Skip to content

Commit 45ce73c

Browse files
committed
[Feat] HKV: optimize for small dim and format
1 parent 8c96d9c commit 45ce73c

File tree

1 file changed

+76
-2
lines changed

1 file changed

+76
-2
lines changed

include/merlin/core_kernels/lookup.cuh

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,53 @@ struct LaunchPipelineLookupV2 {
734734
}
735735
};
736736

737+
// Small-dimension TLP lookup kernel (no Table dependency).
//
// One cooperative tile of TILE_SIZE (4) threads serves one key: the tile
// probes the key's bucket together through find_without_lock, then lane 0
// writes the result. Intended for tiny values (dim * sizeof(V) <= 8 bytes)
// and small-to-mid batch sizes.
//
// Launch layout: 1-D grid and 1-D block; blockDim.x must be a multiple of
// TILE_SIZE, and at least ceil(n / (blockDim.x / TILE_SIZE)) blocks are
// needed to cover all keys. No shared memory is used.
//
// Parameters:
//   buckets      bucket array passed to get_key_position.
//   buckets_num  number of buckets.
//   dim          number of V elements stored per key.
//   keys         n keys to look up.
//   values       output, n * dim elements; row i is written only when key i
//                is found.
//   scores       output, n elements; may be nullptr, in which case no score
//                is written.
//   founds       output, n flags; always written for every key. Assumed
//                non-null.
//   n            number of keys.
template <typename K = uint64_t, typename V = float, typename S = uint32_t>
__global__ void lookup_kernel_tlp_small_dim(
    Bucket<K, V, S>* buckets, const size_t buckets_num, const uint32_t dim,
    const K* __restrict keys, V* __restrict values, S* __restrict scores,
    bool* __restrict founds, const size_t n) {
  constexpr int BUCKET_SIZE = 128;
  constexpr int TILE_SIZE = 4;
  auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());

  // 64-bit index math so that large launches cannot wrap around.
  const size_t group_id =
      (static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) / g.size();
  // Tail guard: the grid is rounded up, so trailing tiles own no key. All
  // lanes of a tile share group_id, hence the whole tile leaves together and
  // the tile-wide probe below never runs with missing lanes.
  if (group_id >= n) return;
  const int rank = g.thread_rank();

  const K find_key = keys[group_id];

  // NOTE(review): reserved keys (e.g. the empty-slot sentinel) can presumably
  // never be stored, yet probing for one could match an unoccupied slot.
  // Report them as not found. Assumes IS_RESERVED_KEY<K> is in scope here -
  // confirm against the other lookup kernels.
  if (IS_RESERVED_KEY<K>(find_key)) {
    if (rank == 0) founds[group_id] = false;
    return;
  }

  size_t bkt_idx = 0;
  size_t start_idx = 0;
  Bucket<K, V, S>* bucket = get_key_position<K>(
      buckets, find_key, bkt_idx, start_idx, buckets_num, BUCKET_SIZE);

  int key_pos = -1;
  int src_lane = -1;
  // Align the start index to a TILE_SIZE boundary to improve coalesced probing.
  const size_t start_idx_aligned =
      start_idx & ~(static_cast<size_t>(TILE_SIZE) - 1);
  const OccupyResult occupy_result = find_without_lock<K, V, S, TILE_SIZE>(
      g, bucket, find_key, static_cast<int>(start_idx_aligned), key_pos,
      src_lane, BUCKET_SIZE);

  const bool found = (occupy_result == OccupyResult::DUPLICATE);

  // Only lane 0 publishes the result for this key.
  if (rank != 0) return;
  founds[group_id] = found;
  // Touch the bucket payload only on a hit; key_pos is not meaningful
  // otherwise.
  if (!found) return;

  // Each key owns dim consecutive elements, both in the bucket and in the
  // output buffer.
  const V* src = reinterpret_cast<const V*>(bucket->vectors) +
                 static_cast<size_t>(key_pos) * dim;
  V* dst = values + group_id * dim;
  for (uint32_t i = 0; i < dim; ++i) {
    dst[i] = __ldg(src + i);
  }

  // Scores are optional for callers.
  if (scores != nullptr) {
    scores[group_id] = *(bucket->scores(key_pos));
  }
}
783+
737784
template <typename ArchTag>
738785
struct LookupValueBufConfig;
739786

