diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 8f55edb925..fbc6877a00 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -91,7 +91,7 @@ struct cuvsKMeansParams { */ int batch_centroids; - /** Check inertia during iterations for early convergence. */ + /** Deprecated, ignored. Kept for ABI compatibility. */ bool inertia_check; /** @@ -108,7 +108,14 @@ struct cuvsKMeansParams { * Number of samples to process per GPU batch for the batched (host-data) API. * When set to 0, defaults to n_samples (process all at once). */ - int64_t streaming_batch_size; + int64_t streaming_batch_size; + + /** + * Number of samples to draw for KMeansPlusPlus initialization. + * When set to 0, uses heuristic min(3 * n_clusters, n_samples) for host data, + * or n_samples for device data. + */ + int64_t init_size; }; typedef struct cuvsKMeansParams* cuvsKMeansParams_t; diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index a84cd50259..495a83f8d5 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -28,7 +28,7 @@ cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) kmeans_params.oversampling_factor = params.oversampling_factor; kmeans_params.batch_samples = params.batch_samples; kmeans_params.batch_centroids = params.batch_centroids; - kmeans_params.inertia_check = params.inertia_check; + kmeans_params.init_size = params.init_size; kmeans_params.streaming_batch_size = params.streaming_batch_size; return kmeans_params; } @@ -237,10 +237,11 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) .oversampling_factor = cpp_params.oversampling_factor, .batch_samples = cpp_params.batch_samples, .batch_centroids = cpp_params.batch_centroids, - .inertia_check = cpp_params.inertia_check, + .inertia_check = false, .hierarchical = false, .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters), - .streaming_batch_size = cpp_params.streaming_batch_size}; + .streaming_batch_size = cpp_params.streaming_batch_size, + .init_size = cpp_params.init_size}; }); } diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index d299d9f483..ff7d056f7d 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -113,9 +113,14 @@ struct params : base_params { int batch_centroids = 0; /** - * If true, check inertia during iterations for early convergence. + * Number of samples to randomly draw for the KMeansPlusPlus initialization + * step. A random subset of this size is used for centroid seeding. + * When set to 0 the default depends on the data location: + * - Device data: n_samples (use the full dataset). + * - Host data: min(3 * n_clusters, n_samples). + * Default: 0. */ - bool inertia_check = false; + int64_t init_size = 0; /** * Number of samples to process per GPU batch when fitting with host data. diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 5a35f203b3..a35e557ba5 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -5,6 +5,7 @@ #pragma once #include "../../core/nvtx.hpp" +#include "../../neighbors/detail/ann_utils.cuh" #include "kmeans_common.cuh" #include @@ -31,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -44,6 +46,7 @@ #include #include #include +#include #include #include @@ -303,150 +306,17 @@ void update_centroids(raft::resources const& handle, new_centroids); } -// TODO: Resizing is needed to use mdarray instead of rmm::device_uvector -template -void kmeans_fit_main(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::device_matrix_view X, - raft::device_vector_view weight, - raft::device_matrix_view centroidsRawData, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) -{ - raft::common::nvtx::range fun_scope("kmeans_fit_main"); - raft::default_logger().set_level(params.verbosity); - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - // stores (key, value) pair corresponding to each sample where - // - key is the index of nearest cluster - // - value is the distance to the nearest cluster - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); - - // temporary buffer to store L2 norm of centroids or distance matrix, - // destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // temporary buffer to store intermediate centroids, destructor releases the - // resource - auto newCentroids = raft::make_device_matrix(handle, n_clusters, n_features); - - // temporary buffer to store weights per cluster, destructor releases the - // resource - auto wtInCluster = raft::make_device_vector(handle, n_clusters); - - rmm::device_scalar clusterCostD(stream); - - // L2 norm of X: ||x||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - auto l2normx_view = - raft::make_device_vector_view(L2NormX.data_handle(), n_samples); - - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::norm(handle, X, L2NormX.view()); - } - - RAFT_LOG_DEBUG( - "Calling KMeans.fit with %d samples of input data and the initialized " - "cluster centers", - n_samples); - - DataT priorClusteringCost = 0; - for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { - RAFT_LOG_DEBUG( - "KMeans.fit: Iteration-%d: fitting the model using the initialized " - "cluster centers", - n_iter[0]); - - auto centroids = raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features); - - // computes minClusterAndDistance[0:n_samples) where - // minClusterAndDistance[i] is a pair where - // 'key' is index to a sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - update_centroids( - handle, - X, - weight, - raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features), - cuda::transform_iterator(minClusterAndDistance.data_handle(), - cuvs::cluster::kmeans::detail::KeyValueIndexOp{}), - wtInCluster.view(), - newCentroids.view(), - workspace); - - // Compute how much centroids shifted - DataT sqrdNormError = compute_centroid_shift( - handle, raft::make_const_mdspan(centroids), raft::make_const_mdspan(newCentroids.view())); - - raft::copy(handle, - raft::make_device_vector_view(centroidsRawData.data_handle(), newCentroids.size()), - raft::make_device_vector_view(newCentroids.data_handle(), newCentroids.size())); - - bool done = false; - if (params.inertia_check) { - // calculate cluster cost phi_x(C) - cuvs::cluster::kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - - DataT curClusteringCost = clusterCostD.value(stream); - - ASSERT(curClusteringCost != (DataT)0.0, - "Too few points and centroids being found is getting 0 cost from " - "centers"); - - if (n_iter[0] > 1) { - DataT delta = curClusteringCost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - } - priorClusteringCost = curClusteringCost; - } - - if (sqrdNormError < params.tol) done = true; - - if (done) { - RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", n_iter[0]); - break; - } - } - - cuvs::cluster::kmeans::cluster_cost(handle, - X, - raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features), - inertia, - std::make_optional(weight)); - - RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", - n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0], - inertia[0]); -} +template +void kmeans_fit( + raft::resources const& handle, + const cuvs::cluster::kmeans::params& pams, + raft::mdspan, raft::row_major, Accessor> X, + std::optional< + raft::mdspan, raft::layout_right, Accessor>> + sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); /* * @brief Selects 'n_clusters' samples from X using scalable kmeans++ algorithm. @@ -651,17 +521,20 @@ void initScalableKMeansPlusPlus(raft::resources const& handle, auto inertia = raft::make_host_scalar(0); auto n_iter = raft::make_host_scalar(0); - cuvs::cluster::kmeans::params default_params; - default_params.n_clusters = params.n_clusters; - - cuvs::cluster::kmeans::detail::kmeans_fit_main(handle, - default_params, - potentialCentroids, - weight.view(), - centroidsRawData, - inertia.view(), - n_iter.view(), - workspace); + cuvs::cluster::kmeans::params recluster_params; + recluster_params.n_clusters = params.n_clusters; + recluster_params.init = cuvs::cluster::kmeans::params::InitMethod::Array; + recluster_params.n_init = 1; + + auto weight_opt = std::make_optional(raft::make_const_mdspan(weight.view())); + cuvs::cluster::kmeans::detail::kmeans_fit( + handle, + recluster_params, + raft::make_const_mdspan(potentialCentroids), + weight_opt, + centroidsRawData, + inertia.view(), + n_iter.view()); } else if ((int)potentialCentroids.extent(0) < n_clusters) { // supplement with random @@ -697,90 +570,109 @@ void initScalableKMeansPlusPlus(raft::resources const& handle, } /** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. + * @brief Unified k-means fit (works with host or device data). + * + * @tparam DataT Data / weight type + * @tparam IndexT Index type + * @tparam Accessor Accessor policy (host or device); deduced from X + * * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. It must be noted - * that the data must be in row-major format and stored in device accessible - * location. - * @param[in] n_samples Number of samples in the input X. - * @param[in] n_features Number of features or the dimensions of each - * sample. + * @param[in] pams Parameters for the KMeans model. + * @param[in] X Training instances to cluster (host or device). + * Row-major, [n_samples x n_features]. * @param[in] sample_weight Optional weights for each observation in X. - * @param[inout] centroids [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] Otherwise, generated centroids from the - * kmeans algorithm is stored at the address pointed by 'centroids'. + * [n_samples]. When std::nullopt, uniform weights + * are used. + * @param[inout] centroids [in] When init is InitMethod::Array, used as + * the initial cluster centers. + * [out] The final centroids produced by the + * algorithm. [n_clusters x n_features]. * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. + * closest cluster center. + * @param[out] n_iter Number of iterations run for the best + * initialization. */ -template -void kmeans_fit(raft::resources const& handle, - const cuvs::cluster::kmeans::params& pams, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) +template +void kmeans_fit( + raft::resources const& handle, + const cuvs::cluster::kmeans::params& pams, + raft::mdspan, raft::row_major, Accessor> X, + std::optional< + raft::mdspan, raft::layout_right, Accessor>> + sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { raft::common::nvtx::range fun_scope("kmeans_fit"); auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = pams.n_clusters; + auto metric = pams.metric; cudaStream_t stream = raft::resource::get_cuda_stream(handle); - // Check that parameters are valid + if (sample_weight.has_value()) RAFT_EXPECTS(sample_weight.value().extent(0) == n_samples, "invalid parameter (sample_weight!=n_samples)"); RAFT_EXPECTS(n_clusters > 0, "invalid parameter (n_clusters<=0)"); RAFT_EXPECTS(pams.tol > 0, "invalid parameter (tol<=0)"); RAFT_EXPECTS(pams.oversampling_factor >= 0, "invalid parameter (oversampling_factor<0)"); - RAFT_EXPECTS((int)centroids.extent(0) == pams.n_clusters, + RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, "invalid parameter (centroids.extent(0) != n_clusters)"); RAFT_EXPECTS(centroids.extent(1) == n_features, "invalid parameter (centroids.extent(1) != n_features)"); - // Display a message if the batch size is smaller than n_samples but will be ignored - if (pams.batch_samples < (int)n_samples && - (pams.metric == cuvs::distance::DistanceType::L2Expanded || - pams.metric == cuvs::distance::DistanceType::L2SqrtExpanded)) { - RAFT_LOG_DEBUG( - "batch_samples=%d was passed, but batch_samples=%d will be used (reason: " - "batch_samples has no impact on the memory footprint when FusedL2NN can be used)", - pams.batch_samples, - (int)n_samples); - } - // Display a message if batch_centroids is set and a fusedL2NN-compatible metric is used - if (pams.batch_centroids != 0 && pams.batch_centroids != pams.n_clusters && - (pams.metric == cuvs::distance::DistanceType::L2Expanded || - pams.metric == cuvs::distance::DistanceType::L2SqrtExpanded)) { - RAFT_LOG_DEBUG( - "batch_centroids=%d was passed, but batch_centroids=%d will be used (reason: " - "batch_centroids has no impact on the memory footprint when FusedL2NN can be used)", - pams.batch_centroids, - pams.n_clusters); + raft::default_logger().set_level(pams.verbosity); + + IndexT streaming_batch_size = static_cast(pams.streaming_batch_size); + if (streaming_batch_size <= 0 || streaming_batch_size > static_cast(n_samples)) { + streaming_batch_size = static_cast(n_samples); } - raft::default_logger().set_level(pams.verbosity); + const DataT* weight_ptr = + sample_weight.has_value() ? sample_weight.value().data_handle() : nullptr; + DataT weight_scale = compute_weight_scale(handle, weight_ptr, n_samples); - // Allocate memory rmm::device_uvector workspace(0, stream); - auto weight = raft::make_device_vector(handle, n_samples); - if (sample_weight.has_value()) - raft::copy(handle, weight.view(), sample_weight.value()); - else - raft::matrix::fill(handle, weight.view(), DataT(1)); - // check if weights sum up to n_samples - checkWeight(handle, weight.view(), workspace); + constexpr bool data_on_device = !raft::is_host_mdspan_v; + + auto init_centroids = [&](const cuvs::cluster::kmeans::params& iter_params, + raft::device_matrix_view centroidsRawData) { + if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { + raft::copy( + handle, + raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features), + raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features)); + return; + } + + raft::random::RngState random_state(iter_params.rng_state.seed); + + if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { + raft::matrix::sample_rows(handle, random_state, X, centroidsRawData); + } else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { + IndexT default_init_size = + data_on_device ? n_samples : std::min(static_cast(3 * n_clusters), n_samples); + IndexT init_sample_size = iter_params.init_size > 0 + ? std::min(static_cast(iter_params.init_size), n_samples) + : default_init_size; + + auto init_sample = + raft::make_device_matrix(handle, init_sample_size, n_features); + raft::matrix::sample_rows(handle, random_state, X, init_sample.view()); - auto centroidsRawData = raft::make_device_matrix(handle, n_clusters, n_features); + auto init_sample_const = raft::make_const_mdspan(init_sample.view()); + if (iter_params.oversampling_factor == 0) + kmeansPlusPlus( + handle, iter_params, init_sample_const, centroidsRawData, workspace); + else + initScalableKMeansPlusPlus( + handle, iter_params, init_sample_const, centroidsRawData, workspace); + } else { + THROW("unknown initialization method to select initial centers"); + } + }; auto n_init = pams.n_init; if (pams.init == cuvs::cluster::kmeans::params::InitMethod::Array && n_init != 1) { @@ -791,70 +683,253 @@ void kmeans_fit(raft::resources const& handle, n_init = 1; } + IndexT centroid_buf_size = n_clusters * n_features; + rmm::device_uvector centroid_buf_A(centroid_buf_size, stream); + rmm::device_uvector centroid_buf_B(centroid_buf_size, stream); + DataT* cur_centroids_ptr = centroid_buf_A.data(); + DataT* new_centroids_ptr = centroid_buf_B.data(); + + auto minClusterAndDistance = raft::make_device_vector, IndexT>( + handle, streaming_batch_size); + auto L2NormBatch = raft::make_device_vector(handle, streaming_batch_size); + auto batch_weights_buf = raft::make_device_vector(handle, streaming_batch_size); + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + + auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto weight_per_cluster = raft::make_device_vector(handle, n_clusters); + auto centroid_norms_buf = raft::make_device_vector(handle, n_clusters); + auto clustering_cost = raft::make_device_scalar(handle, DataT{0}); + auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto batch_counts = raft::make_device_vector(handle, n_clusters); + + cuvs::spatial::knn::detail::utils::batch_load_iterator data_batches( + X.data_handle(), n_samples, n_features, streaming_batch_size, stream); + cuvs::spatial::knn::detail::utils::batch_load_iterator weight_batches( + weight_ptr, n_samples, 1, streaming_batch_size, stream); + + if (weight_ptr == nullptr) { raft::matrix::fill(handle, batch_weights_buf.view(), DataT{1}); } + + auto prepare_batch_weights = [&](const auto& wt_batch, IndexT cur_batch_size) { + if (weight_ptr != nullptr) { + raft::copy(batch_weights_buf.data_handle(), wt_batch.data(), cur_batch_size, stream); + if (weight_scale != DataT{1}) { + auto bw = raft::make_device_vector_view(batch_weights_buf.data_handle(), + cur_batch_size); + raft::linalg::map( + handle, bw, raft::mul_const_op{weight_scale}, raft::make_const_mdspan(bw)); + } + } + return raft::make_device_vector_view(batch_weights_buf.data_handle(), + cur_batch_size); + }; + + RAFT_LOG_DEBUG( + "KMeans.fit: n_samples=%zu, n_features=%zu, n_clusters=%d, streaming_batch_size=%zu", + static_cast(n_samples), + static_cast(n_features), + n_clusters, + static_cast(streaming_batch_size)); + + bool need_compute_norms = metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded; + bool use_norm_cache = need_compute_norms && !data_on_device; + std::vector h_norm_cache; + if (use_norm_cache) { h_norm_cache.resize(n_samples); } + bool norms_cached = false; + + auto compute_batch_norms = [&](const DataT* batch_ptr, IndexT batch_size) { + auto batch_view = + raft::make_device_matrix_view(batch_ptr, batch_size, n_features); + auto norm_view = + raft::make_device_vector_view(L2NormBatch.data_handle(), batch_size); + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, batch_view, norm_view, raft::sqrt_op{}); + } else { + raft::linalg::norm( + handle, batch_view, norm_view); + } + }; + + if (need_compute_norms && data_on_device) { + compute_batch_norms(X.data_handle(), n_samples); + norms_cached = true; + } + std::mt19937 gen(pams.rng_state.seed); inertia[0] = std::numeric_limits::max(); - for (auto seed_iter = 0; seed_iter < n_init; ++seed_iter) { + for (int seed_iter = 0; seed_iter < n_init; ++seed_iter) { cuvs::cluster::kmeans::params iter_params = pams; iter_params.rng_state.seed = gen(); - DataT iter_inertia = std::numeric_limits::max(); - IndexT n_current_iter = 0; - if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { - // initializing with random samples from input dataset - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers by " - "randomly choosing from the " - "input data.", - seed_iter + 1, - n_init); - initRandom(handle, iter_params, X, centroidsRawData.view()); - } else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { - // default method to initialize is kmeans++ - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers using " - "k-means++ algorithm.", - seed_iter + 1, - n_init); - if (iter_params.oversampling_factor == 0) - cuvs::cluster::kmeans::detail::kmeansPlusPlus( - handle, iter_params, X, centroidsRawData.view(), workspace); - else - cuvs::cluster::kmeans::detail::initScalableKMeansPlusPlus( - handle, iter_params, X, centroidsRawData.view(), workspace); - } else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers from " - "the ndarray array input " - "passed to init argument.", - seed_iter + 1, - n_init); - raft::copy( - handle, - raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features), - raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features)); - } else { - THROW("unknown initialization method to select initial centers"); - } + RAFT_LOG_DEBUG("KMeans.fit: n_init iteration %d/%d (seed=%llu)", + seed_iter + 1, + n_init, + (unsigned long long)iter_params.rng_state.seed); - cuvs::cluster::kmeans::detail::kmeans_fit_main( - handle, + cur_centroids_ptr = centroid_buf_A.data(); + new_centroids_ptr = centroid_buf_B.data(); + init_centroids( iter_params, - X, - weight.view(), - centroidsRawData.view(), - raft::make_host_scalar_view(&iter_inertia), - raft::make_host_scalar_view(&n_current_iter), - workspace); + raft::make_device_matrix_view(cur_centroids_ptr, n_clusters, n_features)); + + DataT iter_inertia = std::numeric_limits::max(); + IndexT n_current_iter = 0; + DataT priorClusteringCost = 0; + + for (n_current_iter = 1; n_current_iter <= iter_params.max_iter; ++n_current_iter) { + RAFT_LOG_DEBUG("KMeans.fit: Iteration-%d", n_current_iter); + + raft::matrix::fill(handle, centroid_sums.view(), DataT{0}); + raft::matrix::fill(handle, weight_per_cluster.view(), DataT{0}); + raft::linalg::map(handle, + raft::make_device_scalar_view(clustering_cost.data_handle()), + raft::const_op{DataT{0}}); + + auto centroids_const = raft::make_device_matrix_view( + cur_centroids_ptr, n_clusters, n_features); + auto new_centroids_view = + raft::make_device_matrix_view(new_centroids_ptr, n_clusters, n_features); + + std::optional> centroid_norms_opt = + std::nullopt; + if (need_compute_norms) { + raft::linalg::norm( + handle, centroids_const, centroid_norms_buf.view()); + centroid_norms_opt = raft::make_device_vector_view( + centroid_norms_buf.data_handle(), n_clusters); + } + + data_batches.reset(); + weight_batches.reset(); + auto wt_it = weight_batches.begin(); + for (const auto& data_batch : data_batches) { + IndexT cur_batch_size = static_cast(data_batch.size()); + const auto& wt_batch = *wt_it; + ++wt_it; + + auto batch_data_view = raft::make_device_matrix_view( + data_batch.data(), cur_batch_size, n_features); + auto batch_weights_view = prepare_batch_weights(wt_batch, cur_batch_size); + + auto minCAD_view = raft::make_device_vector_view, IndexT>( + minClusterAndDistance.data_handle(), cur_batch_size); + + if (need_compute_norms && !norms_cached) { + compute_batch_norms(data_batch.data(), cur_batch_size); + if (use_norm_cache) { + raft::copy(h_norm_cache.data() + data_batch.offset(), + L2NormBatch.data_handle(), + cur_batch_size, + stream); + } + } else if (use_norm_cache) { + raft::copy(L2NormBatch.data_handle(), + h_norm_cache.data() + data_batch.offset(), + cur_batch_size, + stream); + } + + auto l2_const_view = raft::make_device_vector_view( + L2NormBatch.data_handle(), cur_batch_size); + + process_batch(handle, + batch_data_view, + batch_weights_view, + centroids_const, + metric, + iter_params.batch_samples, + iter_params.batch_centroids, + minCAD_view, + l2_const_view, + L2NormBuf_OR_DistBuf, + workspace, + centroid_sums.view(), + weight_per_cluster.view(), + batch_sums.view(), + batch_counts.view(), + clustering_cost.view(), + centroid_norms_opt); + } + if (!norms_cached && use_norm_cache) { + raft::resource::sync_stream(handle, stream); + norms_cached = true; + } + + finalize_centroids(handle, + raft::make_const_mdspan(centroid_sums.view()), + raft::make_const_mdspan(weight_per_cluster.view()), + centroids_const, + new_centroids_view); + + DataT sqrdNormError = + compute_centroid_shift(handle, + raft::make_const_mdspan(centroids_const), + raft::make_const_mdspan(new_centroids_view)); + + std::swap(cur_centroids_ptr, new_centroids_ptr); + + bool done = false; + + DataT curClusteringCost = DataT{0}; + raft::copy(&curClusteringCost, clustering_cost.data_handle(), 1, stream); + raft::resource::sync_stream(handle, stream); + + if (curClusteringCost == DataT{0}) { + RAFT_LOG_WARN("Zero clustering cost detected: all points coincide with their centroids."); + } else if (n_current_iter > 1) { + DataT delta = curClusteringCost / priorClusteringCost; + if (delta > 1 - iter_params.tol) done = true; + } + priorClusteringCost = curClusteringCost; + + if (sqrdNormError < iter_params.tol) done = true; + + if (done) { + RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", + n_current_iter); + break; + } + } + + { + auto centroids_const = raft::make_device_matrix_view( + cur_centroids_ptr, n_clusters, n_features); + + iter_inertia = DataT{0}; + data_batches.reset(); + weight_batches.reset(); + auto wt_it = weight_batches.begin(); + for (const auto& data_batch : data_batches) { + IndexT cur_batch_size = static_cast(data_batch.size()); + const auto& wt_batch = *wt_it; + ++wt_it; + + auto batch_data_view = raft::make_device_matrix_view( + data_batch.data(), cur_batch_size, n_features); + + std::optional> batch_sw = std::nullopt; + if (weight_ptr != nullptr) { batch_sw = prepare_batch_weights(wt_batch, cur_batch_size); } + + DataT batch_cost = DataT{0}; + cuvs::cluster::kmeans::cluster_cost(handle, + batch_data_view, + centroids_const, + raft::make_host_scalar_view(&batch_cost), + batch_sw); + + iter_inertia += batch_cost; + } + } + if (iter_inertia < inertia[0]) { inertia[0] = iter_inertia; n_iter[0] = n_current_iter; - raft::copy( - handle, - raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features), - raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features)); + raft::copy(centroids.data_handle(), cur_centroids_ptr, centroid_buf_size, stream); } - RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter[0] - %d", + RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter - %d", seed_iter + 1, n_init, inertia[0], @@ -877,15 +952,26 @@ void kmeans_fit(raft::resources const& handle, auto XView = raft::make_device_matrix_view(X, n_samples, n_features); auto centroidsView = raft::make_device_matrix_view(centroids, pams.n_clusters, n_features); - std::optional> sample_weightView = std::nullopt; + std::optional> sample_weightView = std::nullopt; if (sample_weight) sample_weightView = raft::make_device_vector_view(sample_weight, n_samples); auto inertiaView = raft::make_host_scalar_view(&inertia); auto n_iterView = raft::make_host_scalar_view(&n_iter); - cuvs::cluster::kmeans::detail::kmeans_fit( - handle, pams, XView, sample_weightView, centroidsView, inertiaView, n_iterView); + kmeans_fit(handle, pams, XView, sample_weightView, centroidsView, inertiaView, n_iterView); +} + +template +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + kmeans_fit(handle, params, X, sample_weight, centroids, inertia, n_iter); } template diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh deleted file mode 100644 index e2fc8d334f..0000000000 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ /dev/null @@ -1,510 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ -#pragma once - -#include "kmeans.cuh" -#include "kmeans_common.cuh" - -#include "../../neighbors/detail/ann_utils.cuh" -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include - -#include -#include -#include -#include -#include -#include - -namespace cuvs::cluster::kmeans::detail { - -/** - * @brief Initialize centroids from host data - * - * @tparam T Input data type - * @tparam IdxT Index type - */ -template -void init_centroids_from_host_sample(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - IdxT streaming_batch_size, - raft::host_matrix_view X, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - - if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { - // this is a heuristic to speed up the initialization - IdxT init_sample_size = 3 * streaming_batch_size; - if (init_sample_size < n_clusters) { init_sample_size = 3 * n_clusters; } - init_sample_size = std::min(init_sample_size, n_samples); - - auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); - raft::random::RngState random_state(params.rng_state.seed); - raft::matrix::sample_rows(handle, random_state, X, init_sample.view()); - - if (params.oversampling_factor == 0) { - cuvs::cluster::kmeans::detail::kmeansPlusPlus( - handle, params, raft::make_const_mdspan(init_sample.view()), centroids, workspace); - } else { - cuvs::cluster::kmeans::detail::initScalableKMeansPlusPlus( - handle, params, raft::make_const_mdspan(init_sample.view()), centroids, workspace); - } - } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { - raft::random::RngState random_state(params.rng_state.seed); - raft::matrix::sample_rows(handle, random_state, X, centroids); - } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { - // already provided - } else { - RAFT_FAIL("Unknown initialization method"); - } -} - -/** - * @brief Compute the weight normalization scale factor for host sample weights. Weights are - * normalized to sum to n_samples. Returns the scale factor to apply to each weight. - * - * @param[in] sample_weight Optional host vector of sample weights - * @param[in] n_samples Number of samples - * @return Scale factor (1.0 if no weights or already normalized) - */ -template -T compute_host_weight_scale( - const std::optional>& sample_weight, IdxT n_samples) -{ - if (!sample_weight.has_value()) { return T{1}; } - - T wt_sum = T{0}; - const T* sw_ptr = sample_weight->data_handle(); - for (IdxT i = 0; i < n_samples; ++i) { - wt_sum += sw_ptr[i]; - } - if (wt_sum == static_cast(n_samples)) { return T{1}; } - - RAFT_LOG_DEBUG( - "[Warning!] KMeans: normalizing the user provided sample weight to " - "sum up to %zu samples (scale=%f)", - static_cast(n_samples), - static_cast(static_cast(n_samples) / wt_sum)); - return static_cast(n_samples) / wt_sum; -} - -/** - * @brief Copy host sample weights to device and apply normalization scale. - * - * When sample_weight is provided, copies the batch slice to the device buffer - * and applies the normalization scale factor. When not provided, the device - * buffer is assumed to already be filled with 1.0. - * - * @param[in] handle RAFT resources handle - * @param[in] sample_weight Optional host weights - * @param[in] batch_offset Offset into the host weights for this batch - * @param[in] batch_size Number of elements in this batch - * @param[in] weight_scale Scale factor from compute_host_weight_scale - * @param[inout] batch_weights Device buffer to write normalized weights into - */ -template -void copy_and_scale_batch_weights( - raft::resources const& handle, - const std::optional>& sample_weight, - size_t batch_offset, - IdxT batch_size, - T weight_scale, - raft::device_vector& batch_weights) -{ - if (!sample_weight.has_value()) { return; } - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - raft::copy( - batch_weights.data_handle(), sample_weight->data_handle() + batch_offset, batch_size, stream); - - if (weight_scale != T{1}) { - auto batch_weights_view = - raft::make_device_vector_view(batch_weights.data_handle(), batch_size); - raft::linalg::map(handle, - batch_weights_view, - raft::mul_const_op{weight_scale}, - raft::make_const_mdspan(batch_weights_view)); - } -} - -/** - * @brief Accumulate partial centroid sums and counts from a batch - * - * This function adds the partial sums from a batch to the running accumulators. - * It does NOT divide - that happens once at the end of all batches. - */ -template -void accumulate_batch_centroids( - raft::resources const& handle, - raft::device_matrix_view batch_data, - raft::device_vector_view, IdxT> minClusterAndDistance, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroid_sums, - raft::device_vector_view cluster_counts, - raft::device_matrix_view batch_sums, - raft::device_vector_view batch_counts) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - auto workspace = rmm::device_uvector( - batch_data.extent(0), stream, raft::resource::get_workspace_resource(handle)); - - cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; - thrust::transform_iterator, - const raft::KeyValuePair*> - labels_itr(minClusterAndDistance.data_handle(), conversion_op); - - cuvs::cluster::kmeans::detail::compute_centroid_adjustments( - handle, - batch_data, - sample_weights, - labels_itr, - static_cast(centroid_sums.extent(0)), - batch_sums, - batch_counts, - workspace); - - raft::linalg::add(centroid_sums.data_handle(), - centroid_sums.data_handle(), - batch_sums.data_handle(), - centroid_sums.size(), - stream); - - raft::linalg::add(cluster_counts.data_handle(), - cluster_counts.data_handle(), - batch_counts.data_handle(), - cluster_counts.size(), - stream); -} - -/** - * @brief Main fit function for batched k-means with host data (full-batch / Lloyd's algorithm). - * - * Processes host data in GPU-sized batches per iteration, accumulating partial centroid - * sums and counts, then finalizes centroids at the end of each iteration. - * - * @tparam T Input data type (float, double) - * @tparam IdxT Index type (int, int64_t) - * - * @param[in] handle RAFT resources handle - * @param[in] params K-means parameters - * @param[in] X Input data on HOST [n_samples x n_features] - * @param[in] sample_weight Optional weights per sample (on host) - * @param[inout] centroids Initial/output cluster centers [n_clusters x n_features] - * @param[out] inertia Sum of squared distances to nearest centroid - * @param[out] n_iter Number of iterations run - */ -template -void fit(raft::resources const& handle, - const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - IdxT streaming_batch_size = static_cast(params.streaming_batch_size); - - if (params.streaming_batch_size == 0) { - streaming_batch_size = static_cast(n_samples); - } else if (params.streaming_batch_size < 0 || params.streaming_batch_size > n_samples) { - RAFT_LOG_WARN("streaming_batch_size must be >= 0 and <= n_samples, using n_samples=%zu", - static_cast(n_samples)); - streaming_batch_size = static_cast(n_samples); - } - - RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); - RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, - "centroids.extent(0) must equal n_clusters"); - RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); - - RAFT_LOG_DEBUG( - "KMeans batched fit: n_samples=%zu, n_features=%zu, n_clusters=%d, streaming_batch_size=%zu", - static_cast(n_samples), - static_cast(n_features), - n_clusters, - static_cast(streaming_batch_size)); - - rmm::device_uvector workspace(0, stream); - - auto n_init = params.n_init; - if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array && n_init != 1) { - RAFT_LOG_DEBUG( - "Explicit initial center position passed: performing only one init in " - "k-means instead of n_init=%d", - n_init); - n_init = 1; - } - - auto best_centroids = n_init > 1 - ? raft::make_device_matrix(handle, n_clusters, n_features) - : raft::make_device_matrix(handle, 0, 0); - T best_inertia = std::numeric_limits::max(); - IdxT best_n_iter = 0; - - std::mt19937 gen(params.rng_state.seed); - - // ----- Allocate reusable work buffers (shared across n_init iterations) ----- - auto batch_weights = raft::make_device_vector(handle, streaming_batch_size); - auto minClusterAndDistance = - raft::make_device_vector, IdxT>(handle, streaming_batch_size); - auto L2NormBatch = raft::make_device_vector(handle, streaming_batch_size); - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); - auto weight_per_cluster = raft::make_device_vector(handle, n_clusters); - auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); - auto clustering_cost = raft::make_device_vector(handle, 1); - auto batch_clustering_cost = raft::make_device_vector(handle, 1); - auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); - auto batch_counts = raft::make_device_vector(handle, n_clusters); - - // Compute weight normalization (matches checkWeight in regular kmeans) - T weight_scale = compute_host_weight_scale(sample_weight, n_samples); - - // ---- Main n_init loop ---- - for (int seed_iter = 0; seed_iter < n_init; ++seed_iter) { - cuvs::cluster::kmeans::params iter_params = params; - iter_params.rng_state.seed = gen(); - - RAFT_LOG_DEBUG("KMeans batched fit: n_init iteration %d/%d (seed=%llu)", - seed_iter + 1, - n_init, - (unsigned long long)iter_params.rng_state.seed); - - if (iter_params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { - init_centroids_from_host_sample( - handle, iter_params, streaming_batch_size, X, centroids, workspace); - } - - if (!sample_weight.has_value()) { raft::matrix::fill(handle, batch_weights.view(), T{1}); } - - // Reset per-iteration state - T prior_cluster_cost = 0; - - cuvs::spatial::knn::detail::utils::batch_load_iterator data_batches( - X.data_handle(), n_samples, n_features, streaming_batch_size, stream); - - for (n_iter[0] = 1; n_iter[0] <= iter_params.max_iter; ++n_iter[0]) { - RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); - - raft::matrix::fill(handle, centroid_sums.view(), T{0}); - raft::matrix::fill(handle, weight_per_cluster.view(), T{0}); - - raft::matrix::fill(handle, clustering_cost.view(), T{0}); - - auto centroids_const = raft::make_const_mdspan(centroids); - - for (const auto& data_batch : data_batches) { - IdxT current_batch_size = static_cast(data_batch.size()); - raft::matrix::fill(handle, batch_clustering_cost.view(), T{0}); - - auto batch_data_view = raft::make_device_matrix_view( - data_batch.data(), current_batch_size, n_features); - - copy_and_scale_batch_weights(handle, - sample_weight, - data_batch.offset(), - current_batch_size, - weight_scale, - batch_weights); - - auto batch_weights_view = raft::make_device_vector_view( - batch_weights.data_handle(), current_batch_size); - - auto L2NormBatch_view = - raft::make_device_vector_view(L2NormBatch.data_handle(), current_batch_size); - - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - data_batch.data(), current_batch_size, n_features), - L2NormBatch_view); - } - - auto L2NormBatch_const = raft::make_const_mdspan(L2NormBatch_view); - - auto minClusterAndDistance_view = - raft::make_device_vector_view, IdxT>( - minClusterAndDistance.data_handle(), current_batch_size); - - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - batch_data_view, - centroids_const, - minClusterAndDistance_view, - L2NormBatch_const, - L2NormBuf_OR_DistBuf, - metric, - params.batch_samples, - params.batch_centroids, - workspace); - - auto minClusterAndDistance_const = raft::make_const_mdspan(minClusterAndDistance_view); - - accumulate_batch_centroids(handle, - batch_data_view, - minClusterAndDistance_const, - batch_weights_view, - centroid_sums.view(), - weight_per_cluster.view(), - batch_sums.view(), - batch_counts.view()); - - if (params.inertia_check) { - raft::linalg::map( - handle, - minClusterAndDistance_view, - [=] __device__(const raft::KeyValuePair kvp, T wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }, - raft::make_const_mdspan(minClusterAndDistance_view), - batch_weights_view); - - cuvs::cluster::kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance_view, - workspace, - raft::make_device_scalar_view(batch_clustering_cost.data_handle()), - raft::value_op{}, - raft::add_op{}); - raft::linalg::add(handle, - raft::make_const_mdspan(clustering_cost.view()), - raft::make_const_mdspan(batch_clustering_cost.view()), - clustering_cost.view()); - } - } - - auto centroid_sums_const = raft::make_device_matrix_view( - centroid_sums.data_handle(), n_clusters, n_features); - auto weight_per_cluster_const = - raft::make_device_vector_view(weight_per_cluster.data_handle(), n_clusters); - - finalize_centroids(handle, - centroid_sums_const, - weight_per_cluster_const, - centroids_const, - new_centroids.view()); - - T sqrdNormError = compute_centroid_shift( - handle, raft::make_const_mdspan(centroids), raft::make_const_mdspan(new_centroids.view())); - - raft::copy(handle, centroids, new_centroids.view()); - - bool done = false; - if (params.inertia_check) { - raft::copy(inertia.data_handle(), clustering_cost.data_handle(), 1, stream); - raft::resource::sync_stream(handle); - ASSERT(inertia[0] != (T)0.0, - "Too few points and centroids being found is getting 0 cost from " - "centers"); - if (n_iter[0] > 1) { - T delta = inertia[0] / prior_cluster_cost; - if (delta > 1 - params.tol) done = true; - } - prior_cluster_cost = inertia[0]; - } - - if (sqrdNormError < params.tol) done = true; - - if (done) { - RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", n_iter[0]); - break; - } - } - - // Recompute final weighted inertia with the converged centroids. - { - auto centroids_const_view = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features); - - inertia[0] = T{0}; - for (const auto& data_batch : data_batches) { - IdxT current_batch_size = static_cast(data_batch.size()); - - auto batch_data_view = raft::make_device_matrix_view( - data_batch.data(), current_batch_size, n_features); - - std::optional> batch_sw = std::nullopt; - if (sample_weight.has_value()) { - copy_and_scale_batch_weights(handle, - sample_weight, - data_batch.offset(), - current_batch_size, - weight_scale, - batch_weights); - batch_sw = raft::make_device_vector_view(batch_weights.data_handle(), - current_batch_size); - } - - T batch_cost = T{0}; - cuvs::cluster::kmeans::cluster_cost(handle, - batch_data_view, - centroids_const_view, - raft::make_host_scalar_view(&batch_cost), - batch_sw); - - inertia[0] += batch_cost; - } - } - - RAFT_LOG_DEBUG("KMeans batched: n_init %d/%d completed with inertia=%f", - seed_iter + 1, - n_init, - static_cast(inertia[0])); - - if (n_init > 1 && inertia[0] < best_inertia) { - best_inertia = inertia[0]; - best_n_iter = n_iter[0]; - raft::copy(best_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); - } - } - if (n_init > 1) { - inertia[0] = best_inertia; - n_iter[0] = best_n_iter; - raft::copy(handle, centroids, best_centroids.view()); - RAFT_LOG_DEBUG("KMeans batched: Best of %d runs: inertia=%f, n_iter=%d", - n_init, - static_cast(best_inertia), - best_n_iter); - } -} - -} // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 250563dd12..bc2de15726 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -43,6 +44,9 @@ #include #include #include +#include + +#include #include #include @@ -154,7 +158,7 @@ void checkWeight(raft::resources const& handle, raft::copy(handle, raft::make_host_scalar_view(&wt_sum), raft::make_device_scalar_view(wt_aggr.data_handle())); - raft::resource::sync_stream(handle, stream); + raft::resource::sync_stream(handle); if (wt_sum != n_samples) { RAFT_LOG_DEBUG( @@ -262,7 +266,7 @@ void sampleCentroids(raft::resources const& handle, raft::copy(handle, raft::make_host_scalar_view(&nPtsSampledInRank), raft::make_device_scalar_view(nSelected.data_handle())); - raft::resource::sync_stream(handle, stream); + raft::resource::sync_stream(handle); uint8_t* rawPtr_isSampleCentroid = isSampleCentroid.data_handle(); thrust::for_each_n(raft::resource::get_thrust_policy(handle), @@ -367,7 +371,9 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, int batch_samples, int batch_centroids, - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, + std::optional> precomputed_centroid_norms = + std::nullopt); #define EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(DataT, IndexT) \ extern template void minClusterAndDistanceCompute( \ @@ -380,7 +386,8 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, \ + std::optional>); EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int64_t) EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int) @@ -399,7 +406,9 @@ void minClusterDistanceCompute(raft::resources const& handle, cuvs::distance::DistanceType metric, int batch_samples, int batch_centroids, - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, + std::optional> + precomputed_centroid_norms = std::nullopt); #define EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(DataT, IndexT) \ extern template void minClusterDistanceCompute( \ @@ -412,7 +421,8 @@ void minClusterDistanceCompute(raft::resources const& handle, cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, \ + std::optional>); EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(float, int64_t) EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(double, int64_t) @@ -594,8 +604,174 @@ DataT compute_centroid_shift(raft::resources const& handle, new_centroids.data_handle()); DataT result = 0; raft::copy(&result, sqrdNorm.data_handle(), 1, stream); - raft::resource::sync_stream(handle, stream); + raft::resource::sync_stream(handle); return result; } +/** + * @brief Compute the weight normalization scale factor for sample weights that may + * reside on host memory. Weights are normalized to sum to n_samples. + * + * Works on any contiguous pointer (host or device) by copying to host for the sum. + * + * @tparam DataT Weight type + * @tparam IndexT Index type + * + * @param[in] handle RAFT resources handle + * @param[in] weight_ptr Pointer to sample weights (host or device), may be nullptr + * @param[in] n_samples Number of samples + * @return Scale factor (1.0 if weights already sum to n_samples or nullptr) + */ +template +DataT compute_weight_scale(raft::resources const& handle, const DataT* weight_ptr, IndexT n_samples) +{ + if (weight_ptr == nullptr) { return DataT{1}; } + + bool is_device_accessible = + raft::is_device_accessible(raft::memory_type_from_pointer(weight_ptr)); + + DataT wt_sum = DataT{0}; + if (!is_device_accessible) { + for (IndexT i = 0; i < n_samples; ++i) { + wt_sum += weight_ptr[i]; + } + } else { + std::vector h_weights(n_samples); + auto d_view = raft::make_device_vector_view(weight_ptr, n_samples); + auto h_view = raft::make_host_vector_view(h_weights.data(), n_samples); + raft::copy(handle, h_view, d_view); + raft::resource::sync_stream(handle); + for (IndexT i = 0; i < n_samples; ++i) { + wt_sum += h_weights[i]; + } + } + + if (wt_sum == static_cast(n_samples)) { return DataT{1}; } + + RAFT_LOG_DEBUG( + "[Warning!] KMeans: normalizing the user provided sample weight to " + "sum up to %zu samples", + static_cast(n_samples)); + return static_cast(n_samples) / wt_sum; +} + +/** + * @brief Process a single batch of data in the Lloyd iteration. + * + * Given one batch of data + precomputed norms + weights + current centroids it + * 1. finds the nearest centroid for every sample, + * 2. accumulates weighted centroid sums and counts into the running accumulators, + * 3. accumulates the weighted clustering cost (inertia). + * + * Data norms must be precomputed by the caller and passed in via L2NormBatch. + * + * @tparam DataT Data / weight type (float, double) + * @tparam IndexT Index type (int, int64_t) + * + * @param[in] handle RAFT resources handle + * @param[in] batch_data Device batch data [batch_size x n_features] + * @param[in] batch_weights Device batch weights [batch_size] + * @param[in] centroids Current centroids [n_clusters x n_features] + * @param[in] metric Distance metric + * @param[in] batch_samples_param Batch-samples param forwarded to minClusterAndDistanceCompute + * @param[in] batch_centroids_param Batch-centroids param forwarded to + * minClusterAndDistanceCompute + * @param[inout] minClusterAndDistance Work buffer [batch_size] + * @param[in] L2NormBatch Precomputed data norms [batch_size] + * @param[inout] L2NormBuf_OR_DistBuf Resizable scratch + * @param[inout] workspace Resizable scratch + * @param[inout] centroid_sums Running weighted sums [n_clusters x n_features] (added into) + * @param[inout] weight_per_cluster Running weight counts [n_clusters] (added into) + * @param[inout] batch_sums Scratch for this batch [n_clusters x n_features] + * @param[inout] batch_counts Scratch for this batch [n_clusters] + * @param[inout] clustering_cost Running cost scalar (device) (added into) + * @param[in] centroid_norms Optional precomputed centroid norms [n_clusters]. + * When provided, skips internal centroid norm computation. + */ +template +void process_batch( + raft::resources const& handle, + raft::device_matrix_view batch_data, + raft::device_vector_view batch_weights, + raft::device_matrix_view centroids, + cuvs::distance::DistanceType metric, + int batch_samples_param, + int batch_centroids_param, + raft::device_vector_view, IndexT> minClusterAndDistance, + raft::device_vector_view L2NormBatch, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + rmm::device_uvector& workspace, + raft::device_matrix_view centroid_sums, + raft::device_vector_view weight_per_cluster, + raft::device_matrix_view batch_sums, + raft::device_vector_view batch_counts, + raft::device_scalar_view clustering_cost, + std::optional> centroid_norms = std::nullopt) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + IndexT current_batch_sz = batch_data.extent(0); + + minClusterAndDistanceCompute(handle, + batch_data, + centroids, + minClusterAndDistance, + L2NormBatch, + L2NormBuf_OR_DistBuf, + metric, + batch_samples_param, + batch_centroids_param, + workspace, + centroid_norms); + + KeyValueIndexOp conversion_op; + thrust::transform_iterator, + const raft::KeyValuePair*> + labels_itr(minClusterAndDistance.data_handle(), conversion_op); + + auto batch_workspace = rmm::device_uvector( + current_batch_sz, stream, raft::resource::get_workspace_resource(handle)); + + compute_centroid_adjustments(handle, + batch_data, + batch_weights, + labels_itr, + static_cast(centroid_sums.extent(0)), + batch_sums, + batch_counts, + batch_workspace); + + raft::linalg::add(centroid_sums.data_handle(), + centroid_sums.data_handle(), + batch_sums.data_handle(), + centroid_sums.size(), + stream); + + raft::linalg::add(weight_per_cluster.data_handle(), + weight_per_cluster.data_handle(), + batch_counts.data_handle(), + weight_per_cluster.size(), + stream); + + raft::linalg::map( + handle, + minClusterAndDistance, + [=] __device__(const raft::KeyValuePair kvp, DataT wt) { + raft::KeyValuePair res; + res.value = kvp.value * wt; + res.key = kvp.key; + return res; + }, + raft::make_const_mdspan(minClusterAndDistance), + batch_weights); + + auto batch_cost = raft::make_device_scalar(handle, DataT{0}); + computeClusterCost( + handle, minClusterAndDistance, workspace, batch_cost.view(), raft::value_op{}, raft::add_op{}); + raft::linalg::add(clustering_cost.data_handle(), + clustering_cost.data_handle(), + batch_cost.data_handle(), + 1, + stream); +} + } // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh index fdec2bdd73..cbc75c822c 100644 --- a/cpp/src/cluster/detail/kmeans_mg.cuh +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -701,49 +701,45 @@ void fit(const raft::resources& handle, raft::make_device_vector_view(newCentroids.data_handle(), newCentroids.size())); bool done = false; - if (params.inertia_check) { - rmm::device_scalar> clusterCostD(stream); - - // calculate cluster cost phi_x(C) - cuvs::cluster::kmeans::cluster_cost( - handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - cuda::proclaim_return_type>( - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - })); - - // Cluster cost phi_x(C) from all ranks - comm.allreduce(&(clusterCostD.data()->value), - &(clusterCostD.data()->value), - 1, - raft::comms::op_t::SUM, - stream); - - DataT curClusteringCost = 0; - raft::copy(handle, - raft::make_host_scalar_view(&curClusteringCost), - raft::make_device_scalar_view(&(clusterCostD.data()->value))); - - ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, - "An error occurred in the distributed operation. This can result " - "from a failed rank"); - ASSERT(curClusteringCost != (DataT)0.0, - "Too few points and centroids being found is getting 0 cost from " - "centers\n"); - - if (n_iter[0] > 1) { - DataT delta = curClusteringCost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - } - priorClusteringCost = curClusteringCost; + rmm::device_scalar> clusterCostD(stream); + + // calculate cluster cost phi_x(C) + cuvs::cluster::kmeans::cluster_cost( + handle, + minClusterAndDistance.view(), + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + cuda::proclaim_return_type>( + [] __device__(const raft::KeyValuePair& a, + const raft::KeyValuePair& b) { + raft::KeyValuePair res; + res.key = 0; + res.value = a.value + b.value; + return res; + })); + + // Cluster cost phi_x(C) from all ranks + comm.allreduce(&(clusterCostD.data()->value), + &(clusterCostD.data()->value), + 1, + raft::comms::op_t::SUM, + stream); + + DataT curClusteringCost = 0; + raft::copy(handle, + raft::make_host_scalar_view(&curClusteringCost), + raft::make_device_scalar_view(&(clusterCostD.data()->value))); + + ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, + "An error occurred in the distributed operation. This can result " + "from a failed rank"); + if (curClusteringCost == (DataT)0.0) { + RAFT_LOG_WARN("Zero clustering cost detected: all points coincide with their centroids."); + } else if (n_iter[0] > 1) { + DataT delta = curClusteringCost / priorClusteringCost; + if (delta > 1 - params.tol) done = true; } + priorClusteringCost = curClusteringCost; raft::resource::sync_stream(handle, stream); if (sqrdNormError < params.tol) done = true; diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 8370ff922f..bcfc381753 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -7,6 +7,8 @@ #include +#include + namespace cuvs::cluster::kmeans::detail { // Calculates a pair for every sample in input 'X' where key is an @@ -23,36 +25,34 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, int batch_samples, int batch_centroids, - rmm::device_uvector& workspace) + rmm::device_uvector& workspace, + std::optional> precomputed_centroid_norms) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = centroids.extent(0); - // todo(lsugy): change batch size computation when using fusedL2NN! - bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || + bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded; auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); if (is_fused) { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::norm( - handle, - centroids, - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + if (!precomputed_centroid_norms.has_value()) { + L2NormBuf_OR_DistBuf.resize(n_clusters, stream); + raft::linalg::norm( + handle, + centroids, + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + } } else { - // TODO: Unless pool allocator is used, passing in a workspace for this - // isn't really increasing performance because this needs to do a re-allocation - // anyways. ref https://github.com/rapidsai/raft/issues/930 L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); } - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer + auto centroidsNorm_view = + precomputed_centroid_norms.has_value() + ? precomputed_centroid_norms.value() + : raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); auto pairwiseDistance = raft::make_device_matrix_view( L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); @@ -87,7 +87,7 @@ void minClusterAndDistanceCompute( datasetView.data_handle(), centroids.data_handle(), L2NormXView.data_handle(), - centroidsNorm.data_handle(), + centroidsNorm_view.data_handle(), ns, n_clusters, n_features, @@ -154,7 +154,8 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, \ + std::optional>); INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(float, int64_t) INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(double, int64_t) @@ -164,16 +165,18 @@ INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(double, int) #undef INSTANTIATE_MIN_CLUSTER_AND_DISTANCE template -void minClusterDistanceCompute(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view minClusterDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - cuvs::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) +void minClusterDistanceCompute( + raft::resources const& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view minClusterDistance, + raft::device_vector_view L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + cuvs::distance::DistanceType metric, + int batch_samples, + int batch_centroids, + rmm::device_uvector& workspace, + std::optional> precomputed_centroid_norms) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -186,21 +189,22 @@ void minClusterDistanceCompute(raft::resources const& handle, auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); if (is_fused) { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)), - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + if (!precomputed_centroid_norms.has_value()) { + L2NormBuf_OR_DistBuf.resize(n_clusters, stream); + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)), + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + } } else { L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); } - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer + auto centroidsNorm_view = + precomputed_centroid_norms.has_value() + ? precomputed_centroid_norms.value() + : raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); auto pairwiseDistance = raft::make_device_matrix_view( L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); @@ -232,7 +236,7 @@ void minClusterDistanceCompute(raft::resources const& handle, datasetView.data_handle(), centroids.data_handle(), L2NormXView.data_handle(), - centroidsNorm.data_handle(), + centroidsNorm_view.data_handle(), ns, n_clusters, n_features, @@ -290,7 +294,8 @@ void minClusterDistanceCompute(raft::resources const& handle, cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, \ + std::optional>); INSTANTIATE_MIN_CLUSTER_DISTANCE(float, int64_t) INSTANTIATE_MIN_CLUSTER_DISTANCE(double, int64_t) diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index d7e4748e33..51cd21cb51 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "detail/kmeans_batched.cuh" #include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -72,7 +71,7 @@ void fit(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::detail::fit( + cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); } diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index f86fabcfbd..000774b9c6 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "detail/kmeans_batched.cuh" #include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -72,7 +71,7 @@ void fit(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::detail::fit( + cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); } diff --git a/cpp/src/cluster/kmeans_impl.cuh b/cpp/src/cluster/kmeans_impl.cuh index 437aa16c76..f521edd07f 100644 --- a/cpp/src/cluster/kmeans_impl.cuh +++ b/cpp/src/cluster/kmeans_impl.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -18,8 +18,12 @@ void fit_main(raft::resources const& handle, raft::host_scalar_view n_iter, rmm::device_uvector& workspace) { - cuvs::cluster::kmeans::detail::kmeans_fit_main( - handle, params, X, sample_weights, centroids, inertia, n_iter, workspace); + cuvs::cluster::kmeans::params p = params; + p.init = kmeans::params::InitMethod::Array; + p.n_init = 1; + auto sw = std::make_optional( + raft::make_device_vector_view(sample_weights.data_handle(), X.extent(0))); + cuvs::cluster::kmeans::detail::kmeans_fit(handle, p, X, sw, centroids, inertia, n_iter); } template @@ -31,7 +35,6 @@ void fit(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - // use the mnmg kmeans fit if we have comms initialize, single gpu otherwise if (raft::resource::comms_initialized(handle)) { cuvs::cluster::kmeans::mg::fit(handle, params, X, sample_weight, centroids, inertia, n_iter); } else { @@ -54,4 +57,17 @@ void predict(raft::resources const& handle, handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); } +template +void fit(raft::resources const& handle, + const kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::detail::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter); +} + } // namespace cuvs::cluster::kmeans diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index 1ef8d07623..5d48ef099e 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -433,9 +433,8 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(d_centroids_ref.data(), params.n_clusters, n_features); - params.init = cuvs::cluster::kmeans::params::Array; - params.inertia_check = true; - params.max_iter = 20; + params.init = cuvs::cluster::kmeans::params::Array; + params.max_iter = 20; T ref_inertia = 0; int ref_n_iter = 0; @@ -448,7 +447,6 @@ class KmeansFitBatchedTest : public ::testing::TestWithParam(&ref_n_iter)); cuvs::cluster::kmeans::params batched_params = params; - batched_params.inertia_check = true; batched_params.streaming_batch_size = testparams.streaming_batch_size; std::optional> h_sw = std::nullopt; diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index 6d0c878660..ccacb7042b 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -33,9 +33,10 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: int batch_samples, int batch_centroids, bool inertia_check, - int64_t streaming_batch_size, bool hierarchical, - int hierarchical_n_iters + int hierarchical_n_iters, + int64_t streaming_batch_size, + int64_t init_size ctypedef cuvsKMeansParams* cuvsKMeansParams_t diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index b267c908c9..246ac4138c 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -76,8 +76,10 @@ cdef class KMeansParams: [batch_samples x n_clusters]. batch_centroids : int Number of centroids to process in each batch. If 0, uses n_clusters. - inertia_check : bool - If True, check inertia during iterations for early convergence. + init_size : int + Number of samples to draw for KMeansPlusPlus initialization with + host (out-of-core) data. When set to 0, uses the heuristic + min(3 * n_clusters, n_samples). Default: 0. streaming_batch_size : int Number of samples to process per GPU batch when fitting with host (numpy) data. When set to 0, defaults to n_samples (process all @@ -111,7 +113,7 @@ cdef class KMeansParams: oversampling_factor=None, batch_samples=None, batch_centroids=None, - inertia_check=None, + init_size=None, streaming_batch_size=None, hierarchical=None, hierarchical_n_iters=None): @@ -134,8 +136,8 @@ cdef class KMeansParams: self.params.batch_samples = batch_samples if batch_centroids is not None: self.params.batch_centroids = batch_centroids - if inertia_check is not None: - self.params.inertia_check = inertia_check + if init_size is not None: + self.params.init_size = init_size if streaming_batch_size is not None: self.params.streaming_batch_size = streaming_batch_size if hierarchical is not None: @@ -183,8 +185,8 @@ cdef class KMeansParams: return self.params.batch_centroids @property - def inertia_check(self): - return self.params.inertia_check + def init_size(self): + return self.params.init_size @property def streaming_batch_size(self):