Skip to content

Commit c13c681

Browse files
committed
fix a bug with headDim 256 nvfp4-kv kernels
Signed-off-by: Perkz Zheng <[email protected]>
1 parent 588a2e8 commit c13c681

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,8 @@ struct KernelParams
685685
// The number of elements in 128B for Q.
686686
int32_t numEltsIn128BKv = (128 * 8) / get_size_in_bits(kernelMeta.mDataTypeKv);
687687
// The number of head elts (per token) in each block of shared memory (see above explanation).
688-
int32_t numEltsInClampedHeadDimKv = std::min(numEltsIn128BKv, maxHeadDimKv);
688+
// HeadDim will be split into multiple headDimStages (128) if maxHeadDimKv > 128.
689+
int32_t numEltsInClampedHeadDimKv = std::min({numEltsIn128BKv, maxHeadDimKv, 128});
689690

690691
// Do we have to transform K/V before MMA?
691692
bool const transformsKv{kernelMeta.mDataTypeKv != kernelMeta.mDataTypeQ};

0 commit comments

Comments
 (0)