Skip to content

Commit 588a2e8

Browse files
committed
update trtllm-gen nvfp4 kernels with better performance
Signed-off-by: Perkz Zheng <[email protected]>; update nvfp4 kv cache trtllm-gen kernels && fix several bugs; Signed-off-by: Perkz Zheng <[email protected]>
1 parent 51bf716 commit 588a2e8

File tree

1,950 files changed

+6139
-5587
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

1,950 files changed

+6139
-5587
lines changed

cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,13 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
114114
auto srcPtr = computeBlockPointer(src, pools, poolIdx);
115115
auto dstPtr = computeBlockPointer(dst, pools, poolIdx);
116116

117+
// Does it contain block scales?
118+
auto containsBlockScales = pools[poolIdx].containsBlockScales;
119+
117120
// If no partial tokens or if the dataType is not supported for partial copy, copy entire block.
121+
// Note that nvfp4 kv cache SFs use an interleaved layout, so we need to copy the entire block.
118122
if (numTokensToCopy <= 0 || srcPtr->getDataType() == nvinfer1::DataType::kINT4
119-
|| srcPtr->getDataType() == nvinfer1::DataType::kFP4)
123+
|| srcPtr->getDataType() == nvinfer1::DataType::kFP4 || containsBlockScales)
120124
{
121125
// For partial copy not implemented with these data types,
122126
// just do a full copy.

cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,11 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
189189
tllmRunnerParams.attentionSinksPtr = runnerParams.attentionSinksPtr;
190190
tllmRunnerParams.cumSeqLensQPtr = reinterpret_cast<int const*>(runnerParams.cuQSeqLenPtr);
191191
tllmRunnerParams.cumSeqLensKvPtr = reinterpret_cast<int const*>(runnerParams.cuKvSeqLenPtr);
192+
// Attention scales device pointers (only fp8 kernels need to load scales from the device memory).
192193
tllmRunnerParams.outputScalePtr = reinterpret_cast<float const*>(runnerParams.scaleBmm2Ptr);
193-
// TRTLLM-GEN kernels always use the Log2 scale
194-
tllmRunnerParams.scaleSoftmaxLog2Ptr
195-
= reinterpret_cast<float const*>(runnerParams.scaleBmm1Ptr + kIdxScaleSoftmaxLog2Ptr);
194+
tllmRunnerParams.scaleSoftmaxLog2Ptr = runnerParams.scaleBmm1Ptr
195+
? reinterpret_cast<float const*>(runnerParams.scaleBmm1Ptr + kIdxScaleSoftmaxLog2Ptr)
196+
: nullptr;
196197
tllmRunnerParams.kvPageIdxPtr = reinterpret_cast<int const*>(kvPageIdxPtr);
197198
tllmRunnerParams.oSfScalePtr = runnerParams.oSfScalePtr;
198199
tllmRunnerParams.oPtr = runnerParams.outputPtr;
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:a74c90bed8cdfc61d4d30985f0a037b948a845387af20641313e17b1892c830b
3-
size 612398
2+
oid sha256:331aaf5e84f39f9ce4940fce18d646701f80caf6681d8ba1244934171baf9d03
3+
size 616196
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:47b29936b5167f32d44959bdbbdb8943a3b4152c0f42860fbcafbd27fbe930d4
3-
size 547072
2+
oid sha256:e93cb23f1ee61233c61091dc880258c59fa006abb5950cc6c8e1a99da2537845
3+
size 551858
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:64a5d40ff29adb68f36625bd0f0fba00b347fc0e20c7acf901eea8b97919a1bf
3-
size 601346
2+
oid sha256:748f8edf49b35d4c0502d3a292f11a53673d224539f8a94e2f9724bf17b8502b
3+
size 605146
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:fab341ce70c852b81be9d942b033a133cd803f2f095f3fad59dfb65e03292929
3-
size 536022
2+
oid sha256:e4ba7f26e6cb3e11b76321d6539bf2e3d194908058ae296a5b4c3ecb36fdfdf3
3+
size 540806
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:9a96b284f3fd4745012dffb728a2659b4de6e79d8c5c47efa375d7b3859d00bc
3-
size 594912
2+
oid sha256:799f27d110ed5c5b76d30c39a945d1bf8a28078be2a9dbf35db18a27a8f608dd
3+
size 466054
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:0cdbd895f8ce46c6432bf8eb0d6682c247ba3a79ac84f5d3f37bc00ae626ca74
3-
size 554678
2+
oid sha256:a2e8a5d62a02d3ba248b1a19bfbf9d0cd11674705283022033f5243b98ee72cd
3+
size 432382
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:7759c5d27b868e46921e41870c6ea9cd2f2bb96b46f0cc5308c6890689296dcd
3-
size 584176
2+
oid sha256:a21b45df44b576b61dff0cd0e89caeb324971601c3b458cfc72f32640a76eee4
3+
size 456106
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:cde6a4e81280191ad7ccbcaeb0ffcc3c0aab9e82061b6c5aa71c95e9cce70a27
3-
size 550258
2+
oid sha256:281d59d446597519bb0fdd6fd5b46cc9a69fe5ede08fb6b46426cef8bf5b1327
3+
size 427984

0 commit comments

Comments
 (0)