-
Notifications
You must be signed in to change notification settings - Fork 146
Spectral Clustering #1425
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Spectral Clustering #1425
Changes from all commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
1b763de
spectral clustering
aamijar bf6fe74
Merge branch 'branch-25.12' into spectral-clustering
aamijar 58a2d04
add gtests
aamijar 34e81dd
update
aamijar 38f12ee
Merge branch 'branch-25.12' into spectral-clustering
aamijar c3069f0
update hyperparams
aamijar 67a21c5
Merge branch 'main' into spectral-clustering
aamijar f1ab56a
rename to spectral.hpp
aamijar 376da2e
Merge branch 'main' into spectral-clustering
aamijar 5007df9
Merge branch 'main' into spectral-clustering
aamijar 472eb55
increase precision
aamijar 2d90bb5
Merge branch 'spectral-clustering' of https://github.com/aamijar/cuvs…
aamijar c42e74f
update gtest
aamijar 913d3aa
rerun CI
aamijar e4db202
rerun CI
aamijar bb3846b
rerun CI
aamijar d90614c
Merge branch 'main' into spectral-clustering
aamijar 3ecd180
rerun CI
aamijar f377524
Merge branch 'spectral-clustering' of https://github.com/aamijar/cuvs…
aamijar 7d3fe90
Merge branch 'main' into spectral-clustering
aamijar 4708e69
rerun CI
aamijar 6cc47d6
Merge branch 'main' into spectral-clustering
aamijar d9d4bc8
rng_state
aamijar eae5ef9
Merge branch 'main' into spectral-clustering
aamijar 5c608f3
support double types
aamijar 148bfd5
Merge branch 'main' into spectral-clustering
aamijar 0477a64
update gtests
aamijar 0544f46
Merge branch 'main' into spectral-clustering
aamijar b9c5876
move templates to src files
aamijar c8ce74f
remove whitespace
aamijar 40ac5c2
link issue
aamijar 118af8f
Merge branch 'main' into spectral-clustering
aamijar d9faea2
rerun CI
aamijar File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
| #pragma once | ||
|
|
||
| #include <raft/core/device_coo_matrix.hpp> | ||
| #include <raft/core/device_mdspan.hpp> | ||
| #include <raft/core/resources.hpp> | ||
| #include <raft/random/rng_state.hpp> | ||
|
|
||
| namespace cuvs::cluster::spectral { | ||
|
|
||
| struct params { | ||
| int n_clusters; | ||
| int n_components; | ||
| int n_init; | ||
| int n_neighbors; | ||
| raft::random::RngState rng_state{0}; | ||
| }; | ||
|
|
||
| // TODO: int64_t nnz support (see https://github.com/rapidsai/cuvs/issues/1484) | ||
| void fit_predict(raft::resources const& handle, | ||
| params config, | ||
| raft::device_coo_matrix_view<float, int, int, int> connectivity_graph, | ||
| raft::device_vector_view<int, int> labels); | ||
|
|
||
| void fit_predict(raft::resources const& handle, | ||
| params config, | ||
| raft::device_coo_matrix_view<double, int, int, int> connectivity_graph, | ||
| raft::device_vector_view<int, int> labels); | ||
|
|
||
| } // namespace cuvs::cluster::spectral |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <cuvs/cluster/kmeans.hpp> | ||
| #include <cuvs/cluster/spectral.hpp> | ||
| #include <cuvs/preprocessing/spectral_embedding.hpp> | ||
| #include <raft/core/device_mdarray.hpp> | ||
| #include <raft/linalg/transpose.cuh> | ||
| #include <raft/random/rng_state.hpp> | ||
|
|
||
| namespace cuvs::cluster::spectral::detail { | ||
|
|
||
| template <typename DataT> | ||
| void fit_predict(raft::resources const& handle, | ||
| params config, | ||
| raft::device_coo_matrix_view<DataT, int, int, int> connectivity_graph, | ||
| raft::device_vector_view<int, int> labels) | ||
| { | ||
| int n_samples = connectivity_graph.structure_view().get_n_rows(); | ||
| DataT inertia; | ||
| int n_iter; | ||
| auto embedding_col_major = | ||
| raft::make_device_matrix<DataT, int, raft::col_major>(handle, n_samples, config.n_components); | ||
| auto embedding_row_major = | ||
| raft::make_device_matrix<DataT, int, raft::row_major>(handle, n_samples, config.n_components); | ||
| cuvs::preprocessing::spectral_embedding::params spectral_embedding_config; | ||
| spectral_embedding_config.n_components = config.n_components; | ||
| spectral_embedding_config.n_neighbors = config.n_neighbors; | ||
| spectral_embedding_config.norm_laplacian = true; | ||
| spectral_embedding_config.drop_first = false; | ||
| spectral_embedding_config.seed = config.rng_state.seed; | ||
|
|
||
| cuvs::cluster::kmeans::params kmeans_config; | ||
| kmeans_config.n_clusters = config.n_clusters; | ||
| kmeans_config.rng_state = config.rng_state; | ||
| kmeans_config.n_init = config.n_init; | ||
| kmeans_config.oversampling_factor = 0.0; | ||
|
|
||
| cuvs::preprocessing::spectral_embedding::transform( | ||
| handle, spectral_embedding_config, connectivity_graph, embedding_col_major.view()); | ||
|
|
||
| raft::linalg::transpose(handle, | ||
| embedding_col_major.data_handle(), | ||
| embedding_row_major.data_handle(), | ||
| n_samples, | ||
| config.n_components, | ||
| raft::resource::get_cuda_stream(handle)); | ||
|
|
||
| cuvs::cluster::kmeans::fit_predict(handle, | ||
| kmeans_config, | ||
| embedding_row_major.view(), | ||
| std::nullopt, | ||
| std::nullopt, | ||
| labels, | ||
| raft::make_host_scalar_view(&inertia), | ||
| raft::make_host_scalar_view(&n_iter)); | ||
| } | ||
|
|
||
| } // namespace cuvs::cluster::spectral::detail |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| #include "./detail/spectral.cuh" | ||
|
|
||
| #include <cuvs/cluster/spectral.hpp> | ||
|
|
||
| namespace cuvs::cluster::spectral { | ||
|
|
||
| #define CUVS_INST_SPECTRAL(DataT) \ | ||
| void fit_predict(raft::resources const& handle, \ | ||
| params config, \ | ||
| raft::device_coo_matrix_view<DataT, int, int, int> connectivity_graph, \ | ||
| raft::device_vector_view<int, int> labels) \ | ||
| { \ | ||
| detail::fit_predict<DataT>(handle, config, connectivity_graph, labels); \ | ||
| } | ||
|
|
||
| CUVS_INST_SPECTRAL(float); | ||
| CUVS_INST_SPECTRAL(double); | ||
|
|
||
| #undef CUVS_INST_SPECTRAL | ||
|
|
||
| } // namespace cuvs::cluster::spectral |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not part of the standard "fit()" "predict()" APIs that one would normally use day to day, so please add it to the
cuvs::preprocessing::spectral_embedding::helpersnamespace.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addressed in f1ab56a