Skip to content

Commit 7c8fa6e

Browse files
committed
correct norm computation for cosine
1 parent 079b1bb commit 7c8fa6e

1 file changed

Lines changed: 14 additions & 5 deletions

File tree

cpp/src/cluster/detail/minClusterDistanceCompute.cu

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,20 @@ void minClusterDistanceCompute(raft::resources const& handle,
191191
auto centroidsNorm =
192192
raft::make_device_vector_view<DataT, IndexT>(L2NormBuf_OR_DistBuf.data(), n_clusters);
193193

194-
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
195-
handle,
196-
raft::make_device_matrix_view<const DataT, IndexT>(
197-
centroids.data_handle(), centroids.extent(0), centroids.extent(1)),
198-
centroidsNorm);
194+
if (metric == cuvs::distance::DistanceType::CosineExpanded) {
195+
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
196+
handle,
197+
raft::make_device_matrix_view<const DataT, IndexT>(
198+
centroids.data_handle(), centroids.extent(0), centroids.extent(1)),
199+
centroidsNorm,
200+
raft::sqrt_op{});
201+
} else {
202+
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
203+
handle,
204+
raft::make_device_matrix_view<const DataT, IndexT>(
205+
centroids.data_handle(), centroids.extent(0), centroids.extent(1)),
206+
centroidsNorm);
207+
}
199208

200209
workspace.resize((sizeof(IndexT)) * n_samples, stream);
201210

0 commit comments

Comments
 (0)