1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
| cimport cython from libc.stdint cimport intptr_t
# ACL declaration: only the stream handle type is needed by this module, and
# it is declared as an opaque pointer.
cdef extern from "acl/acl.h":
    ctypedef void *aclrtStream


# Subset of the HCCL C API used by this module.
cdef extern from "hccl/hccl.h":
    # Communicator handle, declared as an opaque pointer.
    ctypedef void *HcclComm

    # Fixed-width integer aliases used in the signatures below.
    ctypedef unsigned int __uint32_t
    ctypedef __uint32_t uint32_t
    ctypedef unsigned long int __uint64_t
    ctypedef __uint64_t uint64_t

    # No enumerators are listed here; this module only casts caller-supplied
    # integers to this type.
    ctypedef enum HcclDataType:
        pass

    # Only HCCL_SUCCESS is named because it is the only value this module
    # compares against.
    ctypedef enum HcclResult:
        HCCL_SUCCESS

    # Length in bytes of the HcclRootInfo.internal array.
    cdef enum:
        HCCL_ROOT_INFO_BYTES = 4108

    # Root info blob, treated by this module as a raw byte array.
    ctypedef struct HcclRootInfo:
        char internal[HCCL_ROOT_INFO_BYTES]

    # All functions are declared nogil so they can be called with the GIL
    # released.
    HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo) nogil
    HcclResult HcclCommInitRootInfo(uint32_t nRanks,
                                    const HcclRootInfo *rootInfo,
                                    uint32_t rank, HcclComm *comm) nogil
    HcclResult HcclSend(void* sendBuf, uint64_t count,
                        HcclDataType dataType, uint32_t destRank,
                        HcclComm comm, aclrtStream stream) nogil
    HcclResult HcclRecv(void* recvBuf, uint64_t count,
                        HcclDataType dataType, uint32_t srcRank,
                        HcclComm comm, aclrtStream stream) nogil
    HcclResult HcclCommDestroy(HcclComm comm) nogil
    const char *HcclGetErrorString(HcclResult code) nogil
# Symbolic names of the HCCL result codes, keyed by integer value.
# The tuple is ordered so that each name's position equals its code (0..23).
cdef dict HCCL_ERR_STR = dict(enumerate((
    'HCCL_SUCCESS',              # 0
    'HCCL_E_PARA',               # 1
    'HCCL_E_PTR',                # 2
    'HCCL_E_MEMORY',             # 3
    'HCCL_E_INTERNAL',           # 4
    'HCCL_E_NOT_SUPPORT',        # 5
    'HCCL_E_NOT_FOUND',          # 6
    'HCCL_E_UNAVAIL',            # 7
    'HCCL_E_SYSCALL',            # 8
    'HCCL_E_TIMEOUT',            # 9
    'HCCL_E_OPEN_FILE_FAILURE',  # 10
    'HCCL_E_TCP_CONNECT',        # 11
    'HCCL_E_ROCE_CONNECT',       # 12
    'HCCL_E_TCP_TRANSFER',       # 13
    'HCCL_E_ROCE_TRANSFER',      # 14
    'HCCL_E_RUNTIME',            # 15
    'HCCL_E_DRV',                # 16
    'HCCL_E_PROFILING',          # 17
    'HCCL_E_CCE',                # 18
    'HCCL_E_NETWORK',            # 19
    'HCCL_E_AGAIN',              # 20
    'HCCL_E_REMOTE',             # 21
    'HCCL_E_SUSPENDING',         # 22
    'HCCL_E_RESERVED',           # 23
)))
class HcclError(RuntimeError):
    """Error raised for a non-success HCCL result code.

    The exception message has the form ``'<name>: <text>'`` where ``<name>``
    is looked up in ``HCCL_ERR_STR`` and ``<text>`` is the string returned by
    ``HcclGetErrorString`` for the same code.

    Args:
        status: integer HCCL result code.

    Attributes:
        status: the code passed to the constructor.
    """

    def __init__(self, int status):
        self.status = status
        cdef const char* msg
        with nogil:
            msg = HcclGetErrorString(<HcclResult>status)
        # A code that is missing from HCCL_ERR_STR must not surface as a
        # KeyError that hides the real HCCL failure, so fall back to a
        # generic label that still carries the numeric value.
        name = HCCL_ERR_STR.get(status, 'HCCL_E_UNKNOWN(%d)' % status)
        # Never decode a NULL pointer; use an empty description instead.
        if msg != NULL:
            detail = msg.decode()
        else:
            detail = ''
        super(HcclError, self).__init__('%s: %s' % (name, detail))

    def __reduce__(self):
        # Pickle as (class, (status,)) so that unpickling goes through
        # __init__ again and rebuilds the message from the status code.
        return (type(self), (self.status,))
@cython.profile(False)
cpdef inline check_hccl_status(HcclResult status):
    """Raise ``HcclError`` unless ``status`` equals ``HCCL_SUCCESS``."""
    if status == HCCL_SUCCESS:
        return
    raise HcclError(status)
