Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
1b763de
spectral clustering
aamijar Oct 14, 2025
bf6fe74
Merge branch 'branch-25.12' into spectral-clustering
aamijar Oct 14, 2025
58a2d04
add gtests
aamijar Oct 15, 2025
34e81dd
update
aamijar Oct 18, 2025
38f12ee
Merge branch 'branch-25.12' into spectral-clustering
aamijar Oct 18, 2025
c3069f0
update hyperparams
aamijar Oct 20, 2025
67a21c5
Merge branch 'main' into spectral-clustering
aamijar Oct 20, 2025
f1ab56a
rename to spectral.hpp
aamijar Oct 21, 2025
376da2e
Merge branch 'main' into spectral-clustering
aamijar Oct 21, 2025
5007df9
Merge branch 'main' into spectral-clustering
aamijar Oct 21, 2025
472eb55
increase precision
aamijar Oct 21, 2025
2d90bb5
Merge branch 'spectral-clustering' of https://github.com/aamijar/cuvs…
aamijar Oct 21, 2025
c42e74f
update gtest
aamijar Oct 21, 2025
913d3aa
rerun CI
aamijar Oct 21, 2025
e4db202
rerun CI
aamijar Oct 22, 2025
bb3846b
rerun CI
aamijar Oct 22, 2025
d90614c
Merge branch 'main' into spectral-clustering
aamijar Oct 22, 2025
3ecd180
rerun CI
aamijar Oct 22, 2025
f377524
Merge branch 'spectral-clustering' of https://github.com/aamijar/cuvs…
aamijar Oct 22, 2025
7d3fe90
Merge branch 'main' into spectral-clustering
aamijar Oct 23, 2025
4708e69
rerun CI
aamijar Oct 23, 2025
6cc47d6
Merge branch 'main' into spectral-clustering
aamijar Oct 24, 2025
d9d4bc8
rng_state
aamijar Oct 27, 2025
eae5ef9
Merge branch 'main' into spectral-clustering
aamijar Oct 27, 2025
5c608f3
support double types
aamijar Oct 28, 2025
148bfd5
Merge branch 'main' into spectral-clustering
aamijar Oct 28, 2025
0477a64
update gtests
aamijar Oct 28, 2025
0544f46
Merge branch 'main' into spectral-clustering
aamijar Oct 29, 2025
b9c5876
move templates to src files
aamijar Oct 30, 2025
c8ce74f
remove whitespace
aamijar Oct 30, 2025
40ac5c2
link issue
aamijar Nov 1, 2025
118af8f
Merge branch 'main' into spectral-clustering
aamijar Nov 1, 2025
d9faea2
rerun CI
aamijar Nov 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ if(NOT BUILD_CPU_ONLY)
src/cluster/kmeans_transform_double.cu
src/cluster/kmeans_transform_float.cu
src/cluster/single_linkage_float.cu
src/cluster/spectral.cu
src/core/bitset.cu
src/core/omp_wrapper.cpp
src/distance/detail/kernels/gram_matrix.cu
Expand Down
33 changes: 33 additions & 0 deletions cpp/include/cuvs/cluster/spectral.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once

#include <raft/core/device_coo_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>
#include <raft/random/rng_state.hpp>

namespace cuvs::cluster::spectral {

struct params {
int n_clusters;
int n_components;
int n_init;
int n_neighbors;
raft::random::RngState rng_state{0};
};

// TODO: int64_t nnz support (see https://github.com/rapidsai/cuvs/issues/1484)
void fit_predict(raft::resources const& handle,
params config,
raft::device_coo_matrix_view<float, int, int, int> connectivity_graph,
raft::device_vector_view<int, int> labels);

void fit_predict(raft::resources const& handle,
params config,
raft::device_coo_matrix_view<double, int, int, int> connectivity_graph,
raft::device_vector_view<int, int> labels);

} // namespace cuvs::cluster::spectral
14 changes: 14 additions & 0 deletions cpp/include/cuvs/preprocessing/spectral_embedding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,22 @@ void transform(raft::resources const& handle,
raft::device_coo_matrix_view<float, int, int, int> connectivity_graph,
raft::device_matrix_view<float, int, raft::col_major> embedding);

