feat: add euclidean one2many #188

Open
richyreachy wants to merge 7 commits into main from feat/euclidean_one2many

Conversation

@richyreachy (Collaborator)

add euclidean one2many

@richyreachy richyreachy requested a review from iaojnh March 1, 2026 14:47
@greptile-apps

greptile-apps bot commented Mar 1, 2026

Greptile Summary

This PR adds Euclidean one-to-many distance computation with SIMD optimizations (AVX2, AVX512F, AVX512FP16) for float32, float16, and int8 data types. However, multiple critical bugs exist in the new implementation files that will cause incorrect distance calculations:

  • euclidean_distance_batch_impl.h: Wrong FMA operation in AVX512F masked tail (line 81) computes dot product instead of squared difference; AVX2 switch statement (lines 128-148) missing dimension offsets
  • euclidean_distance_batch_impl_fp16.h: AVX512F implementation (line 139) computes dot product instead of squared distance
  • euclidean_distance_batch_impl_int8.h: Incorrect pointer arithmetic (line 74) and completely broken switch statement (lines 93-136) with wrong macro usage and array indices

These bugs were already identified in previous review threads and remain unfixed. Additionally:

  • Removed incorrect Hamming distance cases from SquaredEuclideanMetric::batch_distance()
  • Added batch_distance() method to EuclideanMetric
  • Minor formatting improvements across several files

Confidence Score: 0/5

  • This PR is NOT safe to merge - contains multiple critical bugs that will produce incorrect results
  • Score of 0 reflects multiple critical algorithmic errors in the core distance computation logic that will cause incorrect calculations across all three implementation files (fp32, fp16, int8). These bugs were previously identified but remain unfixed, and will result in wrong distance values being returned to callers.
  • All three implementation files require immediate attention: euclidean_distance_batch_impl.h, euclidean_distance_batch_impl_fp16.h, and euclidean_distance_batch_impl_int8.h
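Independent of which SIMD path is taken, a plain scalar reference makes every one of these bugs easy to catch in a unit test. The sketch below is hypothetical (function name and signature are illustrative, not taken from the PR); it mirrors the one-to-many layout the batch API appears to use, and any AVX2/AVX512 kernel should agree with it up to floating-point reassociation:

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical scalar reference for one-to-many squared Euclidean distance.
// Names are illustrative, not the PR's actual API.
static void squared_euclidean_one_to_many_ref(const float *query,
                                              const float *const *ptrs,
                                              std::size_t batch,
                                              std::size_t dimensionality,
                                              float *results) {
  for (std::size_t i = 0; i < batch; ++i) {
    float acc = 0.0f;
    for (std::size_t d = 0; d < dimensionality; ++d) {
      const float diff = query[d] - ptrs[i][d];
      acc += diff * diff;  // accumulate (q - x)^2, not q * x
    }
    results[i] = acc;
  }
}
```

Comparing each SIMD kernel against a reference like this across odd dimensionalities (to exercise the tail paths) would have caught the dot-product and offset bugs before review.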

Important Files Changed

Filename Overview
src/ailego/math_batch/euclidean_distance_batch_impl.h New file with critical bugs in AVX512F masked tail handling (wrong FMA operation) and AVX2 switch statement (missing dim offsets)
src/ailego/math_batch/euclidean_distance_batch_impl_fp16.h New file with critical bug in AVX512F implementation (dot product instead of squared distance at line 139)
src/ailego/math_batch/euclidean_distance_batch_impl_int8.h New file with critical bugs in pointer arithmetic (line 74) and switch statement (lines 91-136: wrong macro usage with pointers and incorrect array indices)
src/ailego/math_batch/euclidean_distance_batch.h New file defining batch distance computation structure, looks correct but depends on buggy implementation files
src/ailego/math_batch/distance_batch.h Added Euclidean and SquaredEuclidean distance batch handlers using constexpr type checks
src/core/metric/euclidean_metric.cc Added batch_distance method to EuclideanMetric and removed incorrect Hamming distance cases from SquaredEuclideanMetric

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[BaseDistance::ComputeBatch] --> B{Distance Type?}
    B -->|Euclidean| C[EuclideanDistanceBatch::ComputeBatch]
    B -->|SquaredEuclidean| D[SquaredEuclideanDistanceBatch::ComputeBatch]
    B -->|Other| E[Default _ComputeBatch]
    
    D --> F{Data Type?}
    F -->|float| G[SquaredEuclideanDistanceBatchImpl float]
    F -->|int8_t| H[SquaredEuclideanDistanceBatchImpl int8_t]
    F -->|Float16| I[SquaredEuclideanDistanceBatchImpl Float16]
    F -->|Other| J[Fallback Implementation]
    
    G --> K{CPU Features?}
    K -->|AVX512F| L[compute_one_to_many_squared_euclidean_avx512f_fp32]
    K -->|AVX2| M[compute_one_to_many_squared_euclidean_avx2_fp32]
    K -->|None| J
    
    H --> N{CPU Features?}
    N -->|AVX2| O[compute_one_to_many_squared_euclidean_avx2_int8]
    N -->|None| J
    
    I --> P{CPU Features?}
    P -->|AVX512FP16| Q[compute_one_to_many_squared_euclidean_avx512fp16_fp16]
    P -->|AVX512F| R[compute_one_to_many_squared_euclidean_avx512f_fp16]
    P -->|AVX2| S[compute_one_to_many_squared_euclidean_avx2_fp16]
    P -->|None| J
    
    C --> T[Call SquaredEuclideanDistanceBatch]
    T --> U[Apply sqrt to results]
    
    style L fill:#f99,stroke:#333,stroke-width:2px
    style M fill:#f99,stroke:#333,stroke-width:2px
    style O fill:#f99,stroke:#333,stroke-width:2px
    style Q fill:#f99,stroke:#333,stroke-width:2px
    style R fill:#f99,stroke:#333,stroke-width:2px
    style S fill:#f99,stroke:#333,stroke-width:2px

Last reviewed commit: b766fb1

@greptile-apps greptile-apps bot left a comment

13 files reviewed, 5 comments

Comment on lines +81 to +84

    accs[i] = _mm512_mask3_fmadd_ps(
        _mm512_mask_loadu_ps(zmm_undefined, mask, query + dim),
        _mm512_mask_loadu_ps(zmm_undefined, mask, ptrs[i] + dim), accs[i],
        mask);

This computes query * ptrs[i] + accs[i] (a dot product) instead of (query - ptrs[i])^2 + accs[i].

Suggested change

    __m512 q_vals = _mm512_mask_loadu_ps(zmm_undefined, mask, query + dim);
    __m512 d_vals = _mm512_mask_loadu_ps(zmm_undefined, mask, ptrs[i] + dim);
    __m512 diff = _mm512_sub_ps(q_vals, d_vals);
    // mask3 variant merges from accs[i] in inactive lanes, so the undefined
    // values loaded into masked-out lanes cannot corrupt the accumulator.
    accs[i] = _mm512_mask3_fmadd_ps(diff, diff, accs[i], mask);
Comment on lines +128 to +148

    switch (dimensionality - dim) {
      case 7:
        SSD_FP32_GENERAL(query[6], ptrs[i][6], results[i]);
        /* FALLTHRU */
      case 6:
        SSD_FP32_GENERAL(query[5], ptrs[i][5], results[i]);
        /* FALLTHRU */
      case 5:
        SSD_FP32_GENERAL(query[4], ptrs[i][4], results[i]);
        /* FALLTHRU */
      case 4:
        SSD_FP32_GENERAL(query[3], ptrs[i][3], results[i]);
        /* FALLTHRU */
      case 3:
        SSD_FP32_GENERAL(query[2], ptrs[i][2], results[i]);
        /* FALLTHRU */
      case 2:
        SSD_FP32_GENERAL(query[1], ptrs[i][1], results[i]);
        /* FALLTHRU */
      case 1:
        SSD_FP32_GENERAL(query[0], ptrs[i][0], results[i]);

Array indices should be offset by dim so the remaining tail elements are processed correctly.

Suggested change

    switch (dimensionality - dim) {
      case 7:
        SSD_FP32_GENERAL(query[dim + 6], ptrs[i][dim + 6], results[i]);
        /* FALLTHRU */
      case 6:
        SSD_FP32_GENERAL(query[dim + 5], ptrs[i][dim + 5], results[i]);
        /* FALLTHRU */
      case 5:
        SSD_FP32_GENERAL(query[dim + 4], ptrs[i][dim + 4], results[i]);
        /* FALLTHRU */
      case 4:
        SSD_FP32_GENERAL(query[dim + 3], ptrs[i][dim + 3], results[i]);
        /* FALLTHRU */
      case 3:
        SSD_FP32_GENERAL(query[dim + 2], ptrs[i][dim + 2], results[i]);
        /* FALLTHRU */
      case 2:
        SSD_FP32_GENERAL(query[dim + 1], ptrs[i][dim + 1], results[i]);
        /* FALLTHRU */
      case 1:
        SSD_FP32_GENERAL(query[dim + 0], ptrs[i][dim + 0], results[i]);
    }
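The fall-through tail pattern can be checked in isolation with a scalar stand-in for SSD_FP32_GENERAL (the helper below is illustrative, not the PR's macro; it covers a 3-element tail for brevity):

```cpp
#include <cassert>
#include <cstddef>

// Scalar stand-in for the SSD_FP32_GENERAL accumulate step, showing why the
// fall-through tail must index from dim upward: the first `dim` elements were
// already consumed by the vectorized main loop.
static float tail_ssd(const float *query, const float *data,
                      std::size_t dim, std::size_t dimensionality) {
  float acc = 0.0f;
  switch (dimensionality - dim) {
    case 3: { const float d = query[dim + 2] - data[dim + 2]; acc += d * d; }
      /* FALLTHRU */
    case 2: { const float d = query[dim + 1] - data[dim + 1]; acc += d * d; }
      /* FALLTHRU */
    case 1: { const float d = query[dim + 0] - data[dim + 0]; acc += d * d; }
  }
  return acc;
}
```

With the unfixed indices (query[0], query[1], ...) the tail would re-read the head of the vector instead of its remainder, silently double-counting the first elements.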

Comment on lines +136 to +140

    for (size_t i = 0; i < dp_batch; ++i) {
      data_regs[i] = _mm512_cvtph_ps(
          _mm256_loadu_si256(reinterpret_cast<const __m256i *>(ptrs[i] + dim)));
      accs[i] = _mm512_fmadd_ps(q, data_regs[i], accs[i]);
    }

This computes q * data_regs[i] + accs[i] (a dot product) instead of (q - data_regs[i])^2 + accs[i].

Suggested change

    for (size_t i = 0; i < dp_batch; ++i) {
      data_regs[i] = _mm512_cvtph_ps(
          _mm256_loadu_si256(reinterpret_cast<const __m256i *>(ptrs[i] + dim)));
      __m512 diff = _mm512_sub_ps(q, data_regs[i]);
      accs[i] = _mm512_fmadd_ps(diff, diff, accs[i]);
    }

    if (dimensionality >= dim + 16) {
      for (size_t i = 0; i < dp_batch; ++i) {
        __m128i q = _mm_loadu_si128((const __m128i *)query + dim);

Pointer arithmetic on a __m128i * scales by 16 bytes per element, so this reads from query + dim * 16 instead of query + dim.

Suggested change

    __m128i q = _mm_loadu_si128((const __m128i *)(query + dim));
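The underlying rule is that pointer arithmetic advances by whole elements of the pointed-to type, so the cast must wrap the entire expression. An intrinsic-free sketch of the same pitfall (Vec16 is a hypothetical 16-byte stand-in for __m128i):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// 16-byte stand-in for __m128i so the example runs without intrinsics.
struct Vec16 { std::uint8_t bytes[16]; };

// (const Vec16 *)base + dim advances by dim * sizeof(Vec16) = dim * 16 bytes.
static std::ptrdiff_t offset_cast_then_add(const std::int8_t *base,
                                           std::ptrdiff_t dim) {
  const Vec16 *p = reinterpret_cast<const Vec16 *>(base) + dim;
  return reinterpret_cast<const std::uint8_t *>(p) -
         reinterpret_cast<const std::uint8_t *>(base);
}

// (const Vec16 *)(base + dim) advances by dim bytes, as intended.
static std::ptrdiff_t offset_add_then_cast(const std::int8_t *base,
                                           std::ptrdiff_t dim) {
  const Vec16 *p = reinterpret_cast<const Vec16 *>(base + dim);
  return reinterpret_cast<const std::uint8_t *>(p) -
         reinterpret_cast<const std::uint8_t *>(base);
}
```

For int8 data the mistaken form reads 16x too far into the buffer, which is both wrong and a potential out-of-bounds access.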

Comment on lines +91 to +137

    for (size_t i = 0; i < dp_batch; ++i) {
      switch (dimensionality - dim) {
        case 15:
          SSD_INT8_GENERAL(query + dim, ptrs[14] + dim, results[i]);
          /* FALLTHRU */
        case 14:
          SSD_INT8_GENERAL(query + dim, ptrs[13 + dim], results[i]);
          /* FALLTHRU */
        case 13:
          SSD_INT8_GENERAL(query + dim, ptrs[12] + dim, results[i]);
          /* FALLTHRU */
        case 12:
          SSD_INT8_GENERAL(query + dim, ptrs[11] + dim, results[i]);
          /* FALLTHRU */
        case 11:
          SSD_INT8_GENERAL(query + dim, ptrs[10 + dim], results[i]);
          /* FALLTHRU */
        case 10:
          SSD_INT8_GENERAL(query + dim, ptrs[9] + dim, results[i]);
          /* FALLTHRU */
        case 9:
          SSD_INT8_GENERAL(query + dim, ptrs[8] + dim, results[i]);
          /* FALLTHRU */
        case 8:
          SSD_INT8_GENERAL(query + dim, ptrs[7] + dim, results[i]);
          /* FALLTHRU */
        case 7:
          SSD_INT8_GENERAL(query + dim, ptrs[6] + dim, results[i]);
          /* FALLTHRU */
        case 6:
          SSD_INT8_GENERAL(query + dim, ptrs[5] + dim, results[i]);
          /* FALLTHRU */
        case 5:
          SSD_INT8_GENERAL(query + dim, ptrs[4] + dim, results[i]);
          /* FALLTHRU */
        case 4:
          SSD_INT8_GENERAL(query + dim, ptrs[3] + dim, results[i]);
          /* FALLTHRU */
        case 3:
          SSD_INT8_GENERAL(query + dim, ptrs[2] + dim, results[i]);
          /* FALLTHRU */
        case 2:
          SSD_INT8_GENERAL(query + dim, ptrs[1] + dim, results[i]);
          /* FALLTHRU */
        case 1:
          SSD_INT8_GENERAL(query + dim, ptrs[0] + dim, results[i]);
      }

This passes pointers instead of scalar values to the SSD_INT8_GENERAL macro, and uses wrong array indices (e.g. ptrs[14] instead of ptrs[i], and ptrs[13 + dim] instead of ptrs[13] + dim).

Suggested change

    for (size_t i = 0; i < dp_batch; ++i) {
      switch (dimensionality - dim) {
        case 15:
          SSD_INT8_GENERAL(query[dim + 14], ptrs[i][dim + 14], results[i]);
          /* FALLTHRU */
        case 14:
          SSD_INT8_GENERAL(query[dim + 13], ptrs[i][dim + 13], results[i]);
          /* FALLTHRU */
        case 13:
          SSD_INT8_GENERAL(query[dim + 12], ptrs[i][dim + 12], results[i]);
          /* FALLTHRU */
        case 12:
          SSD_INT8_GENERAL(query[dim + 11], ptrs[i][dim + 11], results[i]);
          /* FALLTHRU */
        case 11:
          SSD_INT8_GENERAL(query[dim + 10], ptrs[i][dim + 10], results[i]);
          /* FALLTHRU */
        case 10:
          SSD_INT8_GENERAL(query[dim + 9], ptrs[i][dim + 9], results[i]);
          /* FALLTHRU */
        case 9:
          SSD_INT8_GENERAL(query[dim + 8], ptrs[i][dim + 8], results[i]);
          /* FALLTHRU */
        case 8:
          SSD_INT8_GENERAL(query[dim + 7], ptrs[i][dim + 7], results[i]);
          /* FALLTHRU */
        case 7:
          SSD_INT8_GENERAL(query[dim + 6], ptrs[i][dim + 6], results[i]);
          /* FALLTHRU */
        case 6:
          SSD_INT8_GENERAL(query[dim + 5], ptrs[i][dim + 5], results[i]);
          /* FALLTHRU */
        case 5:
          SSD_INT8_GENERAL(query[dim + 4], ptrs[i][dim + 4], results[i]);
          /* FALLTHRU */
        case 4:
          SSD_INT8_GENERAL(query[dim + 3], ptrs[i][dim + 3], results[i]);
          /* FALLTHRU */
        case 3:
          SSD_INT8_GENERAL(query[dim + 2], ptrs[i][dim + 2], results[i]);
          /* FALLTHRU */
        case 2:
          SSD_INT8_GENERAL(query[dim + 1], ptrs[i][dim + 1], results[i]);
          /* FALLTHRU */
        case 1:
          SSD_INT8_GENERAL(query[dim + 0], ptrs[i][dim + 0], results[i]);
      }
    }

@richyreachy (Collaborator, Author)

@greptile