@@ -755,19 +802,46 @@ template <typename K, typename V, typename S = uint64_t,
755802
struct SelectPipelineLookupKernelWithIO {
756803
using ValueBufConfig = LookupValueBufConfig<ArchTag>;
757804

805+
// Fallback overload of small_dim_tlp_try_launch.
//
// Accepts any parameter bundle, ignores all of its arguments and launches
// nothing. It is the candidate left over when the LookupKernelParams overload
// does not match the argument type.
//
// Returns false, i.e. the small-dimension TLP path was not taken.
template <typename ParamsT>
static inline bool small_dim_tlp_try_launch(ParamsT& /*params*/,
                                            uint32_t /*total_value_size*/,
                                            cudaStream_t& /*stream*/) {
  constexpr bool launched = false;  // not applicable for this params type
  return launched;
}
810+
811+
// Tries to serve a lookup with lookup_kernel_tlp_small_dim.
//
// Only applies when a whole value occupies at most 8 bytes. One tile of
// TILE_SIZE threads handles one key, so a block of BLOCK threads covers
// BLOCK / TILE_SIZE keys.
//
// Parameters:
//   params            lookup arguments; reads buckets, buckets_num, dim, keys,
//                     values, scores, found_functor.founds and n.
//   total_value_size  size of one value in bytes.
//   stream            stream the kernel is enqueued on.
//
// Returns true when this path handled the request (including the empty
// request n == 0, for which nothing is launched), false when the caller must
// fall back to another kernel.
template <typename KK, typename VV, typename SS>
static inline bool small_dim_tlp_try_launch(
    LookupKernelParams<KK, VV, SS>& params, const uint32_t total_value_size,
    cudaStream_t& stream) {
  constexpr uint32_t MAX_SMALL_VALUE_BYTES = 8;
  if (total_value_size > MAX_SMALL_VALUE_BYTES) {
    return false;
  }

  const size_t n = static_cast<size_t>(params.n);
  // A grid of zero blocks is an invalid launch configuration; an empty
  // request is already complete.
  if (n == 0) {
    return true;
  }

  constexpr int TILE_SIZE = 4;
  constexpr int BLOCK = 256;
  static_assert(BLOCK % TILE_SIZE == 0,
                "BLOCK must be a multiple of TILE_SIZE");
  constexpr size_t GROUPS_PER_BLOCK = BLOCK / TILE_SIZE;

  // Ceil-div in size_t, narrowed once to the launch dimension type.
  const size_t blocks = (n + GROUPS_PER_BLOCK - 1) / GROUPS_PER_BLOCK;
  const unsigned int grid = static_cast<unsigned int>(blocks);

  lookup_kernel_tlp_small_dim<KK, VV, SS><<<grid, BLOCK, 0, stream>>>(
      params.buckets, static_cast<size_t>(params.buckets_num),
      static_cast<uint32_t>(params.dim), params.keys, params.values,
      params.scores, params.found_functor.founds, n);
  return true;
}
830+
758831
// Reports the value-size limit of this selector, which is the pipeline-v1
// buffer size taken from ValueBufConfig.
static inline uint32_t max_value_size() {
  constexpr uint32_t pipeline_v1_limit = ValueBufConfig::size_pipeline_v1;
  return pipeline_v1_limit;
}
761834

762835
template <template <typename, typename, typename> typename LookupKernelParams>
763836
static void select_kernel(LookupKernelParams<K, V, S>& params,
764837
cudaStream_t& stream) {
838+
// Small-dimension direct TLP path: dim*sizeof(V) <= 8
839+
const uint32_t total_value_size = static_cast<uint32_t>(params.dim * sizeof(V));
840+
if (small_dim_tlp_try_launch(params, total_value_size, stream)) return;
765841
constexpr int BUCKET_SIZE = 128;
766842
constexpr uint32_t buf_size_v1 = ValueBufConfig::size_pipeline_v1;
767843
constexpr uint32_t buf_size_v2 = ValueBufConfig::size_pipeline_v2;
768844

769-
uint32_t total_value_size = static_cast<uint32_t>(params.dim * sizeof(V));
770-
771845
if (params.scores == nullptr) {
772846
using CopyScore = CopyScoreEmpty<S, K, BUCKET_SIZE>;
773847
if (total_value_size <= buf_size_v1) {

0 commit comments

Comments
 (0)