Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
cce66e6
Add pq_len=8 instances
enp1s0 Aug 3, 2025
7d14a39
Merge branch 'branch-25.10' into cagra-q-pq_len-8
enp1s0 Aug 14, 2025
43e6145
Update CAGRA-Q test
enp1s0 Aug 14, 2025
16321bc
Update CAGRA-Q distance kernel
enp1s0 Sep 3, 2025
a4e5050
Merge branch 'branch-25.10' into cagra-q-pq_len-8
enp1s0 Sep 3, 2025
bfdc2d4
Add DatasetBlockDim check
enp1s0 Sep 3, 2025
23c02e1
Update VPQ compute distance kernel
enp1s0 Sep 3, 2025
5b3a832
Merge branch 'branch-25.12' into cagra-q-pq_len-8
enp1s0 Oct 21, 2025
e0f629c
Add fp_8bit4
enp1s0 Oct 21, 2025
0da1aa2
Fix compilation error
enp1s0 Oct 21, 2025
62ba0ad
Add as_u32
enp1s0 Oct 21, 2025
7f9f614
Update VPQ
enp1s0 Oct 21, 2025
058abbb
Fix fp_8bit4 constructor
enp1s0 Oct 21, 2025
e64f8c6
Add sts for u32
enp1s0 Oct 21, 2025
77492dd
Add f8
enp1s0 Oct 21, 2025
4638eb2
Fix a bug
enp1s0 Oct 22, 2025
a64f264
Add native f8 support
enp1s0 Oct 22, 2025
77dbe73
Merge branch 'branch-25.12' into cagra-q-pq_len-8
enp1s0 Oct 22, 2025
60ba5e9
Fix VPQ init
enp1s0 Oct 22, 2025
f37a131
Update clock measure
enp1s0 Aug 26, 2025
3b3c20b
Add fp8x8
enp1s0 Oct 23, 2025
bc572e0
Fix a bug
enp1s0 Oct 24, 2025
ec275d4
Update 2, 4, 8 configs
enp1s0 Oct 24, 2025
a85d8a3
Fix a bug
enp1s0 Oct 26, 2025
d1f628c
Add F8 query support
enp1s0 Oct 27, 2025
e05dfb0
Fix query vec id calc
enp1s0 Oct 30, 2025
5f2b78f
Improve performance
enp1s0 Oct 30, 2025
484f9e6
Improve performance
enp1s0 Oct 31, 2025
88871c9
Merge branch 'cagra-q-pq_len-8' into cagra-q-pq_len-8-query-f8
enp1s0 Oct 31, 2025
76d820f
Merge branch 'main' into cagra-q-pq_len-8
enp1s0 Nov 2, 2025
e710c62
Merge branch 'cagra-q-pq_len-8' into cagra-q-pq_len-8-query-f8
enp1s0 Nov 2, 2025
6bcb0e6
Fix template switch
enp1s0 Nov 2, 2025
581eba1
Fix pq_val_config
enp1s0 Nov 2, 2025
f8a7a74
Merge branch 'cagra-q-pq_len-8' into cagra-q-pq_len-8-query-f8
enp1s0 Nov 2, 2025
a088cfd
Improve smem index calculation
enp1s0 Nov 3, 2025
e244ac1
Merge branch 'main' into cagra-q-pq_len-8
enp1s0 Nov 11, 2025
8f00f22
Merge branch 'cagra-q-pq_len-8' into cagra-q-pq_len-8-query-f8
enp1s0 Nov 11, 2025
12809f3
Update fp8 pack dtype
enp1s0 Nov 11, 2025
f91d041
Refactoring
enp1s0 Nov 11, 2025
7c8ecd4
Fix a bug
enp1s0 Nov 11, 2025
7b46115
Add EnableFP8 flag
enp1s0 Nov 11, 2025
d49959c
Fix a bug
enp1s0 Nov 11, 2025
ed906f7
Fix a bug in compute_distance_00_generate.py
enp1s0 Nov 12, 2025
c0e9ddd
Update VPQ instances
enp1s0 Nov 12, 2025
35259e3
Add `smem_dtype` option
enp1s0 Nov 12, 2025
77fbcf6
Merge branch 'main' into cagra-q-pq_len-8
enp1s0 Nov 12, 2025
7639d02
Remove unnecessary include
enp1s0 Nov 12, 2025
710232a
Remove unnecessary files
enp1s0 Nov 12, 2025
daf6f89
Merge branch 'main' into cagra-q-pq_len-8
enp1s0 Nov 12, 2025
d183be6
Remove unnecessary file
enp1s0 Nov 12, 2025
e7d3d42
Revert "Remove unnecessary file"
enp1s0 Nov 12, 2025
24089bc
Fix Copyright
enp1s0 Nov 12, 2025
fd4530e
Merge branch 'main' into cagra-q-pq_len-8
enp1s0 Nov 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 108 additions & 36 deletions cpp/CMakeLists.txt

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ enum class search_algo {

/** Hash mode options: HASH, SMALL, or AUTO. */
enum class hash_mode {
  HASH  = 0,
  SMALL = 1,
  AUTO  = 100
};

/**
 * Internal data type options: F16, E5M2, or AUTO.
 * NOTE(review): E5M2 is presumably the FP8 format (5 exponent bits, 2 mantissa bits) - confirm.
 */
enum class internal_dtype {
  F16  = 0,
  E5M2 = 1,
  AUTO = 100
};

struct search_params : cuvs::neighbors::search_params {
/** Maximum number of queries to search at the same time (batch size). Auto select when 0.*/
size_t max_queries = 0;
Expand Down Expand Up @@ -277,6 +279,10 @@ struct search_params : cuvs::neighbors::search_params {
* negative, in which case the filtering rate is automatically calculated.
*/
float filtering_rate = -1.0;

/** Data type of the query vector and codebook table in shared memory. Currently, only VPQ
* supports FP8. */
internal_dtype smem_dtype = internal_dtype::AUTO;
};

/**
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ void search_main(raft::resources const& res,
// Dispatch search parameters based on the dataset kind.
if (auto* strided_dset = dynamic_cast<const strided_dataset<T, ds_idx_type>*>(&index.data());
strided_dset != nullptr) {
// This search path only supports AUTO or F16 as the smem_dtype, so warn when the caller
// requests anything else. The two inequality tests must be combined with `&&`: with `||`
// the condition is always true (no value equals both AUTO and F16), which would emit the
// warning on every search, including the default AUTO setting.
if (params.smem_dtype != cuvs::neighbors::cagra::internal_dtype::AUTO &&
    params.smem_dtype != cuvs::neighbors::cagra::internal_dtype::F16) {
  RAFT_LOG_WARN("In this search mode, only AUTO or F16 are supported as the smem_dtype.");
}
// Search using a plain (strided) row-major dataset
RAFT_EXPECTS(index.metric() != cuvs::distance::DistanceType::CosineExpanded ||
index.dataset_norms().has_value(),
Expand Down
1,880 changes: 1,600 additions & 280 deletions cpp/src/neighbors/detail/cagra/compute_distance-ext.cuh

Large diffs are not rendered by default.

336 changes: 312 additions & 24 deletions cpp/src/neighbors/detail/cagra/compute_distance.cu

Large diffs are not rendered by default.

46 changes: 28 additions & 18 deletions cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@
"""

mxdim_team = [(128, 8), (256, 16), (512, 32)]
vpq_2_4_mxdim_team = [(64, 4), (128, 8), (256, 16), (512, 32)]
vpq_8_mxdim_team = [(128, 4), (256, 8), (512, 16), (1024, 32)]
vrq_mxdim_team = [(64, 4), (128, 8), (256, 16), (512, 32)]
# mxdim_team = [(64, 8), (128, 16), (256, 32)]
# mxdim_team = [(32, 8), (64, 16), (128, 32)]

pq_bits = [8]
pq_lens = [2, 4]
pq_lens = [2, 4, 8]

# rblock = [(256, 4), (512, 2), (1024, 1)]
# rcandidates = [32]
Expand Down Expand Up @@ -76,26 +79,33 @@
f.write(template.format(includes=includes, content=content))
cmake_list.append(f" src/neighbors/detail/cagra/{path}")

# CAGRA-Q
for code_book_t in code_book_types:
for pq_len in pq_lens:
for pq_len in pq_lens:
vpq_mxdim_team = (
vpq_8_mxdim_team if pq_len == 8 else vpq_2_4_mxdim_team
)
for mxdim, team in vpq_mxdim_team:
# CAGRA-Q
for code_book_t in code_book_types:
for pq_bit in pq_bits:
for metric in ["L2Expanded"]:
path = f"compute_distance_vpq_{metric}_{type_path}_dim{mxdim}_t{team}_{pq_bit}pq_{pq_len}subd_{code_book_t}.cu"
includes = '#include "compute_distance_vpq-impl.cuh"'
params = f"{metric_prefix}{metric}, {team}, {mxdim}, {pq_bit}, {pq_len}, {code_book_t}, {data_t}, {idx_t}, {distance_t}"
spec = f"vpq_descriptor_spec<{params}>"
content = f"""template struct {spec};"""
specs.append(spec)
with open(path, "w") as f:
f.write(
template.format(
includes=includes, content=content
)
)
cmake_list.append(
f" src/neighbors/detail/cagra/{path}"
for enable_fp8 in ["true", "false"]:
path = f"compute_distance_vpq_{metric}_{type_path}_dim{mxdim}_t{team}_{pq_bit}pq_{pq_len}subd_{code_book_t}_fp8{enable_fp8}.cu"
includes = (
'#include "compute_distance_vpq-impl.cuh"'
)
params = f"{metric_prefix}{metric}, {team}, {mxdim}, {pq_bit}, {pq_len}, {code_book_t}, {data_t}, {idx_t}, {distance_t}, {enable_fp8}"
spec = f"vpq_descriptor_spec<{params}>"
content = f"""template struct {spec};"""
specs.append(spec)
with open(path, "w") as f:
f.write(
template.format(
includes=includes, content=content
)
)
cmake_list.append(
f" src/neighbors/detail/cagra/{path}"
)

# CAGRA (Binary Hamming distance)
for mxdim, team in mxdim_team:
Expand Down
Loading
Loading