
Commit 895581b

add draft token tree runtime

Signed-off-by: Yue Weng <[email protected]>
1 parent 15ceba8

File tree: 25 files changed, +1907 -626 lines
cpp/tensorrt_llm/kernels/speculativeDecoding/draftTokenTreeKernels.cu (new file)

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_runtime_api.h>

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#ifdef ENABLE_FP8
#include <cuda_fp8.h>
#endif

#include "draftTokenTreeKernels.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/cudaUtils.h"

using namespace tensorrt_llm::common;

namespace tensorrt_llm
{
namespace kernels
{

__global__ void extractRealDraftTokensKernel(int const curDraftIdx, int const batchSize, int const maxDraftLen,
    int const maxTotalDraftTokens, int const maxTopK, int const numTokensExpandThisLayer,
    int* tokensGatherIdxForDrafterModel, int* topKList, int* draftTokensIndicesCumsum, int64_t* newDraftTokens,
    int64_t* draftTokensBuffer)
{
    // curDraftIdx: int
    // batchSize: int
    // maxDraftLen: int
    // maxTotalDraftTokens: int
    // maxTopK: int
    // tokensGatherIdxForDrafterModel: int32_t*, indices of the draft tokens that need to be expanded this layer
    //     shape: [numTokensExpandThisLayer]
    // topKList: int32_t*, top-k value for each expandable token
    //     shape: [numTokensExpandThisLayer]
    // draftTokensIndicesCumsum: int32_t*, cumulative sum of the write-back indices for each draft layer
    //     shape: [maxDraftLen + 1]
    // newDraftTokens: int64_t*, the new draft tokens. We only need to extract this layer's tokens and write them
    //     back to draftTokensBuffer
    //     shape: [batchSize, maxTotalDraftTokens + 1 if curDraftIdx > 0 else 1, maxTopK]
    // draftTokensBuffer: int64_t*, the buffer that stores the real draft tokens
    //     shape: [maxBatchSize, maxTotalDraftTokens + 1]

    // Each thread handles one request
    auto const tix = static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x);
    auto const isValid{tix < batchSize};

    if (isValid)
    {
        int newDraftTokensOffset = curDraftIdx == 0 ? 1 : maxTotalDraftTokens + 1;
        auto newDraftTokensStartPtr = newDraftTokens + tix * newDraftTokensOffset * maxTopK;
        auto draftTokensBufferPtr
            = draftTokensBuffer + tix * (maxTotalDraftTokens + 1) + draftTokensIndicesCumsum[curDraftIdx];

        int cnt = 0;
        for (int i = 0; i < numTokensExpandThisLayer; i++)
        {
            int tokenGatherIdx = tokensGatherIdxForDrafterModel[i];
            auto newDraftTokenPtr = newDraftTokensStartPtr + tokenGatherIdx * maxTopK;

            int topKValue = topKList[i];
            for (int j = 0; j < topKValue; j++)
            {
                int64_t newGenDraftToken = newDraftTokenPtr[j];
                draftTokensBufferPtr[cnt] = newGenDraftToken;
                cnt++;
            }
        }
    }
}

void invokeExtractRealDraftTokens(ExtractRealDraftTokensParam& params, cudaStream_t const stream)
{
    int constexpr BLOCK_SIZE = 64;
    int NUM_BLOCKS = divUp(params.batchSize, BLOCK_SIZE);

    extractRealDraftTokensKernel<<<NUM_BLOCKS, BLOCK_SIZE, 0, stream>>>(params.curDraftIdx, params.batchSize,
        params.maxDraftLen, params.maxTotalDraftTokens, params.maxTopK, params.numTokensExpandThisLayer,
        params.tokensGatherIdxForDrafterModel, params.topKList, params.draftTokensIndicesCumsum, params.newDraftTokens,
        params.draftTokensBuffer);

    sync_check_cuda_error(stream);
}

} // namespace kernels
} // namespace tensorrt_llm
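
For reference, a minimal NumPy sketch of the gather that extractRealDraftTokensKernel performs, with one loop iteration standing in for one CUDA thread; the Python names and example sizes are ours, not part of the commit:

import numpy as np

def extract_real_draft_tokens(cur_draft_idx, indices_cumsum, tokens_gather_idx,
                              top_k_list, new_draft_tokens, draft_tokens_buffer):
    """Reference for extractRealDraftTokensKernel (one iteration == one thread)."""
    batch_size = new_draft_tokens.shape[0]
    for b in range(batch_size):
        # Each draft layer writes its tokens at a fixed offset in the flat buffer.
        write_pos = int(indices_cumsum[cur_draft_idx])
        for gather_idx, top_k in zip(tokens_gather_idx, top_k_list):
            # Copy the top-k candidates of each node expanded this layer.
            draft_tokens_buffer[b, write_pos:write_pos + top_k] = \
                new_draft_tokens[b, gather_idx, :top_k]
            write_pos += top_k
    return draft_tokens_buffer

# curDraftIdx == 0: newDraftTokens is [batch_size, 1, maxTopK]
new_tokens = np.arange(2 * 1 * 2, dtype=np.int64).reshape(2, 1, 2)
buffer = np.zeros((2, 8), dtype=np.int64)  # maxTotalDraftTokens == 7
extract_real_draft_tokens(0, [0, 2, 4, 7], [0], [2], new_tokens, buffer)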
cpp/tensorrt_llm/kernels/speculativeDecoding/draftTokenTreeKernels.h (new file)

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/common.h"

namespace tensorrt_llm
{
namespace kernels
{

////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Draft token tree: parameters for extracting this layer's real draft tokens
struct ExtractRealDraftTokensParam
{
    int curDraftIdx;
    int batchSize;
    int maxDraftLen;
    int maxTotalDraftTokens;
    int maxTopK;
    int numTokensExpandThisLayer;
    int* tokensGatherIdxForDrafterModel;
    int* topKList;
    int* draftTokensIndicesCumsum;
    int64_t* newDraftTokens;
    int64_t* draftTokensBuffer;
};

void invokeExtractRealDraftTokens(ExtractRealDraftTokensParam& params, cudaStream_t const stream);

} // namespace kernels
} // namespace tensorrt_llm

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ add_library(
   virtualMemoryAllocator.cpp
   weightOnlyQuantGemm.cpp
   weightOnlyQuantOp.cpp
-  mtpOp.cpp
+  specDecOp.cpp
   loraOp.cpp
   finegrained_mixed_dtype_gemm_thop.cpp
   tinygemm2.cpp

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 2 additions & 1 deletion
@@ -546,7 +546,8 @@ class Runner : public RunnerBase
             = spec_decoding_tensor_params[1].value().data_ptr<int32_t>();
         enqueue_params.spec_decoding_packed_mask = spec_decoding_tensor_params[2].value().data_ptr<int32_t>();
         enqueue_params.spec_decoding_is_generation_length_variable = true;
-        enqueue_params.spec_decoding_max_generation_length = input_seq_length + 1;
+        TLLM_CHECK(spec_decoding_tensor_params[1].value().dim() == 2); // [batch_size, max_draft_len + 1]
+        enqueue_params.spec_decoding_max_generation_length = spec_decoding_tensor_params[1].value().sizes()[1];
     }

     // The current mlaGeneration path uses fmha to do attention, so we don't go into enqueueGeneration
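
This change derives the maximum generation length from the spec-decoding tensor's own layout instead of the input_seq_length + 1 heuristic. A small PyTorch sketch of the shape reasoning; the tensor name and sizes here are illustrative, not from the commit:

import torch

# The checked tensor is 2-D: [batch_size, max_draft_len + 1]
spec_dec_tensor = torch.ones(4, 8, dtype=torch.int32)  # batch_size=4, max_draft_len=7

# Its second dimension bounds how many tokens one request can produce per step,
# which is what spec_decoding_max_generation_length now reads (sizes()[1]).
max_generation_length = spec_dec_tensor.shape[1]  # 8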

cpp/tensorrt_llm/thop/mtpOp.cpp renamed to cpp/tensorrt_llm/thop/specDecOp.cpp

Lines changed: 90 additions & 2 deletions
@@ -15,10 +15,11 @@
  * limitations under the License.
  */

+#include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/opUtils.h"
-#include "tensorrt_llm/runtime/torchUtils.h"
-
+#include "tensorrt_llm/kernels/speculativeDecoding/draftTokenTreeKernels.h"
 #include "tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h"
+#include "tensorrt_llm/runtime/torchUtils.h"

 namespace th = torch;
 namespace tl = tensorrt_llm;
@@ -261,6 +262,78 @@ std::tuple<th::Tensor, th::Tensor> mtp_relaxed_acceptance_op(th::Tensor& reqSlot
     return std::make_tuple(acceptedTokens, numAcceptedTokens);
 }

+////////////////////////////////////////////////////////////////////////////////////////////////////////////
+void extract_real_draft_tokens_op(th::Tensor newDraftTokens, th::Tensor draftTokensBuffer,
+    th::Tensor tokensGatherIdxForDrafterModel, th::Tensor topKList, th::Tensor draftTokensIndicesCumsum,
+    int64_t curDraftIdx, int64_t batchSize, int64_t maxDraftLen, int64_t maxTotalDraftTokens, int64_t maxTopK)
+{
+    // Args:
+    //     curDraftIdx: int
+    //     batchSize: int
+    //     maxTotalDraftTokens: int
+    //     maxTopK: int
+    //     tokensGatherIdxForDrafterModel: Tensor, int32, indices of the draft tokens that need to be expanded
+    //         this layer
+    //         shape: [numTokensExpandThisLayer]
+    //     topKList: Tensor, int32, top-k value for each expandable token
+    //         shape: [numTokensExpandThisLayer]
+    //     draftTokensIndicesCumsum: Tensor, int32, cumulative sum of the write-back indices for each draft layer
+    //         shape: [maxDraftLen + 1]
+    //     newDraftTokens: Tensor, int64, the new draft tokens. We only need to extract this layer's tokens and
+    //         write them back to draftTokensBuffer
+    //         shape: [batchSize, maxTotalDraftTokens + 1 if curDraftIdx > 0 else 1, maxTopK]
+    //     draftTokensBuffer: Tensor, int64, the buffer that stores the real draft tokens
+    //         shape: [maxBatchSize, maxTotalDraftTokens + 1]
+
+    // Check the data types
+    TLLM_CHECK(tokensGatherIdxForDrafterModel.scalar_type() == torch::kInt32);
+    TLLM_CHECK(topKList.scalar_type() == torch::kInt32);
+    TLLM_CHECK(draftTokensIndicesCumsum.scalar_type() == torch::kInt32);
+    TLLM_CHECK(newDraftTokens.scalar_type() == torch::kInt64);
+    TLLM_CHECK(draftTokensBuffer.scalar_type() == torch::kInt64);
+
+    // Check the shapes of 'tokensGatherIdxForDrafterModel' and 'topKList'
+    auto numTokensExpandThisLayer = tokensGatherIdxForDrafterModel.size(0);
+    TLLM_CHECK(numTokensExpandThisLayer > 0);
+    TLLM_CHECK(topKList.size(0) == numTokensExpandThisLayer);
+
+    // Check the shape of 'draftTokensIndicesCumsum'
+    TLLM_CHECK(draftTokensIndicesCumsum.size(0) == maxDraftLen + 1);
+
+    // Check the shape of 'newDraftTokens'
+    TLLM_CHECK(newDraftTokens.size(0) == batchSize);
+    if (curDraftIdx == 0)
+    {
+        TLLM_CHECK(newDraftTokens.size(1) == 1);
+        TLLM_CHECK(newDraftTokens.size(2) == maxTopK);
+    }
+    else
+    {
+        TLLM_CHECK(newDraftTokens.size(1) == maxTotalDraftTokens + 1);
+        TLLM_CHECK(newDraftTokens.size(2) == maxTopK);
+    }
+
+    // Check the shape of 'draftTokensBuffer'
+    TLLM_CHECK(draftTokensBuffer.size(1) == maxTotalDraftTokens + 1);
+
+    auto stream = at::cuda::getCurrentCUDAStream(newDraftTokens.get_device());
+
+    // Fill params
+    tk::ExtractRealDraftTokensParam params;
+    params.curDraftIdx = curDraftIdx;
+    params.batchSize = batchSize;
+    params.maxDraftLen = maxDraftLen;
+    params.maxTotalDraftTokens = maxTotalDraftTokens;
+    params.maxTopK = maxTopK;
+    params.numTokensExpandThisLayer = numTokensExpandThisLayer;
+    params.tokensGatherIdxForDrafterModel = reinterpret_cast<int32_t*>(tokensGatherIdxForDrafterModel.data_ptr());
+    params.topKList = reinterpret_cast<int32_t*>(topKList.data_ptr());
+    params.draftTokensIndicesCumsum = reinterpret_cast<int32_t*>(draftTokensIndicesCumsum.data_ptr());
+    params.newDraftTokens = reinterpret_cast<int64_t*>(newDraftTokens.data_ptr());
+    params.draftTokensBuffer = reinterpret_cast<int64_t*>(draftTokensBuffer.data_ptr());
+
+    tk::invokeExtractRealDraftTokens(params, stream);
+}
+
 } // end namespace torch_ext

 TORCH_LIBRARY_FRAGMENT(trtllm, m)
@@ -323,3 +396,18 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
 {
     m.impl("mtp_relaxed_acceptance_op", &torch_ext::mtp_relaxed_acceptance_op);
 }
+
+////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+TORCH_LIBRARY_FRAGMENT(trtllm, m)
+{
+    m.def(
+        "extract_real_draft_tokens_op(Tensor newDraftTokens, Tensor draftTokensBuffer, "
+        "Tensor tokensGatherIdxForDrafterModel, Tensor topKList, Tensor draftTokensIndicesCumsum, "
+        "int curDraftIdx, int batchSize, int maxDraftLen, int maxTotalDraftTokens, int maxTopK) -> ()");
+}
+
+TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
+{
+    m.impl("extract_real_draft_tokens_op", &torch_ext::extract_real_draft_tokens_op);
+}
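
A hypothetical Python invocation of the newly registered op, assuming the extension that registers the trtllm ops has been loaded; all sizes and tensor contents below are made up to satisfy the op's shape checks (curDraftIdx == 0, so newDraftTokens has a middle dimension of 1):

import torch

batch_size, max_draft_len, max_total_draft_tokens, max_top_k = 2, 3, 7, 2
cur_draft_idx = 0
device = "cuda"

new_draft_tokens = torch.randint(
    0, 32000, (batch_size, 1, max_top_k), dtype=torch.int64, device=device)
draft_tokens_buffer = torch.zeros(
    (batch_size, max_total_draft_tokens + 1), dtype=torch.int64, device=device)
tokens_gather_idx = torch.tensor([0], dtype=torch.int32, device=device)
top_k_list = torch.tensor([2], dtype=torch.int32, device=device)
indices_cumsum = torch.tensor([0, 2, 4, 7], dtype=torch.int32,
                              device=device)  # [max_draft_len + 1]

torch.ops.trtllm.extract_real_draft_tokens_op(
    new_draft_tokens, draft_tokens_buffer, tokens_gather_idx, top_k_list,
    indices_cumsum, cur_draft_idx, batch_size, max_draft_len,
    max_total_draft_tokens, max_top_k)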

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 8 additions & 1 deletion
@@ -11,6 +11,8 @@

 if TYPE_CHECKING:
     from ..speculative.utils import SpecDecodingTensor
+    from ..speculative.interface import SpecMetadata
+    from ..speculative.spec_tree_manager import SpecTreeManager

 from tensorrt_llm.functional import (PositionEmbeddingType, RopeEmbeddingUtils,
                                      RotaryScalingType)
@@ -338,10 +340,15 @@ def restore_from_spec_dec(self) -> None:

     def update_spec_dec_param(
             self,
+            batch_size,
             is_spec_decoding_enabled,
             is_spec_dec_tree,
             is_spec_dec_dynamic_tree,
-            max_draft_tokens,
+            max_draft_len,
+            max_total_draft_tokens,
+            model_is_wrapped: bool = False,
+            spec_metadata: Optional['SpecMetadata'] = None,
+            spec_tree_manager: Optional['SpecTreeManager'] = None,
             spec_decoding_tensor: Optional['SpecDecodingTensor'] = None):
         """
         Hook to be called when using TRTLLM attention backend in spec-dec mode.
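
To illustrate the widened hook, a runnable stub that mirrors the new signature; the class name, body, and argument values are ours, and only the parameter list is taken from the diff:

from typing import Optional

class AttentionMetadataStub:
    # Stub mirroring the widened hook; the real logic lives in the backend.
    def update_spec_dec_param(self, batch_size, is_spec_decoding_enabled,
                              is_spec_dec_tree, is_spec_dec_dynamic_tree,
                              max_draft_len, max_total_draft_tokens,
                              model_is_wrapped: bool = False,
                              spec_metadata=None, spec_tree_manager=None,
                              spec_decoding_tensor=None):
        print(f"spec-dec: tree={is_spec_dec_tree}, "
              f"draft layers={max_draft_len}, tree nodes={max_total_draft_tokens}")

AttentionMetadataStub().update_spec_dec_param(
    batch_size=8, is_spec_decoding_enabled=True, is_spec_dec_tree=True,
    is_spec_dec_dynamic_tree=False, max_draft_len=3, max_total_draft_tokens=7)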

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 12 additions & 5 deletions
@@ -477,17 +477,24 @@ def create_expanded_buffers(self, capture_graph=False):
     # TODO: remove this function when fp8_paged_mqa_logits can support MTP > 1.
     def update_spec_dec_param(
             self,
+            batch_size,
             is_spec_decoding_enabled,
             is_spec_dec_tree,
             is_spec_dec_dynamic_tree,
-            max_draft_tokens,
+            max_draft_len,
+            max_total_draft_tokens,
+            model_is_wrapped: bool = False,
+            spec_metadata: Optional['SpecMetadata'] = None,
+            spec_tree_manager: Optional['SpecTreeManager'] = None,
             spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
     ):
-        super().update_spec_dec_param(is_spec_decoding_enabled,
+        super().update_spec_dec_param(batch_size, is_spec_decoding_enabled,
                                       is_spec_dec_tree,
-                                      is_spec_dec_dynamic_tree,
-                                      max_draft_tokens, spec_decoding_tensor)
-        self.max_draft_tokens = max_draft_tokens
+                                      is_spec_dec_dynamic_tree, max_draft_len,
+                                      max_total_draft_tokens, model_is_wrapped,
+                                      spec_metadata, spec_tree_manager,
+                                      spec_decoding_tensor)
+        self.max_draft_tokens = max_draft_len
         init_shape = self.kv_lens_expanded_host.shape[0]
         if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape:
             capture_graph = torch.cuda.is_current_stream_capturing()