@@ -734,6 +734,53 @@ struct LaunchPipelineLookupV2 {
734734 }
735735};
736736
// Small-dimension lookup kernel (no Table dependency).
//
// Each TILE_SIZE (4) lane cooperative-groups tile handles exactly one key:
// the tile probes the key's bucket together via find_without_lock, then lane 0
// writes the outputs. Intended for tiny rows (e.g. dim <= 2 for a 4-byte V)
// and small-to-mid batch sizes.
//
// Launch layout: 1-D grid and 1-D block. blockDim.x must be a multiple of
// TILE_SIZE so every lane of a tile maps to the same key, and
// gridDim.x * blockDim.x / TILE_SIZE must be >= n.
//
// Parameters:
//   buckets, buckets_num  bucket array handed to get_key_position.
//   dim                   number of V elements per value row.
//   keys                  n keys to look up.
//   values                output rows; row i starts at values + i * dim.
//   scores                output scores, one per key; may be nullptr, in which
//                         case no score is written.
//   founds                output hit flags, one per key; may be nullptr.
//   n                     number of keys.
//
// Rows of keys that are not found are left untouched in values / scores.
//
// NOTE(review): reserved keys are not filtered here (the original filter was
// left commented out) - confirm whether callers can pass reserved keys.
template <typename K = uint64_t, typename V = float, typename S = uint32_t>
__global__ void lookup_kernel_tlp_small_dim(
    Bucket<K, V, S>* buckets, const size_t buckets_num, const uint32_t dim,
    const K* __restrict__ keys, V* __restrict__ values, S* __restrict__ scores,
    bool* __restrict__ founds, const size_t n) {
  constexpr int BUCKET_SIZE = 128;
  constexpr int TILE_SIZE = 4;
  auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());

  // Widen before multiplying so the flat thread index cannot wrap in 32 bits.
  const size_t group_id =
      (static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) / g.size();

  // Tail guard. All lanes of a tile share group_id (blockDim.x is a multiple
  // of TILE_SIZE), so a tile leaves as a whole and the tile-wide probe below
  // never runs with missing lanes.
  if (group_id >= n) return;

  const int rank = g.thread_rank();
  const K find_key = keys[group_id];

  size_t bkt_idx = 0;
  size_t start_idx = 0;
  Bucket<K, V, S>* bucket = get_key_position<K>(
      buckets, find_key, bkt_idx, start_idx, buckets_num, BUCKET_SIZE);

  int key_pos = -1;
  int src_lane = -1;
  // Align the start index to a TILE_SIZE boundary to improve coalesced digest
  // probing.
  const size_t start_idx_aligned =
      start_idx & ~(static_cast<size_t>(TILE_SIZE) - 1);
  const OccupyResult occupy_result = find_without_lock<K, V, S, TILE_SIZE>(
      g, bucket, find_key, static_cast<int>(start_idx_aligned), key_pos,
      src_lane, BUCKET_SIZE);

  // DUPLICATE is the result treated as "key present in the bucket".
  const bool found = (occupy_result == OccupyResult::DUPLICATE);

  // Only lane 0 publishes results; no tile-wide operation follows.
  if (rank != 0) return;

  if (founds != nullptr) founds[group_id] = found;

  // Touch the bucket's value/score storage only on a hit: key_pos is not
  // meaningful otherwise.
  if (!found) return;

  // Copy the whole row. Assumes rows are stored dim-strided both in
  // bucket->vectors and in values - TODO confirm against the Bucket layout.
  const V* src = reinterpret_cast<const V*>(bucket->vectors) +
                 static_cast<size_t>(key_pos) * dim;
  V* dst = values + group_id * static_cast<size_t>(dim);
  for (uint32_t i = 0; i < dim; ++i) {
    dst[i] = __ldg(src + i);
  }

  // The caller may pass no score buffer.
  if (scores != nullptr) {
    scores[group_id] = *(bucket->scores(key_pos));
  }
}
783+
737784template <typename ArchTag>
738785struct LookupValueBufConfig ;
739786
@@ -755,19 +802,46 @@ template <typename K, typename V, typename S = uint64_t,
755802struct SelectPipelineLookupKernelWithIO {
756803 using ValueBufConfig = LookupValueBufConfig<ArchTag>;
757804
// Catch-all overload of the small-dimension TLP launch helper.
// Selected for any parameter type that does not match the
// LookupKernelParams<KK, VV, SS> overload; it launches nothing and reports
// that the small-dimension path was not taken.
template <typename ParamsT>
static inline bool small_dim_tlp_try_launch(ParamsT& /*params*/,
                                            uint32_t /*total_value_size*/,
                                            cudaStream_t& /*stream*/) {
  constexpr bool launched = false;  // not applicable for this params type
  return launched;
}
810+
// Small-dimension TLP launch helper for LookupKernelParams.
//
// When total_value_size (bytes per value row) is at most 8, launches
// lookup_kernel_tlp_small_dim on `stream` with one TILE_SIZE-lane tile per key
// and returns true. Returns false without launching when the row is larger,
// so the caller can fall back to its other kernels.
//
// An empty batch (params.n == 0) is reported as handled without launching,
// because a grid dimension of 0 is an invalid launch configuration.
//
// NOTE(review): launch errors are not checked here; add the project's usual
// cudaGetLastError-based check at the call site if one is not already present.
template <typename KK, typename VV, typename SS>
static inline bool small_dim_tlp_try_launch(
    LookupKernelParams<KK, VV, SS>& params, const uint32_t total_value_size,
    cudaStream_t& stream) {
  constexpr uint32_t MAX_SMALL_VALUE_BYTES = 8;
  if (total_value_size > MAX_SMALL_VALUE_BYTES) {
    return false;
  }

  constexpr int TILE_SIZE = 4;
  constexpr int BLOCK = 256;
  static_assert(BLOCK % TILE_SIZE == 0,
                "BLOCK must be a multiple of TILE_SIZE");
  constexpr size_t GROUPS_PER_BLOCK = BLOCK / TILE_SIZE;

  const size_t n = static_cast<size_t>(params.n);
  if (n == 0) {
    return true;  // nothing to look up; avoid a zero-sized grid
  }

  // Ceil-divide in size_t and hand the runtime an unsigned block count
  // instead of narrowing to a signed int.
  const unsigned int grid =
      static_cast<unsigned int>((n + GROUPS_PER_BLOCK - 1) / GROUPS_PER_BLOCK);

  lookup_kernel_tlp_small_dim<KK, VV, SS><<<grid, BLOCK, 0, stream>>>(
      params.buckets, static_cast<size_t>(params.buckets_num),
      static_cast<uint32_t>(params.dim), params.keys, params.values,
      params.scores, params.found_functor.founds, n);
  return true;
}
830+
// Reports the value-size limit of this selector, which is the
// size_pipeline_v1 constant of the architecture's ValueBufConfig.
static inline uint32_t max_value_size() {
  constexpr uint32_t limit = ValueBufConfig::size_pipeline_v1;
  return limit;
}
761834
762835 template <template <typename , typename , typename > typename LookupKernelParams>
763836 static void select_kernel (LookupKernelParams<K, V, S>& params,
764837 cudaStream_t& stream) {
838+ // Small-dimension direct TLP path: dim*sizeof(V) <= 8
839+ const uint32_t total_value_size = static_cast <uint32_t >(params.dim * sizeof (V));
840+ if (small_dim_tlp_try_launch (params, total_value_size, stream)) return ;
765841 constexpr int BUCKET_SIZE = 128 ;
766842 constexpr uint32_t buf_size_v1 = ValueBufConfig::size_pipeline_v1;
767843 constexpr uint32_t buf_size_v2 = ValueBufConfig::size_pipeline_v2;
768844
769- uint32_t total_value_size = static_cast <uint32_t >(params.dim * sizeof (V));
770-
771845 if (params.scores == nullptr ) {
772846 using CopyScore = CopyScoreEmpty<S, K, BUCKET_SIZE>;
773847 if (total_value_size <= buf_size_v1) {
0 commit comments