Commit 4a015e3

jhaotingc, EmmaQiaoCh, chzblych, and lowsfer authored

* [None][infra] Pin the version for triton to 3.3.1 (NVIDIA#6508)
  Signed-off-by: qqiao <[email protected]>
* [None][infra] Pin the version for triton to 3.3.1 (NVIDIA#6508) (NVIDIA#6519) (NVIDIA#6549)
  Signed-off-by: Yanchao Lu <[email protected]>
* [fix]: use safeInitRowMax instead of fp32_lowest to avoid NaN (NVIDIA#7087)
  Signed-off-by: Yao Yao <[email protected]>
* [None][fix] Fix a numerical stability issue for XQA with spec dec
  Signed-off-by: Yao Yao <[email protected]>
* fix typo
  Signed-off-by: Jhao-Ting Chen <[email protected]>
* fix precompiled multi_query_token kernel not having is_fp8_out hash key (NVIDIA#6279)
  Signed-off-by: Jhao-Ting Chen <[email protected]>
* [fix] Fix missing fields in xqa kernel cache key (NVIDIA#6282)
  Signed-off-by: Yao Yao <[email protected]>

---------

Signed-off-by: qqiao <[email protected]>
Signed-off-by: Yanchao Lu <[email protected]>
Signed-off-by: Yao Yao <[email protected]>
Signed-off-by: Jhao-Ting Chen <[email protected]>
Co-authored-by: Emma Qiao <[email protected]>
Co-authored-by: Yanchao Lu <[email protected]>
Co-authored-by: Yao Yao <[email protected]>
1 parent 9270041 commit 4a015e3

7 files changed: +47 -14 lines changed

cpp/kernels/xqa/mha_sm90.cu

Lines changed: 26 additions & 7 deletions
@@ -630,6 +630,8 @@ CUBIN_EXPORT __global__
 #ifdef NDEBUG
 #if !OPTIMIZE_FOR_LATENCY
 __launch_bounds__(128 * 3, headElems* ctaNbQHeads <= 128 * 16 ? 3 : 2)
+#else
+__launch_bounds__(128 * 3)
 #endif
 #else
 __launch_bounds__(128 * 3, 1)
@@ -999,7 +1001,7 @@ CUBIN_EXPORT __global__
     if (threadIdx.x < smem.gemm1AccColMax.size)
     {
         auto const idx = threadIdx.x;
-        smem.gemm1AccColMax[idx] = mha::numeric_limits<float>::lowest();
+        smem.gemm1AccColMax[idx] = safeInitRowMax;
         smem.gemm1AccColSum[idx] = 0;
     }
     smem.gemm1WarpGrpBar.arrive_and_wait();
@@ -1075,6 +1077,23 @@ CUBIN_EXPORT __global__
         }
     }
     smem.gemm1WarpGrpBar.arrive_and_wait();
+#else
+    if (blockIdx.y == 1 && threadIdx.x == 0)
+    {
+        printf("rowMax:\n");
+        for (int i = 0; i < ctaNbQHeads; i++)
+        {
+            printf("%f, ", smem.xRowMax[idxXBuf][i]);
+        }
+        printf("\n");
+        printf("rowSum:\n");
+        for (int i = 0; i < ctaNbQHeads; i++)
+        {
+            printf("%f, ", smem.xRowSum[idxXBuf][i]);
+        }
+        printf("\n");
+    }
+    smem.gemm1WarpGrpBar.arrive_and_wait();
 #endif
 #endif
@@ -1887,15 +1906,15 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
     uint32_t const globalRow = tileStartRow + row;
     if (globalRow >= cacheSeqLen)
     {
-        acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
+        acc(m, n)(i, j) = safeInitRowMax;
         continue;
     }
     if (globalRow >= maskStartRow)
     {
         uint32_t const maskRow = globalRow - maskStartRow;
         if ((bit_mask >> maskRow) == 0)
         {
-            acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
+            acc(m, n)(i, j) = safeInitRowMax;
         }
     }
 }
@@ -2009,7 +2028,7 @@ __device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32
 #pragma unroll
     for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++)
     {
-        acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
+        acc(m, n)(i, j) = safeInitRowMax;
     }
 }
 }
@@ -2302,9 +2321,9 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
 {
     uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j;
     assert((col < nbValidCols) == bool(endMask & (1ULL << col)));
-    if (((mask >> col) & 1) == 0)
+    if ((mask & (1ULL << col)) == 0)
     {
-        acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
+        acc(m, n)(i, j) = safeInitRowMax;
     }
 }
 }
@@ -2332,7 +2351,7 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uin
 #pragma unroll
     for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++)
     {
-        acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
+        acc(m, n)(i, j) = safeInitRowMax;
     }
 }
 }

cpp/kernels/xqa/utils.cuh

Lines changed: 7 additions & 1 deletion
@@ -30,7 +30,13 @@
 #include <cuda_fp8.h>

 inline constexpr float log2e = 1.4426950408889634; // std::log2(M_E)
-inline constexpr float safeInitRowMax = -1e+30F;
+// we used an optimization where exp(x-rowMax) is computed as:
+/* bias = rowMax * log2e // shared for the whole row
+   exp(x-rowMax) = exp2f(x * log2e - bias)
+*/
+// But this optimization is not numerically stable when (x * log2e - bias) is computed with FMA and x is too large. For
+// this reason, don't set safeInitRowMax with a huge absolute value.
+inline constexpr float safeInitRowMax = -1e+5F;
 inline constexpr int32_t kBAD_PAGE_INDEX = -1;
 __constant__ constexpr float kE4M3_MAX = 448.F;
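
As a concrete illustration of the new comment (a host-side sketch of my own, not code from this commit): the cancellation in exp2f(x * log2e - bias) relies on the product x * log2e being rounded before the subtraction. An FMA keeps the product unrounded, so when |x| is around 1e30 the leftover rounding error of bias is on the order of 1e23 and exp2f returns inf or 0 instead of the expected 1; with |x| around 1e5 the leftover stays far below 1.

#include <cmath>
#include <cstdio>
#include <initializer_list>

int main()
{
    constexpr float log2e = 1.4426950408889634F;
    for (float initRowMax : {-1e+30F, -1e+5F})
    {
        float const x = initRowMax;            // a fully masked entry keeps the init value
        float const bias = initRowMax * log2e; // rowMax * log2e, rounded once to float
        float const prod = x * log2e;          // rounded separately...
        float const separate = prod - bias;    // ...so the two rounded products cancel to 0
        float const fused = std::fmaf(x, log2e, -bias); // unrounded product minus rounded bias
        // With init = -1e+30F the fused residual is ~1e23 in magnitude, so exp2f blows up
        // to inf or collapses to 0; with init = -1e+5F it is below 0.01 and exp2f stays ~1.
        std::printf("init=%g  separate: exp2f(%g) = %g  fused: exp2f(%g) = %g\n", initRowMax,
            separate, std::exp2f(separate), fused, std::exp2f(fused));
    }
    return 0;
}

nvcc fuses multiply-adds by default (--fmad=true), which is presumably how the kernel ends up on the fused path; the smaller safeInitRowMax keeps the worst-case residual harmless.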

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.cpp

Lines changed: 3 additions & 1 deletion
@@ -52,9 +52,11 @@ XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParam
     unsigned int kernel_m_tilesize
         = getKernelMTileSize(num_q_heads_over_kv, xqaParams.multi_query_tokens, qSeqLen, isXqaJit);

+    // precompiled XQA does not use is_fp8_output as hashing key
     return {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, kernel_m_tilesize,
         xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0, xqaParams.paged_kv_cache,
-        xqaParams.multi_query_tokens, xqaParams.is_fp8_output};
+        xqaParams.multi_query_tokens, isXqaJit ? xqaParams.is_fp8_output : false,
+        isXqaJit ? std::optional(xqaParams.position_embedding_type) : std::nullopt};
 }

 } // namespace tensorrt_llm::kernels
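
The isXqaJit conditionals above encode a lookup contract: precompiled cubins are registered under keys built from kernel metadata that carries neither is_fp8_output nor a position embedding type, so the query side must pin those fields to the same constants or the lookup always misses — apparently the failure mode behind the "precompiled multi_query_token kernel not having is_fp8_out hash key" item in this squash. A minimal, simplified sketch of that contract; the struct and names here are my own stand-ins, not the library's API:

#include <map>
#include <optional>
#include <string>
#include <tuple>

// Hypothetical, stripped-down stand-in for XQAKernelRuntimeHashKey: only the two
// fields relevant to this commit; everything else is elided.
struct MiniKey
{
    bool is_fp8_output;
    std::optional<int> position_embedding_type; // int stands in for the real enum

    bool operator<(MiniKey const& o) const
    {
        return std::tie(is_fp8_output, position_embedding_type)
            < std::tie(o.is_fp8_output, o.position_embedding_type);
    }
};

int main()
{
    std::map<MiniKey, std::string> functions;

    // Registration from precompiled kernel metadata: the metadata has no fp8-output or
    // position-embedding field, so both are pinned to fixed values (false / nullopt).
    functions[{false, std::nullopt}] = "precompiled_xqa_cubin";

    // Lookup from runtime params: the JIT path may key on the real values, but the
    // precompiled path must normalize them exactly the way registration did.
    auto makeKey = [](bool isXqaJit, bool isFp8Out, int posEmb) {
        return MiniKey{isXqaJit ? isFp8Out : false,
            isXqaJit ? std::optional<int>(posEmb) : std::nullopt};
    };

    bool const foundPrecompiled = functions.count(makeKey(/*isXqaJit=*/false, true, 1)) != 0;
    return foundPrecompiled ? 0 : 1; // normalization makes the precompiled lookup hit
}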

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h

Lines changed: 4 additions & 1 deletion
@@ -67,14 +67,15 @@ struct XQAKernelRuntimeHashKey
     bool paged_kv_cache;
     bool multi_query_tokens;
     bool is_fp8_output;
+    std::optional<PositionEmbeddingType> position_embedding_type;

     bool operator==(XQAKernelRuntimeHashKey const& other) const
     {
         return kv_data_type == other.kv_data_type && head_size == other.head_size
             && num_q_heads_per_kv == other.num_q_heads_per_kv && beam_size == other.beam_size
             && multi_query_tokens == other.multi_query_tokens && m_tilesize == other.m_tilesize
             && tokens_per_page == other.tokens_per_page && paged_kv_cache == other.paged_kv_cache
-            && is_fp8_output == other.is_fp8_output;
+            && is_fp8_output == other.is_fp8_output && position_embedding_type == other.position_embedding_type;
     }
 };
@@ -103,6 +104,8 @@ struct XQAKernelRuntimeHasher
         key ^= s.multi_query_tokens;
         key <<= 1; // 51
         key ^= s.is_fp8_output;
+        key <<= 8;
+        key ^= static_cast<int8_t>(s.position_embedding_type.value_or(static_cast<PositionEmbeddingType>(-1)));
         return key;
     }
 };

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp

Lines changed: 2 additions & 2 deletions
@@ -37,8 +37,8 @@ using ::tensorrt_llm::kernels::XQAKernelMetaInfo;
 XQAKernelRuntimeHashKey getRuntimeHashKeyFromKernelMeta(XQAKernelMetaInfo const& kernelMeta)
 {
     return {kernelMeta.mKVDataType, kernelMeta.mHeadDim, kernelMeta.mBeamWidth, kernelMeta.mNumQHeadsOverKV,
-        kernelMeta.mMTileSize, kernelMeta.mTokensPerPage, kernelMeta.mPagedKVCache, kernelMeta.mMultiQueryTokens,
-        0 /* xqa jit param is_fp8_output */};
+        kernelMeta.mMTileSize, kernelMeta.mTokensPerPage, kernelMeta.mPagedKVCache, kernelMeta.mMultiQueryTokens, false,
+        std::nullopt};
 }

 } // anonymous namespace

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp

Lines changed: 4 additions & 2 deletions
@@ -97,7 +97,7 @@ class XQAKernelList
         }
         XQAKernelRuntimeHashKey hash_key{kernelMeta.mKVDataType, kernelMeta.mHeadDim, kernelMeta.mBeamWidth,
             kernelMeta.mNumQHeadsOverKV, kernelMeta.mMTileSize, kernelMeta.mTokensPerPage, kernelMeta.mPagedKVCache,
-            kernelMeta.mMultiQueryTokens, 0 /* xqa jit param is_fp8_output */};
+            kernelMeta.mMultiQueryTokens, false, std::nullopt};

         mFunctions.insert(std::make_pair(hash_key, funcInfo));
     }
@@ -124,10 +124,12 @@ class XQAKernelList
             m_tilesize = num_q_heads_over_kv;
         }

+        // precompiled XQA does not support param is_fp8_output in hash key
         XQAKernelRuntimeHashKey hash_key
             = {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, m_tilesize,
                 xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0,
-                xqaParams.paged_kv_cache, xqaParams.multi_query_tokens, xqaParams.is_fp8_output};
+                xqaParams.paged_kv_cache, xqaParams.multi_query_tokens, 0, /* xqa jit param is_fp8_output */
+                std::nullopt};
         auto const findIter = mFunctions.find(hash_key);
         return findIter != mFunctions.end();
     }

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -59,3 +59,4 @@ ninja
 etcd3
 blake3
 llguidance==0.7.29
+triton==3.3.1; platform_machine == "x86_64"
