
NCCL Source Code Analysis ②: Establishing the Bootstrap Network Connections

Author: OneFlow
  • 2023-04-10
    Chongqing


Author | KIDGINBROOK

Updated by | 潘丽晨


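The previous post showed how the machine hosting rank 0 generates the ncclUniqueId and initializes its bootstrap network and communication network. This post picks up from there and looks at how the bootstrap connections between all nodes are established.


The rank 0 node calls ncclGetUniqueId to generate the ncclUniqueId and broadcasts the ID to all nodes via MPI. Every node then calls ncclCommInitRank; at this point the other nodes also initialize their bootstrap network and communication network, and execution then reaches ncclCommInitRankSync.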


ncclResult_t ncclCommInitRankSync(ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank, int cudaDev) {
  ncclResult_t res;

  CUDACHECK(cudaSetDevice(cudaDev));
  NCCLCHECKGOTO(commAlloc(newcomm, nranks, myrank), res, cleanup);
  NCCLCHECKGOTO(initTransportsRank(*newcomm, &commId), res, cleanup);
  NCCLCHECKGOTO(devCommSetup(*newcomm), res, cleanup);

  INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %x - Init COMPLETE", *newcomm, myrank, nranks, (*newcomm)->cudaDev, (*newcomm)->busId);

  return ncclSuccess;
cleanup:
  if ((*newcomm) && (*newcomm)->bootstrap) bootstrapAbort((*newcomm)->bootstrap);
  *newcomm = NULL;
  return res;
}


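ncclComm_t is a pointer to ncclComm. ncclComm is a grab bag that holds all the context needed for communication; its fields will be introduced as they come up. commAlloc allocates newcomm and initializes it, for example recording which GPU the current one is and what its PCIe bus ID is. After that, initTransportsRank is executed.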


static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* commId) {
  // We use 3 AllGathers
  // 1. { peerInfo, comm }
  // 2. ConnectTransport[nranks], ConnectValue[nranks]
  // 3. { nThreads, nrings, compCap, prev[MAXCHANNELS], next[MAXCHANNELS] }

  int rank = comm->rank;
  int nranks = comm->nRanks;
  uint64_t commHash = getHash(commId->internal, NCCL_UNIQUE_ID_BYTES);
  TRACE(NCCL_INIT, "comm %p, commHash %lx, rank %d nranks %d - BEGIN", comm, commHash, rank, nranks);
  NCCLCHECK(bootstrapInit(commId, rank, nranks, &comm->bootstrap));

  // AllGather1 - begin
  struct {
    struct ncclPeerInfo peerInfo;
    struct ncclComm* comm;
  } *allGather1Data;

  NCCLCHECK(ncclCalloc(&allGather1Data, nranks));
  allGather1Data[rank].comm = comm;
  struct ncclPeerInfo* myInfo = &allGather1Data[rank].peerInfo;
  NCCLCHECK(fillInfo(comm, myInfo, commHash));
  NCCLCHECK(bootstrapAllGather(comm->bootstrap, allGather1Data, sizeof(*allGather1Data)));

  NCCLCHECK(ncclCalloc(&comm->peerInfo, nranks+1)); // Extra rank to represent CollNet root
  for (int i = 0; i < nranks; i++) {
    memcpy(comm->peerInfo+i, &allGather1Data[i].peerInfo, sizeof(struct ncclPeerInfo));
    if ((i != rank) && (comm->peerInfo[i].hostHash == myInfo->hostHash) && (comm->peerInfo[i].busId == myInfo->busId)) {
      WARN("Duplicate GPU detected : rank %d and rank %d both on CUDA device %x", rank, i, myInfo->busId);
      return ncclInvalidUsage;
    }
  }
  // ... (excerpt ends here)
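The loop at the end of the excerpt above rejects the case where two ranks sit on the same GPU, identified by the pair (hostHash, busId). As a minimal standalone sketch of that check: the struct peer_info and the function find_duplicate_gpu below are hypothetical, trimmed-down stand-ins written for illustration, not NCCL types or APIs.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical, trimmed-down stand-in for ncclPeerInfo: only the two
 * fields that the duplicate check compares. */
struct peer_info {
  uint64_t host_hash; /* identifies the machine */
  int64_t bus_id;     /* PCIe bus ID of the GPU on that machine */
};

/* Returns the index of another rank that shares both host_hash and bus_id
 * with `rank`, or -1 if no such rank exists. */
int find_duplicate_gpu(const struct peer_info* peers, int nranks, int rank) {
  assert(rank >= 0 && rank < nranks);
  for (int i = 0; i < nranks; i++) {
    if (i != rank && peers[i].host_hash == peers[rank].host_hash &&
        peers[i].bus_id == peers[rank].bus_id) {
      return i;
    }
  }
  return -1;
}
```

Two ranks with the same bus ID on different hosts are fine; only a match on both fields counts as a duplicate.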


Let's look at bootstrapInit:


ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commState) {
  ncclNetHandle_t* netHandle = (ncclNetHandle_t*) id;
  bool idFromEnv = getenv("NCCL_COMM_ID") != NULL;
  struct extState* state;
  NCCLCHECK(ncclCalloc(&state, 1));
  state->rank = rank;
  state->nranks = nranks;
  *commState = state;

  TRACE(NCCL_INIT, "rank %d nranks %d", rank, nranks);

  struct extInfo info = { 0 };
  info.rank = rank;
  info.nranks = nranks;
  void *tmpSendComm, *tmpRecvComm;
  // Pass the remote address to listen via info
  if (idFromEnv) {
    memcpy(&info.extHandleListen, netHandle, sizeof(ncclNetHandle_t));
    memcpy(&info.extHandleListenRoot, netHandle, sizeof(ncclNetHandle_t));
  }
  // listen will return the local address via info (specify interface type 'findSubnetIf')
  state->dev = idFromEnv ? findSubnetIf : 0;
  void* extBstrapListenCommRoot;
  NCCLCHECK(bootstrapNetListen(state->dev, &info.extHandleListen, &state->extBstrapListenComm));
  NCCLCHECK(bootstrapNetListen(state->dev, &info.extHandleListenRoot, &extBstrapListenCommRoot));

  // stagger connection times to avoid an overload of the root at very high rank counts
  if (nranks > 128) {
    long msec = rank;
    struct timespec tv;
    tv.tv_sec = msec / 1000;
    tv.tv_nsec = 1000000 * (msec % 1000);
    TRACE(NCCL_INIT, "rank %d delaying connection to root by %ld msec", rank, msec);
    (void) nanosleep(&tv, NULL);
  }

  // send info on my listening socket to root
  NCCLCHECK(bootstrapNetConnect(state->dev, netHandle, &tmpSendComm));
  NCCLCHECK(bootstrapNetSend(tmpSendComm, &info, sizeof(info)));
  NCCLCHECK(bootstrapNetCloseSend(tmpSendComm));

  // get info on my "next" rank in the bootstrap ring from root
}
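One detail in the excerpt above is easy to skim past: when there are more than 128 ranks, each rank sleeps for `rank` milliseconds before contacting the root, so that the root is not hit by every rank at once. The arithmetic can be isolated as a small sketch; the helper names stagger_delay_msec and msec_to_timespec are made up for this illustration and do not exist in NCCL.

```c
#include <assert.h>
#include <time.h>

/* Delay (in milliseconds) a rank waits before contacting the root:
 * `rank` ms when there are more than 128 ranks, otherwise no delay. */
long stagger_delay_msec(int rank, int nranks) {
  return nranks > 128 ? (long)rank : 0L;
}

/* Split a millisecond delay into the seconds/nanoseconds pair that
 * nanosleep() expects. */
struct timespec msec_to_timespec(long msec) {
  assert(msec >= 0);
  struct timespec tv;
  tv.tv_sec = msec / 1000;
  tv.tv_nsec = 1000000L * (msec % 1000);
  return tv;
}
```

With 2048 ranks, for example, rank 1500 would wait 1.5 seconds, which msec_to_timespec expresses as 1 second plus 500,000,000 nanoseconds.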


First, look at commState, i.e. the bootstrap field of ncclComm, whose type is extState.


struct extState {
  void* extBstrapListenComm;
  void* extBstrapRingRecvComm;
  void* extBstrapRingSendComm;
  ncclNetHandle_t* peerBstrapHandles;
  struct unexConn* unexpectedConnections;
  int rank;
  int nranks;
  int dev;
};


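Here extBstrapRingSendComm is the socket connection from the current node to its next node, extBstrapRingRecvComm is the socket connection between the current node and its prev node, extBstrapListenComm is the current node's listening socket, and peerBstrapHandles holds the IP and port of every rank (the address behind each rank's extBstrapListenComm). dev defaults to 0 and selects which of the machine's IP addresses to use.


Next, bootstrapNetListen is called twice to create two listening bootstrap comms, whose addresses are written to extHandleListen and extHandleListenRoot. As described in the previous post, a bootstrap comm simply wraps an fd. Two comms are needed because extHandleListen is the address used for the actual bootstrap connections between ranks, while extHandleListenRoot is the address the rank 0 node uses to communicate with every rank.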


static ncclResult_t bootstrapNetListen(int dev, ncclNetHandle_t* netHandle, void** listenComm)


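bootstrapNetListen was covered in the previous post: it looks up the dev-th IP address of the current machine, calls listen to obtain a listening fd, writes the IP and port into netHandle, and returns the resulting bootstrap comm through listenComm.


rank, nranks, extHandleListen, and extHandleListenRoot are then written into an extInfo.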


struct extInfo {
  int rank;
  int nranks;
  ncclNetHandle_t extHandleListenRoot;
  ncclNetHandle_t extHandleListen;
};


netHandle is the ncclUniqueId, i.e. the IP and port of rank 0. bootstrapNetConnect then creates a bootstrap send comm. By analogy with bootstrapNetListen, bootstrapNetConnect opens a socket connection to the address in netHandle and stores the socket in sendComm; dev is not used here.


static ncclResult_t bootstrapNetConnect(int dev, ncclNetHandle_t* netHandle, void** sendComm)
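To make the listen and connect steps more concrete, here is a minimal sketch using plain POSIX sockets. It is not NCCL's implementation: struct demo_handle, demo_listen, and demo_connect are hypothetical names, the handle holds only an IPv4 address and port, and the listener binds to the loopback address rather than picking the dev-th interface of the machine.

```c
#include <arpa/inet.h>
#include <assert.h>
#include <netinet/in.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

/* Hypothetical stand-in for a bootstrap handle: an IPv4 address and port. */
struct demo_handle {
  struct sockaddr_in addr;
};

/* Open a listening TCP socket on 127.0.0.1 with a kernel-chosen port and
 * write the resulting address into *handle. Returns the fd, or -1 on error. */
int demo_listen(struct demo_handle* handle) {
  assert(handle != NULL);
  int fd = socket(AF_INET, SOCK_STREAM, 0);
  if (fd < 0) return -1;

  struct sockaddr_in addr;
  memset(&addr, 0, sizeof(addr));
  addr.sin_family = AF_INET;
  addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
  addr.sin_port = 0; /* port 0 lets the kernel pick a free port */

  socklen_t len = sizeof(addr);
  if (bind(fd, (struct sockaddr*)&addr, sizeof(addr)) != 0 ||
      listen(fd, 16) != 0 ||
      getsockname(fd, (struct sockaddr*)&addr, &len) != 0) {
    close(fd);
    return -1;
  }
  handle->addr = addr; /* now contains the port that was actually assigned */
  return fd;
}

/* Connect to the address stored in a handle. Returns the fd, or -1 on error. */
int demo_connect(const struct demo_handle* handle) {
  int fd = socket(AF_INET, SOCK_STREAM, 0);
  if (fd < 0) return -1;
  if (connect(fd, (const struct sockaddr*)&handle->addr,
              sizeof(handle->addr)) != 0) {
    close(fd);
    return -1;
  }
  return fd;
}
```

The point of the sketch is the division of labor: the listener produces a handle that can be shipped to another process, and anyone holding the handle can connect back to it.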


bootstrapNetSend then sends the extInfo out, i.e. to rank 0:


static ncclResult_t bootstrapNetSend(void* sendComm, void* data, int size) {
  struct bootstrapNetComm* comm = (struct bootstrapNetComm*)sendComm;
  NCCLCHECK(socketSend(comm->fd, &size, sizeof(int)));
  NCCLCHECK(socketSend(comm->fd, data, size));
  return ncclSuccess;
}


socketSend simply calls the send interface to transmit the data.
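bootstrapNetSend writes the payload size first and the payload second, which is a length-prefixed framing scheme. The sketch below illustrates the idea with plain POSIX sockets. All function names are hypothetical, and the receive side (in particular how a size larger than the buffer is handled) is an assumption made for this example, not a description of bootstrapNetRecv.

```c
#include <assert.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

/* send() may write fewer bytes than requested, so loop until done.
 * Returns 0 on success, -1 on error. */
int send_all(int fd, const void* buf, size_t size) {
  const char* p = (const char*)buf;
  while (size > 0) {
    ssize_t n = send(fd, p, size, 0);
    if (n <= 0) return -1;
    p += n;
    size -= (size_t)n;
  }
  return 0;
}

/* Loop until exactly `size` bytes have been read. */
int recv_all(int fd, void* buf, size_t size) {
  char* p = (char*)buf;
  while (size > 0) {
    ssize_t n = recv(fd, p, size, 0);
    if (n <= 0) return -1;
    p += n;
    size -= (size_t)n;
  }
  return 0;
}

/* Length-prefixed message: an int holding the size, then the payload. */
int framed_send(int fd, const void* data, int size) {
  assert(size >= 0);
  if (send_all(fd, &size, sizeof(int)) != 0) return -1;
  return send_all(fd, data, (size_t)size);
}

/* Read one framed message into `data`. Returns the payload size, or -1 on
 * error or if the announced size does not fit in `capacity` (assumption). */
int framed_recv(int fd, void* data, int capacity) {
  int size = 0;
  if (recv_all(fd, &size, sizeof(int)) != 0) return -1;
  if (size < 0 || size > capacity) return -1;
  if (recv_all(fd, data, (size_t)size) != 0) return -1;
  return size;
}
```

Sending the size up front lets the receiver know exactly how many bytes to wait for on a stream socket, which has no message boundaries of its own.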


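bootstrapNetCloseSend then closes the fd.


What does rank 0 do once it receives this data? Recall that the rank 0 node calls ncclGetUniqueId to generate the ncclUniqueId, and that at the end of bootstrapCreateRoot it starts a thread that runs bootstrapRoot.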


static void *bootstrapRoot(void* listenComm) {
  struct extInfo info;
  ncclNetHandle_t *rankHandles = NULL;
  ncclNetHandle_t *rankHandlesRoot = NULL; // for initial rank <-> root information exchange
  ncclNetHandle_t zero = { 0 }; // for sanity checking
  void* tmpComm;
  ncclResult_t res;
  setFilesLimit();

  TRACE(NCCL_INIT, "BEGIN");
  /* Receive addresses from all ranks */
  int nranks = 0, c = 0;
  do {
    NCCLCHECKGOTO(bootstrapNetAccept(listenComm, &tmpComm), res, out);
    NCCLCHECKGOTO(bootstrapNetRecv(tmpComm, &info, sizeof(info)), res, out);
    NCCLCHECKGOTO(bootstrapNetCloseRecv(tmpComm), res, out);

    if (c == 0) {
      nranks = info.nranks;
      NCCLCHECKGOTO(ncclCalloc(&rankHandles, nranks), res, out);
      NCCLCHECKGOTO(ncclCalloc(&rankHandlesRoot, nranks), res, out);
    }

    if (nranks != info.nranks) {
      WARN("Bootstrap Root : mismatch in rank count from procs %d : %d", nranks, info.nranks);
      goto out;
    }

    if (memcmp(&zero, &rankHandlesRoot[info.rank], sizeof(ncclNetHandle_t)) != 0) {
      WARN("Bootstrap Root : rank %d of %d ranks has already checked in", info.rank, nranks);
      goto out;
    }

    // Save the connection handle for that rank
    memcpy(rankHandlesRoot+info.rank, info.extHandleListenRoot, sizeof(ncclNetHandle_t));
    memcpy(rankHandles+info.rank, info.extHandleListen, sizeof(ncclNetHandle_t));

    ++c;
    TRACE(NCCL_INIT, "Received connect from rank %d total %d/%d",  info.rank, c, nranks);
  } while (c < nranks);
  TRACE(NCCL_INIT, "COLLECTED ALL %d HANDLES", nranks);

  // Send the connect handle for the next rank in the AllGather ring
  for (int r=0; r<nranks; ++r) {
    int next = (r+1) % nranks;
    void *tmpSendComm;
    NCCLCHECKGOTO(bootstrapNetConnect(0, rankHandlesRoot+r, &tmpSendComm), res, out);
    NCCLCHECKGOTO(bootstrapNetSend(tmpSendComm, rankHandles+next, sizeof(ncclNetHandle_t)), res, out);
    NCCLCHECKGOTO(bootstrapNetCloseSend(tmpSendComm), res, out);
  }
  TRACE(NCCL_INIT, "SENT OUT ALL %d HANDLES", nranks);

out:
  bootstrapNetCloseListen(listenComm);
  if (rankHandles) free(rankHandles);
  if (rankHandlesRoot) free(rankHandlesRoot);

  TRACE(NCCL_INIT, "DONE");
  return NULL;
}


listenComm is the listening fd that rank 0 created in the previous post. bootstrapNetAccept takes a new connection from listenComm and wraps the new connection's fd in a recv comm.


static ncclResult_t bootstrapNetAccept(void* listenComm, void** recvComm)


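bootstrapNetRecv then reads the data on tmpComm, i.e. the extInfo sent by a rank, and the root saves that rank's extHandleListen and extHandleListenRoot. Once every rank has checked in, rank 0 knows the IP and port of all ranks.


With the info from all ranks collected, the root starts building the ring: it sends the extHandleListen of node (r+1) % nranks to node r, i.e. it sends each node r the net handle of its next node. This also shows why each node created two listen comms: rank 0 communicates with the nodes through extHandleListenRoot, while the nodes communicate with each other through extHandleListen.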
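The root's bookkeeping can be sketched without any networking: collect one handle per rank, reject mismatched rank counts and repeated check-ins, then hand each rank r the handle of rank (r+1) % nranks. In the sketch below a "handle" is just a nonzero int and every name (struct demo_root and the demo_root_* functions) is hypothetical. The rank bounds check is an addition of this sketch and is not present in the excerpt above.

```c
#include <assert.h>
#include <string.h>

#define DEMO_MAX_RANKS 64

/* Hypothetical stand-in for the root's state. A handle is an int here,
 * and 0 means "this rank has not checked in yet". */
struct demo_root {
  int nranks;
  int checked_in;
  int handles[DEMO_MAX_RANKS];
};

void demo_root_init(struct demo_root* root, int nranks) {
  assert(nranks > 0 && nranks <= DEMO_MAX_RANKS);
  memset(root, 0, sizeof(*root));
  root->nranks = nranks;
}

/* Record the handle announced by `rank`. Returns 0 on success, or -1 on a
 * mismatched rank count, an out-of-range rank, or a repeated check-in. */
int demo_root_check_in(struct demo_root* root, int rank, int nranks, int handle) {
  if (nranks != root->nranks) return -1;
  if (rank < 0 || rank >= root->nranks) return -1;
  if (handle == 0 || root->handles[rank] != 0) return -1;
  root->handles[rank] = handle;
  root->checked_in++;
  return 0;
}

/* Handle that the root sends to rank r: the handle of (r + 1) % nranks.
 * Returns 0 while some ranks are still missing. */
int demo_root_next_handle(const struct demo_root* root, int r) {
  if (root->checked_in < root->nranks) return 0;
  return root->handles[(r + 1) % root->nranks];
}
```

With four ranks, rank 0 receives the handle of rank 1 and rank 3 wraps around to receive the handle of rank 0, which closes the ring.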


Now back to the rest of bootstrapInit.


ncclResult_t bootstrapInit(ncclUniqueId * id, int rank, int nranks, void** commState) {
  // get info on my "next" rank in the bootstrap ring from root
  ncclNetHandle_t extHandleNext;
  NCCLCHECK(bootstrapNetAccept(extBstrapListenCommRoot, &tmpRecvComm));
  NCCLCHECK(bootstrapNetRecv(tmpRecvComm, &extHandleNext, sizeof(extHandleNext)));
  NCCLCHECK(bootstrapNetCloseRecv(tmpRecvComm));
  NCCLCHECK(bootstrapNetCloseListen(extBstrapListenCommRoot));

  NCCLCHECK(bootstrapNetConnect(state->dev, &extHandleNext, &state->extBstrapRingSendComm));
  // Accept the connect request from the previous rank in the AllGather ring
  NCCLCHECK(bootstrapNetAccept(state->extBstrapListenComm, &state->extBstrapRingRecvComm));

  // AllGather all listen handlers
  NCCLCHECK(ncclCalloc(&state->peerBstrapHandles, nranks));
  memcpy(state->peerBstrapHandles+rank, info.extHandleListen, sizeof(ncclNetHandle_t));
  NCCLCHECK(bootstrapAllGather(state, state->peerBstrapHandles, sizeof(ncclNetHandle_t)));

  TRACE(NCCL_INIT, "rank %d nranks %d - DONE", rank, nranks);

  return ncclSuccess;
}


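Every rank accepts a new connection on its extHandleListenRoot listener (extBstrapListenCommRoot) to create tmpRecvComm, and receives the IP and port of its next rank over it. It then connects to next, storing the resulting bootstrap comm in state->extBstrapRingSendComm, and accepts the connection from prev, storing that bootstrap comm in state->extBstrapRingRecvComm. At this point the bootstrap network connections are fully established, as shown in the figure below: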



Finally, the IP and port of all ranks are gathered. Each rank first places its own net handle at its own position in peerBstrapHandles, as shown below.



It then runs bootstrapAllGather:


ncclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
  struct extState* state = (struct extState*)commState;
  char* data = (char*)allData;
  int rank = state->rank;
  int nranks = state->nranks;

  TRACE(NCCL_INIT, "rank %d nranks %d size %d", rank, nranks, size);

  /* Simple ring based AllGather
   * At each step i receive data from (rank-i-1) from left
   * and send previous step's data from (rank-i) to right
   */
  for (int i=0; i<nranks-1; i++) {
    size_t rslice = (rank - i - 1 + nranks) % nranks;
    size_t sslice = (rank - i + nranks) % nranks;

    // Send slice to the right
    NCCLCHECK(bootstrapNetSend(state->extBstrapRingSendComm, data+sslice*size, size));
    // Recv slice from the left
    NCCLCHECK(bootstrapNetRecv(state->extBstrapRingRecvComm, data+rslice*size, size));
  }

  TRACE(NCCL_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size);
  return ncclSuccess;
}


Step one:



Step two:


At this point every rank has the IP and port of all ranks.
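The slice arithmetic of the ring AllGather can be checked with an in-memory simulation. The sketch below replaces the socket send and receive with a direct copy from the left neighbor's buffer and uses a single int per slice; demo_ring_allgather and DEMO_MAX_RANKS are hypothetical names introduced only for this illustration.

```c
#include <assert.h>

#define DEMO_MAX_RANKS 16

/* data[r][s] is rank r's copy of slice s. On entry only data[r][r] is
 * valid; on return every rank holds every slice. */
void demo_ring_allgather(int data[][DEMO_MAX_RANKS], int nranks) {
  assert(nranks > 0 && nranks <= DEMO_MAX_RANKS);
  for (int i = 0; i < nranks - 1; i++) {
    for (int rank = 0; rank < nranks; rank++) {
      int left = (rank - 1 + nranks) % nranks;
      /* Slice this rank receives at step i. */
      int rslice = (rank - i - 1 + nranks) % nranks;
      /* The left neighbor sends slice (left - i) mod nranks at step i,
       * which is the same index as rslice, so the copy below stands in
       * for one send on the left and one recv on this rank. */
      data[rank][rslice] = data[left][rslice];
    }
  }
}
```

At step 0 each rank receives its left neighbor's own slice; at every later step it receives the slice the neighbor obtained one step earlier, so after nranks-1 steps all slices have travelled around the ring.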


To summarize, this post covered how the bootstrap ring network connections are created and stored in ncclComm.




You are welcome to star and try the latest version of OneFlow: https://github.com/Oneflow-Inc/oneflow/
