Skip to content

Multi-GPU Batched KMeans#2017

Open
viclafargue wants to merge 4 commits into rapidsai:main from
viclafargue:mg-batched-kmeans
Open

Multi-GPU Batched KMeans#2017
viclafargue wants to merge 4 commits into rapidsai:main from
viclafargue:mg-batched-kmeans

Conversation

@viclafargue
Copy link
Copy Markdown
Contributor

Closes #1989.

Adds multi-GPU support to KMeans fit for host-resident data, with two modes:

  • OpenMP (cuVS SNMG): A single process drives all local GPUs via OMP threads and raw NCCL. Activated automatically when the handle is a device_resources_snmg.
  • RAFT comms (Ray / Dask / MPI): Each rank is a separate process that calls fit with its own data shard and an initialized RAFT communicator. Coordination uses the RAFT comms.

Both modes share the same core Lloyd's loop, batched streaming of host data, NCCL/comms allreduce of centroid sums and counts, and synchronized convergence. Supports sample weights, n_init best-of-N restarts, KMeansPlusPlus initialization, and float/double. Falls back to single-GPU when neither multi-GPU resources nor comms are present.

@viclafargue viclafargue self-assigned this Apr 13, 2026
@viclafargue viclafargue requested review from a team as code owners April 13, 2026 14:34
@viclafargue viclafargue added improvement Improves an existing functionality non-breaking Introduces a non-breaking change labels Apr 13, 2026
@viclafargue
Copy link
Copy Markdown
Contributor Author

Here are some instructions for testing the Multi-GPU Batched KMeans API with RAFT comms (to be used with Ray/Dask):

RAFT comms (Ray/Dask) demo code
#include <cuvs/cluster/kmeans.hpp>

#include <raft/comms/std_comms.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/comms.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

#include <cuda_runtime.h>
#include <mpi.h>
#include <nccl.h>

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <numeric>
#include <random>
#include <vector>

