Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions comms/ctran/interfaces/IBootstrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#pragma once

#include <stdexcept>
#include <vector>

#include <folly/futures/Future.h>
Expand Down Expand Up @@ -60,6 +61,43 @@ class IBootstrap {
int localNranks,
std::vector<int> localRankToCommRank) = 0;

/**
 * AllGather within an NVLink domain, which may span multiple hosts (MNNVL).
 *
 * `buf` refers to a continuous memory segment of size `nvlNranks * len`.
 * `nvlLocalRank` is this rank's index within the NVL domain [0, nvlNranks).
 * `nvlRankToCommRank` maps NVL-local indices to global communicator ranks.
 *
 * Unlike allGatherIntraNode (which uses a host-scoped communicator),
 * this creates a dynamic subcommunicator from the specified global ranks,
 * supporting cross-host NVLink domains like GB200 NVL72.
 *
 * Subclasses must override this if NVL domain operations are needed.
 *
 * @return Future resolving to 0 on success (overrides define the contract).
 * @throws std::runtime_error always, in this default implementation.
 */
virtual folly::SemiFuture<int> allGatherNvlDomain(
    [[maybe_unused]] void* buf,
    [[maybe_unused]] int len,
    [[maybe_unused]] int nvlLocalRank,
    [[maybe_unused]] int nvlNranks,
    [[maybe_unused]] std::vector<int> nvlRankToCommRank) {
  // Default stub: parameters are intentionally unused ([[maybe_unused]]
  // keeps -Wunused-parameter builds clean without changing the signature).
  throw std::runtime_error("allGatherNvlDomain not implemented");
}

/**
 * Barrier within an NVLink domain, which may span multiple hosts (MNNVL).
 *
 * `nvlLocalRank` is this rank's index within the NVL domain [0, nvlNranks).
 * `nvlRankToCommRank` maps NVL-local indices to global communicator ranks.
 *
 * Subclasses must override this if NVL domain operations are needed.
 *
 * @return Future resolving to 0 on success (overrides define the contract).
 * @throws std::runtime_error always, in this default implementation.
 */
virtual folly::SemiFuture<int> barrierNvlDomain(
    [[maybe_unused]] int nvlLocalRank,
    [[maybe_unused]] int nvlNranks,
    [[maybe_unused]] std::vector<int> nvlRankToCommRank) {
  // Default stub: parameters are intentionally unused ([[maybe_unused]]
  // keeps -Wunused-parameter builds clean without changing the signature).
  throw std::runtime_error("barrierNvlDomain not implemented");
}