def get_unique_id():
    """Fetch HCCL root info and return it as a tuple of integers.

    Calls ``HcclGetRootInfo`` with the GIL released and converts the
    ``internal`` byte array into a tuple of ``HCCL_ROOT_INFO_BYTES`` values,
    one per ``char`` element.

    Raises:
        HcclError: if ``HcclGetRootInfo`` does not return ``HCCL_SUCCESS``.
    """
    cdef HcclRootInfo info
    cdef HcclResult rc
    with nogil:
        rc = HcclGetRootInfo(&info)
    check_hccl_status(rc)

    # Copy the raw bytes out element by element so the caller gets a plain
    # Python tuple that does not reference the stack-allocated struct.
    values = []
    for pos in range(HCCL_ROOT_INFO_BYTES):
        values.append(info.internal[pos])
    return tuple(values)
cdef class HCCLCommunicator:
    """Wrapper that owns one ``HcclComm`` handle.

    The handle is created in ``__init__`` through ``HcclCommInitRootInfo``
    and released in ``__dealloc__`` through ``HcclCommDestroy``. ``send`` and
    ``recv`` forward to ``HcclSend`` / ``HcclRecv`` on that handle.

    Args:
        ndev: passed to ``HcclCommInitRootInfo`` as ``nRanks``.
        commId: tuple of exactly ``HCCL_ROOT_INFO_BYTES`` integers, copied
            into ``HcclRootInfo.internal``.
        rank: passed to ``HcclCommInitRootInfo`` as ``rank``.

    Raises:
        HcclError: whenever an underlying HCCL call does not return
            ``HCCL_SUCCESS``.
    """

    cdef:
        HcclComm _comm

    @property
    def comm(self):
        """The communicator handle as an integer address."""
        return <intptr_t>self._comm

    def __cinit__(self):
        # Begin with a null handle; __dealloc__ skips destruction while the
        # handle is null.
        self._comm = <HcclComm>0

    def __dealloc__(self):
        cdef HcclResult rc
        if self._comm == NULL:
            return
        with nogil:
            rc = HcclCommDestroy(self._comm)
        check_hccl_status(rc)
        self._comm = <HcclComm>0

    def __init__(self, int ndev, tuple commId, int rank):
        cdef HcclRootInfo info
        cdef HcclResult rc
        assert len(commId) == HCCL_ROOT_INFO_BYTES
        # Rebuild the C struct from the tuple of per-byte integers.
        for pos in range(HCCL_ROOT_INFO_BYTES):
            info.internal[pos] = commId[pos]
        with nogil:
            rc = HcclCommInitRootInfo(ndev, &info, rank, &self._comm)
        check_hccl_status(rc)

    def send(self, intptr_t sendbuf, size_t count, int datatype, int peer,
             intptr_t stream):
        """Call ``HcclSend`` on this communicator.

        Args:
            sendbuf: address of the send buffer, as an integer.
            count: forwarded as the ``count`` argument.
            datatype: integer cast to ``HcclDataType``.
            peer: forwarded as ``destRank``.
            stream: address of the ``aclrtStream``, as an integer.

        Raises:
            HcclError: if ``HcclSend`` does not return ``HCCL_SUCCESS``.
        """
        cdef void* buf = <void*>sendbuf
        cdef HcclDataType dtype = <HcclDataType>datatype
        cdef aclrtStream acl_stream = <aclrtStream>stream
        cdef HcclResult rc
        with nogil:
            rc = HcclSend(buf, count, dtype, peer, self._comm, acl_stream)
        check_hccl_status(rc)

    def recv(self, intptr_t recvbuf, size_t count, int datatype, int peer,
             intptr_t stream):
        """Call ``HcclRecv`` on this communicator.

        Args:
            recvbuf: address of the receive buffer, as an integer.
            count: forwarded as the ``count`` argument.
            datatype: integer cast to ``HcclDataType``.
            peer: forwarded as ``srcRank``.
            stream: address of the ``aclrtStream``, as an integer.

        Raises:
            HcclError: if ``HcclRecv`` does not return ``HCCL_SUCCESS``.
        """
        cdef void* buf = <void*>recvbuf
        cdef HcclDataType dtype = <HcclDataType>datatype
        cdef aclrtStream acl_stream = <aclrtStream>stream
        cdef HcclResult rc
        with nogil:
            rc = HcclRecv(buf, count, dtype, peer, self._comm, acl_stream)
        check_hccl_status(rc)
def create_hccl_communicator(world_size, hccl_root_info, rank):
    """Build and return an ``HCCLCommunicator``.

    Args:
        world_size: forwarded as the communicator's ``ndev`` argument.
        hccl_root_info: forwarded as the communicator's ``commId`` argument.
        rank: forwarded as the communicator's ``rank`` argument.

    Returns:
        The newly constructed ``HCCLCommunicator``.
    """
    communicator = HCCLCommunicator(world_size, hccl_root_info, rank)
    return communicator
|