168 changes: 117 additions & 51 deletions apps/nccl/src/broadcast.hpp
@@ -25,6 +25,7 @@ __global__ void __launch_bounds__(1024, 1)
const size_t nWarp = nThread / WARP_SIZE;
const size_t nPeer = nRanksPerNode - 1;
const size_t chanOffset = nPeer * blockIdx.x;
const size_t peerIdx = blockIdx.x;  // Each thread block serves one peer channel.

__shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> smChans[NRANKS_PER_NODE - 1];
if (threadIdx.x < nPeer) {
@@ -35,12 +36,15 @@ __global__ void __launch_bounds__(1024, 1)
__syncthreads();

const size_t peerRootIdx = (root == rank) ? nPeer : ((root < rank) ? root : (root - 1));
const size_t rootsmaller = (root < rank) ? 1 : 0;  // 1 if the root precedes this rank; used to index this rank's chunk among the non-root ranks.
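// Index sketch (standalone illustration, not meant to compile inside the kernel): with a
// hypothetical 4-GPU node and root == 2, the two lines above give each rank its channel
// toward the root and its chunk slot among the non-root ranks.
constexpr unsigned exRoot = 2, exNPeer = 3;
constexpr unsigned exPeerRootIdx(unsigned r) { return (exRoot == r) ? exNPeer : ((exRoot < r) ? exRoot : exRoot - 1); }
constexpr unsigned exChunkSlot(unsigned r) { return r - ((exRoot < r) ? 1u : 0u); }  // "rank - rootsmaller"
static_assert(exPeerRootIdx(0) == 1 && exPeerRootIdx(1) == 1 && exPeerRootIdx(3) == 2, "channel index pointing at the root");
static_assert(exChunkSlot(0) == 0 && exChunkSlot(1) == 1 && exChunkSlot(3) == 2, "chunk slots skip the root");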

const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
const size_t bytes = bytesPerGPU;
size_t unitBytesPerThread;
if (bytes * nPeer >= nThread * 64) {
if (bytes >= nThread * 64) {
unitBytesPerThread = 64;
// unitBytesPerThread = 16;
// unitBytesPerThread = 32;
} else {
unitBytesPerThread = 16;
}
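// Sizing sketch (standalone illustration; the actual definitions of unitBytesPerBlock,
// unitBytes and nLoop live in the collapsed context below, so the formulas here are
// assumptions): with a hypothetical launch of 1024 threads per block, 7 blocks (one per
// peer) and 64 bytes per thread, each block moves 64 KiB per pass, one full pass covers
// 448 KiB, and a 16 MiB broadcast needs 36 full passes plus a remainder.
constexpr unsigned long long exUnitBytesPerBlock = 64ull * 1024ull;       // 64 B/thread * 1024 threads
constexpr unsigned long long exUnitBytes = exUnitBytesPerBlock * 7ull;    // * gridDim.x blocks
constexpr unsigned long long exNLoop = (16ull << 20) / exUnitBytes;       // full passes for 16 MiB
static_assert(exNLoop == 36 && (16ull << 20) % exUnitBytes != 0, "36 full passes plus a remainder");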
@@ -53,93 +57,155 @@

size_t scratchSub = 0;

// printf("nLoop = %ld, bytes = %ld, unitBytes = %ld, bytes mod unitBytes = %ld \n", nLoop, bytes, unitBytes,
// bytes % unitBytes);

// The first iteration always fits within the scratch buffer.
if (nLoop > 0) {
// First loop unrolling
const size_t offset = blockIdx.x * unitBytesPerBlock;
if (rank == root) {
const size_t offset = blockIdx.x * unitBytesPerBlock;
char* send_ = reinterpret_cast<char*>(sendbuff);
for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.
smChans[peerIdx].copy<16, false>(dst + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
__syncthreads();
if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
}
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.
smChans[peerIdx].copy<16, false>(dst + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
__syncthreads();
if (threadIdx.x == peerIdx) smChans[threadIdx.x].signal();
if constexpr (IsOutOfPlace) {
char* recv_ = reinterpret_cast<char*>(recvbuff);
smChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
}

} else { // rank != root.
if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
__syncthreads();
const size_t offset = (rank - rootsmaller) * unitBytesPerBlock;  // This rank's own chunk.
if (blockIdx.x == (rank - rootsmaller) && threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
deviceSyncer.sync(gridDim.x); // All blocks in the GPU wait.

// Step 2.
char* recv_ = reinterpret_cast<char*>(recvbuff);
char* scratch_ = reinterpret_cast<char*>(scratchbuff); // My scratchbuff.
smChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset, unitBytesPerBlock, threadIdx.x,
blockDim.x);
char* scratch_ = reinterpret_cast<char*>(scratchbuff); // My scratchbuff.
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.
if (peerIdx != peerRootIdx) {  // Forward this rank's chunk to peer peerIdx (skip the root's channel).
smChans[peerIdx].copy<16, false>(dst + offset, scratch_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
}
__syncthreads();
if (threadIdx.x != peerRootIdx && threadIdx.x < nPeer) {
smChans[threadIdx.x].signal();
smChans[threadIdx.x].wait();
}
deviceSyncer.sync(gridDim.x); // All blocks in the GPU wait.
//__syncthreads();
{
Review comment: Is this bracket needed?

const size_t offset = blockIdx.x * unitBytesPerBlock;
smChans[peerIdx].copy<16, false>(recv_ + offset, scratch_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
}
}
}
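// Routing sketch (host-side illustration, not part of the kernel): the branch above
// appears to implement a two-step broadcast per iteration, assuming one thread block per
// peer (gridDim.x == nPeer): the root pushes chunk b to the peer whose chunk slot is b,
// then the non-root ranks exchange their chunks among themselves before copying everything
// from scratch to recvbuff. All names below are illustrative.
#include <cassert>
#include <vector>

int main() {
  const int nRanks = 8, root = 3, nPeer = nRanks - 1;
  auto chunkOf = [&](int r) { return (root < r) ? r - 1 : r; };  // "rank - rootsmaller"

  // have[r][c] == true once rank r holds chunk c of the current iteration.
  std::vector<std::vector<bool>> have(nRanks, std::vector<bool>(nPeer, false));
  for (int c = 0; c < nPeer; ++c) have[root][c] = true;  // root starts with every chunk.

  // Step 1: root's block b writes chunk b into the scratch of the peer whose chunk slot is b.
  for (int r = 0; r < nRanks; ++r)
    if (r != root) have[r][chunkOf(r)] = true;

  // Step 2: every non-root forwards its own chunk to every other non-root's scratch.
  for (int src = 0; src < nRanks; ++src)
    for (int dst = 0; dst < nRanks; ++dst)
      if (src != root && dst != root && src != dst) have[dst][chunkOf(src)] = true;

  // After the exchange, every rank holds all nPeer chunks of this iteration.
  for (int r = 0; r < nRanks; ++r)
    for (int c = 0; c < nPeer; ++c) assert(have[r][c]);
  return 0;
}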

for (size_t i = 1; i < nLoop; ++i) {
const size_t offset = blockIdx.x * unitBytesPerBlock + i * unitBytes;
if (i % nLoopToSync == 0) {  // Sync before reusing the scratch buffer.
scratchSub = -i * unitBytes;
deviceSyncer.sync(gridDim.x);
if (threadIdx.x < nPeer) {
smChans[threadIdx.x].relaxedSignal();
smChans[threadIdx.x].signal();
smChans[threadIdx.x].wait();
}
}
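// Scratch-reuse sketch (standalone illustration): scratchSub relies on size_t wrap-around
// so that offset + scratchSub rebases writes to the start of the scratch buffer every
// nLoopToSync iterations, once all peers have signalled that the earlier data was consumed.
// With the hypothetical numbers above (unitBytes = 448 KiB, unitBytesPerBlock = 64 KiB) and
// an assumed nLoopToSync == 4, block 0 at iteration i == 5 writes one unit into scratch,
// not five units deep:
constexpr unsigned long long exUnit = 7ull * 64ull * 1024ull;                   // 458752
constexpr unsigned long long exSub = 0ull - 4ull * exUnit;                      // scratchSub set at i == nLoopToSync == 4
constexpr unsigned long long exOff = 0ull * (64ull * 1024ull) + 5ull * exUnit;  // block 0's offset at i == 5
static_assert(exOff + exSub == 1ull * exUnit, "the write lands one unit into the scratch buffer");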
if (rank == root) {
const size_t offset = blockIdx.x * unitBytesPerBlock + i * unitBytes;
char* send_ = reinterpret_cast<char*>(sendbuff);
for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.
smChans[peerIdx].copy<16, false>(dst + offset + scratchSub, send_ + offset, unitBytesPerBlock, threadIdx.x,
blockDim.x);
__syncthreads();
if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
}
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.

smChans[peerIdx].copy<16, false>(dst + offset + scratchSub, send_ + offset, unitBytesPerBlock, threadIdx.x,
blockDim.x);
__syncthreads();
if (threadIdx.x == peerIdx) smChans[threadIdx.x].signal();
if constexpr (IsOutOfPlace) {
char* recv_ = reinterpret_cast<char*>(recvbuff);
smChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
}
} else { // rank != root.
if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
__syncthreads();
const size_t offset = (rank - rootsmaller) * unitBytesPerBlock + i * unitBytes;
if (blockIdx.x == (rank - rootsmaller) && threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
deviceSyncer.sync(gridDim.x); // All blocks in the GPU wait.

// Step 2.
char* recv_ = reinterpret_cast<char*>(recvbuff);
char* scratch_ = reinterpret_cast<char*>(scratchbuff); // My scratchbuff.
smChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset + scratchSub, unitBytesPerBlock,
threadIdx.x, blockDim.x);
char* scratch_ = reinterpret_cast<char*>(scratchbuff); // My scratchbuff.
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.
if (peerIdx != peerRootIdx) {
smChans[peerIdx].copy<16, false>(dst + offset + scratchSub, scratch_ + offset + scratchSub, unitBytesPerBlock,
threadIdx.x, blockDim.x);
}
__syncthreads();
if (threadIdx.x != peerRootIdx && threadIdx.x < nPeer) {
smChans[threadIdx.x].signal();
smChans[threadIdx.x].wait();
}
deviceSyncer.sync(gridDim.x); // All blocks in the GPU wait.
{
const size_t offset = blockIdx.x * unitBytesPerBlock + i * unitBytes;
smChans[peerIdx].copy<16, false>(recv_ + offset, scratch_ + offset + scratchSub, unitBytesPerBlock, threadIdx.x,
blockDim.x);
}
}
}

// The remainder also fits in the scratch buffer since we subtract unitBytes from SCRATCH_SIZE.
if (bytes % unitBytes > 0) { // remainder.
const size_t offset = blockIdx.x * unitBytesPerBlock + nLoop * unitBytes;
const size_t remainBytes = (offset < bytes) ? (bytes - offset) : 0;
if (remainBytes > 0) {
if (rank == root) {
char* send_ = reinterpret_cast<char*>(sendbuff);
for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.
smChans[peerIdx].copy<16, true>(dst + offset + scratchSub, send_ + offset, remainBytes, threadIdx.x,
blockDim.x);
__syncthreads();
if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
}
if constexpr (IsOutOfPlace) {
char* recv_ = reinterpret_cast<char*>(recvbuff);
smChans[0].copy<16, true>(recv_ + offset, send_ + offset, remainBytes, threadIdx.x, blockDim.x);
}
} else { // rank != root.
if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
__syncthreads();
// const size_t remainTotalBytes = bytes - nLoop * unitBytes;
// const size_t nblocks_to_use_base = remainTotalBytes / unitBytesPerBlock;
// const size_t nblocks_to_use =
// (remainTotalBytes % unitBytesPerBlock) ? nblocks_to_use_base + 1 : nblocks_to_use_base;

// printf("nLoop = %ld, bytes = %ld, nblocks_to_use = %ld\n", nLoop, bytes, nblocks_to_use);

// if (blockIdx.x < nblocks_to_use) {
if (rank == root) {
const size_t offset = blockIdx.x * unitBytesPerBlock + nLoop * unitBytes;
const size_t remainBytes =
offset < bytes ? ((bytes - offset) > unitBytesPerBlock ? unitBytesPerBlock : (bytes - offset)) : 0;
char* send_ = reinterpret_cast<char*>(sendbuff);
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.

smChans[peerIdx].copy<16, true>(dst + offset + scratchSub, send_ + offset, remainBytes, threadIdx.x, blockDim.x);
__syncthreads();
if (threadIdx.x == peerIdx) smChans[threadIdx.x].signal();
if constexpr (IsOutOfPlace) {
char* recv_ = reinterpret_cast<char*>(recvbuff);
char* scratch_ = reinterpret_cast<char*>(scratchbuff); // My scratchbuff.
smChans[peerRootIdx].copy<16, true>(recv_ + offset, scratch_ + offset + scratchSub, remainBytes, threadIdx.x,
blockDim.x);
smChans[0].copy<16, true>(recv_ + offset, send_ + offset, remainBytes, threadIdx.x, blockDim.x);
}
} // remainBytes > 0.

} else { // rank != root.
const size_t offset = (rank - rootsmaller) * unitBytesPerBlock + nLoop * unitBytes;
const size_t remainBytes =
(offset < bytes) ? ((bytes - offset) > unitBytesPerBlock ? unitBytesPerBlock : (bytes - offset)) : 0;

if (blockIdx.x == (rank - rootsmaller) && threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
deviceSyncer.sync(gridDim.x); // All blocks in the GPU wait.
__syncthreads();

// Step 2.
char* recv_ = reinterpret_cast<char*>(recvbuff);
char* scratch_ = reinterpret_cast<char*>(scratchbuff); // My scratchbuff.
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_); // Peer's scratchbuff.
if (peerIdx != peerRootIdx) {
smChans[peerIdx].copy<16, true>(dst + offset + scratchSub, scratch_ + offset + scratchSub, remainBytes,
threadIdx.x, blockDim.x);
}
__syncthreads();
if (threadIdx.x != peerRootIdx && threadIdx.x < nPeer) {
smChans[threadIdx.x].signal();
smChans[threadIdx.x].wait();
}
deviceSyncer.sync(gridDim.x); // All blocks in the GPU wait.
{
const size_t offset = blockIdx.x * unitBytesPerBlock + nLoop * unitBytes;
const size_t remainBytes =
(offset < bytes) ? ((bytes - offset) > unitBytesPerBlock ? unitBytesPerBlock : (bytes - offset)) : 0;
smChans[peerIdx].copy<16, true>(recv_ + offset, scratch_ + offset + scratchSub, remainBytes, threadIdx.x,
blockDim.x);
}
}
//} // remainBytes > 0.
}
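// Remainder sketch (standalone illustration, continuing the hypothetical numbers above):
// each block handles at most one more unitBytesPerBlock-sized piece, clamped to what is
// left of the buffer. For a 16 MiB broadcast the tail after 36 full passes is 256 KiB,
// so blocks 0..3 copy full 64 KiB pieces and blocks 4..6 see remainBytes == 0.
constexpr unsigned long long exTail = (16ull << 20) - 36ull * (7ull * 64ull * 1024ull);
static_assert(exTail == 4ull * 64ull * 1024ull, "blocks 0..3 get full pieces; the rest copy nothing");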

deviceSyncer.sync(gridDim.x);