/*
* `buf` refers to a continuous memory segment that is of size `len`
* `peer` must be a valid value between 0 and `nranks - 1`
Expand Down
10 changes: 10 additions & 0 deletions comms/ctran/tests/bootstrap/MockBootstrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ void MockBootstrap::expectSuccessfulCtranInitCalls() {
.WillRepeatedly([](int localRank,
int localNRanks,
std::vector<int> localRankToCommRank) { return 0; });
EXPECT_CALL(*this, allGatherNvlDomain(_, _, _, _, _))
.WillRepeatedly([](void* buf,
int len,
int nvlLocalRank,
int nvlNranks,
std::vector<int> nvlRankToCommRank) { return 0; });
EXPECT_CALL(*this, barrierNvlDomain(_, _, _))
.WillRepeatedly([](int nvlLocalRank,
int nvlNranks,
std::vector<int> nvlRankToCommRank) { return 0; });
}

} // namespace ctran::testing
14 changes: 14 additions & 0 deletions comms/ctran/tests/bootstrap/MockBootstrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ class MockBootstrap : public ctran::bootstrap::IBootstrap {
barrierIntraNode,
(int localRank, int localNranks, std::vector<int> localRankToCommRank),
(override));
// gmock for IBootstrap::allGatherNvlDomain (NVL-domain allgather; the base
// class default throws, so tests must set expectations before use).
MOCK_METHOD(
folly::SemiFuture<int>,
allGatherNvlDomain,
(void* buf,
int len,
int nvlLocalRank,
int nvlNranks,
std::vector<int> nvlRankToCommRank),
(override));
// gmock for IBootstrap::barrierNvlDomain (NVL-domain barrier; the base
// class default throws, so tests must set expectations before use).
MOCK_METHOD(
folly::SemiFuture<int>,
barrierNvlDomain,
(int nvlLocalRank, int nvlNranks, std::vector<int> nvlRankToCommRank),
(override));
MOCK_METHOD(
folly::SemiFuture<int>,
send,
Expand Down
93 changes: 93 additions & 0 deletions comms/pipes/MultiPeerDeviceHandle.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#pragma once

#include <cstdint>

#include "comms/pipes/DeviceSpan.cuh"

// In CUDA compilation, include full Transport definition for device accessors.
// In host-only compilation, a forward declaration suffices because DeviceSpan
// only stores a pointer (T*) — it doesn't need sizeof(T).
#ifdef __CUDACC__
#include "comms/pipes/P2pIbgdaTransportDevice.cuh"
#include "comms/pipes/Transport.cuh"
#else
namespace comms::pipes {
struct Transport;
enum class TransportType : uint8_t;
class P2pNvlTransportDevice;
class P2pIbgdaTransportDevice;
} // namespace comms::pipes
#endif

namespace comms::pipes {

/**
* MultiPeerDeviceHandle - Unified device-side handle for mixed-transport
* communication.
*
* Lightweight struct passed to CUDA kernels. Contains a single DeviceSpan
* of Transport objects (one per rank) plus peer counts. The Transport union
* already carries the type discriminant, so no separate type array is needed.
*
* Layout: transports[0..nRanks-1] where transports[myRank].type == SELF,
* NVL peers sorted first, followed by IBGDA-only peers.
*
* USAGE:
* __global__ void kernel(MultiPeerDeviceHandle handle, ...) {
* for (int rank = 0; rank < handle.nRanks; ++rank) {
* switch (handle.get_type(rank)) {
* case TransportType::SELF: ... break;
* case TransportType::P2P_NVL: handle.get_nvl(rank).send(...); break;
* case TransportType::P2P_IBGDA: handle.get_ibgda(rank).put(...);
* break;
* }
* }
* }
*
* NOTE(review): the accessors below perform no bounds or type checking.
* Callers must pass a rank in [0, nRanks) and must consult get_type(rank)
* before calling get_nvl()/get_ibgda() — reading the inactive union member,
* or dereferencing p2p_ibgda for a non-IBGDA rank, is undefined behavior.
*/
struct MultiPeerDeviceHandle {
// Global rank of this process within the communicator (-1 = unset).
int myRank{-1};
// Total number of ranks; presumably equals transports.size() — TODO confirm
// against the host-side code that builds this handle.
int nRanks{0};

// Unified transport array indexed by global rank.
// transports[rank].type gives the transport type for that rank.
DeviceSpan<Transport> transports;

// Number of NVL peers (excluding self)
int numNvlPeers{0};

// Number of IBGDA peers (= nRanks - 1, all non-self)
int numIbPeers{0};

// Device accessors are only compiled under nvcc; host-only TUs see just the
// POD layout above (see the forward-declaration rationale at file top).
#ifdef __CUDACC__
/** @return Transport type for the given global rank.
* Precondition: rank is a valid index into `transports`. */
__device__ __forceinline__ TransportType get_type(int rank) const {
return transports[rank].type;
}

/** @return Mutable reference to the NVL transport for the given rank.
* Only valid when get_type(rank) == TransportType::P2P_NVL. */
__device__ __forceinline__ P2pNvlTransportDevice& get_nvl(int rank) {
return transports[rank].p2p_nvl;
}

/** @return Const reference to the NVL transport for the given rank.
* Only valid when get_type(rank) == TransportType::P2P_NVL. */
__device__ __forceinline__ const P2pNvlTransportDevice& get_nvl(
int rank) const {
return transports[rank].p2p_nvl;
}

/** @return Mutable reference to the IBGDA transport for the given rank.
* Dereferences the stored p2p_ibgda pointer; only valid when
* get_type(rank) == TransportType::P2P_IBGDA. */
__device__ __forceinline__ P2pIbgdaTransportDevice& get_ibgda(int rank) {
return *transports[rank].p2p_ibgda;
}

/** @return Const reference to the IBGDA transport for the given rank.
* Dereferences the stored p2p_ibgda pointer; only valid when
* get_type(rank) == TransportType::P2P_IBGDA. */
__device__ __forceinline__ const P2pIbgdaTransportDevice& get_ibgda(
int rank) const {
return *transports[rank].p2p_ibgda;
}
#endif
};

} // namespace comms::pipes
Loading