Commit a3d7cd6

refactor: Remove maxNumSequences parameter from MakeDecodingBatchInputOutput
- Removed the maxNumSequences parameter from createDecoderBatchInputs and related function calls, streamlining the interface.
- Updated all relevant implementations and tests to reflect the changed function signatures.

Signed-off-by: Robin Kobus <[email protected]>
1 parent 4115bd5 commit a3d7cd6
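
At a glance, the commit drops the trailing size argument from the static helper; both signatures below are taken verbatim from the header diff further down:

// Before
static void createDecoderBatchInputs(DecoderInputBuffers& inputBuffers, std::vector<SizeType32> const& activeSlots,
    runtime::decoder::DecoderState const& decoderState, std::vector<TensorPtr> const& logits,
    SizeType32 maxNumSequences);

// After
static void createDecoderBatchInputs(DecoderInputBuffers& inputBuffers, std::vector<SizeType32> const& activeSlots,
    runtime::decoder::DecoderState const& decoderState, std::vector<TensorPtr> const& logits);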

File tree

4 files changed: +11 −14 lines

cpp/include/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h

Lines changed: 2 additions & 3 deletions
@@ -48,12 +48,11 @@ class MakeDecodingBatchInputOutput : Algorithm
 
     void operator()(DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
         RequestVector const& contextRequests, RequestVector const& generationRequests,
-        std::vector<TensorPtr> const& logits, runtime::ModelConfig const& modelConfig, SizeType32 maxNumSequences,
+        std::vector<TensorPtr> const& logits, runtime::ModelConfig const& modelConfig,
         OptionalRef<RuntimeBuffers> fusedRuntimeBuffers) const;
 
     static void createDecoderBatchInputs(DecoderInputBuffers& inputBuffers, std::vector<SizeType32> const& activeSlots,
-        runtime::decoder::DecoderState const& decoderState, std::vector<TensorPtr> const& logits,
-        SizeType32 maxNumSequences);
+        runtime::decoder::DecoderState const& decoderState, std::vector<TensorPtr> const& logits);
 };
 
 } // namespace tensorrt_llm::batch_manager
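
For context, a minimal caller-side sketch of the updated helper. The wrapper function prepareDecoderStep and the hard-coded slot list are illustrative assumptions, not part of this commit; only the createDecoderBatchInputs call reflects the new interface:

#include <vector>

#include "tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h"

namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime;

// Hypothetical wrapper: builds decoder inputs for the currently active requests.
void prepareDecoderStep(tb::DecoderInputBuffers& inputBuffers, tr::decoder::DecoderState const& decoderState,
    std::vector<tb::MakeDecodingBatchInputOutput::TensorPtr> const& logits)
{
    // One entry per active request; callers no longer pass maxNumSequences,
    // since buffer sizing now follows the active slots themselves.
    std::vector<tr::SizeType32> const activeSlots{0, 1, 2};

    tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs(inputBuffers, activeSlots, decoderState, logits);
}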

cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp

Lines changed: 4 additions & 5 deletions
@@ -33,7 +33,7 @@ using TensorPtr = MakeDecodingBatchInputOutput::TensorPtr;
 
 void MakeDecodingBatchInputOutput::createDecoderBatchInputs(DecoderInputBuffers& inputBuffers,
     std::vector<SizeType32> const& activeSlots, runtime::decoder::DecoderState const& decoderState,
-    std::vector<TensorPtr> const& logits, SizeType32 maxNumSequences)
+    std::vector<TensorPtr> const& logits)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
 
@@ -46,7 +46,7 @@ void MakeDecodingBatchInputOutput::createDecoderBatchInputs(DecoderInputBuffers&
 
     for (SizeType32 step = 0; step < maxDecoderSteps; ++step)
     {
-        batchSlots.at(step)->resize(maxNumSequences);
+        batchSlots.at(step)->resize(activeSlots.size());
     }
 
     std::vector<SizeType32> batchIdx(maxDecoderSteps);
@@ -181,14 +181,13 @@ void setEagleInputs(tr::DecodingInput& dInput, RuntimeBuffers const& fusedRuntim
 void MakeDecodingBatchInputOutput::operator()(DecoderInputBuffers& inputBuffers,
     runtime::decoder::DecoderState& decoderState, RequestVector const& contextRequests,
     RequestVector const& generationRequests, std::vector<TensorPtr> const& logits,
-    runtime::ModelConfig const& modelConfig, SizeType32 maxNumSequences,
-    OptionalRef<RuntimeBuffers> fusedRuntimeBuffers) const
+    runtime::ModelConfig const& modelConfig, OptionalRef<RuntimeBuffers> fusedRuntimeBuffers) const
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
 
     auto [activeSlots, generationSteps] = getActiveSlots(contextRequests, generationRequests);
 
-    createDecoderBatchInputs(inputBuffers, activeSlots, decoderState, logits, maxNumSequences);
+    createDecoderBatchInputs(inputBuffers, activeSlots, decoderState, logits);
 
     auto const maxBeamWidth = decoderState.getMaxBeamWidth();
     if (maxBeamWidth > 1)
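
The one behavioral change is the resize target: each step's batch-slot buffer is now sized to the number of active slots rather than to the engine-wide maximum. A self-contained sketch of that sizing pattern, with plain std::vector standing in for the runtime tensor type:

#include <cstdint>
#include <iostream>
#include <vector>

using SizeType32 = std::int32_t;

int main()
{
    std::vector<SizeType32> const activeSlots{0, 2, 5}; // three active requests out of a larger pool
    SizeType32 const maxDecoderSteps = 2;               // the decoder may run several steps per forward

    // One slot buffer per decoder step, each sized to the active batch only;
    // before this commit they were resized to maxNumSequences regardless of occupancy.
    std::vector<std::vector<SizeType32>> batchSlots(maxDecoderSteps);
    for (SizeType32 step = 0; step < maxDecoderSteps; ++step)
    {
        batchSlots.at(step).resize(activeSlots.size());
    }

    std::cout << "per-step slot count: " << batchSlots.front().size() << '\n'; // prints 3
    return 0;
}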

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp

Lines changed: 1 addition & 1 deletion
@@ -2008,7 +2008,7 @@ runtime::CudaEvent TrtGptModelInflightBatching::decoderStepAsync(ScheduledReques
     auto& fusedRuntimeBuffers = mBuffers.at(fusedBufferId);
 
     (*mMakeDecodingBatchInputOutput)(decoderInputBuffers, *mDecoderState, scheduledRequests.contextRequests,
-        scheduledRequests.generationRequests, seqSlotLogits, mModelConfig, getMaxNumSequences(), *fusedRuntimeBuffers);
+        scheduledRequests.generationRequests, seqSlotLogits, mModelConfig, *fusedRuntimeBuffers);
 
     auto decoderFinishEvent = mDecoder->forwardAsync(*mDecoderState, decoderInputBuffers);

cpp/tests/runtime/gptDecoderBatchedTest.cpp

Lines changed: 4 additions & 5 deletions
@@ -27,7 +27,6 @@
 #include "tensorrt_llm/runtime/iBuffer.h"
 #include "tensorrt_llm/runtime/iTensor.h"
 #include "tensorrt_llm/runtime/modelConfig.h"
-#include "tensorrt_llm/runtime/runtimeKernels.h"
 #include "tensorrt_llm/runtime/worldConfig.h"
 
 #include <gmock/gmock-matchers.h>
@@ -354,7 +353,7 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector<SamplingConfig>& sa
     auto activeSlots = std::vector<SizeType32>(batchSize);
     std::iota(activeSlots.begin(), activeSlots.end(), 0);
     tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs(
-        inputBuffers, activeSlots, decoderState, decoderInputs.logits, batchSize);
+        inputBuffers, activeSlots, decoderState, decoderInputs.logits);
     decoder.forward(decoderState, inputBuffers);
 
     checkSequenceLengths(*decoderState.getSequenceLengths(), expectedLengths, manager);
@@ -484,7 +483,7 @@ void testDecoderWavefront(nvinfer1::DataType const dtype, std::vector<SamplingCo
     auto activeSlots = std::vector<SizeType32>(batchIdx + 1);
     std::iota(activeSlots.begin(), activeSlots.end(), 0);
     tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs(
-        inputBuffers, activeSlots, decoderState, decoderInputs.logits, batchSize);
+        inputBuffers, activeSlots, decoderState, decoderInputs.logits);
     decoder.forward(decoderState, inputBuffers);
 
     advanceSequenceLengths(
@@ -507,7 +506,7 @@ void testDecoderWavefront(nvinfer1::DataType const dtype, std::vector<SamplingCo
     while (!std::all_of(expectedFinished.begin(), expectedFinished.end(), [](bool finish) { return finish; }))
     {
         tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs(
-            inputBuffers, activeSlots, decoderState, decoderInputs.logits, batchSize);
+            inputBuffers, activeSlots, decoderState, decoderInputs.logits);
         decoder.forward(decoderState, inputBuffers);
         finishedVec = getFinished(*decoderState.getFinishedSum(), samplingConfigs, manager);
 
@@ -643,7 +642,7 @@ void testDecoderDraft(nvinfer1::DataType const dtype, std::vector<SamplingConfig
     auto activeSlots = std::vector<SizeType32>(batchSize);
     std::iota(activeSlots.begin(), activeSlots.end(), 0);
     tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs(
-        inputBuffers, activeSlots, decoderState, decoderInputs.logits, batchSize);
+        inputBuffers, activeSlots, decoderState, decoderInputs.logits);
     decoder.forward(decoderState, inputBuffers);
     checkSequenceLengths(*decoderState.getSequenceLengths(), expectedLengths, manager);
     EXPECT_THAT(getFinished(*decoderState.getFinishedSum(), samplingConfigs, manager), ::testing::Each(false));
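
Condensed from the hunks above, the call pattern the tests now share (identifiers exactly as they appear in gptDecoderBatchedTest.cpp):

// Mark all batchSize slots active, ...
auto activeSlots = std::vector<SizeType32>(batchSize);
std::iota(activeSlots.begin(), activeSlots.end(), 0);
// ... build the decoder inputs without the former trailing batchSize argument,
// and run one decoder step on the prepared buffers.
tb::MakeDecodingBatchInputOutput::createDecoderBatchInputs(
    inputBuffers, activeSlots, decoderState, decoderInputs.logits);
decoder.forward(decoderState, inputBuffers);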
