Skip to content

Commit 45e220d

Browse files
authored
Improved memory efficiency in UMAP given precomputed knn graphs (#7481)
Closes #7143 This PR improves memory usage in UMAP when given a precomputed knn graph. Previously, a user-given knn graph will occupy GPU memory throughout the full UMAP pipeline even though it is not needed in later steps of UMAP. In this PR, if the user-given knn graph is on host memory, we keep it on host memory and copy to device at the cpp level to allow better memory management. ### This PR with precomputed knn graph on CPU <img width="808" height="313" alt="Screenshot 2025-11-12 at 7 00 33 PM" src="https://github.com/user-attachments/assets/6c752f62-a1b2-4fb1-a44d-d86ed468915b" /> ### Before with precomputed knn graph on CPU <img width="828" height="316" alt="Screenshot 2025-11-12 at 7 01 12 PM" src="https://github.com/user-attachments/assets/8237fdd4-e0bb-48f5-bc46-71878ce14b33" /> Authors: - Jinsol Park (https://github.com/jinsolp) Approvers: - Philip Hyunsu Cho (https://github.com/hcho3) - Simon Adorf (https://github.com/csadorf) - Tarang Jain (https://github.com/tarang-jain) URL: #7481
1 parent f69ab90 commit 45e220d

File tree

4 files changed

+40
-9
lines changed

4 files changed

+40
-9
lines changed

cpp/include/cuml/manifold/common.hpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
#pragma once
77

8+
#include <raft/spatial/knn/detail/ann_utils.cuh>
9+
810
#include <stdint.h>
911

1012
namespace ML {
@@ -104,7 +106,15 @@ struct manifold_precomputed_knn_inputs_t : public manifold_inputs_t<value_t> {
104106

105107
knn_graph<value_idx, value_t> knn_graph;
106108

107-
bool alloc_knn_graph() const { return false; }
109+
bool alloc_knn_graph() const
110+
{
111+
// Return true if data is on CPU (need to allocate device memory)
112+
// Return false if data is already on device (no allocation needed)
113+
auto pointer_residency = raft::spatial::knn::detail::utils::check_pointer_residency(
114+
knn_graph.knn_indices, knn_graph.knn_dists);
115+
return pointer_residency == raft::spatial::knn::detail::utils::pointer_residency::host_only ||
116+
pointer_residency == raft::spatial::knn::detail::utils::pointer_residency::mixed;
117+
}
108118
};
109119

110120
}; // end namespace ML

cpp/src/umap/knn_graph/algo.cuh

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,14 @@ inline void launcher(const raft::handle_t& handle,
197197
const ML::UMAPParams* params,
198198
cudaStream_t stream)
199199
{
200-
out.knn_indices = inputsA.knn_graph.knn_indices;
201-
out.knn_dists = inputsA.knn_graph.knn_dists;
200+
if (inputsA.alloc_knn_graph()) {
201+
// if new space for the knn graph is allocated, copy the data from the precomputed knn graph
202+
raft::copy(out.knn_indices, inputsA.knn_graph.knn_indices, inputsA.n * n_neighbors, stream);
203+
raft::copy(out.knn_dists, inputsA.knn_graph.knn_dists, inputsA.n * n_neighbors, stream);
204+
} else {
205+
out.knn_indices = inputsA.knn_graph.knn_indices;
206+
out.knn_dists = inputsA.knn_graph.knn_dists;
207+
}
202208
}
203209

204210
// Instantiation for precomputed inputs, int indices
@@ -211,8 +217,14 @@ inline void launcher(const raft::handle_t& handle,
211217
const ML::UMAPParams* params,
212218
cudaStream_t stream)
213219
{
214-
out.knn_indices = inputsA.knn_graph.knn_indices;
215-
out.knn_dists = inputsA.knn_graph.knn_dists;
220+
if (inputsA.alloc_knn_graph()) {
221+
// if new space for the knn graph is allocated, copy the data from the precomputed knn graph
222+
raft::copy(out.knn_indices, inputsA.knn_graph.knn_indices, inputsA.n * n_neighbors, stream);
223+
raft::copy(out.knn_dists, inputsA.knn_graph.knn_dists, inputsA.n * n_neighbors, stream);
224+
} else {
225+
out.knn_indices = inputsA.knn_graph.knn_indices;
226+
out.knn_dists = inputsA.knn_graph.knn_dists;
227+
}
216228
}
217229

218230
} // namespace Algo

