
Commit ec2b953

refactor: Enhanced handling of decoder requests and logits within the batch manager (#6055)
Signed-off-by: Robin Kobus <[email protected]>
1 parent 77acb4f commit ec2b953

16 files changed: +168 -155 lines

cpp/include/tensorrt_llm/batch_manager/decoderBuffers.h

Lines changed: 7 additions & 4 deletions
@@ -16,6 +16,7 @@
 
 #pragma once
 
+#include "tensorrt_llm/batch_manager/common.h"
 #include "tensorrt_llm/runtime/bufferManager.h"
 #include "tensorrt_llm/runtime/iTensor.h"
 #include "tensorrt_llm/runtime/modelConfig.h"
@@ -38,8 +39,8 @@ class DecoderInputBuffers
     using SizeType32 = runtime::SizeType32;
     using TensorPtr = runtime::ITensor::SharedPtr;
 
-    explicit DecoderInputBuffers(SizeType32 maxNumSequences, SizeType32 maxBatchSize, SizeType32 maxDecoderSteps,
-        runtime::BufferManager const& manager);
+    explicit DecoderInputBuffers(
+        SizeType32 maxBatchSize, SizeType32 maxDecoderSteps, runtime::BufferManager const& manager);
 
     void setupMedusaLogits(SizeType32 maxNumSequences, runtime::ModelConfig const& modelConfig);
 
@@ -56,11 +57,13 @@ class DecoderInputBuffers
 
     //! Buffers for decoder forward
 
+    //! Requests considered in decoder forward
+    RequestVector decoderRequests;
+
     //! Batch slots for all decoder steps, [maxDecoderSteps][maxBatchSize]
     std::vector<TensorPtr> forwardBatchSlots;
 
-    //! Logits for all batch slots, [maxNumSequences]
-    //! The vector is sparse, only slots in forwardBatchSlots are used.
+    //! Logits of decoder requests
     std::vector<TensorPtr> logits;
 
     //! Logits for speculative decoding (Medusa)
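
The layout change above is the core of the refactor: instead of a sparse logits vector sized to maxNumSequences and indexed by seqSlot, DecoderInputBuffers now keeps two dense, parallel vectors. A minimal stand-alone sketch of the resulting invariant, using hypothetical stand-in types rather than the real TensorRT-LLM classes:

#include <cassert>
#include <memory>
#include <vector>

struct Tensor
{
};
using TensorPtr = std::shared_ptr<Tensor>;

struct Request
{
    int seqSlot;
};
using RequestPtr = std::shared_ptr<Request>;

// Old layout: logits was sized to maxNumSequences and indexed by seqSlot, so it was
// sparse -- only the slots listed in forwardBatchSlots held valid tensors.
struct OldBuffers
{
    std::vector<TensorPtr> logits; // [maxNumSequences], mostly unused entries
};

// New layout: decoderRequests and logits are parallel, densely packed vectors;
// logits[i] belongs to decoderRequests[i], so no seqSlot indirection is needed and
// no pre-sizing to maxNumSequences is required.
struct NewBuffers
{
    std::vector<RequestPtr> decoderRequests;
    std::vector<TensorPtr> logits; // [decoderRequests.size()]
};

int main()
{
    NewBuffers buffers;
    buffers.decoderRequests.push_back(std::make_shared<Request>(Request{42}));
    buffers.logits.push_back(std::make_shared<Tensor>());
    // The invariant the refactor establishes:
    assert(buffers.decoderRequests.size() == buffers.logits.size());
    return 0;
}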

cpp/include/tensorrt_llm/batch_manager/guidedDecoder.h

Lines changed: 2 additions & 2 deletions
@@ -29,6 +29,7 @@ class GrammarCompiler;
 
 namespace tensorrt_llm::batch_manager
 {
+class DecoderInputBuffers;
 
 class GuidedDecoder
 {
@@ -40,8 +41,7 @@ GuidedDecoder
     GuidedDecoder(executor::GuidedDecodingConfig const& guidedDecodingConfig, SizeType32 maxNumSequences,
         SizeType32 vocabSizePadded, nvinfer1::DataType logitsDtype, runtime::BufferManager const& runtimeBufferManager);
     void build(ScheduledRequests const& scheduledRequests);
-    void execute(ScheduledRequests const& scheduledRequests, runtime::BufferManager const& runtimeBufferManager,
-        std::vector<TensorPtr> const& decoderBuffersLogits);
+    void execute(DecoderInputBuffers const& decoderInputBuffers, runtime::BufferManager const& runtimeBufferManager);
 
 private:
     executor::GuidedDecodingConfig::GuidedDecodingBackend mGuidedDecodingBackend;

cpp/include/tensorrt_llm/batch_manager/logitsPostProcessor.h

Lines changed: 7 additions & 6 deletions
@@ -24,28 +24,29 @@
 
 namespace tensorrt_llm::runtime
 {
-class TllmRuntime;
+class CudaStream;
 }
 
 namespace tensorrt_llm::batch_manager
 {
+class DecoderInputBuffers;
 
 class LogitsPostProcessor : Algorithm
 {
 public:
+    using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
+
     using LogitsPostProcessorBatched = std::function<void(std::vector<batch_manager::LlmRequest::RequestIdType> const&,
         std::vector<batch_manager::LlmRequest::TensorPtr>&,
-        std::vector<std::reference_wrapper<batch_manager::LlmRequest::BeamTokens const>> const&,
-        runtime::BufferManager::CudaStreamPtr const&,
+        std::vector<std::reference_wrapper<batch_manager::LlmRequest::BeamTokens const>> const&, CudaStreamPtr const&,
         std::vector<std::optional<batch_manager::LlmRequest::RequestIdType>> const&)>;
 
     constexpr static auto name{"LogitsPostProcessor"};
 
     LogitsPostProcessor() = default;
 
-    bool operator()(RequestVector const& contextRequests, RequestVector const& generationRequests,
-        bool replicateLogitsPostProcessor, std::vector<batch_manager::LlmRequest::TensorPtr>& seqSlotLogits,
-        runtime::WorldConfig const& worldConfig, runtime::TllmRuntime& runtime,
+    bool operator()(DecoderInputBuffers& inputBuffers, bool replicateLogitsPostProcessor,
+        runtime::WorldConfig const& worldConfig, CudaStreamPtr const& stream,
         std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched = std::nullopt) const;
 };

cpp/include/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h

Lines changed: 1 addition & 2 deletions
@@ -46,8 +46,7 @@ class MakeDecodingBatchInputOutput : Algorithm
 
     MakeDecodingBatchInputOutput() = default;
 
-    std::unique_ptr<runtime::decoder_batch::Input> operator()(RequestVector const& contextRequests,
-        RequestVector const& generationRequests, DecoderInputBuffers const& inputBuffers,
+    std::unique_ptr<runtime::decoder_batch::Input> operator()(DecoderInputBuffers& inputBuffers,
         runtime::decoder::DecoderState& decoderState, runtime::ModelConfig const& modelConfig,
         SizeType32 maxNumSequences, OptionalRef<RuntimeBuffers> fusedRuntimeBuffers) const;
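
The same consolidation shows up at call sites: request lists no longer travel alongside the buffers. A minimal stand-alone sketch of the pattern, with hypothetical stand-in types rather than the real signatures:

#include <memory>
#include <vector>

struct Request
{
};
using RequestVector = std::vector<std::shared_ptr<Request>>;

struct DecoderInputBuffers // stand-in for the real struct
{
    RequestVector decoderRequests; // filled by the logits handlers earlier in the step
};

// Old shape: three parameters whose contents the caller had to keep consistent.
void makeDecodingInputOld(
    RequestVector const& contextRequests, RequestVector const& generationRequests, DecoderInputBuffers const& buffers)
{
    (void) contextRequests;
    (void) generationRequests;
    (void) buffers;
}

// New shape: one source of truth for "which requests are decoded this step".
void makeDecodingInputNew(DecoderInputBuffers& inputBuffers)
{
    for (auto const& llmReq : inputBuffers.decoderRequests)
    {
        (void) llmReq; // derive batch slots etc. from the embedded request list
    }
}

int main()
{
    DecoderInputBuffers buffers;
    buffers.decoderRequests.push_back(std::make_shared<Request>());
    makeDecodingInputNew(buffers);
    return 0;
}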

cpp/tensorrt_llm/batch_manager/decoderBuffers.cpp

Lines changed: 1 addition & 3 deletions
@@ -31,7 +31,7 @@ namespace tensorrt_llm::batch_manager
 {
 
 DecoderInputBuffers::DecoderInputBuffers(
-    SizeType32 maxNumSequences, SizeType32 maxBatchSize, SizeType32 maxDecoderSteps, BufferManager const& manager)
+    SizeType32 maxBatchSize, SizeType32 maxDecoderSteps, BufferManager const& manager)
 {
     auto const maxBatchSizeShape = ITensor::makeShape({maxBatchSize});
     auto const nvSizeType = TRTDataType<SizeType32>::value;
@@ -49,8 +49,6 @@ DecoderInputBuffers::DecoderInputBuffers(
     {
         forwardBatchSlots.emplace_back(BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize}), nvSizeType));
     }
-
-    logits.resize(maxNumSequences);
 }
 
 void DecoderInputBuffers::setupMedusaLogits(SizeType32 maxNumSequences, ModelConfig const& modelConfig)

cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp

Lines changed: 18 additions & 22 deletions
@@ -16,6 +16,7 @@
 */
 
 #include "tensorrt_llm/batch_manager/guidedDecoder.h"
+#include "tensorrt_llm/batch_manager/decoderBuffers.h"
 #include "tensorrt_llm/batch_manager/llmRequest.h"
 #include "tensorrt_llm/kernels/logitsBitmask.h"
 
@@ -136,8 +137,7 @@ void GuidedDecoder::build(ScheduledRequests const& scheduledRequests)
     }
 }
 
-void GuidedDecoder::execute(ScheduledRequests const& scheduledRequests, BufferManager const& runtimeBufferManager,
-    std::vector<TensorPtr> const& decoderBuffersLogits)
+void GuidedDecoder::execute(DecoderInputBuffers const& decoderInputBuffers, BufferManager const& runtimeBufferManager)
 {
     auto const& stream = runtimeBufferManager.getStream();
 
@@ -150,32 +150,28 @@ void GuidedDecoder::execute(ScheduledRequests const& scheduledRequests, BufferMa
     mCopyBufferManager.getStream().record(event);
     stream.wait(event);
 
-    SizeType32 batchIdx{0};
-    if (mGuidedDecodingBackend == executor::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR)
+    if (mGuidedDecodingBackend == executor::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR
+        && !decoderInputBuffers.decoderRequests.empty())
     {
-        for (auto const& requests : {scheduledRequests.contextRequests, scheduledRequests.generationRequests})
+        SizeType32 batchIdx{0};
+        for (size_t requestIdx = 0; requestIdx < decoderInputBuffers.decoderRequests.size(); ++requestIdx)
         {
-            for (auto const& llmReq : requests)
+            auto const& llmReq = decoderInputBuffers.decoderRequests.at(requestIdx);
+
+            auto const& guidedDecodingParams = llmReq->getGuidedDecodingParams();
+            if (guidedDecodingParams.has_value())
             {
-                if (llmReq->isContextInitState() && !llmReq->isLastContextChunk())
-                {
-                    continue;
-                }
-                auto const& guidedDecodingParams = llmReq->getGuidedDecodingParams();
-                if (guidedDecodingParams.has_value())
-                {
-                    auto const seqSlot = llmReq->mSeqSlot.value();
+                auto const seqSlot = llmReq->mSeqSlot.value();
 
-                    auto const& logits = decoderBuffersLogits.at(seqSlot);
-                    auto const logitsBitmask = ITensor::at(mLogitsBitmask, {seqSlot});
+                auto const& logits = decoderInputBuffers.logits.at(requestIdx);
+                auto const logitsBitmask = ITensor::at(mLogitsBitmask, {seqSlot});
 
-                    // Use void* to unify the code for different mLogitsDtype
-                    *reinterpret_cast<void**>(ITensor::at(mLogitsPtrVecHost, {batchIdx})->data()) = logits->data();
-                    *reinterpret_cast<void**>(ITensor::at(mLogitsBitmaskPtrVecHost, {batchIdx})->data())
-                        = logitsBitmask->data();
+                // Use void* to unify the code for different mLogitsDtype
+                *reinterpret_cast<void**>(ITensor::at(mLogitsPtrVecHost, {batchIdx})->data()) = logits->data();
+                *reinterpret_cast<void**>(ITensor::at(mLogitsBitmaskPtrVecHost, {batchIdx})->data())
+                    = logitsBitmask->data();
 
-                    ++batchIdx;
-                }
+                ++batchIdx;
             }
         }
         if (batchIdx > 0)
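
Note the two different indices in the new loop: logits are fetched by dense position (requestIdx), while the per-sequence bitmask is still keyed by the request's persistent seqSlot. The old context-chunk filter is gone because decoderRequests is populated upstream from last-chunk context requests only (see handleContextLogits.cpp below). A simplified stand-alone model of the indexing, with hypothetical types:

#include <cstddef>
#include <cstdio>
#include <memory>
#include <optional>
#include <vector>

struct Request
{
    int seqSlot;
    std::optional<int> guidedDecodingParams; // stand-in for the real params object
};
using RequestPtr = std::shared_ptr<Request>;

int main()
{
    std::vector<RequestPtr> decoderRequests{
        std::make_shared<Request>(Request{7, std::nullopt}), // no guided decoding
        std::make_shared<Request>(Request{3, 1}),            // uses guided decoding
    };
    std::vector<float*> logits(decoderRequests.size(), nullptr); // parallel vector

    int batchIdx = 0;
    for (std::size_t requestIdx = 0; requestIdx < decoderRequests.size(); ++requestIdx)
    {
        auto const& llmReq = decoderRequests.at(requestIdx);
        if (!llmReq->guidedDecodingParams.has_value())
        {
            continue; // request does not take part in guided decoding
        }
        auto const seqSlot = llmReq->seqSlot;    // keys the persistent bitmask
        auto* reqLogits = logits.at(requestIdx); // keys the dense logits vector
        (void) reqLogits;
        std::printf("batchIdx=%d requestIdx=%zu seqSlot=%d\n", batchIdx, requestIdx, seqSlot);
        ++batchIdx;
    }
    return 0;
}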

cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp

Lines changed: 28 additions & 16 deletions
@@ -76,6 +76,13 @@ SizeType32 HandleContextLogits::operator()(DecoderInputBuffers& inputBuffers, Re
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
     NVTX3_SCOPED_RANGE(HandleContextLogits);
 
+    auto& decoderRequests = inputBuffers.decoderRequests;
+    decoderRequests.clear();
+    decoderRequests.reserve(contextRequests.size());
+    auto& allDecoderLogits = inputBuffers.logits;
+    allDecoderLogits.clear();
+    allDecoderLogits.reserve(contextRequests.size());
+
     SizeType32 batchIndex{0};
     SizeType32 logitsIndex{0};
     // Copy logits into decoderBuffers.logits
@@ -115,7 +122,6 @@ SizeType32 HandleContextLogits::operator()(DecoderInputBuffers& inputBuffers, Re
         // Get the logits from the last context token and draft tokens
         auto const numDecoderLogits = 1 + draftLength;
         auto const seqSlot = llmReq->mSeqSlot.value();
-        auto& decoderLogits = inputBuffers.logits.at(seqSlot);
         TensorPtr logitsView = ITensor::slice(logits, logitsIndex - numDecoderLogits, numDecoderLogits);
 
         if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits())
@@ -136,22 +142,28 @@ SizeType32 HandleContextLogits::operator()(DecoderInputBuffers& inputBuffers, Re
 
         TLLM_CHECK_DEBUG_WITH_INFO(tru::tensorHasInvalid<float>(*logitsView, manager, "logits") == false,
             "Found invalid number (NaN or Inf) in logits");
-        // Scatter the output logits to the decoderLogits
-        auto const reqBeamWidth = llmReq->getBeamWidthByIter();
-        if (reqBeamWidth > 1)
-        {
-            // Tile logits of context requests
-            auto const logitsShape = logitsView->getShape();
-            auto const logitsType = logitsView->getDataType();
-            decoderLogits = manager.gpu(ITensor::makeShape({reqBeamWidth, logitsShape.d[1]}), logitsType);
-            tensorrt_llm::runtime::kernels::tileTensor(*decoderLogits, *logitsView, reqBeamWidth, manager.getStream());
-            decoderLogits->unsqueeze(0);
-        }
-        else
+
+        if (llmReq->isLastContextChunk())
         {
-            auto const logitsViewShape = logitsView->getShape();
-            decoderLogits
-                = ITensor::view(logitsView, ITensor::makeShape({logitsViewShape.d[0], 1, logitsViewShape.d[1]}));
+            TensorPtr decoderLogits;
+            auto const reqBeamWidth = llmReq->getBeamWidthByIter();
+            if (reqBeamWidth > 1)
+            {
+                // Tile logits of context requests
+                auto const& logitsShape = logitsView->getShape();
+                auto const logitsType = logitsView->getDataType();
+                decoderLogits = manager.gpu(ITensor::makeShape({reqBeamWidth, logitsShape.d[1]}), logitsType);
+                tensorrt_llm::runtime::kernels::tileTensor(
+                    *decoderLogits, *logitsView, reqBeamWidth, manager.getStream());
+                decoderLogits->unsqueeze(0);
+            }
+            else
+            {
+                decoderLogits = logitsView;
+                decoderLogits->unsqueeze(1);
+            }
+            decoderRequests.push_back(llmReq);
+            allDecoderLogits.emplace_back(std::move(decoderLogits));
        }
 
         ++batchIndex;
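
HandleContextLogits runs first in the step, so it now owns the reset of the per-step vectors (clear() plus reserve()) and appends a request only on its last context chunk; earlier chunks of a chunked request contribute no decoder logits. A stand-alone sketch of that population pattern, with hypothetical simplified types:

#include <cassert>
#include <memory>
#include <vector>

struct Request
{
    bool lastContextChunk;
    bool isLastContextChunk() const { return lastContextChunk; }
};
using RequestPtr = std::shared_ptr<Request>;

struct Buffers
{
    std::vector<RequestPtr> decoderRequests;
    std::vector<int> logits; // stand-in for std::vector<TensorPtr>
};

void handleContextLogits(Buffers& buffers, std::vector<RequestPtr> const& contextRequests)
{
    // Reset the dense vectors for this iteration before appending.
    buffers.decoderRequests.clear();
    buffers.decoderRequests.reserve(contextRequests.size());
    buffers.logits.clear();
    buffers.logits.reserve(contextRequests.size());

    int fakeLogits = 0;
    for (auto const& llmReq : contextRequests)
    {
        if (llmReq->isLastContextChunk()) // mid-chunk requests are skipped
        {
            buffers.decoderRequests.push_back(llmReq);
            buffers.logits.push_back(fakeLogits++);
        }
    }
}

int main()
{
    Buffers buffers;
    std::vector<RequestPtr> requests{
        std::make_shared<Request>(Request{false}), // still chunking: not decoded yet
        std::make_shared<Request>(Request{true}),  // last chunk: decoded this step
    };
    handleContextLogits(buffers, requests);
    assert(buffers.decoderRequests.size() == 1 && buffers.logits.size() == 1);
    return 0;
}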

cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp

Lines changed: 13 additions & 4 deletions
@@ -22,6 +22,7 @@
 #include "tensorrt_llm/batch_manager/medusaBuffers.h"
 #include "tensorrt_llm/batch_manager/runtimeBuffers.h"
 #include "tensorrt_llm/batch_manager/utils/inflightBatchingUtils.h"
+#include "tensorrt_llm/common/assert.h"
 #include "tensorrt_llm/common/nvtxUtils.h"
 #include "tensorrt_llm/runtime/iTensor.h"
 #include "tensorrt_llm/runtime/utils/debugUtils.h"
@@ -82,6 +83,11 @@ void HandleGenerationLogits::operator()(DecoderInputBuffers& inputBuffers, Reque
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
     NVTX3_SCOPED_RANGE(HandleGenerationLogits);
 
+    auto& decoderRequests = inputBuffers.decoderRequests;
+    decoderRequests.reserve(decoderRequests.size() + generationRequests.size());
+    auto& allDecoderLogits = inputBuffers.logits;
+    allDecoderLogits.reserve(allDecoderLogits.size() + generationRequests.size());
+
     for (auto const& llmReq : generationRequests)
     {
         auto const reqBeamWidth = llmReq->getBeamWidthByIter();
@@ -101,18 +107,21 @@ void HandleGenerationLogits::operator()(DecoderInputBuffers& inputBuffers, Reque
         TensorPtr logitsView = ITensor::slice(logits, logitsIndex, numLogits);
         TLLM_CHECK_DEBUG_WITH_INFO(tru::tensorHasInvalid<float>(*logitsView, manager, "logits") == false,
             "Found invalid number (NaN or Inf) in logits");
-        auto& decoderLogits = inputBuffers.logits.at(seqSlot);
-        auto const logitsViewShape = logitsView->getShape();
+
+        TLLM_CHECK(llmReq->isGenerationInProgressState());
+        TensorPtr decoderLogits;
         if (reqBeamWidth > 1)
         {
             decoderLogits = logitsView;
             decoderLogits->unsqueeze(0);
         }
         else
         {
-            decoderLogits
-                = ITensor::view(logitsView, ITensor::makeShape({logitsViewShape.d[0], 1, logitsViewShape.d[1]}));
+            decoderLogits = logitsView;
+            decoderLogits->unsqueeze(1);
         }
+        decoderRequests.push_back(llmReq);
+        allDecoderLogits.emplace_back(std::move(decoderLogits));
 
         if (llmReq->getReturnGenerationLogits())
         {
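
Both branches now produce a 3-D logits tensor via unsqueeze: with beamWidth > 1 the view is unsqueezed at dim 0, and with beamWidth == 1 it is unsqueezed at dim 1, which yields the same shape the old code built with ITensor::view. A shape-only model of the two cases, with a hypothetical unsqueeze helper rather than the real runtime::ITensor:

#include <cassert>
#include <cstddef>
#include <vector>

using Shape = std::vector<std::size_t>;

// Insert a dimension of extent 1 at position dim, like ITensor::unsqueeze.
Shape unsqueeze(Shape shape, std::size_t dim)
{
    shape.insert(shape.begin() + dim, 1);
    return shape;
}

int main()
{
    std::size_t const vocabSize = 32000;

    // beamWidth > 1: logitsView is [beamWidth, vocabSize];
    // unsqueeze(0) yields [1, beamWidth, vocabSize].
    Shape beamView{4, vocabSize};
    assert((unsqueeze(beamView, 0) == Shape{1, 4, vocabSize}));

    // beamWidth == 1: logitsView is [numLogits, vocabSize]; unsqueeze(1) yields
    // [numLogits, 1, vocabSize] -- the shape the old ITensor::view call produced.
    Shape singleView{2, vocabSize};
    assert((unsqueeze(singleView, 1) == Shape{2, 1, vocabSize}));
    return 0;
}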

cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp

Lines changed: 22 additions & 30 deletions
@@ -17,25 +17,24 @@
 
 #include "tensorrt_llm/batch_manager/logitsPostProcessor.h"
 
+#include "tensorrt_llm/batch_manager/decoderBuffers.h"
 #include "tensorrt_llm/batch_manager/llmRequest.h"
 #include "tensorrt_llm/batch_manager/runtimeBuffers.h"
 #include "tensorrt_llm/common/nvtxUtils.h"
 #include "tensorrt_llm/runtime/iTensor.h"
-#include "tensorrt_llm/runtime/tllmRuntime.h"
 
 namespace tr = tensorrt_llm::runtime;
 
 namespace tensorrt_llm::batch_manager
 {
 
-using BufferManager = tensorrt_llm::runtime::BufferManager;
 using TensorPtr = runtime::ITensor::SharedPtr;
 using ITensor = runtime::ITensor;
 using SizeType32 = tensorrt_llm::runtime::SizeType32;
 
-bool LogitsPostProcessor::operator()(RequestVector const& contextRequests, RequestVector const& generationRequests,
-    bool replicateLogitsPostProcessor, std::vector<TensorPtr>& seqSlotLogits, tr::WorldConfig const& worldConfig,
-    tr::TllmRuntime& runtime, std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched) const
+bool LogitsPostProcessor::operator()(DecoderInputBuffers& inputBuffers, bool replicateLogitsPostProcessor,
+    tr::WorldConfig const& worldConfig, CudaStreamPtr const& stream,
+    std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched) const
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
     NVTX3_SCOPED_RANGE(LogitsPostProcessor);
@@ -47,35 +46,28 @@ bool LogitsPostProcessor::operator()(RequestVector const& contextRequests, Reque
     std::vector<std::optional<LlmRequest::RequestIdType>> clientIdsVec;
 
     bool logitsPostProcessorIsApplied = false;
-    for (auto const& requests : {contextRequests, generationRequests})
+    for (size_t batchIdx = 0; batchIdx < inputBuffers.decoderRequests.size(); ++batchIdx)
     {
-        for (auto const& llmReq : requests)
+        auto const& llmReq = inputBuffers.decoderRequests.at(batchIdx);
+        auto& logits = inputBuffers.logits.at(batchIdx);
+
+        // Invoke non-batched processor or collect arguments for batched processor
+        if (llmReq->mLogitsPostProcessor)
         {
-            if (llmReq->isContextInitState() ? llmReq->isLastContextChunk() : llmReq->isGenerationInProgressState())
+            logitsPostProcessorIsApplied = true;
+            if (replicateLogitsPostProcessor || worldConfig.isFirstTensorParallelRank())
             {
-                // Invoke non-batched processor or collect arguments for batched processor
-                if (llmReq->mLogitsPostProcessor)
-                {
-                    logitsPostProcessorIsApplied = true;
-                    if (replicateLogitsPostProcessor || worldConfig.isFirstTensorParallelRank())
-                    {
-                        auto& logits = seqSlotLogits.at(llmReq->mSeqSlot.value());
-                        (*llmReq->mLogitsPostProcessor)(
-                            llmReq->mRequestId, logits, llmReq->getTokens(), runtime.getStreamPtr(), llmReq->mClientId);
-                    }
-                }
-                else if (llmReq->mApplyLogitsPostProcessorBatched)
-                {
-                    reqIdsVec.push_back(llmReq->mRequestId);
-
-                    auto& logits = seqSlotLogits.at(llmReq->mSeqSlot.value());
-                    logitsVec.push_back(logits);
-
-                    beamTokensVec.emplace_back(llmReq->getTokens());
-                    clientIdsVec.push_back(llmReq->mClientId);
-                }
+                (*llmReq->mLogitsPostProcessor)(
+                    llmReq->mRequestId, logits, llmReq->getTokens(), stream, llmReq->mClientId);
             }
         }
+        else if (llmReq->mApplyLogitsPostProcessorBatched)
+        {
+            reqIdsVec.push_back(llmReq->mRequestId);
+            logitsVec.push_back(logits);
+            beamTokensVec.emplace_back(llmReq->getTokens());
+            clientIdsVec.push_back(llmReq->mClientId);
+        }
     }
 
     // Invoke batched processor
@@ -84,7 +76,7 @@ bool LogitsPostProcessor::operator()(RequestVector const& contextRequests, Reque
         logitsPostProcessorIsApplied = true;
         if (replicateLogitsPostProcessor || worldConfig.isFirstTensorParallelRank())
         {
-            (*logitsPostProcessorBatched)(reqIdsVec, logitsVec, beamTokensVec, runtime.getStreamPtr(), clientIdsVec);
+            (*logitsPostProcessorBatched)(reqIdsVec, logitsVec, beamTokensVec, stream, clientIdsVec);
         }
     }
 
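
With decoderRequests pre-filtered by the logits handlers, the post processor's request-state checks and seqSlot lookups disappear; positions in decoderRequests and logits correspond one-to-one. A self-contained model of the new loop, with hypothetical simplified types rather than the real LlmRequest:

#include <cstdio>
#include <functional>
#include <memory>
#include <vector>

struct Request
{
    unsigned long requestId = 0;
    std::function<void(unsigned long, float*)> logitsPostProcessor; // per-request hook
};

int main()
{
    std::vector<std::shared_ptr<Request>> decoderRequests;
    std::vector<float*> logits; // parallel to decoderRequests

    auto req = std::make_shared<Request>();
    req->requestId = 1;
    req->logitsPostProcessor = [](unsigned long id, float*) { std::printf("req %lu\n", id); };
    decoderRequests.push_back(req);
    logits.push_back(nullptr);

    bool applied = false;
    for (size_t batchIdx = 0; batchIdx < decoderRequests.size(); ++batchIdx)
    {
        auto const& llmReq = decoderRequests.at(batchIdx);
        auto& reqLogits = logits.at(batchIdx); // dense index, no seqSlot lookup
        if (llmReq->logitsPostProcessor)
        {
            applied = true;
            llmReq->logitsPostProcessor(llmReq->requestId, reqLogits);
        }
    }
    std::printf("applied=%d\n", applied);
    return 0;
}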