void transform(raft::resources const& handle,
params config,
raft::device_coo_matrix_view<double, int, int, int> connectivity_graph,
raft::device_matrix_view<double, int, raft::col_major> embedding);

/**
* @}
*/

} // namespace cuvs::preprocessing::spectral_embedding

namespace cuvs::preprocessing::spectral_embedding::helpers {

void create_connectivity_graph(raft::resources const& handle,
Copy link
Member

@cjnolet cjnolet Oct 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not part of the standard "fit()" "predict()" APIs that one would normally use day to day, so please add it to the cuvs::preprocessing::spectral_embedding::helpers namespace.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in f1ab56a

params spectral_embedding_config,
raft::device_matrix_view<float, int, raft::row_major> dataset,
raft::device_coo_matrix<float, int, int, int>& connectivity_graph);

} // namespace cuvs::preprocessing::spectral_embedding::helpers
63 changes: 63 additions & 0 deletions cpp/src/cluster/detail/spectral.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <cuvs/cluster/kmeans.hpp>
#include <cuvs/cluster/spectral.hpp>
#include <cuvs/preprocessing/spectral_embedding.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/linalg/transpose.cuh>
#include <raft/random/rng_state.hpp>

namespace cuvs::cluster::spectral::detail {

template <typename DataT>
void fit_predict(raft::resources const& handle,
params config,
raft::device_coo_matrix_view<DataT, int, int, int> connectivity_graph,
raft::device_vector_view<int, int> labels)
{
int n_samples = connectivity_graph.structure_view().get_n_rows();
DataT inertia;
int n_iter;
auto embedding_col_major =
raft::make_device_matrix<DataT, int, raft::col_major>(handle, n_samples, config.n_components);
auto embedding_row_major =
raft::make_device_matrix<DataT, int, raft::row_major>(handle, n_samples, config.n_components);
cuvs::preprocessing::spectral_embedding::params spectral_embedding_config;
spectral_embedding_config.n_components = config.n_components;
spectral_embedding_config.n_neighbors = config.n_neighbors;
spectral_embedding_config.norm_laplacian = true;
spectral_embedding_config.drop_first = false;
spectral_embedding_config.seed = config.rng_state.seed;

cuvs::cluster::kmeans::params kmeans_config;
kmeans_config.n_clusters = config.n_clusters;
kmeans_config.rng_state = config.rng_state;
kmeans_config.n_init = config.n_init;
kmeans_config.oversampling_factor = 0.0;

cuvs::preprocessing::spectral_embedding::transform(
handle, spectral_embedding_config, connectivity_graph, embedding_col_major.view());

raft::linalg::transpose(handle,
embedding_col_major.data_handle(),
embedding_row_major.data_handle(),
n_samples,
config.n_components,
raft::resource::get_cuda_stream(handle));

cuvs::cluster::kmeans::fit_predict(handle,
kmeans_config,
embedding_row_major.view(),
std::nullopt,
std::nullopt,
labels,
raft::make_host_scalar_view(&inertia),
raft::make_host_scalar_view(&n_iter));
}

} // namespace cuvs::cluster::spectral::detail
26 changes: 26 additions & 0 deletions cpp/src/cluster/spectral.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include "./detail/spectral.cuh"

#include <cuvs/cluster/spectral.hpp>

namespace cuvs::cluster::spectral {

#define CUVS_INST_SPECTRAL(DataT) \
void fit_predict(raft::resources const& handle, \
params config, \
raft::device_coo_matrix_view<DataT, int, int, int> connectivity_graph, \
raft::device_vector_view<int, int> labels) \
{ \
detail::fit_predict<DataT>(handle, config, connectivity_graph, labels); \
}

CUVS_INST_SPECTRAL(float);
CUVS_INST_SPECTRAL(double);

#undef CUVS_INST_SPECTRAL

} // namespace cuvs::cluster::spectral
Loading