Conversation
Greptile Summary
This PR adds Euclidean one-to-many distance computation with SIMD optimizations (AVX2, AVX512F, AVX512FP16) for float32, float16, and int8 data types. However, multiple critical bugs exist in the new implementation files that will cause incorrect distance calculations; they are detailed in the inline review comments below.
These bugs were already identified in previous review threads and remain unfixed.
Confidence Score: 0/5
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[BaseDistance::ComputeBatch] --> B{Distance Type?}
B -->|Euclidean| C[EuclideanDistanceBatch::ComputeBatch]
B -->|SquaredEuclidean| D[SquaredEuclideanDistanceBatch::ComputeBatch]
B -->|Other| E[Default _ComputeBatch]
D --> F{Data Type?}
F -->|float| G[SquaredEuclideanDistanceBatchImpl float]
F -->|int8_t| H[SquaredEuclideanDistanceBatchImpl int8_t]
F -->|Float16| I[SquaredEuclideanDistanceBatchImpl Float16]
F -->|Other| J[Fallback Implementation]
G --> K{CPU Features?}
K -->|AVX512F| L[compute_one_to_many_squared_euclidean_avx512f_fp32]
K -->|AVX2| M[compute_one_to_many_squared_euclidean_avx2_fp32]
K -->|None| J
H --> N{CPU Features?}
N -->|AVX2| O[compute_one_to_many_squared_euclidean_avx2_int8]
N -->|None| J
I --> P{CPU Features?}
P -->|AVX512FP16| Q[compute_one_to_many_squared_euclidean_avx512fp16_fp16]
P -->|AVX512F| R[compute_one_to_many_squared_euclidean_avx512f_fp16]
P -->|AVX2| S[compute_one_to_many_squared_euclidean_avx2_fp16]
P -->|None| J
C --> T[Call SquaredEuclideanDistanceBatch]
T --> U[Apply sqrt to results]
style L fill:#f99,stroke:#333,stroke-width:2px
style M fill:#f99,stroke:#333,stroke-width:2px
style O fill:#f99,stroke:#333,stroke-width:2px
style Q fill:#f99,stroke:#333,stroke-width:2px
style R fill:#f99,stroke:#333,stroke-width:2px
style S fill:#f99,stroke:#333,stroke-width:2px
Last reviewed commit: b766fb1 |
| accs[i] = _mm512_mask3_fmadd_ps( | ||
| _mm512_mask_loadu_ps(zmm_undefined, mask, query + dim), | ||
| _mm512_mask_loadu_ps(zmm_undefined, mask, ptrs[i] + dim), accs[i], | ||
| mask); |
There was a problem hiding this comment.
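This computes query * ptrs[i] + accs[i] (a dot-product term) instead of the squared difference (query - ptrs[i])^2 + accs[i]. Subtract first, then accumulate diff * diff; the accumulator must remain the pass-through operand for masked-off lanes (as in _mm512_mask3_fmadd_ps), otherwise those lanes lose their running sums.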
| accs[i] = _mm512_mask3_fmadd_ps( | |
| _mm512_mask_loadu_ps(zmm_undefined, mask, query + dim), | |
| _mm512_mask_loadu_ps(zmm_undefined, mask, ptrs[i] + dim), accs[i], | |
| mask); | |
| __m512 q_vals = _mm512_mask_loadu_ps(zmm_undefined, mask, query + dim); | |
| __m512 d_vals = _mm512_mask_loadu_ps(zmm_undefined, mask, ptrs[i] + dim); | |
| __m512 diff = _mm512_sub_ps(q_vals, d_vals); | |
| accs[i] = _mm512_mask_fmadd_ps(diff, mask, diff, accs[i]); |
| switch (dimensionality - dim) { | ||
| case 7: | ||
| SSD_FP32_GENERAL(query[6], ptrs[i][6], results[i]); | ||
| /* FALLTHRU */ | ||
| case 6: | ||
| SSD_FP32_GENERAL(query[5], ptrs[i][5], results[i]); | ||
| /* FALLTHRU */ | ||
| case 5: | ||
| SSD_FP32_GENERAL(query[4], ptrs[i][4], results[i]); | ||
| /* FALLTHRU */ | ||
| case 4: | ||
| SSD_FP32_GENERAL(query[3], ptrs[i][3], results[i]); | ||
| /* FALLTHRU */ | ||
| case 3: | ||
| SSD_FP32_GENERAL(query[2], ptrs[i][2], results[i]); | ||
| /* FALLTHRU */ | ||
| case 2: | ||
| SSD_FP32_GENERAL(query[1], ptrs[i][1], results[i]); | ||
| /* FALLTHRU */ | ||
| case 1: | ||
| SSD_FP32_GENERAL(query[0], ptrs[i][0], results[i]); |
There was a problem hiding this comment.
Array indices should be offset by dim to process remaining elements correctly
| switch (dimensionality - dim) { | |
| case 7: | |
| SSD_FP32_GENERAL(query[6], ptrs[i][6], results[i]); | |
| /* FALLTHRU */ | |
| case 6: | |
| SSD_FP32_GENERAL(query[5], ptrs[i][5], results[i]); | |
| /* FALLTHRU */ | |
| case 5: | |
| SSD_FP32_GENERAL(query[4], ptrs[i][4], results[i]); | |
| /* FALLTHRU */ | |
| case 4: | |
| SSD_FP32_GENERAL(query[3], ptrs[i][3], results[i]); | |
| /* FALLTHRU */ | |
| case 3: | |
| SSD_FP32_GENERAL(query[2], ptrs[i][2], results[i]); | |
| /* FALLTHRU */ | |
| case 2: | |
| SSD_FP32_GENERAL(query[1], ptrs[i][1], results[i]); | |
| /* FALLTHRU */ | |
| case 1: | |
| SSD_FP32_GENERAL(query[0], ptrs[i][0], results[i]); | |
| switch (dimensionality - dim) { | |
| case 7: | |
| SSD_FP32_GENERAL(query[dim + 6], ptrs[i][dim + 6], results[i]); | |
| /* FALLTHRU */ | |
| case 6: | |
| SSD_FP32_GENERAL(query[dim + 5], ptrs[i][dim + 5], results[i]); | |
| /* FALLTHRU */ | |
| case 5: | |
| SSD_FP32_GENERAL(query[dim + 4], ptrs[i][dim + 4], results[i]); | |
| /* FALLTHRU */ | |
| case 4: | |
| SSD_FP32_GENERAL(query[dim + 3], ptrs[i][dim + 3], results[i]); | |
| /* FALLTHRU */ | |
| case 3: | |
| SSD_FP32_GENERAL(query[dim + 2], ptrs[i][dim + 2], results[i]); | |
| /* FALLTHRU */ | |
| case 2: | |
| SSD_FP32_GENERAL(query[dim + 1], ptrs[i][dim + 1], results[i]); | |
| /* FALLTHRU */ | |
| case 1: | |
| SSD_FP32_GENERAL(query[dim + 0], ptrs[i][dim + 0], results[i]); | |
| } |
| for (size_t i = 0; i < dp_batch; ++i) { | ||
| data_regs[i] = _mm512_cvtph_ps( | ||
| _mm256_loadu_si256(reinterpret_cast<const __m256i *>(ptrs[i] + dim))); | ||
| accs[i] = _mm512_fmadd_ps(q, data_regs[i], accs[i]); | ||
| } |
There was a problem hiding this comment.
Computing q * data_regs[i] + accs[i] instead of (q - data_regs[i])^2 + accs[i]
| for (size_t i = 0; i < dp_batch; ++i) { | |
| data_regs[i] = _mm512_cvtph_ps( | |
| _mm256_loadu_si256(reinterpret_cast<const __m256i *>(ptrs[i] + dim))); | |
| accs[i] = _mm512_fmadd_ps(q, data_regs[i], accs[i]); | |
| } | |
| for (size_t i = 0; i < dp_batch; ++i) { | |
| data_regs[i] = _mm512_cvtph_ps( | |
| _mm256_loadu_si256(reinterpret_cast<const __m256i *>(ptrs[i] + dim))); | |
| __m512 diff = _mm512_sub_ps(q, data_regs[i]); | |
| accs[i] = _mm512_fmadd_ps(diff, diff, accs[i]); | |
| } |
|
|
||
| if (dimensionality >= dim + 16) { | ||
| for (size_t i = 0; i < dp_batch; ++i) { | ||
| __m128i q = _mm_loadu_si128((const __m128i *)query + dim); |
There was a problem hiding this comment.
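The cast binds before the addition, so the pointer arithmetic is performed on __m128i* and the load reads from byte offset dim * 16 (query + dim * 16) instead of query + dim. Add dim to query first, then cast the result.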
| __m128i q = _mm_loadu_si128((const __m128i *)query + dim); | |
| __m128i q = _mm_loadu_si128((const __m128i *)(query + dim)); |
| for (size_t i = 0; i < dp_batch; ++i) { | ||
| switch (dimensionality - dim) { | ||
| case 15: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[14] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 14: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[13 + dim], results[i]); | ||
| /* FALLTHRU */ | ||
| case 13: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[12] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 12: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[11] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 11: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[10 + dim], results[i]); | ||
| /* FALLTHRU */ | ||
| case 10: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[9] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 9: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[8] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 8: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[7] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 7: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[6] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 6: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[5] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 5: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[4] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 4: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[3] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 3: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[2] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 2: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[1] + dim, results[i]); | ||
| /* FALLTHRU */ | ||
| case 1: | ||
| SSD_INT8_GENERAL(query + dim, ptrs[0] + dim, results[i]); | ||
| } |
There was a problem hiding this comment.
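The scalar tail passes pointers (query + dim, ptrs[k] + dim) to the SSD_INT8_GENERAL macro where scalar values are expected, and it indexes ptrs with the case number (ptrs[14], or even ptrs[13 + dim]) instead of the batch index i. Each case k should use query[dim + k] and ptrs[i][dim + k].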
| for (size_t i = 0; i < dp_batch; ++i) { | |
| switch (dimensionality - dim) { | |
| case 15: | |
| SSD_INT8_GENERAL(query + dim, ptrs[14] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 14: | |
| SSD_INT8_GENERAL(query + dim, ptrs[13 + dim], results[i]); | |
| /* FALLTHRU */ | |
| case 13: | |
| SSD_INT8_GENERAL(query + dim, ptrs[12] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 12: | |
| SSD_INT8_GENERAL(query + dim, ptrs[11] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 11: | |
| SSD_INT8_GENERAL(query + dim, ptrs[10 + dim], results[i]); | |
| /* FALLTHRU */ | |
| case 10: | |
| SSD_INT8_GENERAL(query + dim, ptrs[9] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 9: | |
| SSD_INT8_GENERAL(query + dim, ptrs[8] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 8: | |
| SSD_INT8_GENERAL(query + dim, ptrs[7] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 7: | |
| SSD_INT8_GENERAL(query + dim, ptrs[6] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 6: | |
| SSD_INT8_GENERAL(query + dim, ptrs[5] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 5: | |
| SSD_INT8_GENERAL(query + dim, ptrs[4] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 4: | |
| SSD_INT8_GENERAL(query + dim, ptrs[3] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 3: | |
| SSD_INT8_GENERAL(query + dim, ptrs[2] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 2: | |
| SSD_INT8_GENERAL(query + dim, ptrs[1] + dim, results[i]); | |
| /* FALLTHRU */ | |
| case 1: | |
| SSD_INT8_GENERAL(query + dim, ptrs[0] + dim, results[i]); | |
| } | |
| for (size_t i = 0; i < dp_batch; ++i) { | |
| switch (dimensionality - dim) { | |
| case 15: | |
| SSD_INT8_GENERAL(query[dim + 14], ptrs[i][dim + 14], results[i]); | |
| /* FALLTHRU */ | |
| case 14: | |
| SSD_INT8_GENERAL(query[dim + 13], ptrs[i][dim + 13], results[i]); | |
| /* FALLTHRU */ | |
| case 13: | |
| SSD_INT8_GENERAL(query[dim + 12], ptrs[i][dim + 12], results[i]); | |
| /* FALLTHRU */ | |
| case 12: | |
| SSD_INT8_GENERAL(query[dim + 11], ptrs[i][dim + 11], results[i]); | |
| /* FALLTHRU */ | |
| case 11: | |
| SSD_INT8_GENERAL(query[dim + 10], ptrs[i][dim + 10], results[i]); | |
| /* FALLTHRU */ | |
| case 10: | |
| SSD_INT8_GENERAL(query[dim + 9], ptrs[i][dim + 9], results[i]); | |
| /* FALLTHRU */ | |
| case 9: | |
| SSD_INT8_GENERAL(query[dim + 8], ptrs[i][dim + 8], results[i]); | |
| /* FALLTHRU */ | |
| case 8: | |
| SSD_INT8_GENERAL(query[dim + 7], ptrs[i][dim + 7], results[i]); | |
| /* FALLTHRU */ | |
| case 7: | |
| SSD_INT8_GENERAL(query[dim + 6], ptrs[i][dim + 6], results[i]); | |
| /* FALLTHRU */ | |
| case 6: | |
| SSD_INT8_GENERAL(query[dim + 5], ptrs[i][dim + 5], results[i]); | |
| /* FALLTHRU */ | |
| case 5: | |
| SSD_INT8_GENERAL(query[dim + 4], ptrs[i][dim + 4], results[i]); | |
| /* FALLTHRU */ | |
| case 4: | |
| SSD_INT8_GENERAL(query[dim + 3], ptrs[i][dim + 3], results[i]); | |
| /* FALLTHRU */ | |
| case 3: | |
| SSD_INT8_GENERAL(query[dim + 2], ptrs[i][dim + 2], results[i]); | |
| /* FALLTHRU */ | |
| case 2: | |
| SSD_INT8_GENERAL(query[dim + 1], ptrs[i][dim + 1], results[i]); | |
| /* FALLTHRU */ | |
| case 1: | |
| SSD_INT8_GENERAL(query[dim + 0], ptrs[i][dim + 0], results[i]); | |
| } | |
| } |
|
@greptile |
add euclidean one2many