
Commit 37543a9

[None][refactor] Simplify decoder state initialization for speculative decoding (#6869)
Signed-off-by: Robin Kobus <[email protected]>
1 parent c232ba8 commit 37543a9

8 files changed, +223 -305 lines changed

cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h

Lines changed: 0 additions & 32 deletions
@@ -24,7 +24,6 @@
 #include "tensorrt_llm/runtime/common.h"
 #include "tensorrt_llm/runtime/iTensor.h"
 #include "tensorrt_llm/runtime/modelConfig.h"
-#include "tensorrt_llm/runtime/request.h"
 #include "tensorrt_llm/runtime/worldConfig.h"

 namespace tensorrt_llm::runtime

@@ -88,37 +87,6 @@ class CreateNewDecoderRequests : Algorithm
         SizeType32 maxSequenceLength, OptionalRef<MedusaBuffers const> medusaBuffers) const;

 private:
-    //! @brief Setups decoder internal tensors for new speculative decoding request
-    static void newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
-        SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
-        DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream,
-        CudaStream const& decoderStream, SpeculativeDecodingMode const& speculativeDecodingMode,
-        SizeType32 maxDecodingEngineTokens);
-
-    //! @brief Setups decoder internal tensors for new request in Draft model Sps mode
-    static void newRequestDraftTokensExternal(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
-        SamplingConfig const& samplingConfig, DecodingInput& jointDecodingInput, CudaStream const& decoderStream);
-
-    //! @brief Setups decoder internal tensors for new Medusa request
-    static void newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
-        DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens);
-
-    //! @brief Setups decoder internal tensors for new Lookahead request
-    static void newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
-        DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);
-
-    //! @brief Setups decoder internal tensors for new Explicit draft tokens request
-    static void newRequestExplicitDraftTokens(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
-        DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);
-
-    //! @brief Setups decoder internal tensors for new Eagle request
-    static void newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
-        runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);
-
-    [[nodiscard]] std::shared_ptr<runtime::ITensor> retrieveDraftLogits(runtime::ModelConfig const& modelConfig,
-        runtime::WorldConfig const& worldConfig, std::shared_ptr<runtime::ITensor> const& tensor,
-        runtime::BufferManager const& bufferManager) const;
-
     bool mSpeculativeDecodingFastLogits;
     bool mIsLeaderInOrchMode;
     bool mIsNormalizeLogProbs;
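
The block removed above declared one dispatcher (newRequestSpeculativeDecoding) plus a private setup helper per speculative-decoding mode (external draft tokens, Medusa, Lookahead, explicit draft tokens, Eagle). The self-contained sketch below only illustrates that dispatch-per-mode pattern; every type, signature, and body in it is a simplified stand-in, not the TensorRT-LLM implementation that this commit consolidates elsewhere.

#include <cstdint>
#include <iostream>

// Which speculative-decoding variant a new request uses; mirrors the modes
// named in the removed comments above (the enum and its values are stand-ins).
enum class SpecDecMode { DraftTokensExternal, Medusa, Lookahead, ExplicitDraftTokens, Eagle };

// Toy stand-in for the decoder's per-request state.
struct DecoderState
{
    std::int32_t maxDecodingEngineTokens = 1;
    bool useDraftLogits = false;
};

namespace
{
// Mode-specific setup helpers, analogous in spirit to the removed private
// declarations (newRequestDraftTokensExternal, newRequestMedusa, ...).
void setupDraftTokensExternal(DecoderState& s) { s.useDraftLogits = true; }
void setupMedusa(DecoderState& s, std::int32_t maxTokens) { s.maxDecodingEngineTokens = maxTokens; }
} // namespace

// Single entry point that branches per mode, so callers only ever see one
// initialization routine.
void newRequestSpeculativeDecoding(DecoderState& state, SpecDecMode mode, std::int32_t maxTokens)
{
    switch (mode)
    {
    case SpecDecMode::DraftTokensExternal: setupDraftTokensExternal(state); break;
    case SpecDecMode::Medusa: setupMedusa(state, maxTokens); break;
    default: break; // remaining modes handled analogously
    }
}

int main()
{
    DecoderState state;
    newRequestSpeculativeDecoding(state, SpecDecMode::Medusa, 64);
    std::cout << state.maxDecodingEngineTokens << '\n'; // prints 64
}

Since the removed declarations were private, the observable change in this file is simply that the header no longer carries the per-mode setup surface.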

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 1 addition & 1 deletion
@@ -1110,7 +1110,7 @@ class GenericLlmRequest

     [[nodiscard]] SizeType32 getNumDraftTokens() const
     {
-        return mDraftTokens->size();
+        return hasDraftTokens() ? mDraftTokens->size() : 0;
     }

     void discardDraftTokens(SizeType32 numTokensToDiscard)
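
The new return statement guards the draft-token count behind hasDraftTokens(). Below is a minimal, self-contained sketch of that pattern, assuming mDraftTokens is a shared_ptr to a token vector and hasDraftTokens() reports whether it is set and non-empty; the surrounding class is a stand-in, and only the two member functions mirror the diff.

#include <cstdint>
#include <memory>
#include <vector>

using SizeType32 = std::int32_t;
using TokenIdType = std::int32_t;
using VecTokens = std::vector<TokenIdType>;

class RequestSketch
{
public:
    // Assumed semantics: draft tokens are present only if the pointer is set
    // and the vector is non-empty.
    [[nodiscard]] bool hasDraftTokens() const
    {
        return mDraftTokens && !mDraftTokens->empty();
    }

    // Before: mDraftTokens->size() dereferenced the pointer unconditionally.
    // After: requests without draft tokens simply report a count of 0.
    [[nodiscard]] SizeType32 getNumDraftTokens() const
    {
        return hasDraftTokens() ? static_cast<SizeType32>(mDraftTokens->size()) : 0;
    }

private:
    std::shared_ptr<VecTokens> mDraftTokens; // may be unset for non-speculative requests
};

With the guard, the count no longer depends on how mDraftTokens happens to be initialized for non-speculative requests.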

cpp/include/tensorrt_llm/runtime/decodingInput.h

Lines changed: 2 additions & 0 deletions
@@ -102,11 +102,13 @@ class DecodingInput
     {
     public:
         TensorPtr draftLogits;
+        TensorPtr draftLogitsHost;
         TensorPtr draftProbs;
         TensorPtr targetProbs;
         TensorPtr numDraftTokens;
         TensorPtr numDraftTokensHost;
         TensorPtr draftTokenIds;
+        TensorPtr draftTokenIdsHost;
         TensorPtr useDraftLogits;
         TensorPtr useDraftLogitsHost;

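The two added members pair the existing device-side draft tensors with host-side counterparts, following the same naming as numDraftTokensHost and useDraftLogitsHost. The sketch below shows the general host-mirror pattern this suggests: fill or inspect draft data on the CPU, then stage it to the GPU asynchronously. The struct, raw pointers, and the use of cudaMemcpyAsync are illustrative assumptions, not the TensorRT-LLM ITensor/BufferManager code; only the four member names are taken from the diff.

#include <cstdint>
#include <vector>
#include <cuda_runtime_api.h>

// Each device-side draft buffer has a matching host-side mirror.
struct DraftBuffers
{
    float* draftLogits = nullptr;                 // device
    std::vector<float> draftLogitsHost;           // host mirror
    std::int32_t* draftTokenIds = nullptr;        // device
    std::vector<std::int32_t> draftTokenIdsHost;  // host mirror
};

// Stage the host mirrors to the device on the given stream.
inline void stageDraftInputs(DraftBuffers& b, cudaStream_t stream)
{
    cudaMemcpyAsync(b.draftTokenIds, b.draftTokenIdsHost.data(),
        b.draftTokenIdsHost.size() * sizeof(std::int32_t), cudaMemcpyHostToDevice, stream);
    cudaMemcpyAsync(b.draftLogits, b.draftLogitsHost.data(),
        b.draftLogitsHost.size() * sizeof(float), cudaMemcpyHostToDevice, stream);
}

In the actual runtime the copies presumably go through the BufferManager on the appropriate stream; the point here is only the one-to-one pairing of host and device buffers.
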
cpp/include/tensorrt_llm/runtime/request.h

Lines changed: 0 additions & 54 deletions
This file was deleted.

0 commit comments