// Abort the entire MPI job on any CUDA runtime failure, reporting the
// failing expression (via #call stringification), the CUDA error string,
// and the source location. MPI_Abort (rather than exit/return) tears down
// every rank, so surviving ranks cannot hang inside NCCL/MPI collectives.
#define CHECK_CUDA(call)                                                 \
  do {                                                                   \
    cudaError_t e = (call);                                              \
    if (e != cudaSuccess) {                                              \
      std::fprintf(stderr, "CUDA error %s in '%s' @ %s:%d\n",            \
                   cudaGetErrorString(e), #call, __FILE__, __LINE__);    \
      MPI_Abort(MPI_COMM_WORLD, 1);                                      \
    }                                                                    \
  } while (0)

// Abort the entire MPI job on any NCCL failure, reporting the failing
// expression (via #call stringification), the NCCL error string, and the
// source location. MPI_Abort ensures all ranks terminate together instead
// of deadlocking in a collective that the failed rank never joins.
#define CHECK_NCCL(call)                                                 \
  do {                                                                   \
    ncclResult_t r = (call);                                             \
    if (r != ncclSuccess) {                                              \
      std::fprintf(stderr, "NCCL error %s in '%s' @ %s:%d\n",            \
                   ncclGetErrorString(r), #call, __FILE__, __LINE__);    \
      MPI_Abort(MPI_COMM_WORLD, 1);                                      \
    }                                                                    \
  } while (0)

// Demo driver: multi-process, multi-GPU KMeans using cuVS with a RAFT
// communicator built over NCCL, with MPI used only for bootstrap
// (rank discovery and broadcasting the NCCL unique id).
// Each rank binds one GPU and fits on its own host-resident data shard;
// rank 0 then runs predict on the full dataset and reports clustering
// quality via permutation-invariant accuracy.
int main(int argc, char** argv)
{
  MPI_Init(&argc, &argv);

  int rank, num_ranks;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &num_ranks);

  // NOTE(review): binds GPU = MPI rank, which assumes a single node with at
  // least num_ranks GPUs. For multi-node runs this should be the node-local
  // rank (e.g. OMPI_COMM_WORLD_LOCAL_RANK) — confirm intended deployment.
  CHECK_CUDA(cudaSetDevice(rank));

  // NCCL bootstrap: rank 0 creates the unique id, MPI broadcasts it so all
  // ranks join the same NCCL communicator.
  ncclUniqueId nccl_id;
  if (rank == 0) CHECK_NCCL(ncclGetUniqueId(&nccl_id));
  MPI_Bcast(&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD);

  ncclComm_t nccl_comm;
  CHECK_NCCL(ncclCommInitRank(&nccl_comm, num_ranks, nccl_id, rank));

  // Install the NCCL-backed RAFT communicator on the handle; cuVS fit()
  // detects it and switches to the multi-rank code path.
  raft::resources handle;
  raft::comms::build_comms_nccl_only(&handle, nccl_comm, num_ranks, rank);

  // --- Demo parameters ---
  constexpr int64_t n_samples       = 100'000;
  constexpr int64_t n_features      = 32;
  constexpr int     n_clusters      = 10;
  constexpr int64_t streaming_batch = 10'000;  // rows streamed host->device per step
  constexpr float   cluster_spread  = 1.0f;    // per-cluster Gaussian noise stddev
  constexpr float   center_range    = 30.0f;   // true centers drawn from [-range, range]

  if (rank == 0) {
    std::printf("=== Multi-GPU KMeans Demo (%d ranks) ===\n", num_ranks);
    std::printf("Samples: %ld | Features: %ld | k: %d | batch: %ld\n\n",
                long(n_samples), long(n_features), n_clusters, long(streaming_batch));
  }

  // Generate synthetic blobs with well-separated cluster centers.
  // Every rank generates the identical full dataset (same fixed seed) and
  // later views only its own row slice — no data exchange needed.
  std::vector<float> h_data(n_samples * n_features);
  std::vector<int>   h_true_labels(n_samples);
  std::vector<float> cluster_centers(n_clusters * n_features);
  {
    std::mt19937 gen(12345);
    std::uniform_real_distribution<float> center_dist(-center_range, center_range);
    std::normal_distribution<float> noise(0.0f, cluster_spread);

    // Draw the ground-truth cluster centers.
    for (int c = 0; c < n_clusters; ++c)
      for (int d = 0; d < n_features; ++d)
        cluster_centers[c * n_features + d] = center_dist(gen);

    // Assign labels round-robin and add Gaussian noise around each center.
    for (int64_t i = 0; i < n_samples; ++i) {
      int label = static_cast<int>(i % n_clusters);
      h_true_labels[i] = label;
      for (int d = 0; d < n_features; ++d)
        h_data[i * n_features + d] = cluster_centers[label * n_features + d] + noise(gen);
    }

    // Shuffle so labels aren't just sequential runs (keeps per-rank shards
    // statistically similar).
    std::vector<int64_t> perm(n_samples);
    std::iota(perm.begin(), perm.end(), 0);
    std::shuffle(perm.begin(), perm.end(), gen);

    std::vector<float> tmp_data(h_data);
    std::vector<int>   tmp_labels(h_true_labels);
    for (int64_t i = 0; i < n_samples; ++i) {
      std::memcpy(h_data.data() + i * n_features,
                  tmp_data.data() + perm[i] * n_features,
                  n_features * sizeof(float));
      h_true_labels[i] = tmp_labels[perm[i]];
    }
  }

  // Block-partition rows across ranks; the first `rem` ranks get one extra
  // row so all samples are covered exactly once.
  int64_t base    = n_samples / num_ranks;
  int64_t rem     = n_samples % num_ranks;
  int64_t offset  = rank * base + std::min<int64_t>(rank, rem);
  int64_t n_local = base + (rank < rem ? 1 : 0);

  std::printf("[rank %d / GPU %d]  rows [%ld .. %ld)  (%ld samples)\n",
              rank, rank, long(offset), long(offset + n_local), long(n_local));

  // Non-owning host view of this rank's shard; fit() streams it to device
  // in batches of streaming_batch rows.
  auto X_local = raft::make_host_matrix_view<const float, int64_t>(
    h_data.data() + offset * n_features, n_local, n_features);

  // Output centroids live on device; every rank receives the same result.
  auto d_centroids = raft::make_device_matrix<float, int64_t>(handle, n_clusters, n_features);

  cuvs::cluster::kmeans::params params;
  params.n_clusters           = n_clusters;
  params.max_iter             = 50;
  params.tol                  = 1e-4;
  params.init                 = cuvs::cluster::kmeans::params::KMeansPlusPlus;
  params.rng_state.seed       = 42;
  params.inertia_check        = true;
  params.streaming_batch_size = streaming_batch;

  float   inertia = 0.0f;
  int64_t n_iter  = 0;

  // Collective call: all ranks must enter fit(); coordination happens via
  // the RAFT comms attached to `handle`.
  cuvs::cluster::kmeans::fit(handle,
                             params,
                             X_local,
                             std::nullopt,  // no sample weights
                             d_centroids.view(),
                             raft::make_host_scalar_view(&inertia),
                             raft::make_host_scalar_view(&n_iter));

  auto stream = raft::resource::get_cuda_stream(handle);
  CHECK_CUDA(cudaStreamSynchronize(stream));

  if (rank == 0) {
    // --- Predict labels on the full dataset (on rank 0 only) ---
    auto d_X = raft::make_device_matrix<float, int64_t>(handle, n_samples, n_features);
    CHECK_CUDA(cudaMemcpy(d_X.data_handle(), h_data.data(),
                          sizeof(float) * n_samples * n_features, cudaMemcpyHostToDevice));

    auto d_labels = raft::make_device_vector<int64_t, int64_t>(handle, n_samples);
    float predict_inertia = 0.0f;

    cuvs::cluster::kmeans::predict(
      handle, params,
      raft::make_device_matrix_view<const float, int64_t>(d_X.data_handle(), n_samples, n_features),
      std::nullopt,  // no sample weights
      raft::make_device_matrix_view<const float, int64_t>(
        d_centroids.data_handle(), n_clusters, n_features),
      d_labels.view(),
      false,  // do not normalize weights
      raft::make_host_scalar_view(&predict_inertia));
    CHECK_CUDA(cudaStreamSynchronize(stream));

    std::vector<int64_t> h_labels(n_samples);
    CHECK_CUDA(cudaMemcpy(h_labels.data(), d_labels.data_handle(),
                          sizeof(int64_t) * n_samples, cudaMemcpyDeviceToHost));

    // --- Quality: permutation-invariant accuracy via majority voting ---
    // KMeans cluster ids are arbitrary, so build a predicted-vs-true
    // confusion matrix and match cluster ids to true labels.
    std::vector<std::vector<int64_t>> confusion(n_clusters, std::vector<int64_t>(n_clusters, 0));
    for (int64_t i = 0; i < n_samples; ++i)
      confusion[h_labels[i]][h_true_labels[i]]++;

    // Greedy matching: repeatedly take the largest remaining confusion cell
    // and pair that predicted cluster with that true label. Greedy (not
    // optimal Hungarian), but adequate for well-separated blobs.
    std::vector<int> pred_to_true(n_clusters, -1);
    std::vector<bool> true_taken(n_clusters, false);
    for (int round = 0; round < n_clusters; ++round) {
      int64_t best_count = -1;
      int best_pred = -1, best_true = -1;
      for (int p = 0; p < n_clusters; ++p) {
        if (pred_to_true[p] >= 0) continue;  // this predicted cluster already matched
        for (int t = 0; t < n_clusters; ++t) {
          if (true_taken[t]) continue;       // this true label already claimed
          if (confusion[p][t] > best_count) {
            best_count = confusion[p][t];
            best_pred = p;
            best_true = t;
          }
        }
      }
      pred_to_true[best_pred] = best_true;
      true_taken[best_true] = true;
    }

    // Score each sample against the matched label; also collect per-cluster
    // sizes and hit counts for the breakdown table below.
    int64_t correct = 0;
    std::vector<int64_t> cluster_sizes(n_clusters, 0);
    std::vector<int64_t> cluster_correct(n_clusters, 0);
    for (int64_t i = 0; i < n_samples; ++i) {
      int p = static_cast<int>(h_labels[i]);
      cluster_sizes[p]++;
      if (h_true_labels[i] == pred_to_true[p]) {
        ++correct;
        ++cluster_correct[p];
      }
    }
    double accuracy = 100.0 * correct / n_samples;

    // --- Compute centroid-to-true-center distances ---
    std::vector<float> h_centroids(n_clusters * n_features);
    CHECK_CUDA(cudaMemcpy(h_centroids.data(), d_centroids.data_handle(),
                          sizeof(float) * n_clusters * n_features, cudaMemcpyDeviceToHost));

    std::printf("\n============ Multi-GPU KMeans Results ============\n");
    std::printf("  Ranks             : %d\n", num_ranks);
    std::printf("  Total samples     : %ld\n", long(n_samples));
    std::printf("  Features          : %ld\n", long(n_features));
    std::printf("  Clusters (k)      : %d\n", n_clusters);
    std::printf("  Streaming batch   : %ld\n", long(streaming_batch));
    std::printf("  Lloyd iterations  : %ld\n", long(n_iter));
    std::printf("  Final inertia     : %.6f\n", double(inertia));
    std::printf("  Predict inertia   : %.6f\n", double(predict_inertia));
    std::printf("\n  --- Clustering Quality ---\n");
    std::printf("  Overall accuracy  : %.2f%% (%ld / %ld)\n",
                accuracy, long(correct), long(n_samples));

    std::printf("\n  Per-cluster breakdown:\n");
    std::printf("  %6s  %10s  %10s  %8s  %12s\n",
                "Pred", "TrueLabel", "Size", "Acc%", "CentroidErr");
    for (int p = 0; p < n_clusters; ++p) {
      int t = pred_to_true[p];
      double pct = cluster_sizes[p] > 0
                     ? 100.0 * cluster_correct[p] / cluster_sizes[p]
                     : 0.0;

      // L2 distance between learned centroid and ground truth center
      double dist2 = 0.0;
      for (int d = 0; d < n_features; ++d) {
        double diff = h_centroids[p * n_features + d] - cluster_centers[t * n_features + d];
        dist2 += diff * diff;
      }
      std::printf("  %6d  %10d  %10ld  %7.2f%%  %12.4f\n",
                  p, t, long(cluster_sizes[p]), pct, std::sqrt(dist2));
    }

    std::printf("\n  Expected accuracy for well-separated blobs: >99%%\n");
    if (accuracy >= 99.0)
      std::printf("  PASS: Clustering quality is high.\n");
    else if (accuracy >= 90.0)
      std::printf("  WARN: Clustering quality is acceptable but not ideal.\n");
    else
      std::printf("  FAIL: Clustering quality is poor!\n");

    std::printf("==================================================\n");
  }

  // Tear down NCCL before MPI; all ranks reach here (only rank 0 ran the
  // prediction/report section above).
  CHECK_NCCL(ncclCommDestroy(nccl_comm));
  MPI_Finalize();
  return 0;
}
Compilation command
nvcc -std=c++17 -x cu --extended-lambda -arch=native       \
 -I$CONDA_PREFIX/include/rapids                            \
 -I$CONDA_PREFIX/include                                   \
 demo_mg_kmeans_raft_comms.cu                              \
 -L$CONDA_PREFIX/lib -lcuvs -lnccl -lrmm -lmpi             \
 -lucxx -lucp -lucs                                       \
 -Xlinker=-rpath,$CONDA_PREFIX/lib                         \
 -o demo_mg_kmeans
Launch command

mpirun -np 2 ./demo_mg_kmeans

@viclafargue viclafargue requested a review from tarang-jain April 13, 2026 14:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

improvement Improves an existing functionality non-breaking Introduces a non-breaking change

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

[FEA] Multi-node Multi-GPU Kmeans (C++) to support new out-of-core batching

1 participant