python/cuml/cuml/common/sparsefuncs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def _determine_k_from_arrays(
263263
return total_elements // n_samples
264264

265265

266-
def extract_knn_graph(knn_info, n_neighbors):
266+
def extract_knn_graph(knn_info, n_neighbors, mem_type="device"):
267267
"""
268268
Extract the nearest neighbors distances and indices
269269
from the knn_info parameter.
@@ -367,6 +367,7 @@ def extract_knn_graph(knn_info, n_neighbors):
367367
deepcopy=deepcopy,
368368
check_dtype=np.int64,
369369
convert_to_dtype=np.int64,
370+
convert_to_mem_type=mem_type,
370371
)
371372

372373
knn_dists_m, _, _, _ = input_to_cuml_array(
@@ -375,6 +376,7 @@ def extract_knn_graph(knn_info, n_neighbors):
375376
deepcopy=deepcopy,
376377
check_dtype=np.float32,
377378
convert_to_dtype=np.float32,
379+
convert_to_mem_type=mem_type,
378380
)
379381

380382
return knn_indices_m, knn_dists_m

python/cuml/cuml/manifold/umap/umap.pyx

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,9 @@ class UMAP(Base, InteropMixin, CMajorInputTagMixin, SparseInputTagMixin):
522522
sparse array (preferably CSR/COO). This feature allows
523523
the precomputation of the KNN outside of UMAP
524524
and also allows the use of a custom distance function. This function
525-
should match the metric used to train the UMAP embeedings.
525+
should match the metric used to train the UMAP embeedings. For most efficient
526+
memory usage, the precomputed knn graph should be CPU-accessible arrays
527+
such as numpy arrays.
526528
random_state : int, RandomState instance or None, optional (default=None)
527529
random_state is the seed used by the random number generator during
528530
embedding initialization and during sampling used by the optimizer.
@@ -900,7 +902,9 @@ class UMAP(Base, InteropMixin, CMajorInputTagMixin, SparseInputTagMixin):
900902
the precomputation of the KNN outside of UMAP
901903
and also allows the use of a custom distance function. This function
902904
should match the metric used to train the UMAP embeedings.
903-
Takes precedence over the precomputed_knn parameter.
905+
Takes precedence over the precomputed_knn parameter. For most efficient
906+
memory usage, the precomputed knn graph should be CPU-accessible arrays
907+
such as numpy arrays.
904908
"""
905909
if len(X.shape) != 2:
906910
raise ValueError("Reshape your data: data should be two dimensional")
@@ -968,6 +972,7 @@ class UMAP(Base, InteropMixin, CMajorInputTagMixin, SparseInputTagMixin):
968972
knn_indices, knn_dists = extract_knn_graph(
969973
(knn_graph if knn_graph is not None else self.precomputed_knn),
970974
self._n_neighbors,
975+
mem_type=False, # mirrors the input graph mem type
971976
)
972977
if X_is_sparse:
973978
knn_indices = input_to_cuml_array(
@@ -1072,7 +1077,9 @@ class UMAP(Base, InteropMixin, CMajorInputTagMixin, SparseInputTagMixin):
10721077
the precomputation of the KNN outside of UMAP
10731078
and also allows the use of a custom distance function. This function
10741079
should match the metric used to train the UMAP embeedings.
1075-
Takes precedence over the precomputed_knn parameter.
1080+
Takes precedence over the precomputed_knn parameter. For most efficient
1081+
memory usage, the precomputed knn graph should be CPU-accessible arrays
1082+
such as numpy arrays.
10761083
"""
10771084
self.fit(X, y, convert_dtype=convert_dtype, knn_graph=knn_graph)
10781085
return self.embedding_

0 commit comments

Comments
 (0)