diff --git a/cpp/include/tensorrt_llm/batch_manager/handleContextLogits.h b/cpp/include/tensorrt_llm/batch_manager/handleContextLogits.h index cb77545578c..b429821746e 100644 --- a/cpp/include/tensorrt_llm/batch_manager/handleContextLogits.h +++ b/cpp/include/tensorrt_llm/batch_manager/handleContextLogits.h @@ -47,7 +47,7 @@ class HandleContextLogits : Algorithm runtime::SizeType32 operator()(DecoderInputBuffers& inputBuffers, RequestVector const& contextRequests, runtime::ITensor::SharedPtr const& logits, std::vector const& numContextLogitsVec, runtime::ModelConfig const& modelConfig, runtime::BufferManager const& manager, - OptionalRef medusaBuffers) const; + OptionalRef medusaBuffers, runtime::SizeType32 vocabId = 0) const; }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h b/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h index f9fd58800a6..343dbd13f2b 100644 --- a/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h +++ b/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h @@ -47,7 +47,8 @@ class HandleGenerationLogits : Algorithm void operator()(DecoderInputBuffers& inputBuffers, RequestVector const& generationRequests, runtime::ITensor::SharedPtr const& logits, runtime::SizeType32 logitsIndex, runtime::ModelConfig const& modelConfig, runtime::BufferManager const& manager, - OptionalRef genRuntimeBuffers, OptionalRef medusaBuffers) const; + OptionalRef genRuntimeBuffers, OptionalRef medusaBuffers, + runtime::SizeType32 vocabId = 0) const; }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index d97b87086f5..41509d882d3 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -553,7 +553,7 @@ class WindowBlockManager SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr stream, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, - std::shared_ptr kvCacheConnectorManager, + std::shared_ptr kvCacheConnectorManager, SizeType32 numVocabs = 1, std::shared_ptr loopbackAgent = nullptr); ~WindowBlockManager(); @@ -888,6 +888,9 @@ class WindowBlockManager bool mEnablePartialReuse; // Whether partially matched blocks that are already in use should be copied and reused. 
bool mCopyOnPartialReuse; + + SizeType32 mNumVocabs; + // The kv cache connector manager std::shared_ptr mKvCacheConnectorManager; }; @@ -907,7 +910,7 @@ class BlockManager SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, - bool copyOnPartialReuse = true, + bool copyOnPartialReuse = true, SizeType32 numVocabs = 1, std::shared_ptr kvCacheConnectorManager = nullptr, std::optional agentConfig = std::nullopt); @@ -1212,6 +1215,7 @@ class BlockManager std::vector mLayerToWindowSize; std::vector mAbsolutePoolToWindowSize; std::vector mAbsolutePoolToRelativePoolIndex; + SizeType32 mNumVocabs; }; struct OffsetTableDimensions @@ -1433,7 +1437,7 @@ class KVCacheManager : public BaseKVCacheManager bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, - bool copyOnpartialReuse = true, + bool copyOnpartialReuse = true, SizeType32 numVocabs = 1, std::shared_ptr kvCacheConnectorManager = nullptr); KVCacheManager(std::vector const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, @@ -1444,7 +1448,7 @@ class KVCacheManager : public BaseKVCacheManager bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, - bool copyOnpartialReuse = true, + bool copyOnpartialReuse = true, SizeType32 numVocabs = 1, std::shared_ptr kvCacheConnectorManager = nullptr); KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, @@ -1455,7 +1459,7 @@ class KVCacheManager : public BaseKVCacheManager bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, - bool copyOnpartialReuse = true, + bool copyOnpartialReuse = true, SizeType32 numVocabs = 1, std::shared_ptr kvCacheConnectorManager = nullptr); KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, @@ -1464,7 +1468,7 @@ class KVCacheManager : public BaseKVCacheManager std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkTokenLength, int64_t stream, SizeType32 maxSequenceLength, bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, bool enablePartialReuse = true, - bool copyOnpartialReuse = true); + bool copyOnpartialReuse = true, SizeType32 numVocabs = 1); ~KVCacheManager() override = default; diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h index 2aebf77b96d..f5dcce330dc 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h @@ -40,11 +40,35 @@ class BlockRange return BlockRange(cacheManager, blockIds, requestId); } - static BlockRange fromNewlyAllocatedBlockIds( - BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId) + static BlockRange fromAllBlockIds( + BaseKVCacheManager const& cacheManager, LlmRequest const& llmRequest, SizeType32 beam = kFIRST_AND_ONLY_BEAM) { + assert(kFIRST_AND_ONLY_BEAM == beam); + + const 
LlmRequest::RequestIdType requestId = llmRequest.getSeqSlotId(0); + std::vector blockIds; + auto const windowSize = firstWindowSize(cacheManager); + + for (int i = 0; i < llmRequest.getNumSequences(); i++) + { + auto const& thisBlockIds = cacheManager.getSequence(llmRequest.getSeqSlotId(i)) + .getCacheBlockIds(windowSize) + .at(kFIRST_AND_ONLY_BEAM); + blockIds.insert(blockIds.end(), thisBlockIds.begin(), thisBlockIds.end()); + } + return BlockRange(cacheManager, blockIds, requestId); + } + + static BlockRange fromNewlyAllocatedBlockIds(BaseKVCacheManager const& cacheManager, LlmRequest const& llmRequest) + { + const LlmRequest::RequestIdType requestId = llmRequest.getSeqSlotId(0); auto const windowSize = firstWindowSize(cacheManager); - auto const blockIds = cacheManager.getNewlyAllocatedBlockIds(requestId, windowSize); + std::vector blockIds; + for (int i = 0; i < llmRequest.getNumSequences(); i++) + { + auto const& thisBlockIds = cacheManager.getNewlyAllocatedBlockIds(llmRequest.getSeqSlotId(i), windowSize); + blockIds.insert(blockIds.end(), thisBlockIds.begin(), thisBlockIds.end()); + } return BlockRange(cacheManager, blockIds, requestId); } diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 85c9a3ac942..d86375e6bc1 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -130,6 +131,7 @@ class GenericLlmRequest executor::PriorityType priority = executor::Request::kDefaultPriority, std::optional encoderInputFeatures = std::nullopt, std::optional encoderOutputLength = std::nullopt, + std::optional decoderContextFeatures = std::nullopt, std::optional crossAttentionMask = std::nullopt, LlmRequestType llmRequestType = LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::optional> inputTokenExtraIds = std::nullopt, @@ -139,9 +141,10 @@ class GenericLlmRequest std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt) + std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt, + SizeType32 numVocabs = 1) : mRequestId(requestId) - , mPromptLen(inputTokens->size()) + , mPromptLen(inputTokens->size() / numVocabs) , mMaxNewTokens(maxNewTokens) , mSamplingConfig(samplingConfig) , mEndId(endId) @@ -185,6 +188,7 @@ class GenericLlmRequest , mFinishReasons(samplingConfig.beamWidth) , mEncoderInputFeatures(std::move(encoderInputFeatures)) , mEncoderOutputLength(encoderOutputLength) + , mDecoderContextFeatures(std::move(decoderContextFeatures)) , mCrossAttentionMask(std::move(crossAttentionMask)) , mLlmRequestType(llmRequestType) , mContextPhaseParams(contextPhaseParams) @@ -197,6 +201,7 @@ class GenericLlmRequest , mLanguageAdapterUid(languageAdapterUid) , mAllottedTimeMs(allottedTimeMs) , mCacheSaltID(cacheSaltID) + , mNumVocabs{numVocabs} { if (mEncoderTokens.has_value() || encoderInputFeatures.has_value()) { @@ -224,9 +229,9 @@ class GenericLlmRequest executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1, std::optional languageAdapterUid = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt) + std::optional cacheSaltID = std::nullopt, SizeType32 numVocabs = 1) 
: mRequestId(requestId) - , mPromptLen(inputTokens.size()) + , mPromptLen(inputTokens.size() / numVocabs) , mMaxNewTokens(maxNewTokens) , mSamplingConfig(samplingConfig) , mEndId(endId) @@ -265,6 +270,7 @@ class GenericLlmRequest , mNumReturnSequences(numReturnSequences) , mLanguageAdapterUid(languageAdapterUid) , mCacheSaltID(cacheSaltID) + , mNumVocabs{numVocabs} { if (mEncoderTokens.has_value()) { @@ -275,7 +281,7 @@ class GenericLlmRequest GenericLlmRequest(RequestIdType requestId, executor::Request const& req) : mRequestId(requestId) - , mPromptLen(req.getInputTokenIds().size()) + , mPromptLen(req.getInputTokenIds().size() / req.getNumVocabs()) , mMaxNewTokens(req.getMaxTokens()) , mSamplingConfig(req.getSamplingConfig(), req.getExternalDraftTokensConfig()) , mEndId(req.getEndId()) @@ -304,6 +310,7 @@ class GenericLlmRequest , mLanguageAdapterUid(req.getLanguageAdapterUid()) , mAllottedTimeMs(req.getAllottedTimeMs()) , mCacheSaltID(req.getCacheSaltID()) + , mNumVocabs(req.getNumVocabs()) { if (req.getRequestType() == executor::RequestType::REQUEST_TYPE_GENERATION_ONLY) { @@ -437,6 +444,16 @@ class GenericLlmRequest mEncoderInputFeatures = std::nullopt; } + auto const& decoderContextFeatures = req.getDecoderContextFeatures(); + if (decoderContextFeatures.has_value()) + { + mDecoderContextFeatures = executor::detail::toITensor(decoderContextFeatures.value()); + } + else + { + mDecoderContextFeatures = std::nullopt; + } + auto const& crossAttentionMask = req.getCrossAttentionMask(); if (crossAttentionMask.has_value()) { @@ -512,12 +529,54 @@ class GenericLlmRequest mContextProgress = progress; } + /// @brief Get number of vocabs for this multi vocab sampling + /// @return The number of vocabs + [[nodiscard]] SizeType32 getNumVocabs() const + { + return mNumVocabs; + } + /// @brief Get total number of tokens for this req (prompt + generated) /// @param beam The beam index /// @return The number of tokens [[nodiscard]] SizeType32 getNumTokens(SizeType32 beam) const { - return mTokens.at(beam).size() - mNumPreDecodedTokens[beam]; + return mTokens.at(beam).size() / getNumVocabs() - mNumPreDecodedTokens[beam]; + } + + [[nodiscard]] bool isCfg() const + { + return mSamplingConfig.cfgScale.has_value() && mSamplingConfig.cfgScale->at(0) != 1.0f; + } + + [[nodiscard]] SizeType32 getNumSequences() const + { + if (isCfg()) + { + TLLM_CHECK_WITH_INFO(mSamplingConfig.beamWidth == 1, "cfgScale is only supported for beamWidth = 1"); + return 2; + } + return 1; + } + + [[nodiscard]] SizeType32 getSeqSlot(int idx) const + { + TLLM_CHECK_WITH_INFO(idx >= 0 && idx < getNumSequences(), "seq slot idx is out of range"); + return mSeqSlots[idx]; + } + + [[nodiscard]] uint64_t getSeqSlotId(int idx = 0) const + { + if (idx == 0) + { + return mRequestId; + } + if (isCfg() && idx == 1) + { + return std::numeric_limits::max() - mRequestId; + } + TLLM_CHECK_WITH_INFO(false, "Sequence slot id is implemented for CFG only"); + return 0; } /// @brief Get the number of subrequests, the expected number of responses under non-streaming mode. 
In sampling @@ -818,9 +877,9 @@ class GenericLlmRequest for (std::size_t beam = 0; beam < mTokens.size(); ++beam) { auto& beamTokens = mTokens.at(beam); - beamTokens.resize(newPromptLen); + beamTokens.resize(newPromptLen * mNumVocabs); auto& beamUniqueTokens = mUniqueTokens.at(beam); - beamUniqueTokens.resize(newPromptLen); + beamUniqueTokens.resize(newPromptLen * mNumVocabs); if (returnLogProbs()) { @@ -840,7 +899,7 @@ class GenericLlmRequest mPrepopulatedPromptLenTarget = 0; mPrepopulatedPromptLenDraft = 0; mContextChunkSize = mPromptLen; - mSeqSlot.reset(); + mSeqSlots.clear(); } /// @brief Get the maximum length of tokens returned to the client. Use to ensure we don't return to @@ -1174,6 +1233,54 @@ class GenericLlmRequest return mEncoderInputFeatures.value_or(nullptr); } + [[nodiscard]] TensorPtr getDecoderContextFeatures() const + { + return mDecoderContextFeatures.value_or(nullptr); + } + + void setAttentionPriorIdx(SizeType32 attentionPriorIdx, runtime::ModelConfig const& modelConfig) + { + auto const lastIdx = getEncoderOutputLen() - modelConfig.getAttentionPriorWindowRight() - 1; + if (attentionPriorIdx > lastIdx) + { + // no need to move further the attention window will cover all tokens till end + attentionPriorIdx = lastIdx; + } + mAttentionPriorIdx = attentionPriorIdx; + if (mAttentionPriorCounters.size() == 0) + { + // TODO: lazy initialization due to inconsistencies between + // runtime::ITensor::SharedPtr and at::Tensor + mAttentionPriorCounters.resize(getEncoderOutputLen(), 0); + } + mAttentionPriorCounters[attentionPriorIdx]++; + if (attentionPriorIdx >= lastIdx) + { + mAttentionPriorCounterCloseToEnd++; + } + else if (mAttentionPriorCounters[attentionPriorIdx] >= 8) + { + // increment to avoid getting stuck in the same encoder output + setAttentionPriorIdx(attentionPriorIdx + 1, modelConfig); + } + } + + bool isAttentionPriorFinished() const + { + return mAttentionPriorCounterCloseToEnd >= 20; + } + + [[nodiscard]] SizeType32 getAttentionPriorIdx(runtime::ModelConfig const& modelConfig) + { + if (!mAttentionPriorIdx.has_value()) + { + setAttentionPriorIdx(1, modelConfig); + } + // `setAttentionPriorIdx` takes care to avoid getting stuck in the same encoder output, + // it is expected that `mAttentionPriorCounters[mAttentionPriorIdx]` is always < 8 + return mAttentionPriorIdx.value(); + } + void setEncoderOutputHost(TensorPtr encoderOutputHost) { mEncoderOutputHost = std::move(encoderOutputHost); @@ -1879,7 +1986,7 @@ class GenericLlmRequest runtime::SamplingConfig mSamplingConfig; std::optional mEndId{std::nullopt}; std::optional mPadId{std::nullopt}; - std::optional mSeqSlot{std::nullopt}; + std::vector mSeqSlots{}; std::optional mLogitsPostProcessor{std::nullopt}; bool mApplyLogitsPostProcessorBatched{false}; std::optional mClientId{std::nullopt}; @@ -1972,6 +2079,13 @@ class GenericLlmRequest // Encoder input tokens std::optional> mEncoderTokens{std::nullopt}; + // for attention prior, placeholder for where to focus in encoder output + std::optional mAttentionPriorIdx; + // counts how many times certain attention prior idx was attended + std::vector mAttentionPriorCounters; + // counts how many times attention prior idx is close to the end of sequence + SizeType32 mAttentionPriorCounterCloseToEnd{0}; + bool mReturnEncoderOutput; // Encoder output, used to compute cross attention KV-Cache. @@ -1992,6 +2106,9 @@ class GenericLlmRequest // which encoder output shape cannot be inferred from encoder input shape due to downsampling. 
std::optional mEncoderOutputLength{std::nullopt}; + // decoder context features to replace the token encodings + std::optional mDecoderContextFeatures{std::nullopt}; + // Input cross attention mask. std::optional mCrossAttentionMask{std::nullopt}; @@ -2047,6 +2164,8 @@ class GenericLlmRequest // Context request only. The hashes of the blocks that are requested by the corresponding generation request. std::vector mRequestedBlockHashes; + SizeType32 mNumVocabs; + bool mIsDummyRequest{false}; bool mUseDraftModel{false}; @@ -2227,6 +2346,7 @@ class LlmRequest : public GenericLlmRequest executor::PriorityType priority = executor::Request::kDefaultPriority, std::optional encoderInputFeatures = std::nullopt, std::optional encoderOutputLength = std::nullopt, + std::optional decoderContextFeatures = std::nullopt, std::optional crossAttentionMask = std::nullopt, LlmRequestType llmRequestType = LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::optional inputTokenExtraIds = std::nullopt, SizeType32 numReturnSequences = 1, @@ -2262,7 +2382,7 @@ class LlmRequest : public GenericLlmRequest encoderInputTokens ? std::make_optional(std::make_shared(std::move(*encoderInputTokens))) : std::optional>(std::nullopt), returnEncoderOutput, clientId, priority, std::move(encoderInputFeatures), encoderOutputLength, - std::move(crossAttentionMask), llmRequestType, + std::move(decoderContextFeatures), std::move(crossAttentionMask), llmRequestType, inputTokenExtraIds ? std::make_optional(std::make_shared(std::move(*inputTokenExtraIds))) : std::optional>(std::nullopt), numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks, returnPerfMetrics, diff --git a/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h b/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h index 13bde6d07a5..04c814b2c00 100644 --- a/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h +++ b/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h @@ -72,10 +72,14 @@ class RuntimeBuffers static constexpr auto kHostContextLengthsTensorName = "host_context_lengths"; static constexpr auto kSequenceLengthsTensorName = "sequence_length"; static constexpr auto kPromptEmbeddingTableTensorName = "prompt_embedding_table"; + static constexpr auto kDecoderContextFeaturesTensorName = "decoder_context_features"; + static constexpr auto kDecoderContextFeaturesMaskTensorName = "decoder_context_features_mask"; static constexpr auto kTasksTensorName = "tasks"; static constexpr auto kPromptVocabSizeTensorName = "prompt_vocab_size"; static constexpr auto kMRopeRotaryCosSinTensorName = "mrope_rotary_cos_sin"; static constexpr auto kMRopePositionDeltasTensorName = "mrope_position_deltas"; + static constexpr auto kAttentionPriorScoresTensorName = "attention_prior_scores"; + static constexpr auto kAttentionPriorFocusTensorName = "attention_prior_focus"; using SizeType32 = runtime::SizeType32; using TensorPtr = runtime::ITensor::SharedPtr; @@ -147,6 +151,17 @@ class RuntimeBuffers //! Prompt-Tuning std::unique_ptr promptTuningBuffers; + //! Attention prior + TensorPtr attentionPriorScores; // [b*5,] + TensorPtr attentionPriorFocus; // [b,] + bool useAttentionPrior; + bool useContextEmbeddings; + int attentionPriorLookahead; + + //! Overwriting decoder context features + TensorPtr decoderContextFeatures; // [b*t, d] + TensorPtr decoderContextFeaturesMask; // [b*t] + private: //! Runtime //! 
Type of host tensor: 0 for context, 1 for generation @@ -290,6 +305,9 @@ class RuntimeBuffers runtime::EagleBuffers::Inputs const& eagleBuffers, runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig); + void processAttentionPriorScores( + RequestVector const& genRequests, runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig); + private: void create(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLen, runtime::TllmRuntime const& runtime, diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 9dda07d19c6..feddb0f1d67 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -76,7 +76,8 @@ class SamplingConfig std::optional const& noRepeatNgramSize = std::nullopt, std::optional const& numReturnSequences = std::nullopt, std::optional const& minP = std::nullopt, - std::optional> const& beamWidthArray = std::nullopt); + std::optional> const& beamWidthArray = std::nullopt, + std::optional const& cfgScale = std::nullopt); bool operator==(SamplingConfig const& other) const; @@ -100,6 +101,7 @@ class SamplingConfig [[nodiscard]] std::optional getNumReturnSequences() const; [[nodiscard]] std::optional getMinP() const; [[nodiscard]] std::optional> getBeamWidthArray() const; + [[nodiscard]] std::optional getCfgScale() const; void setBeamWidth(SizeType32 beamWidth); void setTopK(std::optional const& topK); @@ -120,6 +122,7 @@ class SamplingConfig void setNumReturnSequences(std::optional const& numReturnSequences); void setMinP(std::optional const& minP); void setBeamWidthArray(std::optional> const& beamWidthArray); + void setCfgScale(std::optional const& cfgScale); private: static SizeType32 checkBeamWidth(SizeType32 beamWidth); @@ -141,6 +144,7 @@ class SamplingConfig static std::optional const& checkMinP(std::optional const& minP); static std::pair> const&, SizeType32 const> const checkBeamWidthArray( std::optional> const& beamWidthArray, SizeType32 const beamWidth); + static std::optional const& checkCfgScale(std::optional const& cfgScale); void updateNumReturnBeams(); friend class Serialization; @@ -192,6 +196,8 @@ class SamplingConfig std::optional mMinP; /// @brief Controls the beam width for each step for Variable-Beam-Width-Search. std::optional> mBeamWidthArray; + /// @brief Controls the cfg scale for sampling. + std::optional mCfgScale; }; /// @brief Additional output that should be gathered. @@ -660,6 +666,7 @@ class Request /// @param encoderInputFeatures Encoder input features for multimodal models. /// @param encoderOutputLength Encoder output length if encoder input and output have different lengths (due to /// convolution down-sampling, etc.) + /// @param decoderContextFeatures Decoder context features for multimodal models. /// @param crossAttentionMask Cross attention mask. /// @param numReturnSequences The number of returning sequences. /// @param eagleConfig The EAGLE speculative decoding configuration @@ -670,6 +677,7 @@ class Request /// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism /// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled. /// @param cacheSaltID Salt ID for KV cache blocks to limit the kv cache reuse to the requests with the same string. 
+ /// @param numVocabs The number of vocabs. Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false, SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(), std::optional const& endId = std::nullopt, std::optional const& padId = std::nullopt, @@ -692,12 +700,13 @@ class Request std::optional contextPhaseParams = std::nullopt, std::optional encoderInputFeatures = std::nullopt, std::optional encoderOutputLength = std::nullopt, + std::optional decoderContextFeatures = std::nullopt, std::optional crossAttentionMask = std::nullopt, SizeType32 numReturnSequences = 1, std::optional eagleConfig = std::nullopt, std::optional skipCrossAttnBlocks = std::nullopt, std::optional guidedDecodingParams = std::nullopt, std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, - std::optional cacheSaltID = std::nullopt); + std::optional cacheSaltID = std::nullopt, SizeType32 numVocabs = 1); /// @brief This logits postprocessor name will dispatch to the batched logits postprocessor static auto constexpr kBatchedPostProcessorName = "batched"; @@ -738,6 +747,7 @@ class Request [[nodiscard]] std::optional const& getContextPhaseParams() const; [[nodiscard]] std::optional getEncoderInputFeatures() const; [[nodiscard]] std::optional getEncoderOutputLength() const; + [[nodiscard]] std::optional getDecoderContextFeatures() const; [[nodiscard]] std::optional getCrossAttentionMask() const; [[nodiscard]] RequestType getRequestType() const; [[nodiscard]] std::optional getEagleConfig() const; @@ -747,6 +757,7 @@ class Request [[nodiscard]] std::optional getAllottedTimeMs() const; [[nodiscard]] std::optional getCacheSaltID() const; [[nodiscard]] std::optional> getAdditionalOutputNames() const; + [[nodiscard]] SizeType32 getNumVocabs() const; void setStreaming(bool streaming); void setSamplingConfig(SamplingConfig const& config); @@ -775,6 +786,7 @@ class Request void setContextPhaseParams(ContextPhaseParams contextPhaseParams); void setEncoderInputFeatures(Tensor encoderInputFeatures); void setEncoderOutputLength(SizeType32 encoderOutputLength); + void setDecoderContextFeatures(Tensor decoderContextFeatures); void setCrossAttentionMask(Tensor crossAttentionMask); void setEagleConfig(std::optional const& eagleConfig); void setSkipCrossAttnBlocks(Tensor skipCrossAttnBlocks); @@ -782,6 +794,7 @@ class Request void setLanguageAdapterUid(SizeType32 languageAdapterUid); void setAllottedTimeMs(MillisecondsType allottedTimeMs); void setCacheSaltID(CacheSaltIDType cacheSaltID); + void setNumVocabs(SizeType32 numVocabs); private: friend class Serialization; diff --git a/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h b/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h index d35080d5588..0fc4b1a3d53 100644 --- a/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h +++ b/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h @@ -133,6 +133,11 @@ class DefaultDecodingParams { return std::vector{1}; } + + [[nodiscard]] __host__ __device__ static constexpr float getCfgScale() + { + return 1.0f; + } }; } // namespace layers } // namespace tensorrt_llm diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h index d0a9e726d13..b1eef2f794d 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h @@ -48,7 +48,8 @@ class GptDecoderBatched : public IGptDecoderBatched explicit 
GptDecoderBatched(CudaStreamPtr stream); void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, - nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) override; + nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig, + SizeType32 vocab_size = 0) override; void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override; diff --git a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h index 606ba3c98a4..d6876606720 100644 --- a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h @@ -87,7 +87,8 @@ class IGptDecoderBatched //! @brief Setup the decoder before calling `forward()` virtual void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, - nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) + nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig, + SizeType32 vocab_size = 0) = 0; //! @brief Disable Lookahead decoding. diff --git a/cpp/include/tensorrt_llm/runtime/modelConfig.h b/cpp/include/tensorrt_llm/runtime/modelConfig.h index b1858573e68..dd08dd258a5 100644 --- a/cpp/include/tensorrt_llm/runtime/modelConfig.h +++ b/cpp/include/tensorrt_llm/runtime/modelConfig.h @@ -100,7 +100,8 @@ class ModelConfig }; explicit ModelConfig(SizeType32 vocabSize, SizeType32 nbLayers, SizeType32 nbAttentionLayers, - SizeType32 nbRnnLayers, SizeType32 nbHeads, SizeType32 hiddenSize, nvinfer1::DataType dtype) + SizeType32 nbRnnLayers, SizeType32 nbHeads, SizeType32 hiddenSize, nvinfer1::DataType dtype, + std::optional> vocabSizes = std::nullopt) : mVocabSize(vocabSize) , mNbLayers(nbLayers) , mNbAttentionLayers(nbAttentionLayers) @@ -141,10 +142,20 @@ class ModelConfig , mManageWeightsType(ManageWeightsType::kDisabled) , mSkipCrossAttnBlocks(false) , mNumLanguages(0) + , mVocabSizes{vocabSizes} + , mUseAttentionPrior(false) + , mUseContextEmbeddings(false) { TLLM_CHECK_WITH_INFO(mNbLayers >= mNbAttentionLayers + mNbRnnLayers, "Number of layers (%d) expected to be >= number of attention (%d) + number of rnn layers (%d)", mNbLayers, mNbAttentionLayers, mNbRnnLayers); + if (mVocabSizes) + { + SizeType32 const sizesSum = std::accumulate(mVocabSizes.value().cbegin(), mVocabSizes.value().cend(), 0); + TLLM_CHECK_WITH_INFO( + sizesSum == vocabSize, "Sum of all vocab sizes (%d) must equal vocabSize (%d)", sizesSum, vocabSize); + } + setNbKvHeads(mNbHeads); } @@ -158,9 +169,93 @@ class ModelConfig return mVocabSize; } - [[nodiscard]] SizeType32 constexpr getVocabSizePadded(SizeType32 worldSize) const noexcept + [[nodiscard]] SizeType32 getNumVocabs() const + { + return mVocabSizes ? mVocabSizes.value().size() : 1; + } + + [[nodiscard]] std::vector getVocabSizes() const + { + return mVocabSizes ?
*mVocabSizes : std::vector{mVocabSize}; + } + + [[nodiscard]] bool constexpr useAttentionPrior() const noexcept + { + return mUseAttentionPrior; + } + + [[nodiscard]] bool constexpr useContextEmbeddings() const noexcept + { + return mUseContextEmbeddings; + } + + void constexpr useAttentionPrior(bool useAttentionPrior) noexcept + { + mUseAttentionPrior = useAttentionPrior; + } + + void constexpr useContextEmbeddings(bool useContextEmbeddings) noexcept + { + mUseContextEmbeddings = useContextEmbeddings; + } + + [[nodiscard]] std::vector getComputeAttentionPriorFromLayers() const noexcept + { + return mComputeAttentionPriorFromLayers; + } + + [[nodiscard]] std::vector getApplyAttentionPriorToLayers() const noexcept + { + return mApplyAttentionPriorToLayers; + } + + [[nodiscard]] SizeType32 constexpr getAttentionPriorLookahead() const noexcept { - return (mVocabSize + worldSize - 1) / worldSize * worldSize; + return mAttentionPriorLookahead; + } + + [[nodiscard]] SizeType32 constexpr getAttentionPriorWindowLeft() const noexcept + { + return mAttentionPriorWindowLeft; + } + + [[nodiscard]] SizeType32 constexpr getAttentionPriorWindowRight() const noexcept + { + return mAttentionPriorWindowRight; + } + + void setComputeAttentionPriorFromLayers(std::vector const& computeAttentionPriorFromLayers) noexcept + { + mComputeAttentionPriorFromLayers = computeAttentionPriorFromLayers; + } + + void setApplyAttentionPriorToLayers(std::vector const& applyAttentionPriorToLayers) noexcept + { + mApplyAttentionPriorToLayers = applyAttentionPriorToLayers; + } + + void constexpr setAttentionPriorLookahead(SizeType32 attentionPriorLookahead) noexcept + { + mAttentionPriorLookahead = attentionPriorLookahead; + } + + void constexpr setAttentionPriorWindowLeft(SizeType32 attentionPriorWindowLeft) noexcept + { + mAttentionPriorWindowLeft = attentionPriorWindowLeft; + } + + void constexpr setAttentionPriorWindowRight(SizeType32 attentionPriorWindowRight) noexcept + { + mAttentionPriorWindowRight = attentionPriorWindowRight; + } + + [[nodiscard]] SizeType32 constexpr getVocabSizePadded(SizeType32 worldSize, SizeType32 vocabSize = 0) const noexcept + { + if (vocabSize == 0) + { + vocabSize = mVocabSize; + } + return (vocabSize + worldSize - 1) / worldSize * worldSize; } [[nodiscard]] SizeType32 countLocalLayers( @@ -950,6 +1045,17 @@ class ModelConfig // Language adapter info std::optional mNumLanguages; + + // Size of each vocab if there are multiple vocabs + std::optional> mVocabSizes; + // parameters of attention prior + bool mUseAttentionPrior; + bool mUseContextEmbeddings; + std::vector mComputeAttentionPriorFromLayers; + std::vector mApplyAttentionPriorToLayers; + SizeType32 mAttentionPriorLookahead; + SizeType32 mAttentionPriorWindowLeft; + SizeType32 mAttentionPriorWindowRight; }; } // namespace tensorrt_llm::runtime diff --git a/cpp/include/tensorrt_llm/runtime/promptTuningParams.h b/cpp/include/tensorrt_llm/runtime/promptTuningParams.h index a2fbc7fd0e0..4e1e2ea4361 100644 --- a/cpp/include/tensorrt_llm/runtime/promptTuningParams.h +++ b/cpp/include/tensorrt_llm/runtime/promptTuningParams.h @@ -50,6 +50,7 @@ class GenericPromptTuningParams std::vector promptTuningEnabled; // [batchSize] vector of bool that indicates which requests in a batch have ptuning enabled + SizeType32 numVocabs; // [1], on gpu }; class PromptTuningParams : public GenericPromptTuningParams diff --git a/cpp/include/tensorrt_llm/runtime/samplingConfig.h b/cpp/include/tensorrt_llm/runtime/samplingConfig.h index 099dce17312..872856b69db 
100644 --- a/cpp/include/tensorrt_llm/runtime/samplingConfig.h +++ b/cpp/include/tensorrt_llm/runtime/samplingConfig.h @@ -140,6 +140,9 @@ class SamplingConfig configs, [&configs](size_t ci) { return configs[ci].topK; }, layers::DefaultDecodingParams::getTopK()); topP = fuseValues( configs, [&configs](size_t ci) { return configs[ci].topP; }, layers::DefaultDecodingParams::getTopP()); + cfgScale = fuseValues( + configs, [&configs](size_t ci) { return configs[ci].cfgScale; }, + layers::DefaultDecodingParams::getCfgScale()); // Generate a random seed for each samplingConfig with randomSeed == std::nullopt randomSeed = std::vector(configs.size()); @@ -229,6 +232,7 @@ class SamplingConfig SET_FROM_OPTIONAL(noRepeatNgramSize, NoRepeatNgramSize, SizeType32) SET_FROM_OPTIONAL(minP, MinP, FloatType) SET_FROM_OPTIONAL(beamWidthArray, BeamWidthArray, std::vector) + SET_FROM_OPTIONAL(cfgScale, CfgScale, FloatType) #undef SET_FROM_OPTIONAL } @@ -278,6 +282,8 @@ class SamplingConfig // valid &= validateVec("lengthPenalty", lengthPenalty, 0.f); valid &= validateVec("noRepeatNgramSize", noRepeatNgramSize, 0); valid &= validateVec("minP", minP, -fltEpsilon, {1.f}); + // TODO: validation of cfgScale? + valid &= validateVec("cfgScale", cfgScale, -10.0f); // TODO: check `beamWidthArray` // Detect greedy sampling and overwrite params. @@ -371,6 +377,8 @@ class SamplingConfig std::optional normalizeLogProbs; + OptVec cfgScale; // [1] or [batchSize] + bool operator==(SamplingConfig const& other) const { return beamWidth == other.beamWidth && numReturnSequences == other.numReturnSequences @@ -383,7 +391,8 @@ class SamplingConfig && lengthPenalty == other.lengthPenalty && earlyStopping == other.earlyStopping && draftAcceptanceThreshold == other.draftAcceptanceThreshold && topKMedusaHeads == other.topKMedusaHeads && normalizeLogProbs == other.normalizeLogProbs && outputLogProbs == other.outputLogProbs - && cumLogProbs == other.cumLogProbs && minP == other.minP && beamWidthArray == other.beamWidthArray; + && cumLogProbs == other.cumLogProbs && minP == other.minP && beamWidthArray == other.beamWidthArray + && cfgScale == other.cfgScale; } SizeType32 getNumReturnBeams() const diff --git a/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp b/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp index c0482deb554..a207e50fdbf 100644 --- a/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp +++ b/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp @@ -30,14 +30,10 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager { if (llmReq->isFirstContextChunk()) { - auto const requestId = llmReq->mRequestId; auto const promptLen = llmReq->mPromptLen; auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth; auto draftLength = llmReq->getNumDraftTokens(); - // Allocate/Reuse KV cache - kvCacheManager.addSequence(requestId, promptLen, reqBeamWidth, llmReq); - // EagleNet will increment kv cache up to maxPathLen to account for accepted tokens. // Then up to maxDecodingDraftTokens will be used to generate next draft tokens. 
if (modelConfig.getSpeculativeDecodingMode().isEagle()) @@ -45,26 +41,29 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager draftLength = modelConfig.getSpeculativeDecodingModule().getMaxPathLen() + modelConfig.getSpeculativeDecodingModule().getMaxDecodingTokens(); } - - // Allocate more KV cache for speculative decoding - if (draftLength > 0) + for (int i = 0; i < llmReq->getNumSequences(); i++) { - for (SizeType32 di = 0; di < draftLength; ++di) + auto const requestId = llmReq->getSeqSlotId(i); + // Allocate/Reuse KV cache + kvCacheManager.addSequence(requestId, promptLen, reqBeamWidth, llmReq); + // Allocate more KV cache for speculative decoding + if (draftLength > 0) { - kvCacheManager.addToken(requestId); + for (SizeType32 di = 0; di < draftLength; ++di) + { + kvCacheManager.addToken(requestId); + } + } + if (crossKvCacheManager) + { + crossKvCacheManager->addSequence(requestId, llmReq->getEncoderOutputLen(), reqBeamWidth, llmReq); } - } - - if (crossKvCacheManager) - { - crossKvCacheManager->addSequence(requestId, llmReq->getEncoderOutputLen(), reqBeamWidth, llmReq); } } } for (auto const& llmReq : generationRequests) { - auto const requestId = llmReq->mRequestId; auto decodingTokens = llmReq->getNumDraftTokens() + 1; // EagleNet will increment kv cache up to maxPathLen to account for accepted tokens. @@ -75,9 +74,14 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager + modelConfig.getSpeculativeDecodingModule().getMaxDecodingTokens(); } - for (SizeType32 di = 0; di < decodingTokens; ++di) + for (int i = 0; i < llmReq->getNumSequences(); i++) { - kvCacheManager.addToken(requestId); + + auto const requestId = llmReq->getSeqSlotId(i); + for (SizeType32 di = 0; di < decodingTokens; ++di) + { + kvCacheManager.addToken(requestId); + } } } diff --git a/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp b/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp index 514d100fe58..894521e5374 100644 --- a/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp +++ b/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp @@ -35,15 +35,27 @@ void tensorrt_llm::batch_manager::AssignReqSeqSlots::operator()(SequenceSlotMana // Skip assigning sequence slot for DISAGG_GENERATION_INIT request continue; } - auto const isReqNew = (llmReq->isContextInitState() && !llmReq->mSeqSlot) + auto const isReqNew = (llmReq->isContextInitState() && llmReq->mSeqSlots.empty()) || (llmReq->isDisaggGenerationTransmissionComplete()); if (isReqNew && llmReq->getReturnPerfMetrics()) { llmReq->setFirstScheduledTime(); } - auto const reqSeqSlot = seqSlotManager.getSequenceSlot(isReqNew, llmReq->mRequestId); - TLLM_CHECK_WITH_INFO(reqSeqSlot, "Unable to get batch slot for request ID %lu", llmReq->mRequestId); - llmReq->mSeqSlot = reqSeqSlot; + + for (int i = 0; i < llmReq->getNumSequences(); i++) + { + auto const reqSeqSlot = seqSlotManager.getSequenceSlot(isReqNew, llmReq->getSeqSlotId(i)); + TLLM_CHECK_WITH_INFO( + reqSeqSlot, "Unable to get batch slot for request ID %lu", llmReq->getSeqSlotId(i)); + if ((int) llmReq->mSeqSlots.size() >= i + 1) + { + llmReq->mSeqSlots[i] = reqSeqSlot.value(); + } + else + { + llmReq->mSeqSlots.push_back(reqSeqSlot.value()); + } + } } } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index 168ea89693f..93195cc03bf 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ 
b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -44,7 +44,7 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest { size_t requestBlockNum = llmRequest.getRequestedBlockHashes().size(); constexpr SizeType32 beam{0}; - auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam); + auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest, beam); auto poolNum = cacheManager->getBlockManager().getNumPools(); if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer()) { @@ -70,7 +70,7 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques constexpr SizeType32 beam{0}; return BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam); } - return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId); + return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest); } bool CacheFormatter::needSendCache( diff --git a/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp b/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp index 9c9c56ba9d6..b554ca93a4a 100644 --- a/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp +++ b/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp @@ -407,9 +407,14 @@ std::tuple MaxUtilizationScheduler::operator()( // If we can't allocate a started request, we need to start freeing started requests // from the end of the vector and try again // Here we simulate freeing the kvCache blocks associated with that sequence - kvCacheManager.schedulingRemoveSequence((*lastStartedReqIt)->mRequestId); + for (int i = 0; i < (*lastStartedReqIt)->getNumSequences(); i++) + { + auto const requestId = (*lastStartedReqIt)->getSeqSlotId(i); + kvCacheManager.schedulingRemoveSequence((*lastStartedReqIt)->getSeqSlotId(i)); + TLLM_LOG_DEBUG( + "MaxUtilizationScheduler: request ID %lu -> pause", (*lastStartedReqIt)->getSeqSlotId(i)); + } pausedRequests.emplace_back(*lastStartedReqIt); - TLLM_LOG_DEBUG("MaxUtilizationScheduler: request ID %lu -> pause", (*lastStartedReqIt)->mRequestId); reqItEnd = std::next(lastStartedReqIt).base(); } else diff --git a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp index 3335d69a015..871020b2dde 100644 --- a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp +++ b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp @@ -75,7 +75,7 @@ void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffe auto const currentSequenceLen = llmReq->mPromptLen + llmReq->getMaxNumGeneratedTokens() + disaggFirstGenTokenSize; // Get position of the current sequence in the decoder - auto const seqSlot = llmReq->mSeqSlot.value(); + auto const seqSlot = llmReq->mSeqSlots.at(0); batchSlotsRange[batchIdx] = seqSlot; fillValuesRange[batchIdx] = currentSequenceLen; ++batchIdx; @@ -661,8 +661,8 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon { llmReq->mSamplingConfig.normalizeLogProbs = mIsNormalizeLogProbs; - TLLM_CHECK(llmReq->mSeqSlot.has_value()); - auto const batchSlot = llmReq->mSeqSlot.value(); + TLLM_CHECK(!llmReq->mSeqSlots.empty()); + auto const batchSlot = llmReq->mSeqSlots.at(0); auto const batchSize = decoderState.getMaxNumSequences(); TLLM_CHECK(0 <= batchSlot && batchSlot < batchSize); @@ -676,6 +676,7 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon decoderState.setBeamWidth(batchSlot, beamWidth); auto const promptLen = 
llmReq->getPromptLen(); + auto const numVocabs = modelConfig.getNumVocabs(); SizeType32 numDecodingEngineTokens{1}; if (modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal()) @@ -709,7 +710,7 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon initializeLogProbs(dJointOutput, batchSlot, samplingConfig, decoderBufferManager); auto const& reqTokens = llmReq->getTokens(0); - TLLM_CHECK(reqTokens.size() == static_cast(promptLen)); + TLLM_CHECK(reqTokens.size() == static_cast(promptLen * numVocabs)); TensorPtr requestIds = ITensor::slice(inputIds, inputOffset, promptLen); // Copy to pinned host memory (don't care about stream of bufferManager) decoderBufferManager.copy(reqTokens.data(), *requestIds); diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index 527291b220b..62dd68f76ce 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -691,8 +691,7 @@ class CacheReceiver::Impl if (!disableSelectiveCacheTransfer) { auto* cacheManager = mFormatter->getCacheManager(); - auto blockRange - = kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId); + auto blockRange = kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest); requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState); } diff --git a/cpp/tensorrt_llm/batch_manager/encoderBuffers.cpp b/cpp/tensorrt_llm/batch_manager/encoderBuffers.cpp index 56fd393c68d..b538d9cc342 100644 --- a/cpp/tensorrt_llm/batch_manager/encoderBuffers.cpp +++ b/cpp/tensorrt_llm/batch_manager/encoderBuffers.cpp @@ -121,6 +121,11 @@ void EncoderBuffers::updateBufferSizes(RequestVector const& requests, ModelConfi { encoderInputLen += req->getEncoderInputLen(); encoderOutputLen += req->getEncoderOutputLen(); + if (req->isCfg()) + { + // for CFG, repeat the encoder output twice + encoderOutputLen += req->getEncoderOutputLen(); + } maxInputLengthInBatch = std::max(maxInputLengthInBatch, req->getEncoderInputLen()); // Decoder input is encoder output } @@ -222,40 +227,59 @@ void EncoderBuffers::setFromInputs(RequestVector const& requests, ModelConfig co { SizeType32 const inputLength = llmReq->getEncoderInputLen(); SizeType32 const outputLength = llmReq->getEncoderOutputLen(); - if (llmReq->getEncoderInputFeatures()) + for (int s = 0; s < llmReq->getNumSequences(); s++) { - auto const& reqFeatures - = llmReq - ->getEncoderInputFeatures(); // whisper: [length, featureDim]; Vision: [batch_size, channel, W, H] - TLLM_LOG_DEBUG("EncoderBuffers::setFromInputs - request id = %d, input features length = %d", - llmReq->mRequestId, inputLength); - manager.copy(*reqFeatures, *ITensor::slice(inputFeatures, offset, inputLength)); - offset += inputLength; - } - else - { - auto const& reqTokens = *llmReq->getEncoderTokens().value(); - inputIdsAll.insert(inputIdsAll.end(), reqTokens.begin(), reqTokens.end()); - if (tokenTypeIds) + if (llmReq->getEncoderInputFeatures()) { - tokenTypeIdsAll.insert( - tokenTypeIdsAll.end(), tokenTypeIdsReserved.begin(), tokenTypeIdsReserved.begin() + inputLength); + if (s == 0) + { + // copy input features from request to the buffer for conditional generation + auto const& reqFeatures = llmReq->getEncoderInputFeatures(); // whisper: [length, featureDim]; + // Vision: [batch_size, channel, W, H] + TLLM_LOG_DEBUG("EncoderBuffers::setFromInputs - request id = %d, input features length = %d", + 
llmReq->mRequestId, inputLength); + manager.copy(*reqFeatures, *ITensor::slice(inputFeatures, offset, inputLength)); + offset += inputLength; + } + else if (s == 1) + { + // need to add dummy input of zeros for CFG + auto uncondFeatures = ITensor::slice(inputFeatures, offset, inputLength); + manager.setMem(*uncondFeatures, 0); + offset += inputLength; + } + else + { + TLLM_CHECK_WITH_INFO( + false, "Unexpected sequence index for llmReq [%ld]: %d", llmReq->mRequestId, s); + } } + else + { + // TODO: CFG support for encoder that processes tokens is not implemented yet + auto const& reqTokens = *llmReq->getEncoderTokens().value(); + inputIdsAll.insert(inputIdsAll.end(), reqTokens.begin(), reqTokens.end()); + if (tokenTypeIds) + { + tokenTypeIdsAll.insert(tokenTypeIdsAll.end(), tokenTypeIdsReserved.begin(), + tokenTypeIdsReserved.begin() + inputLength); + } + } + if (positionIds) + { + SizeType32 const length = modelConfig.isWhisper() ? outputLength : inputLength; + positionIdsAll.insert( + positionIdsAll.end(), positionIdsReserved.begin(), positionIdsReserved.begin() + length); + } + if (modelConfig.useLanguageAdapter()) + { + auto const languageAdapterRouting + = llmReq->getLanguageAdapterRouting(modelConfig.getNumLanguages().value(), inputLength); + languageAdapterRoutingAll.insert(languageAdapterRoutingAll.end(), std::begin(languageAdapterRouting), + std::end(languageAdapterRouting)); + } + inputLengthsAll.push_back(inputLength); } - if (positionIds) - { - SizeType32 const length = modelConfig.isWhisper() ? outputLength : inputLength; - positionIdsAll.insert( - positionIdsAll.end(), positionIdsReserved.begin(), positionIdsReserved.begin() + length); - } - if (modelConfig.useLanguageAdapter()) - { - auto const languageAdapterRouting - = llmReq->getLanguageAdapterRouting(modelConfig.getNumLanguages().value(), inputLength); - languageAdapterRoutingAll.insert( - languageAdapterRoutingAll.end(), std::begin(languageAdapterRouting), std::end(languageAdapterRouting)); - } - inputLengthsAll.push_back(inputLength); } // copy inputs from host to device @@ -396,6 +420,10 @@ void EncoderBuffers::rearrangeOutputs(RequestVector const& requests, ModelConfig } } offset += size; + if (req->isCfg()) + { + offset += size; + } } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); @@ -478,16 +506,25 @@ void EncoderBuffers::setBufferSizes(RequestVector const& contextRequests, Reques for (auto const& llmReq : contextRequests) { - numRequests += 1; - encoderInputLen += llmReq->getEncoderInputLen(); + numRequests += llmReq->getNumSequences(); encoderOutputLen += llmReq->getEncoderOutputLen(); + if (llmReq->isCfg()) + { + encoderInputLen += llmReq->getEncoderInputLen(); + encoderOutputLen += llmReq->getEncoderOutputLen(); + } maxInputLengthInBatch = std::max(maxInputLengthInBatch, llmReq->getEncoderInputLen()); } for (auto const& llmReq : genRequests) { - auto const reqBeamWidth = llmReq->getBeamWidthByIter(); - numRequests += reqBeamWidth; // tile by beam width + encoderOutputLen += llmReq->getEncoderOutputLen(); + if (llmReq->isCfg()) + { + encoderOutputLen += llmReq->getEncoderOutputLen(); + } + auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth; + numRequests += reqBeamWidth * llmReq->getNumSequences(); // tile by beam width maxInputLengthInBatch = std::max(maxInputLengthInBatch, llmReq->getEncoderInputLen()); } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); @@ -523,17 +560,26 @@ void EncoderBuffers::fill( bool isCtx = llmReq->isContextInitState(); if (isCtx) { + // copy encoder output to encoder output 
buffer for both ctx and gen requests, + // disable freeing enc buffer in llm request for it size = llmReq->getEncoderOutputLen(); - auto const encoderOutputSlice = runtime::ITensor::slice(encoderOutput, offset, size); + auto encoderOutputSlice = runtime::ITensor::slice(encoderOutput, offset, size); manager.copy(*llmReq->getEncoderOutput(), *encoderOutputSlice); offset += size; - inputLengthsAll.emplace_back(size); + if (llmReq->isCfg()) + { + auto encoderOutputSlice = runtime::ITensor::slice(encoderOutput, offset, size); + manager.setMem(*encoderOutputSlice, 0); + offset += size; + inputLengthsAll.emplace_back(size); + } } else { auto const reqBeamWidth = llmReq->getBeamWidthByIter(); - std::fill_n(std::back_inserter(inputLengthsAll), reqBeamWidth, + auto const numSeq = llmReq->getNumSequences(); + std::fill_n(std::back_inserter(inputLengthsAll), reqBeamWidth * numSeq, llmReq->getEncoderOutputLen()); // although encoder output is not needed, gen phase still needs the // encoder length info for cross kv cache. Also tile by beam width } diff --git a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp index 3a68d03eb69..fee15e6cb9c 100644 --- a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp +++ b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp @@ -87,7 +87,7 @@ void GuidedDecoder::build(ScheduledRequests const& scheduledRequests) { continue; } - auto const seqSlot = llmReq->mSeqSlot.value(); + auto const seqSlot = llmReq->mSeqSlots.at(0); if (llmReq->isContextInitState() && llmReq->isFirstContextChunk()) { // The request is in the first context forward step (considering kv cache reuse). @@ -180,7 +180,7 @@ void GuidedDecoder::execute(DecoderInputBuffers const& decoderInputBuffers, Buff auto const& guidedDecodingParams = llmReq->getGuidedDecodingParams(); if (guidedDecodingParams.has_value()) { - auto const seqSlot = llmReq->mSeqSlot.value(); + auto const seqSlot = llmReq->mSeqSlots.at(0); auto const& logits = decoderInputBuffers.logits.at(requestIdx); auto const logitsBitmask = ITensor::at(mLogitsBitmask, {seqSlot}); diff --git a/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp b/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp index df3840c14b4..3beb26ef77d 100644 --- a/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp +++ b/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp @@ -22,6 +22,7 @@ #include "tensorrt_llm/batch_manager/medusaBuffers.h" #include "tensorrt_llm/batch_manager/runtimeBuffers.h" #include "tensorrt_llm/common/nvtxUtils.h" +#include "tensorrt_llm/kernels/cfgKernels.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/runtimeKernels.h" #include "tensorrt_llm/runtime/utils/debugUtils.h" @@ -70,8 +71,8 @@ void setupMedusaLogits(std::vector& medusaLogitsHeads, TensorPtr cons SizeType32 HandleContextLogits::operator()(DecoderInputBuffers& inputBuffers, RequestVector const& contextRequests, tr::ITensor::SharedPtr const& logits, std::vector const& numContextLogitsVec, - tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, - OptionalRef medusaBuffers) const + tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, OptionalRef medusaBuffers, + SizeType32 vocabId) const { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(HandleContextLogits); @@ -121,11 +122,31 @@ SizeType32 HandleContextLogits::operator()(DecoderInputBuffers& inputBuffers, Re // Get the logits from the last context token and draft tokens auto const numDecoderLogits = 1 + 
draftLength; - auto const seqSlot = llmReq->mSeqSlot.value(); + TensorPtr logitsView = ITensor::slice(logits, logitsIndex - numDecoderLogits, numDecoderLogits); + // this is CFG support implementation, where we advance the logits index through the unconditional logits + if (llmReq->isCfg()) + { + logitsIndex += numContextLogits + draftLength; + TensorPtr uncondLogitsView = ITensor::slice(logits, logitsIndex - numDecoderLogits, numDecoderLogits); + // TODO: implement CFG, apply logitsView = logitsView * cfgScale + uncondLogitsView * (1 - cfgScale) + + float cfgScale = llmReq->mSamplingConfig.cfgScale->at(0); + SizeType32 vocabOffset = 0; + auto vocabSizes = modelConfig.getVocabSizes(); + for (SizeType32 i = 0; i < vocabId; ++i) + { + vocabOffset += vocabSizes[i]; + } + tensorrt_llm::kernels::invokeCfg( + manager.getStream(), logitsView, uncondLogitsView, cfgScale, vocabOffset, vocabSizes[vocabId]); + } + auto const seqSlot = llmReq->mSeqSlots.at(0); + if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits()) { + // speculative decoding is not supported for numVocabs > 1 auto& medusaLogitsHeads = inputBuffers.predictedDraftLogits.at(seqSlot); TLLM_CHECK(medusaBuffers); setupMedusaLogits(medusaLogitsHeads, medusaBuffers->medusaLogitsDevice, @@ -159,14 +180,34 @@ SizeType32 HandleContextLogits::operator()(DecoderInputBuffers& inputBuffers, Re } else { - decoderLogits = logitsView; - decoderLogits->unsqueeze(1); + auto curVocablogitsView = logitsView; + auto const logitsViewShape = logitsView->getShape(); + auto vocabSizes = modelConfig.getVocabSizes(); + if (logitsViewShape.d[0] == 1) // if current nTok is 1, could have multiple vocabs + { + SizeType32 offset = 0; + for (SizeType32 i = 0; i < vocabId; ++i) + { + offset += vocabSizes[i]; + } + auto const vocabSizePadded = logitsViewShape.d[1]; + curVocablogitsView = ITensor::slice(logitsView, {0, offset}, vocabSizes[vocabId]); // [vocabSize,] + curVocablogitsView + = ITensor::view(curVocablogitsView, ITensor::makeShape({1, vocabSizes[vocabId]})); + } + auto const updateLogitsViewShape = curVocablogitsView->getShape(); + decoderLogits = ITensor::view(curVocablogitsView, + ITensor::makeShape({updateLogitsViewShape.d[0], 1, updateLogitsViewShape.d[1]})); } decoderRequests.push_back(llmReq); allDecoderLogits.emplace_back(std::move(decoderLogits)); } ++batchIndex; + if (llmReq->isCfg()) + { + ++batchIndex; + } } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); diff --git a/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp b/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp index 5018ae36290..40b7ac2f6b1 100644 --- a/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp +++ b/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp @@ -24,6 +24,7 @@ #include "tensorrt_llm/batch_manager/utils/inflightBatchingUtils.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/nvtxUtils.h" +#include "tensorrt_llm/kernels/cfgKernels.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/utils/debugUtils.h" @@ -78,11 +79,18 @@ void setupMedusaLogits(std::vector& medusaLogitsHeads, TensorPtr cons void HandleGenerationLogits::operator()(DecoderInputBuffers& inputBuffers, RequestVector const& generationRequests, tr::ITensor::SharedPtr const& logits, tr::SizeType32 logitsIndex, tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, OptionalRef genRuntimeBuffers, - OptionalRef medusaBuffers) const + OptionalRef medusaBuffers, SizeType32 vocabId) const { TLLM_LOG_TRACE("%s start", 
__PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(HandleGenerationLogits); + auto vocabSizes = modelConfig.getVocabSizes(); + SizeType32 vocabOffset = 0; + for (SizeType32 i = 0; i < vocabId; ++i) + { + vocabOffset += vocabSizes[i]; + } + auto& decoderRequests = inputBuffers.decoderRequests; decoderRequests.reserve(decoderRequests.size() + generationRequests.size()); auto& allDecoderLogits = inputBuffers.logits; @@ -91,7 +99,7 @@ void HandleGenerationLogits::operator()(DecoderInputBuffers& inputBuffers, Reque for (auto const& llmReq : generationRequests) { auto const reqBeamWidth = llmReq->getBeamWidthByIter(); - auto const seqSlot = llmReq->mSeqSlot.value(); + auto const seqSlot = llmReq->mSeqSlots.at(0); auto const draftLength = llmReq->getNumDraftTokens(); auto const numLogits = draftLength + reqBeamWidth; @@ -108,6 +116,18 @@ void HandleGenerationLogits::operator()(DecoderInputBuffers& inputBuffers, Reque TLLM_CHECK_DEBUG_WITH_INFO(tru::tensorHasInvalid(*logitsView, manager, "logits") == false, "Found invalid number (NaN or Inf) in logits"); + // CFG implementation: get unconditional logits and add them to logitsView + if (llmReq->isCfg()) + { + logitsIndex += numLogits; + TensorPtr uncondLogitsView = ITensor::slice(logits, logitsIndex, numLogits); + // TODO: implement CFG, apply logitsView = logitsView * cfgScale + uncondLogitsView * (1 - cfgScale) + float cfgScale = llmReq->mSamplingConfig.cfgScale->at(0); + tensorrt_llm::kernels::invokeCfg( + manager.getStream(), logitsView, uncondLogitsView, cfgScale, vocabOffset, vocabSizes[vocabId]); + } + auto const logitsViewShape = logitsView->getShape(); + TLLM_CHECK(llmReq->isGenerationInProgressState()); TensorPtr decoderLogits; if (reqBeamWidth > 1) @@ -117,8 +137,16 @@ void HandleGenerationLogits::operator()(DecoderInputBuffers& inputBuffers, Reque } else { - decoderLogits = logitsView; - decoderLogits->unsqueeze(1); + auto curVocablogitsView = logitsView; + if (logitsViewShape.d[0] == 1) // if current nTok is 1, could have multiple vocabs + { + curVocablogitsView = ITensor::slice(logitsView, {0, vocabOffset}, vocabSizes[vocabId]); // [vocabSize,] + curVocablogitsView = ITensor::view( + curVocablogitsView, ITensor::makeShape({1, vocabSizes[vocabId]})); // [numLogits == 1, vocabSize] + } + auto const updateLogitsViewShape = curVocablogitsView->getShape(); + decoderLogits = ITensor::view( + curVocablogitsView, ITensor::makeShape({updateLogitsViewShape.d[0], 1, updateLogitsViewShape.d[1]})); } decoderRequests.push_back(llmReq); allDecoderLogits.emplace_back(std::move(decoderLogits)); @@ -147,6 +175,7 @@ void HandleGenerationLogits::operator()(DecoderInputBuffers& inputBuffers, Reque } if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits()) { + // speculative decoding is not supported for numVocabs > 1 auto& medusaLogitsHeads = inputBuffers.predictedDraftLogits.at(seqSlot); TLLM_CHECK(medusaBuffers); setupMedusaLogits(medusaLogitsHeads, medusaBuffers->medusaLogitsDevice, diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index eaa6b1cac4e..b85cf6caac6 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -562,13 +562,14 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, - std::shared_ptr 
kvCacheConnectorManager, + SizeType32 numVocabs, std::shared_ptr kvCacheConnectorManager, std::optional agentConfig) : mNumLayers{static_cast(numKvHeadsPerLayer.size())} , mTokensPerBlock{tokensPerBlock} , mEventManager{std::move(eventManager)} , mStream{stream} , mCacheType{cacheType} + , mNumVocabs{numVocabs} { if (agentConfig.has_value()) mLoopbackAgent = makeLoopbackAgent("nixl", &agentConfig.value()); @@ -603,7 +604,7 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer, sizePerHead, tokensPerBlock, /*isSWA=*/windowSize < maxSequenceLength, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream, onboardBlocks, cacheType, secondaryOffloadMinPriority, - mEventManager, enablePartialReuse, copyOnPartialReuse, kvCacheConnectorManager, mLoopbackAgent); + mEventManager, enablePartialReuse, copyOnPartialReuse, kvCacheConnectorManager, numVocabs, mLoopbackAgent); } auto const numAllPools = getNumPools(); @@ -658,7 +659,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr stream, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, - std::shared_ptr kvCacheConnectorManager, + std::shared_ptr kvCacheConnectorManager, SizeType32 numVocabs, std::shared_ptr loopbackAgent) : mDataType{dtype} , mWindowSize{windowSize} @@ -686,6 +687,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind , mEnablePartialReuse{enablePartialReuse} , mCopyOnPartialReuse{copyOnPartialReuse} , mKvCacheConnectorManager{std::move(kvCacheConnectorManager)} + , mNumVocabs{numVocabs} { std::map numLayersPerPool; @@ -788,9 +790,10 @@ void BlockManager::storeContextBlocks(GenerationRequest& sequence, LlmRequest co { auto cacheBlockIds = sequence.getCacheBlockIds(windowSize); auto const& uniqueTokens = llmRequest.getUniqueTokens(beamIdx); + auto const numVocabs = llmRequest.getNumVocabs(); - auto blockedUniqueTokens - = chopVectorIntoBlocks(uniqueTokens, uniqueTokens.size() - 1, getTokensPerBlock(), false); + auto blockedUniqueTokens = chopVectorIntoBlocks( + uniqueTokens, uniqueTokens.size() - numVocabs, getTokensPerBlock() * numVocabs, false); auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest); storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], windowSize); } @@ -1089,8 +1092,9 @@ void WindowBlockManager::offloadBlock( std::optional WindowBlockManager::findNewContextBlock( VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const { + auto const numVocabs = llmRequest.getNumVocabs(); auto blockedUniqueTokens - = chopVectorIntoBlocks(uniqueTokens, uniqueTokens.size(), mTokensPerBlock, false); + = chopVectorIntoBlocks(uniqueTokens, uniqueTokens.size(), mTokensPerBlock * numVocabs, false); auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest); BlockKey ret; ret.loraTaskId = llmRequest.getLoraTaskId(); @@ -1137,7 +1141,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& { KVCacheBlock::IdType matchingBlockId = matchingBlock->getBlockId(); - numMatchedTokens += numMatched > 0 ? numMatched : blockItr->uniqueTokens.size(); + numMatchedTokens += numMatched > 0 ? 
numMatched : blockItr->uniqueTokens.size() / mNumVocabs; if (perBlockRetentions[bi].retentionPriority.has_value() && matchingBlock->getPriority() != perBlockRetentions[bi].retentionPriority && mEventManager) { @@ -1278,7 +1282,9 @@ void WindowBlockManager::addSequence( : *(llmRequest.getEncoderUniqueTokens().value()); // Ignore last token because it can't be recovered - auto blockedUniqueTokens = chopVectorIntoBlocks(uniqueTokens, inputLength - 1, mTokensPerBlock, true); + auto const numVocabs = llmRequest.getNumVocabs(); + auto blockedUniqueTokens = chopVectorIntoBlocks( + uniqueTokens, (inputLength - 1) * numVocabs, mTokensPerBlock * numVocabs, true); // Add empty block if last token is separated if (inputLength % mTokensPerBlock == 1) { @@ -1305,7 +1311,8 @@ void WindowBlockManager::addSequence( TLLM_CHECK(perBlockRetentions.size() == (size_t) numContextBlocks); auto const prepopulatedPromptLen - = loadOrAllocateBlocks(blockKeys, numContextBlocks, sequence, perBlockRetentions, mode, directory); + = loadOrAllocateBlocks(blockKeys, numContextBlocks, sequence, perBlockRetentions, mode, directory) / numVocabs; + mReusedTokens += static_cast(prepopulatedPromptLen); mTotalInputTokens += static_cast(uniqueTokens.size()); @@ -1319,7 +1326,7 @@ void WindowBlockManager::addSequence( llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock()); TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d", - llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens); + llmRequest.getSeqSlotId(), inputLength, prepopulatedPromptLen, numConnectorMatchedTokens); } void BlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse) @@ -1479,7 +1486,8 @@ SizeType32 WindowBlockManager::storeBlocks( TLLM_LOG_DEBUG("%s::storeBlocks - No match, inserting block %d into search structure", mLogPrefix.c_str(), block->getBlockId()); needMatch = false; // no matching needed for following blocks - block->setBlockKey(blockKey, static_cast(blockKey.uniqueTokens.size()) == mTokensPerBlock); + block->setBlockKey( + blockKey, static_cast(blockKey.uniqueTokens.size()) == mTokensPerBlock * mNumVocabs); block->setPrevBlock(searchRoot); block->setPrevBlockInSeq(searchRoot); searchRoot->addNextBlock(blockKey, block); @@ -1711,11 +1719,13 @@ void WindowBlockManager::releaseBlocks(GenerationRequest& sequence, OptionalRef< { // If llmRequest is provided, store the blocks for reuse. auto const& uniqueTokens = llmRequest->getUniqueTokens(/*beamIdx=*/0); + auto const numVocabs = llmRequest->getNumVocabs(); // Only (length - 1) tokens of the sequence have their kv-state // recorded in kv-cache. We assume the last token's state is not filled yet. 
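[Editor's note, not part of the patch] The numVocabs scaling in the surrounding kv-cache hunks (chopVectorIntoBlocks with tokensPerBlock * numVocabs, uniqueTokens.size() / mNumVocabs, and usableSize dropping numVocabs entries) suggests the flat token vector stores numVocabs entries per position, interleaved. The following is only an illustrative sketch under that assumption; the helper name is hypothetical and does not exist in the codebase.

// Illustrative sketch: chopping an interleaved multi-vocab token stream into
// KV-cache-sized chunks. Each position contributes numVocabs entries, so a block
// covering tokensPerBlock positions spans tokensPerBlock * numVocabs entries of
// the flat vector (the last chunk may be partial, mirroring allowPartial = true).
#include <algorithm>
#include <vector>

std::vector<std::vector<int>> chopInterleavedTokens(
    std::vector<int> const& tokens, int usablePositions, int tokensPerBlock, int numVocabs)
{
    std::vector<std::vector<int>> blocks;
    int const entriesPerBlock = tokensPerBlock * numVocabs;
    int const usableEntries = usablePositions * numVocabs; // mirrors (inputLength - 1) * numVocabs above
    for (int begin = 0; begin < usableEntries; begin += entriesPerBlock)
    {
        int const end = std::min(begin + entriesPerBlock, usableEntries);
        blocks.emplace_back(tokens.begin() + begin, tokens.begin() + end);
    }
    return blocks;
}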
- auto const usableSize = static_cast(uniqueTokens.size()) - 1; + auto const usableSize = static_cast(uniqueTokens.size()) - numVocabs; auto blockedUniqueTokens - = chopVectorIntoBlocks(uniqueTokens, usableSize, mTokensPerBlock, /*allowPartial=*/true); + = chopVectorIntoBlocks(uniqueTokens, usableSize, mTokensPerBlock * numVocabs, + /*allowPartial=*/true); auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); std::vector cacheBlockIds(allocatedBlocks.size()); @@ -1772,11 +1782,12 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkTokenLength, int64_t stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse, - bool onboardBlocks, CacheType cacheType, bool enablePartialReuse, bool copyOnPartialReuse) + bool onboardBlocks, CacheType cacheType, bool enablePartialReuse, bool copyOnPartialReuse, SizeType32 numVocabs) : KVCacheManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::make_shared(reinterpret_cast(stream)), maxSequenceLength, - enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, enablePartialReuse, copyOnPartialReuse) + enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, enablePartialReuse, copyOnPartialReuse, + numVocabs) { } @@ -1787,12 +1798,12 @@ KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer SizeType32 sinkTokenLength, int64_t stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, - std::shared_ptr kvCacheConnectorManager) + SizeType32 numVocabs, std::shared_ptr kvCacheConnectorManager) : KVCacheManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::make_shared(reinterpret_cast(stream)), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse, - copyOnPartialReuse, kvCacheConnectorManager) + copyOnPartialReuse, numVocabs, kvCacheConnectorManager) { } @@ -1803,7 +1814,7 @@ KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer SizeType32 sinkTokenLength, CudaStreamPtr stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, - std::shared_ptr kvCacheConnectorManager) + SizeType32 numVocabs, std::shared_ptr kvCacheConnectorManager) : mMaxBeamWidth(maxBeamWidth) , mDataType(dtype) , mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end())) @@ -1813,7 +1824,7 @@ KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer , mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager), - enablePartialReuse, copyOnPartialReuse, 
std::move(kvCacheConnectorManager)) + enablePartialReuse, copyOnPartialReuse, numVocabs, std::move(kvCacheConnectorManager)) // disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case , mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse} { @@ -1837,11 +1848,11 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size SizeType32 sinkTokenLength, CudaStreamPtr stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, - std::shared_ptr kvCacheConnectorManager) + SizeType32 numVocabs, std::shared_ptr kvCacheConnectorManager) : KVCacheManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, - std::move(eventManager), enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager)) + std::move(eventManager), enablePartialReuse, copyOnPartialReuse, numVocabs, std::move(kvCacheConnectorManager)) { } @@ -1942,7 +1953,8 @@ SizeType32 KVCacheManager::getNeededBlocksOneStep( auto const numUnSharedBlocks = tc::ceilDiv(numUnSharedTokens, getTokensPerBlock()) * req.mSamplingConfig.beamWidth; auto const numRequiredBlocks = numSharedBlocks + numUnSharedBlocks; - return numRequiredBlocks; + // we need more blocks if there are multiple sequences in this request + return numRequiredBlocks * req.getNumSequences(); } if (req.isGenerationInProgressState()) @@ -1980,7 +1992,7 @@ SizeType32 KVCacheManager::getRemainingBlocksToCompletion(LlmRequest const& req, { if (req.isContextInitState() && req.getContextCurrentPosition() == 0) { - return tc::ceilDiv(req.getEncoderOutputLen(), getTokensPerBlock()); + return tc::ceilDiv(req.getEncoderOutputLen(), getTokensPerBlock()) * req.getNumSequences(); } return 0; // cross KV cache doesn't grow after the initial context phase @@ -1988,23 +2000,26 @@ SizeType32 KVCacheManager::getRemainingBlocksToCompletion(LlmRequest const& req, auto const temporaryAttentionWindow = mBlockManager.getWindowSizeMetadata(windowSize).temporaryAttentionWindow; - SizeType32 const numContextBlocks - = (std::min(req.mPromptLen, windowSize + temporaryAttentionWindow) + mSinkBubbleLength) / getTokensPerBlock(); - - SizeType32 const numTotalBlocksPerBeam = tc::ceilDiv( - std::min(req.mPromptLen + req.mMaxNewTokens, windowSize + temporaryAttentionWindow) + mSinkBubbleLength, - getTokensPerBlock()); + SizeType32 const numContextBlocks = req.getNumSequences() + * (std::min(req.mPromptLen, windowSize + temporaryAttentionWindow) + mSinkBubbleLength) / getTokensPerBlock(); + SizeType32 const numTotalBlocksPerBeam = req.getNumSequences() + * tc::ceilDiv( + std::min(req.mPromptLen + req.mMaxNewTokens, windowSize + temporaryAttentionWindow) + mSinkBubbleLength, + getTokensPerBlock()); SizeType32 const numGenBlocksPerBeam = numTotalBlocksPerBeam - numContextBlocks; SizeType32 numAllocBlocksPerBeam = 0; { std::scoped_lock lck(mSequencesMtx); - auto const seqIt = mSequences.find(req.mRequestId); - if (seqIt != mSequences.end()) + for (int i = 0; i < req.getNumSequences(); i++) { - auto const& seq = seqIt->second; - numAllocBlocksPerBeam = seq.getCacheBlockIds(windowSize).at(0).size(); + auto 
const seqIt = mSequences.find(req.getSeqSlotId(i)); + if (seqIt != mSequences.end()) + { + auto const& seq = seqIt->second; + numAllocBlocksPerBeam += seq.getCacheBlockIds(windowSize).at(0).size(); + } } } @@ -2197,18 +2212,21 @@ void KVCacheManager::addSequence( void KVCacheManager::storeContextBlocks(LlmRequest const& llmRequest) { - auto const requestId = llmRequest.mRequestId; - if (mSequences.find(requestId) != mSequences.end()) + for (int i = 0; i < llmRequest.getNumSequences(); i++) { - auto& sequence = getSequence(requestId); - if (mEnableBlockReuse && !llmRequest.isDummyRequest()) + auto const requestId = llmRequest.getSeqSlotId(i); + if (mSequences.find(requestId) != mSequences.end()) { - mBlockManager.storeContextBlocks(sequence, llmRequest); + auto& sequence = getSequence(requestId); + if (mEnableBlockReuse && !llmRequest.isDummyRequest()) + { + mBlockManager.storeContextBlocks(sequence, llmRequest); + } + } + else + { + TLLM_LOG_WARNING("[kv cache manager] storeContextBlocks: Can not find sequence for request %lu", requestId); } - } - else - { - TLLM_LOG_WARNING("[kv cache manager] storeContextBlocks: Can not find sequence for request %lu", requestId); } } diff --git a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp index fb6aa5cc67f..1ffb85e6b6c 100644 --- a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp @@ -135,12 +135,12 @@ std::optional LlmRequest::createResult(bool useFastLogits, int for (SizeType32 beam = 0; beam < nbBeams; ++beam) { auto const& tokens = getTokens(beam); - auto const nbTokensOut = calculateNbTokensOut(tokens.size()); + auto const nbTokensOut = calculateNbTokensOut(tokens.size() / getNumVocabs()); if (nbTokensOut > 0) { - auto const first = tokens.data() + startTokenPos; - result.outputTokenIds.at(beam).assign(first, first + nbTokensOut); + auto const first = tokens.data() + startTokenPos * getNumVocabs(); + result.outputTokenIds.at(beam).assign(first, first + nbTokensOut * getNumVocabs()); } } @@ -322,7 +322,7 @@ std::shared_ptr LlmRequest::createChildRequest(RequestIdType request childReq->mSequenceIndex = mChildRequests.size() + 1; childReq->mParentRequestId = this->mRequestId; childReq->mSequenceFinalVec = this->mSequenceFinalVec; - childReq->mSeqSlot.reset(); + childReq->mSeqSlots.clear(); // To ensure different randomness across children, assign a unique random seed to each child // by adding its sequence index to the base seed. If no seed is provided, the parent's seed defaults to 0. 
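[Editor's note, not part of the patch] Referring back to the invokeCfg calls added in handleContextLogits.cpp and handleGenerationLogits.cpp above: the CUDA kernel itself is not shown in this diff, so the code below is only a host-side sketch of the blend described by the TODO comments there (logits = cond * cfgScale + uncond * (1 - cfgScale), applied to the sub-vocabulary starting at vocabOffset). The row-major [numTokens, vocabSizePadded] float layout and the function name are assumptions for illustration only.

// Host-side sketch of the blend invokeCfg is expected to perform on device.
// Only the slice [vocabOffset, vocabOffset + vocabSize) is blended, in place,
// into the conditional logits view.
#include <cstdint>

void cfgBlendReference(float* condLogits, float const* uncondLogits, std::int32_t numTokens,
    std::int32_t vocabSizePadded, float cfgScale, std::int32_t vocabOffset, std::int32_t vocabSize)
{
    for (std::int32_t tok = 0; tok < numTokens; ++tok)
    {
        float* cond = condLogits + static_cast<std::int64_t>(tok) * vocabSizePadded;
        float const* uncond = uncondLogits + static_cast<std::int64_t>(tok) * vocabSizePadded;
        for (std::int32_t v = vocabOffset; v < vocabOffset + vocabSize; ++v)
        {
            // logits = cond * cfgScale + uncond * (1 - cfgScale)
            cond[v] = cond[v] * cfgScale + uncond[v] * (1.0f - cfgScale);
        }
    }
}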
diff --git a/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp b/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp index dbb90da326a..fcbf689555b 100644 --- a/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp +++ b/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp @@ -49,7 +49,8 @@ bool LogitsPostProcessor::operator()(DecoderInputBuffers& inputBuffers, bool rep for (size_t batchIdx = 0; batchIdx < inputBuffers.decoderRequests.size(); ++batchIdx) { auto const& llmReq = inputBuffers.decoderRequests.at(batchIdx); - auto& logits = inputBuffers.logits.at(batchIdx); + auto& logits = inputBuffers.logits.at( + batchIdx); // Check if this should be ...at(llmReq->mSeqSlots.at(0)) instead for CFG // Invoke non-batched processor or collect arguments for batched processor if (llmReq->mLogitsPostProcessor) diff --git a/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp b/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp index c9b2bb0b937..a2baf9b31bc 100644 --- a/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp +++ b/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp @@ -90,7 +90,7 @@ std::pair, std::vector> getActiveSlots(Reque std::vector generationSteps; for (auto const& llmReq : decoderRequests) { - activeSlots.push_back(llmReq->mSeqSlot.value()); + activeSlots.push_back(llmReq->mSeqSlots.at(0)); generationSteps.push_back(llmReq->getDecodingIter()); } diff --git a/cpp/tensorrt_llm/batch_manager/rnnStateBuffers.cpp b/cpp/tensorrt_llm/batch_manager/rnnStateBuffers.cpp index 6fc7977ef8f..eb333726f95 100644 --- a/cpp/tensorrt_llm/batch_manager/rnnStateBuffers.cpp +++ b/cpp/tensorrt_llm/batch_manager/rnnStateBuffers.cpp @@ -51,7 +51,9 @@ void RnnStateBuffers::fillSlotMappings( SizeType32 batchIdx{0}; for (auto const& llmReq : contextRequests) { - auto const seqSlot = llmReq->mSeqSlot.value(); + // TODO: rnn state does not support CFG yet + TLLM_CHECK_WITH_INFO(!llmReq->isCfg(), "rnn state buffers do not support CFG yet"); + auto const seqSlot = llmReq->mSeqSlots.at(0); auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth; rnnStateManager->fillSlotMapping(*slotMappingHost, batchIdx, seqSlot, reqBeamWidth); ++batchIdx; diff --git a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp index 691fb9c7efd..d28dfb1e232 100644 --- a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp +++ b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp @@ -79,6 +79,10 @@ void RuntimeBuffers::create(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + useAttentionPrior = modelConfig.useAttentionPrior(); + useContextEmbeddings = modelConfig.useContextEmbeddings(); + attentionPriorLookahead = modelConfig.getAttentionPriorLookahead(); + auto const& manager = runtime.getBufferManager(); auto const& engine = runtime.getEngine(); @@ -118,6 +122,20 @@ void RuntimeBuffers::create(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, inputsIds = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32); + if (useAttentionPrior) + { + // probs in attention kernel are in full precision + attentionPriorScores = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kFLOAT); + attentionPriorFocus = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32); + } + + if (useContextEmbeddings) + { + auto const featsType = engine.getTensorDataType(kDecoderContextFeaturesTensorName); + decoderContextFeatures = manager.emptyTensor(MemoryType::kGPU, 
featsType); + decoderContextFeaturesMask = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kBOOL); + } + if (worldConfig.isPipelineParallel()) { hiddenStates = manager.emptyTensor(MemoryType::kGPU, modelConfig.getDataType()); @@ -246,17 +264,21 @@ void RuntimeBuffers::setBufferSizes(RequestVector const& contextRequests, Reques NVTX3_SCOPED_RANGE(runtimeBuffersSetBufferSizes); // set context sizes - numContextRequests = static_cast(contextRequests.size()); + numContextRequests = 0; + for (auto const& llmReq : contextRequests) + { + numContextRequests += llmReq->getNumSequences(); + } auto numContextLogits = numContextRequests; numContextTokens = 0; maxContextLength = 0; for (auto const& llmReq : contextRequests) { auto const draftLength = llmReq->isLastContextChunk() ? llmReq->getNumDraftTokens() : 0; - numContextLogits += draftLength; + numContextLogits += draftLength * llmReq->getNumSequences(); auto const contextChunkSize = llmReq->getContextChunkSize(); - numContextTokens += contextChunkSize + draftLength; + numContextTokens += (contextChunkSize + draftLength) * llmReq->getNumSequences(); if (maxContextLength < llmReq->mPromptLen) { maxContextLength = llmReq->mPromptLen; @@ -264,15 +286,19 @@ void RuntimeBuffers::setBufferSizes(RequestVector const& contextRequests, Reques } // set generation sizes - numGenRequests = static_cast(genRequests.size()); + numGenRequests = 0; + for (auto const& llmReq : genRequests) + { + numGenRequests += llmReq->getNumSequences(); + } numGenSequences = 0; numGenTokens = 0; for (auto const& llmReq : genRequests) { auto const reqBeamWidth = llmReq->getBeamWidthByIter(); - numGenSequences += reqBeamWidth; + numGenSequences += reqBeamWidth * llmReq->getNumSequences(); auto const draftLen = llmReq->getNumDraftTokens(); - numGenTokens += draftLen + reqBeamWidth; + numGenTokens += (draftLen + reqBeamWidth) * llmReq->getNumSequences(); } numLogits = numContextLogits + numGenTokens; @@ -381,7 +407,7 @@ void RuntimeBuffers::reshape(TllmRuntime const& runtime, ModelConfig const& mode seqSlotsDevice->reshape(numRequestsShape); auto const numTokens = getNumTokens(); - inputsIds->reshape(ITensor::makeShape({numTokens})); + inputsIds->reshape(ITensor::makeShape({numTokens * modelConfig.getNumVocabs()})); if (modelConfig.useMrope()) { @@ -416,6 +442,18 @@ void RuntimeBuffers::reshape(TllmRuntime const& runtime, ModelConfig const& mode tensor->reshape(shape); } + if (useAttentionPrior) + { + attentionPriorScores->reshape(ITensor::makeShape({attentionPriorLookahead * getNumSequences()})); + attentionPriorFocus->reshape(ITensor::makeShape({getNumSequences()})); + } + + if (useContextEmbeddings) + { + decoderContextFeatures->reshape(ITensor::makeShape({numTokens, modelConfig.getHiddenSize()})); + decoderContextFeaturesMask->reshape(ITensor::makeShape({numTokens})); + runtime.getBufferManager().setMem(*decoderContextFeaturesMask, 0); + } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -491,9 +529,11 @@ void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, Request for (auto const& llmReq : requests) { // Get position of the current sequence in the decoder - auto const seqSlot = llmReq->mSeqSlot.value(); - seqSlotIndices[batchIdx] = seqSlot; - ++batchIdx; + for (auto const& seqSlot : llmReq->mSeqSlots) + { + seqSlotIndices[batchIdx] = seqSlot; + ++batchIdx; + } } } @@ -501,11 +541,17 @@ void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, Request manager.copy(*seqSlots, *seqSlotsDevice); } + SizeType32 contextRequestsSize = 0; 
+ for (auto const& llmReq : contextRequests) + { + contextRequestsSize += llmReq->getNumSequences(); + } + // context preparation loop - if (!contextRequests.empty()) + if (contextRequestsSize > 0) { NVTX3_SCOPED_RANGE(contextPrepareLoop); - numContextLogits.resize(contextRequests.size()); + numContextLogits.resize(contextRequestsSize); SizeType32 batchIdx{0}; for (auto const& llmReq : contextRequests) @@ -516,202 +562,266 @@ void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, Request llmReq->getMaxNumGeneratedTokens() == 0, "Context request should not have generated tokens."); auto const& reqTokens = llmReq->getTokens(0); - auto const& draftTokens = llmReq->getDraftTokens(); - auto const draftLength = llmReq->getNumDraftTokens(); - auto const& positionIds = llmReq->getPositionIds(); - - auto const contextChunkSize = llmReq->getContextChunkSize(); - auto const beginCompute = llmReq->getContextCurrentPosition(); - auto const endCompute = beginCompute + contextChunkSize; - inputHost.insert(inputHost.end(), reqTokens.begin() + beginCompute, reqTokens.begin() + endCompute); - - logitsIdsHostPtr[totalNumLogits++] = contextChunkSize; - numContextLogits.at(batchIdx) = modelConfig.computeContextLogits() ? contextChunkSize : 1; - - if (llmReq->isLastContextChunk()) + // for CFG requests, add the inputs to the buffer twice + std::vector is_conditional_vec{true}; + if (llmReq->isCfg()) { - inputHost.insert(inputHost.end(), draftTokens->begin(), draftTokens->end()); - std::fill_n(logitsIdsHostPtr + totalNumLogits, draftLength, 1); - totalNumLogits += draftLength; + is_conditional_vec.push_back(false); } - auto const inputLength = contextChunkSize + (llmReq->isLastContextChunk() ? draftLength : 0); - contextLengthsHostPtr[batchIdx] = inputLength; - auto const sequenceLen = inputLength + llmReq->getContextCurrentPosition(); - sequenceLengthsHostPtr[batchIdx] = sequenceLen; - - if (static_cast(pastKeyValueLengthsPtr)) + for (auto const& is_conditional : is_conditional_vec) { - pastKeyValueLengthsPtr[batchIdx] = beginCompute + inputLength; - } + auto const& origTokens = llmReq->getTokens(0); + std::vector dummyTokens; + if (!is_conditional) + { + // that is special token added in "convert_checkpoint", + // that is expanded to all zeros + dummyTokens.assign(origTokens.size(), modelConfig.getVocabSize()); + } - if (positionIds.has_value()) - { - TLLM_CHECK_WITH_INFO(!(isChatGlm || isGlm), "ChatGLM-6B and Glm only use the default initialization"); - positionIdsHost.insert(positionIdsHost.end(), positionIds.value()->begin() + beginCompute, - positionIds.value()->begin() + endCompute); - } - else - { - if (isChatGlm) + auto const& reqTokens = is_conditional ? origTokens : dummyTokens; + auto const& draftTokens = llmReq->getDraftTokens(); + auto const draftLength = llmReq->getNumDraftTokens(); + auto const& positionIds = llmReq->getPositionIds(); + + auto const contextChunkSize = llmReq->getContextChunkSize(); + auto const beginCompute = llmReq->getContextCurrentPosition(); + auto const endCompute = beginCompute + contextChunkSize; + inputHost.insert(inputHost.end(), reqTokens.begin() + beginCompute, + reqTokens.begin() + beginCompute + contextChunkSize * llmReq->getNumVocabs()); + + logitsIdsHostPtr[totalNumLogits++] = contextChunkSize; + numContextLogits.at(batchIdx) = modelConfig.computeContextLogits() ? 
contextChunkSize : 1; + + if (llmReq->isLastContextChunk()) { - // Specialize for ChatGLM-6B with 2D-Position-Embedding - positionIdsHost.resize(totalInputSize + inputLength); - std::iota(std::begin(positionIdsHost) + totalInputSize, std::end(positionIdsHost), 0); - positionIdsHost.back() = positionIdsHost.back() - 1; + inputHost.insert(inputHost.end(), draftTokens->begin(), draftTokens->end()); + std::fill_n(logitsIdsHostPtr + totalNumLogits, draftLength, 1); + totalNumLogits += draftLength; + } + auto const inputLength = contextChunkSize + (llmReq->isLastContextChunk() ? draftLength : 0); + contextLengthsHostPtr[batchIdx] = inputLength; + auto const sequenceLen = inputLength + llmReq->getContextCurrentPosition(); + sequenceLengthsHostPtr[batchIdx] = sequenceLen; - positionIdsHostRow2.resize(totalInputSize + inputLength); - positionIdsHostRow2.back() = 1; + if (static_cast(pastKeyValueLengthsPtr)) + { + pastKeyValueLengthsPtr[batchIdx] = beginCompute + inputLength; } - else if (isGlm) + + if (positionIds.has_value()) { - // Specialize for GLM-10B with 2D-Position-Embedding and special value of the mask id position - auto start = inputHost.begin() + totalInputSize; - auto end = start + inputLength; - auto it = std::find_if( - start, end, [](SizeType32 id) { return id == 50260 || id == 50263 || id == 50264; }); - llmReq->mMaskPosition = (it != end) ? std::distance(start, it) : maxContextLength; - - positionIdsHost.resize(totalInputSize + inputLength); - std::iota(std::begin(positionIdsHost) + totalInputSize, std::end(positionIdsHost), 0); - positionIdsHost.back() = llmReq->mMaskPosition; - - positionIdsHostRow2.resize(totalInputSize + inputLength); - positionIdsHostRow2.back() = 1; + TLLM_CHECK_WITH_INFO( + !(isChatGlm || isGlm), "ChatGLM-6B and Glm only use the default initialization"); + positionIdsHost.insert(positionIdsHost.end(), positionIds.value()->begin() + beginCompute, + positionIds.value()->begin() + endCompute); } else { - // Other models - positionIdsHost.resize(totalInputSize + inputLength); - std::iota(std::begin(positionIdsHost) + totalInputSize, - std::begin(positionIdsHost) + totalInputSize + inputLength, beginCompute); + if (isChatGlm) + { + // Specialize for ChatGLM-6B with 2D-Position-Embedding + positionIdsHost.resize(totalInputSize + inputLength); + std::iota(std::begin(positionIdsHost) + totalInputSize, std::end(positionIdsHost), 0); + positionIdsHost.back() = positionIdsHost.back() - 1; + + positionIdsHostRow2.resize(totalInputSize + inputLength); + positionIdsHostRow2.back() = 1; + } + else if (isGlm) + { + // Specialize for GLM-10B with 2D-Position-Embedding and special value of the mask id position + auto start = inputHost.begin() + totalInputSize; + auto end = start + inputLength; + auto it = std::find_if( + start, end, [](SizeType32 id) { return id == 50260 || id == 50263 || id == 50264; }); + llmReq->mMaskPosition = (it != end) ? 
std::distance(start, it) : maxContextLength; + + positionIdsHost.resize(totalInputSize + inputLength); + std::iota(std::begin(positionIdsHost) + totalInputSize, std::end(positionIdsHost), 0); + positionIdsHost.back() = llmReq->mMaskPosition; + + positionIdsHostRow2.resize(totalInputSize + inputLength); + positionIdsHostRow2.back() = 1; + } + else + { + // Other models + positionIdsHost.resize(totalInputSize + inputLength); + std::iota(std::begin(positionIdsHost) + totalInputSize, + std::begin(positionIdsHost) + totalInputSize + inputLength, beginCompute); + } } - } - if (modelConfig.useMrope()) - { - auto optMropeRotaryCosSin = llmReq->getMropeRotaryCosSin().value(); - TLLM_CHECK_WITH_INFO(optMropeRotaryCosSin->getShape().d[0] == mropeRotaryCosSinSize, - "Provided MropeRotarySinCos is %ld and expected is %d.\n", optMropeRotaryCosSin->getShape().d[0], - int(mropeRotaryCosSinSize)); + if (modelConfig.useMrope()) + { + auto optMropeRotaryCosSin = llmReq->getMropeRotaryCosSin().value(); + TLLM_CHECK_WITH_INFO(optMropeRotaryCosSin->getShape().d[0] == mropeRotaryCosSinSize, + "Provided MropeRotarySinCos is %ld and expected is %d.\n", + optMropeRotaryCosSin->getShape().d[0], int(mropeRotaryCosSinSize)); - auto const mropeRotaryCosSinCtx = ITensor::slice(mropeRotaryCosSin, batchIdx, 1); - manager.copy(*optMropeRotaryCosSin, *mropeRotaryCosSinCtx); - } + auto const mropeRotaryCosSinCtx = ITensor::slice(mropeRotaryCosSin, batchIdx, 1); + manager.copy(*optMropeRotaryCosSin, *mropeRotaryCosSinCtx); + } - if (modelConfig.useLanguageAdapter()) - { - auto const languageAdapterRouting = llmReq->getLanguageAdapterRouting( - modelConfig.getNumLanguages().value(), endCompute - beginCompute); - languageAdapterRoutingsHost.insert(languageAdapterRoutingsHost.end(), - std::begin(languageAdapterRouting), std::end(languageAdapterRouting)); + if (modelConfig.useLanguageAdapter()) + { + auto const languageAdapterRouting = llmReq->getLanguageAdapterRouting( + modelConfig.getNumLanguages().value(), endCompute - beginCompute); + languageAdapterRoutingsHost.insert(languageAdapterRoutingsHost.end(), + std::begin(languageAdapterRouting), std::end(languageAdapterRouting)); + } + totalInputSize += inputLength; + ++batchIdx; } - totalInputSize += inputLength; - ++batchIdx; } if (rnnStateBuffers) { + // TODO: dont implement CFG for rnn state buffers for now rnnStateBuffers->fillSlotMappings(contextRequests, rnnStateManagerPtr); } + + // set decoder context features and mask + if (useContextEmbeddings) + { + SizeType32 tokenIdx = 0; + for (auto const& llmReq : contextRequests) + { + auto const contextPosition = llmReq->getContextCurrentPosition(); + auto const contextChunkSize = llmReq->getContextChunkSize(); + if (llmReq->getDecoderContextFeatures()) + { + auto const& reqFeatures = llmReq->getDecoderContextFeatures(); + TLLM_CHECK_WITH_INFO(contextPosition + contextChunkSize <= reqFeatures->getShape().d[0], + "Decoder context features [%d, %d], but request is at position %d and chunk size %d", + (int) reqFeatures->getShape().d[0], (int) reqFeatures->getShape().d[1], contextPosition, + contextChunkSize); + // specifying offset and size across 0th dimension + manager.copy(*ITensor::slice(reqFeatures, contextPosition, contextChunkSize), + *ITensor::slice(decoderContextFeatures, tokenIdx, contextChunkSize)); + manager.setMem(*ITensor::slice(decoderContextFeaturesMask, tokenIdx, contextChunkSize), 1); + } + tokenIdx += llmReq->getNumSequences() * contextChunkSize; + } + } } - // generation preparation loop + // generation 
preparation loop - CHECK THIS LINE ONWARDS if (!genRequests.empty()) { NVTX3_SCOPED_RANGE(genPrepareLoop); - auto const numContextRequests = static_cast(contextRequests.size()); - auto numSequences = numContextRequests; + auto numSequences = contextRequestsSize; + for (auto const& llmReq : genRequests) { - auto const reqBeamWidth = llmReq->getBeamWidthByIter(); - auto const draftLength = llmReq->getNumDraftTokens(); - auto const& draftTokens = llmReq->getDraftTokens(); - auto const numLogits = draftLength + reqBeamWidth; - TLLM_CHECK(draftLength == 0 || reqBeamWidth == 1); - - auto const promptLen = llmReq->mPromptLen; - auto const sequenceLen - = promptLen + llmReq->getMaxNumGeneratedTokens() + static_cast(trtOverlap); - auto const& positionIds = llmReq->getPositionIds(); - for (int beam = 0; beam < reqBeamWidth; ++beam) + for (int s = 0; s < llmReq->getNumSequences(); s++) { - auto const numTokens = llmReq->getNumTokens(beam) + static_cast(trtOverlap); - // TODO: can this be removed completely? - if (!trtOverlap) + auto reqBeamWidth = llmReq->getBeamWidthByIter(); + auto const draftLength = llmReq->getNumDraftTokens(); + auto const& draftTokens = llmReq->getDraftTokens(); + auto const numLogits = draftLength + reqBeamWidth; + TLLM_CHECK(draftLength == 0 || reqBeamWidth == 1); + + auto const promptLen = llmReq->mPromptLen; + auto const sequenceLen + = promptLen + llmReq->getMaxNumGeneratedTokens() + static_cast(trtOverlap); + auto const& positionIds = llmReq->getPositionIds(); + for (int reqBeam = 0; reqBeam < reqBeamWidth; ++reqBeam) { - auto const lastToken = llmReq->getLastTokens(beam); - inputHost.push_back(lastToken); - if (draftLength > 0) + // for CFG, simply use tokens from the 0th beam during generation + int beam = llmReq->isCfg() ? 0 : reqBeam; + auto const numTokens = llmReq->getNumTokens(beam) + static_cast(trtOverlap); + // TODO: can this be removed completely? + if (!trtOverlap) { - inputHost.insert(inputHost.end(), draftTokens->begin(), draftTokens->end()); + if (llmReq->getNumVocabs() > 1) + { + auto const& beamTokens = llmReq->getTokens(beam); + TLLM_CHECK_WITH_INFO(beamTokens.size() % llmReq->getNumVocabs() == 0, + "Number of tokens needs to be a multiple of number of vocabs!"); + inputHost.insert( + inputHost.end(), beamTokens.cend() - llmReq->getNumVocabs(), beamTokens.cend()); + } + else + { + auto const lastToken = llmReq->getLastTokens(beam); + inputHost.push_back(lastToken); + } + if (draftLength > 0) + { + inputHost.insert(inputHost.end(), draftTokens->begin(), draftTokens->end()); + } } - } - // If model updates generation position ids do not append them here. - if (!modelConfig.getSpeculativeDecodingMode().updatesPositionIds()) - { - if (positionIds.has_value()) - { - TLLM_CHECK_WITH_INFO( - !(isChatGlm || isGlm), "ChatGLM-6B and Glm only use the default initialization"); - auto last_context_position_id = positionIds.value()->back(); - positionIdsHost.push_back( - static_cast(last_context_position_id + sequenceLen - promptLen)); - } - else + // If model updates generation position ids do not append them here. 
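[Editor's note, not part of the patch] One note on the CFG path in the context preparation loop above: a CFG request is walked twice, once with its real prompt tokens (conditional pass) and once with every token replaced by the modelConfig.getVocabSize() sentinel, which, per the comment in the diff, convert_checkpoint expands to an all-zero embedding (unconditional pass). A minimal sketch of that expansion follows; the helper name is illustrative and not present in the patch.

// Illustrative sketch: building the flat input token stream for one CFG request,
// matching dummyTokens.assign(origTokens.size(), modelConfig.getVocabSize()) above.
#include <vector>

std::vector<int> buildCfgContextInputs(std::vector<int> const& promptTokens, int vocabSizeSentinel, bool isCfg)
{
    std::vector<int> inputs(promptTokens); // conditional pass: real prompt tokens
    if (isCfg)
    {
        // unconditional pass: same length, every position set to the sentinel token
        inputs.insert(inputs.end(), promptTokens.size(), vocabSizeSentinel);
    }
    return inputs;
}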
+ if (!modelConfig.getSpeculativeDecodingMode().updatesPositionIds()) { - if (isChatGlm) // ChatGLM-6B - { - positionIdsHost.push_back(static_cast(promptLen - 2)); - positionIdsHostRow2.push_back(static_cast(sequenceLen - promptLen + 1)); - } - else if (isGlm) + if (positionIds.has_value()) { - positionIdsHost.push_back(llmReq->mMaskPosition); - positionIdsHostRow2.push_back(static_cast(sequenceLen - promptLen + 1)); + TLLM_CHECK_WITH_INFO( + !(isChatGlm || isGlm), "ChatGLM-6B and Glm only use the default initialization"); + auto last_context_position_id = positionIds.value()->back(); + positionIdsHost.push_back( + static_cast(last_context_position_id + sequenceLen - promptLen)); } - else // GPT / ChatGLM2-6B / ChatGLM3-6B / BART + else { - // positionIds is just the size of tokens -1 - positionIdsHost.push_back(numTokens - 1); + if (isChatGlm) // ChatGLM-6B + { + positionIdsHost.push_back(static_cast(promptLen - 2)); + positionIdsHostRow2.push_back(static_cast(sequenceLen - promptLen + 1)); + } + else if (isGlm) + { + positionIdsHost.push_back(llmReq->mMaskPosition); + positionIdsHostRow2.push_back(static_cast(sequenceLen - promptLen + 1)); + } + else // GPT / ChatGLM2-6B / ChatGLM3-6B / BART + { + // positionIds is just the size of tokens -1 + positionIdsHost.push_back(numTokens - 1); + } } } - } - if (modelConfig.useMrope()) - { - auto optMropePositionDeltas = llmReq->getMropePositionDeltas().value(); - mropePositionDeltasHost.push_back(optMropePositionDeltas); + if (modelConfig.useMrope()) + { + auto optMropePositionDeltas = llmReq->getMropePositionDeltas().value(); + mropePositionDeltasHost.push_back(optMropePositionDeltas); + } + + if (modelConfig.useLanguageAdapter()) + { + // Generation requests only have one token per sequence + auto const languageAdapterRouting + = llmReq->getLanguageAdapterRouting(modelConfig.getNumLanguages().value(), 1); + languageAdapterRoutingsHost.insert(languageAdapterRoutingsHost.end(), + std::begin(languageAdapterRouting), std::end(languageAdapterRouting)); + } } - if (modelConfig.useLanguageAdapter()) + if (static_cast(pastKeyValueLengthsPtr)) { - // Generation requests only have one token per sequence - auto const languageAdapterRouting - = llmReq->getLanguageAdapterRouting(modelConfig.getNumLanguages().value(), 1); - languageAdapterRoutingsHost.insert(languageAdapterRoutingsHost.end(), - std::begin(languageAdapterRouting), std::end(languageAdapterRouting)); + SizeType32 pastKeyValueLength = sequenceLen - 1; + std::fill_n(pastKeyValueLengthsPtr + numSequences, reqBeamWidth, pastKeyValueLength); } - } - - if (static_cast(pastKeyValueLengthsPtr)) - { - SizeType32 pastKeyValueLength = sequenceLen - 1; - std::fill_n(pastKeyValueLengthsPtr + numSequences, reqBeamWidth, pastKeyValueLength); - } - totalInputSize += numLogits; + totalInputSize += numLogits; - std::fill_n(logitsIdsHostPtr + totalNumLogits, numLogits, 1); + std::fill_n(logitsIdsHostPtr + totalNumLogits, numLogits, 1); - totalNumLogits += numLogits; + totalNumLogits += numLogits; - if (rnnStateBuffers) - { - auto const seqSlot = llmReq->mSeqSlot.value(); - auto& rnnStateManager = *rnnStateManagerPtr; - rnnStateManager.fillSlotMapping(*rnnStateBuffers->slotMappingHost, numSequences, seqSlot, reqBeamWidth); + if (rnnStateBuffers) + { + TLLM_CHECK_WITH_INFO(!llmReq->isCfg(), "CFG is not supported for rnn state buffers"); + auto const seqSlot = llmReq->mSeqSlots[0]; + auto& rnnStateManager = *rnnStateManagerPtr; + rnnStateManager.fillSlotMapping( + *rnnStateBuffers->slotMappingHost, numSequences, 
seqSlot, reqBeamWidth); + } + numSequences += reqBeamWidth; } - numSequences += reqBeamWidth; } if (transformerBuffers && maxBeamWidth > 1) @@ -719,26 +829,46 @@ void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, Request transformerBuffers->copyCacheIndirection(genRequests, decoderState.getCacheIndirectionOutput(), stream); } - numSequences = numContextRequests; + numSequences = contextRequestsSize; for (auto const& llmReq : genRequests) { - auto const reqBeamWidth = llmReq->getBeamWidthByIter(); - auto const draftLength = llmReq->getNumDraftTokens(); + for (int s = 0; s < llmReq->getNumSequences(); s++) + { + auto const reqBeamWidth = llmReq->getBeamWidthByIter(); + auto const draftLength = llmReq->getNumDraftTokens(); - auto const contextQLength = llmReq->mPromptLen + draftLength; - auto const sequenceLen - = contextQLength + llmReq->getMaxNumGeneratedTokens() + static_cast(trtOverlap); + auto const contextQLength = llmReq->mPromptLen + draftLength; + auto const sequenceLen + = contextQLength + llmReq->getMaxNumGeneratedTokens() + static_cast(trtOverlap); - std::fill_n(contextLengthsHostPtr + numSequences, reqBeamWidth, contextQLength); - std::fill_n(sequenceLengthsHostPtr + numSequences, reqBeamWidth, sequenceLen); - numSequences += reqBeamWidth; + std::fill_n(contextLengthsHostPtr + numSequences, reqBeamWidth, contextQLength); + std::fill_n(sequenceLengthsHostPtr + numSequences, reqBeamWidth, sequenceLen); + numSequences += reqBeamWidth; + } } + if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()) { // copy from lookahead decoding buffer mLookaheadBuffers->setFromInputs(numContextRequests, numGenRequests, *requestTypes, *seqSlots, decoderState.getLookaheadBuffers(), runtime, modelConfig, worldConfig); } + + if (useAttentionPrior) + { + // set to zero attention prior scores, so scores from different layers can be accumulated + manager.setMem(*attentionPriorScores, 0); + // copy focus indices from llm requests to a buffer + std::vector focus_lst(numContextRequests, 0); + for (auto const& llmReq : genRequests) + { + for (SizeType32 i = 0; i < llmReq->getNumSequences(); i++) + { + focus_lst.push_back(llmReq->getAttentionPriorIdx(modelConfig)); + } + } + manager.copy(focus_lst.data(), *attentionPriorFocus, runtime::MemoryType::kCPU); + } } // check skipCrossAttnBlocks @@ -901,6 +1031,51 @@ void RuntimeBuffers::prepareEagleBuffers(RequestVector const& contextRequests, R TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } +void RuntimeBuffers::processAttentionPriorScores( + RequestVector const& genRequests, TllmRuntime const& runtime, ModelConfig const& modelConfig) +{ + /** + * is called after inference is done. processes the "scores" buffer and sets up + * the index with most attention focus for each request. 
+ */ + if (!useAttentionPrior) + { + TLLM_LOG_WARNING("processing attention prior scores, when attention prior is disabled"); + return; + } + + // copy scores to host + auto const& manager = runtime.getBufferManager(); + auto const& stream = runtime.getStream(); + auto scoresHost + = manager.cpu(ITensor::makeShape({getNumSequences() * attentionPriorLookahead}), nvinfer1::DataType::kFLOAT); + manager.copy(*attentionPriorScores, *scoresHost); + stream.synchronize(); + + // for each generation request, analyze scores and set the attention prior idx + size_t scoresOffset = numContextRequests * attentionPriorLookahead; + auto* scoresHostPtr = bufferCast(*scoresHost); + for (auto const& llmReq : genRequests) + { + size_t prevPriorIdx = llmReq->getAttentionPriorIdx(modelConfig); + float maxScore = scoresHostPtr[scoresOffset]; + int idxShift = 0; + for (int i = 1; i < attentionPriorLookahead; i++) + { + if (scoresHostPtr[scoresOffset + i] > maxScore) + { + maxScore = scoresHostPtr[scoresOffset + i]; + idxShift = i; + } + } + + llmReq->setAttentionPriorIdx(prevPriorIdx + idxShift, modelConfig); + + // TODO: remove hardcode of lookahead size + scoresOffset += attentionPriorLookahead * llmReq->getNumSequences(); + } +} + std::tuple RuntimeBuffers::prepareStep( RequestVector const& contextRequests, RequestVector const& genRequests, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, runtime::decoder::DecoderState const& decoderState, @@ -974,6 +1149,15 @@ void RuntimeBuffers::fillIOMaps(ModelConfig const& modelConfig, WorldConfig cons inputMap.insert_or_assign(kHostContextLengthsTensorName, contextLengthsHost); inputMap.insert_or_assign(kSequenceLengthsTensorName, sequenceLengthsDevice); + if (useContextEmbeddings) + { + inputMap.insert_or_assign(kDecoderContextFeaturesTensorName, decoderContextFeatures); + inputMap.insert_or_assign(kDecoderContextFeaturesMaskTensorName, decoderContextFeaturesMask); + } + if (useAttentionPrior) + { + inputMap.insert_or_assign(kAttentionPriorFocusTensorName, attentionPriorFocus); + } if (modelConfig.useCrossAttention()) { encoderBuffers->insertInputTensors(inputMap); @@ -1017,6 +1201,10 @@ void RuntimeBuffers::fillIOMaps(ModelConfig const& modelConfig, WorldConfig cons { mEagleBuffers->insertInputTensors(inputMap, outputMap, worldConfig); } + if (useAttentionPrior) + { + outputMap.insert_or_assign(kAttentionPriorScoresTensorName, attentionPriorScores); + } for (auto const& outputTensor : mAdditionalOutputTensors) { diff --git a/cpp/tensorrt_llm/batch_manager/transformerBuffers.cpp b/cpp/tensorrt_llm/batch_manager/transformerBuffers.cpp index 4f81c892668..b61f5730da2 100644 --- a/cpp/tensorrt_llm/batch_manager/transformerBuffers.cpp +++ b/cpp/tensorrt_llm/batch_manager/transformerBuffers.cpp @@ -354,19 +354,22 @@ void TransformerBuffers::copyKvBlockOffsets(RequestVector const& contextRequests { for (auto const& llmReq : requests) { - auto const requestId = llmReq->mRequestId; - auto const isContextRequest = llmReq->isContextInitState(); - auto const beamWidth = isContextRequest ? 
contextBeamWidth : llmReq->getBeamWidthByIter(); - auto const maxBeamBlockCount - = kvCacheManager->copyBlockOffsets(*kvCacheBlockOffsetsHost, numSequences, requestId); - maxBlockCount = std::max(maxBlockCount, maxBeamBlockCount); - if (crossKvCacheBlockOffsetsHost) + for (int i = 0; i < llmReq->getNumSequences(); i++) { - auto const maxCrossBeamBlockCount - = crossKvCacheManager->copyBlockOffsets(*crossKvCacheBlockOffsetsHost, numSequences, requestId); - maxCrossBlockCount = std::max(maxCrossBlockCount, maxCrossBeamBlockCount); + auto const requestId = llmReq->getSeqSlotId(i); + auto const isContextRequest = llmReq->isContextInitState(); + auto const beamWidth = isContextRequest ? contextBeamWidth : llmReq->getBeamWidthByIter(); + auto const maxBeamBlockCount + = kvCacheManager->copyBlockOffsets(*kvCacheBlockOffsetsHost, numSequences, requestId); + maxBlockCount = std::max(maxBlockCount, maxBeamBlockCount); + if (crossKvCacheBlockOffsetsHost) + { + auto const maxCrossBeamBlockCount + = crossKvCacheManager->copyBlockOffsets(*crossKvCacheBlockOffsetsHost, numSequences, requestId); + maxCrossBlockCount = std::max(maxCrossBlockCount, maxCrossBeamBlockCount); + } + numSequences += beamWidth; } - numSequences += beamWidth; } } @@ -402,7 +405,15 @@ void TransformerBuffers::copyCacheIndirection( TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(copyCacheIndirection); - auto const numGenerationRequests = genRequests.size(); + std::vector slots; + for (auto const& llmReq : genRequests) + { + for (int i = 0; i < llmReq->getNumSequences(); i++) + { + slots.push_back(llmReq->getSeqSlot(i)); + } + } + auto const numGenerationRequests = slots.size(); auto batchedCopySrcOffsets = BufferRange(*cacheIndirBatchedCopySrcOffsets); auto batchedCopyDstOffsets = BufferRange(*cacheIndirBatchedCopyDstOffsets); @@ -421,8 +432,8 @@ void TransformerBuffers::copyCacheIndirection( cacheIndirShape.d[1] = reqBeamWidth; // Use beam width of current step rather than max beam width as dst offset auto const copySize = static_cast(ITensor::volume(cacheIndirShape)); - std::transform(genRequests.begin(), genRequests.end(), batchedCopySrcOffsets.begin(), - [copySize](auto const& llmReq) { return llmReq->mSeqSlot.value() * copySize; }); + std::transform(slots.begin(), slots.end(), batchedCopySrcOffsets.begin(), + [copySize](auto const& slot) { return slot * copySize; }); std::generate_n( batchedCopyDstOffsets.begin(), numGenerationRequests, [copySize, i = 0]() mutable { return (i++) * copySize; }); std::fill_n(batchedCopySizes.begin(), numGenerationRequests, copySize); @@ -442,6 +453,7 @@ void TransformerBuffers::copyCrossAttentionMasks(RequestVector const& contextReq { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto const& manager = runtime.getBufferManager(); + auto const& stream = runtime.getStream(); // Reshape the tensor to make sure the dim1 matches maxEncoderInputLengthInBatch. auto crossAttentionMaskShape = crossAttentionMaskDevice->getShape(); @@ -462,7 +474,6 @@ void TransformerBuffers::copyCrossAttentionMasks(RequestVector const& contextReq } } // If not all requests have cross attention mask, let us create the default ones. 
- auto const& stream = runtime.getStream(); if (!allContextCrossAttentionMaskProvided) { TLLM_LOG_WARNING("Default padding attention mask will be used as not all requests have cross attention mask."); @@ -499,136 +510,149 @@ void TransformerBuffers::copyCrossAttentionMasks(RequestVector const& contextReq bool* pinnedMemPtr = bufferCastOrNull(crossAttentionMaskPinnedHost); for (auto const& llmReq : contextRequests) { - auto const& crossAttentionMaskRequest = llmReq->getCrossAttentionMask(); - auto const position = llmReq->getContextCurrentPosition(); - auto const size = llmReq->getContextChunkSize(); - if (bufferCastOrNull(crossAttentionMaskRequest) != nullptr) + for (int s = 0; s < llmReq->getNumSequences(); s++) { - auto memType = crossAttentionMaskRequest->getMemoryType(); - auto const crossAttentionMaskRequestDim0 - = static_cast(crossAttentionMaskRequest->getShape().d[0]); - auto const crossAttentionMaskRequestDim1 - = static_cast(crossAttentionMaskRequest->getShape().d[1]); - TLLM_LOG_DEBUG("copyCrossAttentionMasks (shape [%d, %d]) from contextRequests position %d chunkSize %d", - crossAttentionMaskRequestDim0, crossAttentionMaskRequestDim1, position, size); - if ((position + size - 1) >= crossAttentionMaskRequestDim0) + auto const& crossAttentionMaskRequest = llmReq->getCrossAttentionMask(); + auto const position = llmReq->getContextCurrentPosition(); + auto const size = llmReq->getContextChunkSize(); + if (bufferCastOrNull(crossAttentionMaskRequest) != nullptr) + { + auto memType = crossAttentionMaskRequest->getMemoryType(); + auto const crossAttentionMaskRequestDim0 + = static_cast(crossAttentionMaskRequest->getShape().d[0]); + auto const crossAttentionMaskRequestDim1 + = static_cast(crossAttentionMaskRequest->getShape().d[1]); + TLLM_LOG_DEBUG("copyCrossAttentionMasks (shape [%d, %d]) from contextRequests position %d chunkSize %d", + crossAttentionMaskRequestDim0, crossAttentionMaskRequestDim1, position, size); + if ((position + size - 1) >= crossAttentionMaskRequestDim0) + { + TLLM_LOG_WARNING( + "The provided crossAttentionMask input is not complete for context phases, the last row " + "will be " + "used by default."); + } + // copy it to pinned memory if it is a cpu tensor. 
+ if (memType == MemoryType::kCPU) + { + TLLM_LOG_DEBUG("CrossAttentionMask tensor is on CPU."); + auto const copiedPosition + = std::min(crossAttentionMaskRequestDim0 - 1, static_cast(position)); + auto const copiedSize + = std::min(crossAttentionMaskRequestDim0 - copiedPosition, static_cast(size)); + SizeType64 inputMaskOffset = (copiedPosition * crossAttentionMaskRequestDim1); + SizeType64 inputMaskSize = (copiedSize * crossAttentionMaskRequestDim1); + std::memcpy(pinnedMemPtr, bufferCastOrNull(crossAttentionMaskRequest) + inputMaskOffset, + inputMaskSize); + pinnedMemPtr += inputMaskSize; + for (SizeType32 tokenId = position; tokenId < position + size; tokenId++) + { + SizeType64 tokenIdInPinnedMem + = std::min(copiedSize - 1, static_cast(tokenId - position)); + batchedCopySrcOffsets.begin()[numCopiedTokens] + = (pinnedMemPtr - primarySrcPtr) + tokenIdInPinnedMem * crossAttentionMaskRequestDim1; + batchedCopyDstOffsets.begin()[numCopiedTokens] + = numTokens * static_cast(maxEncoderInputLengthInBatch); + batchedCopySizes.begin()[numCopiedTokens] = crossAttentionMaskRequestDim1; + numCopiedTokens++; + numTokens++; + } + } + else + { + TLLM_LOG_DEBUG("CrossAttentionMask tensor is on GPU."); + for (SizeType32 tokenId = position; tokenId < position + size; tokenId++) + { + batchedCopySrcOffsets.begin()[numCopiedTokens] + = static_cast(bufferCastOrNull(crossAttentionMaskRequest) - primarySrcPtr) + + std::min(crossAttentionMaskRequestDim0 - 1, static_cast(tokenId)) + * crossAttentionMaskRequestDim1; + batchedCopyDstOffsets.begin()[numCopiedTokens] + = numTokens * static_cast(maxEncoderInputLengthInBatch); + batchedCopySizes.begin()[numCopiedTokens] = crossAttentionMaskRequestDim1; + numCopiedTokens++; + numTokens++; + } + } + } + else { + numTokens += size; TLLM_LOG_WARNING( - "The provided crossAttentionMask input is not complete for context phases, the last row " + "CrossAttentionMask is not provided for sequence %d of request. Default padding attention mask " "will be " - "used by default."); + "created.", + s); } - // copy it to pinned memory if it is a cpu tensor. 
- if (memType == MemoryType::kCPU) + } + } + sync_check_cuda_error(stream.get()); + + for (auto const& llmReq : genRequests) + { + for (int s = 0; s < llmReq->getNumSequences(); s++) + { + auto const promptLen = llmReq->mPromptLen; + auto const decodingIter = llmReq->getDecodingIter(); + auto const& crossAttentionMaskRequest = llmReq->getCrossAttentionMask(); + if (bufferCastOrNull(crossAttentionMaskRequest) != nullptr) { - TLLM_LOG_DEBUG("CrossAttentionMask tensor is on CPU."); - auto const copiedPosition - = std::min(crossAttentionMaskRequestDim0 - 1, static_cast(position)); - auto const copiedSize - = std::min(crossAttentionMaskRequestDim0 - copiedPosition, static_cast(size)); - SizeType64 inputMaskOffset = (copiedPosition * crossAttentionMaskRequestDim1); - SizeType64 inputMaskSize = (copiedSize * crossAttentionMaskRequestDim1); - std::memcpy( - pinnedMemPtr, bufferCastOrNull(crossAttentionMaskRequest) + inputMaskOffset, inputMaskSize); - pinnedMemPtr += inputMaskSize; - for (SizeType32 tokenId = position; tokenId < position + size; tokenId++) + auto const memType = crossAttentionMaskRequest->getMemoryType(); + auto const crossAttentionMaskRequestDim0 + = static_cast(crossAttentionMaskRequest->getShape().d[0]); + auto const crossAttentionMaskRequestDim1 + = static_cast(crossAttentionMaskRequest->getShape().d[1]); + TLLM_LOG_DEBUG("copyCrossAttentionMasks (shape [%d, %d]) from genRequests decodingIter %d", + crossAttentionMaskRequestDim0, crossAttentionMaskRequestDim1, decodingIter); + if (promptLen + decodingIter - 1 >= crossAttentionMaskRequestDim0) { - SizeType64 tokenIdInPinnedMem - = std::min(copiedSize - 1, static_cast(tokenId - position)); + TLLM_LOG_WARNING( + "The provided crossAttentionMask input [%d, %d] is not complete for generation phases: %d >= " + "%d.", + crossAttentionMaskRequestDim0, crossAttentionMaskRequestDim1, promptLen + decodingIter - 1, + crossAttentionMaskRequestDim0); + } + // copy it to pinned memory if it is a cpu tensor. 
+ if (memType == MemoryType::kCPU) + { + TLLM_LOG_DEBUG("CrossAttentionMask tensor is on CPU."); + SizeType64 copiedPosition = std::min( + crossAttentionMaskRequestDim0 - 1, static_cast(promptLen + decodingIter - 1)); + SizeType64 inputMaskOffset = (copiedPosition * crossAttentionMaskRequestDim1); + SizeType64 inputMaskSize = crossAttentionMaskRequestDim1; + std::memcpy(pinnedMemPtr, bufferCastOrNull(crossAttentionMaskRequest) + inputMaskOffset, + inputMaskSize); + pinnedMemPtr += inputMaskSize; batchedCopySrcOffsets.begin()[numCopiedTokens] - = (pinnedMemPtr - primarySrcPtr) + tokenIdInPinnedMem * crossAttentionMaskRequestDim1; + = static_cast(pinnedMemPtr - primarySrcPtr); batchedCopyDstOffsets.begin()[numCopiedTokens] = numTokens * static_cast(maxEncoderInputLengthInBatch); batchedCopySizes.begin()[numCopiedTokens] = crossAttentionMaskRequestDim1; - numCopiedTokens++; - numTokens++; } - } - else - { - TLLM_LOG_DEBUG("CrossAttentionMask tensor is on GPU."); - for (SizeType32 tokenId = position; tokenId < position + size; tokenId++) + else { + TLLM_LOG_DEBUG("CrossAttentionMask tensor is on GPU."); batchedCopySrcOffsets.begin()[numCopiedTokens] = static_cast(bufferCastOrNull(crossAttentionMaskRequest) - primarySrcPtr) - + std::min(crossAttentionMaskRequestDim0 - 1, static_cast(tokenId)) + + std::min( + crossAttentionMaskRequestDim0 - 1, static_cast(promptLen + decodingIter - 1)) * crossAttentionMaskRequestDim1; batchedCopyDstOffsets.begin()[numCopiedTokens] = numTokens * static_cast(maxEncoderInputLengthInBatch); batchedCopySizes.begin()[numCopiedTokens] = crossAttentionMaskRequestDim1; - numCopiedTokens++; - numTokens++; } - } - } - else - { - numTokens += size; - TLLM_LOG_WARNING( - "CrossAttentionMask is not provided for the request. Default padding attention mask will be " - "created."); - } - } - sync_check_cuda_error(stream.get()); - - for (auto const& llmReq : genRequests) - { - auto const promptLen = llmReq->mPromptLen; - auto const decodingIter = llmReq->getDecodingIter(); - auto const& crossAttentionMaskRequest = llmReq->getCrossAttentionMask(); - if (bufferCastOrNull(crossAttentionMaskRequest) != nullptr) - { - auto const memType = crossAttentionMaskRequest->getMemoryType(); - auto const crossAttentionMaskRequestDim0 - = static_cast(crossAttentionMaskRequest->getShape().d[0]); - auto const crossAttentionMaskRequestDim1 - = static_cast(crossAttentionMaskRequest->getShape().d[1]); - TLLM_LOG_DEBUG("copyCrossAttentionMasks (shape [%d, %d]) from genRequests decodingIter %d", - crossAttentionMaskRequestDim0, crossAttentionMaskRequestDim1, decodingIter); - if (promptLen + decodingIter - 1 >= crossAttentionMaskRequestDim0) - { - TLLM_LOG_WARNING( - "The provided crossAttentionMask input is not complete for generation phases, the last row " - "will be " - "used by default."); - } - // copy it to pinned memory if it is a cpu tensor. 
- if (memType == MemoryType::kCPU) - { - TLLM_LOG_DEBUG("CrossAttentionMask tensor is on CPU."); - SizeType64 copiedPosition = std::min( - crossAttentionMaskRequestDim0 - 1, static_cast(promptLen + decodingIter - 1)); - SizeType64 inputMaskOffset = (copiedPosition * crossAttentionMaskRequestDim1); - SizeType64 inputMaskSize = crossAttentionMaskRequestDim1; - std::memcpy( - pinnedMemPtr, bufferCastOrNull(crossAttentionMaskRequest) + inputMaskOffset, inputMaskSize); - pinnedMemPtr += inputMaskSize; - batchedCopySrcOffsets.begin()[numCopiedTokens] = static_cast(pinnedMemPtr - primarySrcPtr); - batchedCopyDstOffsets.begin()[numCopiedTokens] - = numTokens * static_cast(maxEncoderInputLengthInBatch); - batchedCopySizes.begin()[numCopiedTokens] = crossAttentionMaskRequestDim1; + numCopiedTokens++; + numTokens++; } else { - TLLM_LOG_DEBUG("CrossAttentionMask tensor is on GPU."); - batchedCopySrcOffsets.begin()[numCopiedTokens] - = static_cast(bufferCastOrNull(crossAttentionMaskRequest) - primarySrcPtr) - + std::min(crossAttentionMaskRequestDim0 - 1, static_cast(promptLen + decodingIter - 1)) - * crossAttentionMaskRequestDim1; - batchedCopyDstOffsets.begin()[numCopiedTokens] - = numTokens * static_cast(maxEncoderInputLengthInBatch); - batchedCopySizes.begin()[numCopiedTokens] = crossAttentionMaskRequestDim1; + numTokens++; + TLLM_LOG_WARNING( + "CrossAttentionMask is not provided for sequence %d of generation request. Full valid " + "attentionMask will " + "be used " + "by default.", + s); } - numCopiedTokens++; - numTokens++; - } - else - { - numTokens++; - TLLM_LOG_WARNING( - "CrossAttentionMask is not provided for the generation request. Full valid attentionMask will " - "be used " - "by default."); } } sync_check_cuda_error(stream.get()); diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index c54e02642ca..1614e58ebca 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -693,7 +693,7 @@ std::unique_ptr TrtGptModelInflightBatching::c kvCacheConfig.getEventBufferMaxSize() > 0 ? 
std::make_unique(kvCacheConfig.getEventBufferMaxSize()) : nullptr, - kvCacheConfig.getEnablePartialReuse(), kvCacheConfig.getCopyOnPartialReuse()); + kvCacheConfig.getEnablePartialReuse(), kvCacheConfig.getCopyOnPartialReuse(), mModelConfig.getNumVocabs()); reshapeKvTensors(kvCacheManager->getOffsetTableDimensions()); @@ -872,7 +872,8 @@ void TrtGptModelInflightBatching::forwardSync() { llmReq->setNumPreDecodedTokens(0, beam); } - if (llmReq->isGenerationToCompleteState()) + bool crossAttnFinished = mModelConfig.useAttentionPrior() && llmReq->isAttentionPriorFinished(); + if (llmReq->isGenerationToCompleteState() || crossAttnFinished) { llmReq->setState(LlmRequestState::kGENERATION_COMPLETE); terminateRequest(llmReq); @@ -902,7 +903,10 @@ void TrtGptModelInflightBatching::forwardSync() llmReq->finishByReason(mReqIdsToTerminate[llmReq->mRequestId]); llmReq->clearGeneratedTokens(); } - mReqIdsToTerminate.erase(llmReq->mRequestId); + for (int i = 0; i < llmReq->getNumSequences(); i++) + { + mReqIdsToTerminate.erase(llmReq->getSeqSlotId(i)); + } } } } @@ -921,7 +925,10 @@ void TrtGptModelInflightBatching::forwardSync() "cacheTransceiverConfig."); mCacheTransceiver->respondAndSendAsync(llmReq.get()); } - mSeqSlotManager->freeSequenceSlot(llmReq->mRequestId); + for (int i = 0; i < llmReq->getNumSequences(); i++) + { + mSeqSlotManager->freeSequenceSlot(llmReq->getSeqSlotId(i)); + } } } } @@ -1099,6 +1106,13 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests sync_check_cuda_error(mRuntime->getStream().get()); + // forward the attention prior index to llm requests for the next iteration + if (mModelConfig.useAttentionPrior()) + { + mBuffers[getFusedBufferId()]->processAttentionPriorScores( + currRequests.generationRequests, *mRuntime, mModelConfig); + } + // Postpone decoder setup if model does not need to setup buffers for the context phase. 
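Illustrative sketch (not part of the patch): with several sequences per request, the termination bookkeeping above repeats the same operation for every sequence slot of a request (erasing pending terminations, freeing sequence slots). A generic restatement of that pattern; forEachSequenceId is an invented helper.

template <typename RequestT, typename Fn>
void forEachSequenceId(RequestT const& llmReq, Fn&& onSeqId)
{
    for (int i = 0; i < llmReq.getNumSequences(); ++i)
    {
        // e.g. mSeqSlotManager->freeSequenceSlot(id) or mReqIdsToTerminate.erase(id)
        onSeqId(llmReq.getSeqSlotId(i));
    }
}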
if (!mModelConfig.getSpeculativeDecodingMode().needsDecoderPrologue()) { @@ -1479,8 +1493,8 @@ void TrtGptModelInflightBatching::createDecoder(std::optional(); - + // mDecoderState = std::make_unique(); + mDecoderStates.clear(); if (mWorldConfig.isLastPipelineParallelRank()) { auto decoderType = mRuntime->getEngine().getTensorDataType("logits"); @@ -1500,24 +1514,38 @@ void TrtGptModelInflightBatching::createDecoder(std::optional(mRuntime->getStreamPtr()); - mDecoder->setup( - decodingMode, getMaxNumSequences(), mOperatingBeamWidth, decoderType, mModelConfig, mWorldConfig); + mDecoders.clear(); + for (SizeType32 i = 0; i < getNumVocabs(); i++) + { - mDecoderState->setup(getMaxNumSequences(), mOperatingBeamWidth, getMaxAttentionWindow(), getSinkTokenLen(), - getMaxSequenceLen(), decoderType, mModelConfig, mWorldConfig, mRuntime->getBufferManager()); + mDecoders.push_back(std::make_unique(mRuntime->getStreamPtr())); + auto& decoder = mDecoders.back(); + decoder->setup(decodingMode, getMaxNumSequences(), mOperatingBeamWidth, decoderType, mModelConfig, + mWorldConfig, mModelConfig.getVocabSizes()[i]); - if (!mModelConfig.getSpeculativeDecodingMode().isNone()) - { - mDecoderState->setupSpeculativeDecoding(mModelConfig.getSpeculativeDecodingMode(), - mModelConfig.getMaxDecodingTokens(), decoderType, mModelConfig, mWorldConfig, - mRuntime->getBufferManager()); + mDecoderStates.push_back(std::make_unique()); + auto& decoderState = mDecoderStates.back(); + + decoderState->setup(getMaxNumSequences(), mOperatingBeamWidth, getMaxAttentionWindow(), getSinkTokenLen(), + getMaxSequenceLen(), decoderType, mModelConfig, mWorldConfig, mRuntime->getBufferManager()); + + if (!mModelConfig.getSpeculativeDecodingMode().isNone()) + { + decoderState->setupSpeculativeDecoding(mModelConfig.getSpeculativeDecodingMode(), + mModelConfig.getMaxDecodingTokens(), decoderType, mModelConfig, mWorldConfig, + mRuntime->getBufferManager()); + } } } else { - mDecoderState->setupCacheIndirection( - getMaxNumSequences(), mOperatingBeamWidth, getMaxAttentionWindow(), mRuntime->getBufferManager()); + for (SizeType32 i = 0; i < getNumVocabs(); i++) + { + mDecoderStates.push_back(std::make_unique()); + auto& decoderState = mDecoderStates.back(); + decoderState->setupCacheIndirection( + getMaxNumSequences(), mOperatingBeamWidth, getMaxAttentionWindow(), mRuntime->getBufferManager()); + } } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); @@ -1539,15 +1567,25 @@ void TrtGptModelInflightBatching::createBuffers(executor::DecodingConfig const& mDecoderInputBuffers.clear(); mDecoderOutputBuffers.clear(); + mDecoderOutputBuffers.resize(getNumVocabs()); + for (SizeType32 i = 0; i < mNumMicroBatches; ++i) { mDecoderInputBuffers.emplace_back( getMaxBatchSize(), mModelConfig.getMaxDecodingTokens(), mRuntime->getBufferManager()); mDecoderInputBuffers.back().setupMedusaLogits(getMaxNumSequences(), mModelConfig); - mDecoderOutputBuffers.emplace_back(getMaxNumSequences(), mOperatingBeamWidth, getMaxSequenceLen(), - mModelConfig.getMaxDecodingTokens(), mRuntime->getBufferManager()); - mDecoderOutputBuffers.back().setupSpeculativeDecoding( - getMaxNumSequences(), mModelConfig.getMaxDecodingTokens(), mModelConfig); + } + + for (SizeType32 vid = 0; vid < getNumVocabs(); vid++) + { + for (SizeType32 i = 0; i < mNumMicroBatches; ++i) + { + + mDecoderOutputBuffers[vid].emplace_back(getMaxNumSequences(), mOperatingBeamWidth, getMaxSequenceLen(), + mModelConfig.getMaxDecodingTokens(), mRuntime->getBufferManager()); + 
mDecoderOutputBuffers[vid].back().setupSpeculativeDecoding( + getMaxNumSequences(), mModelConfig.getMaxDecodingTokens(), mModelConfig); + } } mSlotDecoderBuffers.clear(); @@ -1557,7 +1595,11 @@ void TrtGptModelInflightBatching::createBuffers(executor::DecodingConfig const& mOperatingBeamWidth, getMaxSequenceLen(), mRuntime->getBufferManager())); } - mDecodingInputs.resize(mNumMicroBatches); + mDecodingInputs.resize(getNumVocabs()); + for (auto& inputs : mDecodingInputs) + { + inputs.resize(mNumMicroBatches); + } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -1674,7 +1716,7 @@ void TrtGptModelInflightBatching::prepareDistGenBufferAndDecoder(RequestVector c auto const bufferId = getFusedBufferId(); auto& runtimeBuffers = *mBuffers[bufferId]; runtimeBuffers.prepareStep(cacheTransCompleteRequests, {}, getMaxBeamWidth(), getMaxAttentionWindow(), - *mDecoderState, mKvCacheManager.get(), mCrossKvCacheManager.get(), mRnnStateManager.get(), + *mDecoderStates.front(), mKvCacheManager.get(), mCrossKvCacheManager.get(), mRnnStateManager.get(), mPeftTables[mMicroBatchId], *mRuntime, mModelConfig, mWorldConfig, getGatherGenerationLogits(), isTrtOverlap()); auto const contextBufferId = mCtxGenFusion ? getFusedBufferId() : getContextBufferId(); @@ -1733,13 +1775,13 @@ TrtGptModelInflightBatching::prepareBuffers( NVTX3_SCOPED_RANGE(prepareBuffers); auto& runtimeBuffers = *mBuffers.at(bufferId); - + auto& decoderState = mDecoderStates.front(); // used for fill sequenceLength tensor, could use any vocab's buffer auto allNewTokens = mWorldConfig.isLastPipelineParallelRank() - ? RuntimeBuffers::OptionalRef(mDecoderState->getAllNewTokens()) + ? RuntimeBuffers::OptionalRef(decoderState->getAllNewTokens()) : std::nullopt; auto [optProfileId, inputMap, outputMap] = runtimeBuffers.prepareStep(contextRequests, generationRequests, - mOperatingBeamWidth, getMaxAttentionWindow(), *mDecoderState, mKvCacheManager.get(), mCrossKvCacheManager.get(), + mOperatingBeamWidth, getMaxAttentionWindow(), *decoderState, mKvCacheManager.get(), mCrossKvCacheManager.get(), mRnnStateManager.get(), mPeftTables[bufferId], *mRuntime, mModelConfig, mWorldConfig, getGatherGenerationLogits(), isTrtOverlap(), allNewTokens); @@ -1865,23 +1907,42 @@ void TrtGptModelInflightBatching::setupDecoderStep( { auto const logitsType = mRuntime->getEngine().getTensorDataType("logits"); - auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] - = (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, contextRequests, logitsType, - inputBuffers, *mDecoderState, mRuntime->getStream(), *mDecoder->getDecoderStream(), getMaxSequenceLen(), - mOperatingBeamWidth, buffers.mMedusaBuffers); + // resize input buffers based on requests + unsigned decoderInputSize{0}; + + if (!contextRequests.empty()) + { + for (auto const& llmReq : contextRequests) + { + auto const& reqTokens = llmReq->getTokens(0); + if (llmReq->isLastContextChunk()) + { + decoderInputSize += reqTokens.size(); + } + } + } + inputBuffers.inputsIds->resize(decoderInputSize); - auto const localBatchSize = batchSlots->getSize(); - if (localBatchSize > 0) + for (SizeType32 vocabId = 0; vocabId < getNumVocabs(); vocabId++) { - auto samplingConfig = SamplingConfig(samplingConfigs); - mDecoder->getUnderlyingDecoder().setup(samplingConfig, localBatchSize, batchSlots, - {mDecoderState->getJointDecodingOutput()}, mModelConfig.getDataType(), lookaheadPrompt, - lookaheadAlgoConfigs); + auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = 
(*mCreateNewDecoderRequests)( + mModelConfig, mWorldConfig, mDecodingConfig, contextRequests, logitsType, inputBuffers, + *mDecoderStates[vocabId], mRuntime->getStream(), *mDecoders[vocabId]->getDecoderStream(), + getMaxSequenceLen(), mOperatingBeamWidth, buffers.mMedusaBuffers); - auto const& stream = mDecoder->getDecoderStream(); - CudaEvent event{}; - stream->record(event); - mRuntime->getStreamPtr()->wait(event); + auto const localBatchSize = batchSlots->getSize(); + if (localBatchSize > 0) + { + auto samplingConfig = SamplingConfig(samplingConfigs); + mDecoders[vocabId]->getUnderlyingDecoder().setup(samplingConfig, localBatchSize, batchSlots, + {mDecoderStates[vocabId]->getJointDecodingOutput()}, mModelConfig.getDataType(), lookaheadPrompt, + lookaheadAlgoConfigs); + + auto const& stream = mDecoders[vocabId]->getDecoderStream(); + CudaEvent event{}; + stream->record(event); + mRuntime->getStreamPtr()->wait(event); + } } } @@ -1892,7 +1953,7 @@ void TrtGptModelInflightBatching::postProcessRequest( LlmRequest& llmReq, std::vector const& numDroppedTokens) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - auto const seqSlot = llmReq.mSeqSlot.value(); + auto const seqSlot = llmReq.mSeqSlots.at(0); auto const reqBeamWidth = llmReq.getBeamWidthByIter(true); auto const& bufferManager = getBufferManager(); @@ -1960,14 +2021,17 @@ void TrtGptModelInflightBatching::getDecoderSlotHostOutputs( if (mWorldConfig.isLastPipelineParallelRank()) { - auto event = mDecoder->finalize(*mDecoderState, seqSlot, samplingConfig, streaming); + + auto& decoder = mDecoders.front(); + auto& decoderState = mDecoderStates.front(); + auto event = decoder->finalize(*decoderState, seqSlot, samplingConfig, streaming); // Make sure that postprocessing is done before copying outputIds mCopyBufferManager.getStream().wait(event.get()); - auto sequenceLengths = mDecoderState->getSequenceLengths(seqSlot); - auto outputIds = mDecoderState->getGatheredIds(seqSlot); - auto cumLogProbs = mDecoderState->getCumLogProbs(seqSlot); - auto logProbs = mDecoderState->getLogProbs(seqSlot); + auto sequenceLengths = decoderState->getSequenceLengths(seqSlot); + auto outputIds = decoderState->getGatheredIds(seqSlot); + auto cumLogProbs = decoderState->getCumLogProbs(seqSlot); + auto logProbs = decoderState->getLogProbs(seqSlot); mCopyBufferManager.copy(*sequenceLengths, *mSlotDecoderBuffers[seqSlot]->sequenceLengths); mCopyBufferManager.copy(*outputIds, *mSlotDecoderBuffers[seqSlot]->outputIds); @@ -2039,54 +2103,63 @@ runtime::CudaEvent TrtGptModelInflightBatching::decoderStepAsync(ScheduledReques { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(decoderStepAsync); + runtime::CudaEvent decoderFinishEvent; + for (SizeType32 vid = 0; vid < getNumVocabs(); vid++) + { + auto& decoderInputBuffers = mDecoderInputBuffers.at(getFusedBufferId()); + auto& decoderState = mDecoderStates[vid]; - auto& decoderInputBuffers = mDecoderInputBuffers.at(getFusedBufferId()); - - auto const contextBufferId = mCtxGenFusion ? getFusedBufferId() : getContextBufferId(); - auto& contextRuntimeBuffers = mBuffers.at(contextBufferId); - auto const logitsIndex = (*mHandleContextLogits)(decoderInputBuffers, scheduledRequests.contextRequests, - contextRuntimeBuffers->logits, contextRuntimeBuffers->numContextLogits, mModelConfig, - mRuntime->getBufferManager(), contextRuntimeBuffers->mMedusaBuffers); + auto const contextBufferId = mCtxGenFusion ? 
getFusedBufferId() : getContextBufferId(); + auto& contextRuntimeBuffers = mBuffers.at(contextBufferId); + auto const logitsIndex = (*mHandleContextLogits)(decoderInputBuffers, scheduledRequests.contextRequests, + contextRuntimeBuffers->logits, contextRuntimeBuffers->numContextLogits, mModelConfig, + mRuntime->getBufferManager(), contextRuntimeBuffers->mMedusaBuffers, vid); - auto const genLogitsIndex = mCtxGenFusion ? logitsIndex : 0; - auto const genBufferId = mCtxGenFusion ? getFusedBufferId() : getGenerationBufferId(); - auto& genRuntimeBuffers = mBuffers.at(genBufferId); - (*mHandleGenerationLogits)(decoderInputBuffers, scheduledRequests.generationRequests, genRuntimeBuffers->logits, - genLogitsIndex, mModelConfig, mRuntime->getBufferManager(), *genRuntimeBuffers, - genRuntimeBuffers->mMedusaBuffers); + auto const genLogitsIndex = mCtxGenFusion ? logitsIndex : 0; + auto const genBufferId = mCtxGenFusion ? getFusedBufferId() : getGenerationBufferId(); + auto& genRuntimeBuffers = mBuffers.at(genBufferId); + (*mHandleGenerationLogits)(decoderInputBuffers, scheduledRequests.generationRequests, genRuntimeBuffers->logits, + genLogitsIndex, mModelConfig, mRuntime->getBufferManager(), *genRuntimeBuffers, + genRuntimeBuffers->mMedusaBuffers, vid); - if (mOperatingBeamWidth > 1) - { - copyCacheIndirectionFromOutputsToInputs(scheduledRequests, genBufferId); - } + if (mOperatingBeamWidth > 1) + { + copyCacheIndirectionFromOutputsToInputs(scheduledRequests, genBufferId, vid); + } - mLogitsPostProcessorIsApplied = (*mLogitsPostProcessor)(decoderInputBuffers, mReplicateLogitsPostProcessor, - mWorldConfig, mRuntime->getStreamPtr(), mLogitsPostProcessorBatched); + mLogitsPostProcessorIsApplied = (*mLogitsPostProcessor)(decoderInputBuffers, mReplicateLogitsPostProcessor, + mWorldConfig, mRuntime->getStreamPtr(), mLogitsPostProcessorBatched); - if (mGuidedDecoder) - { - mGuidedDecoder->execute(decoderInputBuffers, mRuntime->getBufferManager()); - } + if (mGuidedDecoder) + { + mGuidedDecoder->execute(decoderInputBuffers, mRuntime->getBufferManager()); + } - auto const fusedBufferId = getFusedBufferId(); - auto& fusedRuntimeBuffers = mBuffers.at(fusedBufferId); + auto const fusedBufferId = getFusedBufferId(); + auto& fusedRuntimeBuffers = mBuffers.at(fusedBufferId); - auto& decodingInput = mDecodingInputs.at(mMicroBatchId); - decodingInput = (*mMakeDecodingBatchInputOutput)(mDecoderInputBuffers.at(fusedBufferId), *mDecoderState, - mModelConfig, getMaxNumSequences(), *fusedRuntimeBuffers); + auto& decodingInput = mDecodingInputs[vid].at(mMicroBatchId); + decodingInput = (*mMakeDecodingBatchInputOutput)(mDecoderInputBuffers.at(fusedBufferId), *decoderState, + mModelConfig, getMaxNumSequences(), *fusedRuntimeBuffers); - auto decoderFinishEvent = mDecoder->forwardAsync(*mDecoderState, *decodingInput); + auto finishedEvent = mDecoders[vid]->forwardAsync(*decoderState, *decodingInput); - auto const returnLogProbs = batchReturnLogProbs(scheduledRequests); - auto updateDecoderBuffersEvent = (*mUpdateDecoderBuffers)(mModelConfig, mDecoderOutputBuffers.at(fusedBufferId), - mRuntime->getBufferManager(), *mDecoderState, returnLogProbs, decoderFinishEvent); + auto const returnLogProbs = batchReturnLogProbs(scheduledRequests); + auto updateDecoderBuffersEvent + = (*mUpdateDecoderBuffers)(mModelConfig, mDecoderOutputBuffers[vid].at(fusedBufferId), + mRuntime->getBufferManager(), *decoderState, returnLogProbs, finishedEvent); + if (vid == getNumVocabs() - 1) + { + decoderFinishEvent = 
std::move(updateDecoderBuffersEvent); + } + } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return updateDecoderBuffersEvent; + return decoderFinishEvent; } void TrtGptModelInflightBatching::copyCacheIndirectionFromOutputsToInputs( - ScheduledRequests const& scheduledRequests, SizeType32 genBufferId) + ScheduledRequests const& scheduledRequests, SizeType32 genBufferId, SizeType32 vocabId) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(copyCacheIndirectionFromOutputsToInputs); @@ -2097,7 +2170,8 @@ void TrtGptModelInflightBatching::copyCacheIndirectionFromOutputsToInputs( auto* copySizesPtr = bufferCast(*genRuntimeBuffers.cacheIndirDecoderIOBatchedCopySizes); // Only `cacheIndirShape.d[2]` is used - auto const& cacheIndirShape = mDecoderState->getCacheIndirectionOutput()->getShape(); + auto& decoderState = mDecoderStates[vocabId]; + auto const& cacheIndirShape = decoderState->getCacheIndirectionOutput()->getShape(); auto const maxBeamWidth = cacheIndirShape.d[1]; auto const maxAttentionWindow = cacheIndirShape.d[2]; auto const slotOffset = maxBeamWidth * maxAttentionWindow; @@ -2109,14 +2183,17 @@ void TrtGptModelInflightBatching::copyCacheIndirectionFromOutputsToInputs( { for (auto const& llmReq : requests) { - auto const reqBeamWidth = llmReq->getBeamWidthByIter(); - auto const seqSlot = llmReq->mSeqSlot.value(); - auto const copySize = reqBeamWidth * maxAttentionWindow; - srcOffsetsPtr[batchIdx] = seqSlot * slotOffset; - dstOffsetsPtr[batchIdx] = seqSlot * slotOffset; - copySizesPtr[batchIdx] = copySize; - maxCopySize = std::max(maxCopySize, copySize); - batchIdx++; + for (int s = 0; s < llmReq->getNumSequences(); s++) + { + auto const reqBeamWidth = llmReq->getBeamWidthByIter(); + auto const seqSlot = llmReq->getSeqSlot(s); + auto const copySize = reqBeamWidth * maxAttentionWindow; + srcOffsetsPtr[batchIdx] = seqSlot * slotOffset; + dstOffsetsPtr[batchIdx] = seqSlot * slotOffset; + copySizesPtr[batchIdx] = copySize; + maxCopySize = std::max(maxCopySize, copySize); + batchIdx++; + } } } if (batchIdx != 0) @@ -2137,8 +2214,8 @@ void TrtGptModelInflightBatching::copyCacheIndirectionFromOutputsToInputs( auto const copySizesDeviceSlice = ITensor::slice(genRuntimeBuffers.mCacheIndirDecoderIOBatchedCopyCopySizesDevice, 0, batchIdx); manager.copy(sizesSlice->data(), *copySizesDeviceSlice); // Explicitly move to device for faster access. 
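Illustrative sketch (not part of the patch): decoderStepAsync above now drives one decoder per vocabulary, indexing mDecoders[vid], mDecoderStates[vid], mDecodingInputs[vid] and mDecoderOutputBuffers[vid], and only the last vocabulary's buffer-update event is handed back to the caller. A stripped-down restatement of that control flow; Event, runDecoderForVocab and stepAllVocabDecoders are invented for the sketch.

#include <functional>

struct Event
{
    int id = -1; // stand-in for runtime::CudaEvent
};

Event stepAllVocabDecoders(int numVocabs, std::function<Event(int /*vocabId*/)> const& runDecoderForVocab)
{
    Event last{};
    for (int vid = 0; vid < numVocabs; ++vid)
    {
        // Per vocab: gather logits, run that vocab's decoder forwardAsync, update its output buffers.
        last = runDecoderForVocab(vid);
    }
    return last; // only the final vocabulary's update event is returned
}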
- runtime::kernels::invokeCopyBatch(*mDecoderState->getCacheIndirectionOutput(), - *mDecoderState->getCacheIndirectionInput(), *srcOffsetsSliceDeviceSlice, *dstOffsetsSliceDeviceSlice, + runtime::kernels::invokeCopyBatch(*decoderState->getCacheIndirectionOutput(), + *decoderState->getCacheIndirectionInput(), *srcOffsetsSliceDeviceSlice, *dstOffsetsSliceDeviceSlice, *copySizesDeviceSlice, maxCopySize, manager.getStream()); } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); @@ -2150,40 +2227,46 @@ std::vector> TrtGptModelInflightBatching:: TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(communicateDecoderBuffers); - auto& decoderOutputBuffers = mDecoderOutputBuffers.at(getFusedBufferId()); - std::vector> asyncHandles; - if (mWorldConfig.isLastPipelineParallelRank()) + for (SizeType32 vid = 0; vid < getNumVocabs(); vid++) { - if (broadcastPostDecoder()) - { - DecoderStepAsyncSend::bcast(decoderOutputBuffers, *mDecoderState, returnLogProbs, mOperatingBeamWidth, - mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), *mMpiCommTensorPara, 0); - } + auto& decoderOutputBuffers = mDecoderOutputBuffers[vid].at(getFusedBufferId()); - if (mWorldConfig.isPipelineParallel()) + if (mWorldConfig.isLastPipelineParallelRank()) { - auto const peerSend = 0; - asyncHandles.emplace_back(std::make_unique(decoderOutputBuffers, *mDecoderState, - returnLogProbs, mOperatingBeamWidth, mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), - *mMpiCommPipelinePara, peerSend)); + if (broadcastPostDecoder()) + { + DecoderStepAsyncSend::bcast(decoderOutputBuffers, *mDecoderStates[vid], returnLogProbs, + mOperatingBeamWidth, mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), + *mMpiCommTensorPara, 0); + } + + if (mWorldConfig.isPipelineParallel()) + { + auto const peerSend = 0; + asyncHandles.emplace_back(std::make_unique(decoderOutputBuffers, + *mDecoderStates[vid], returnLogProbs, mOperatingBeamWidth, + mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), *mMpiCommPipelinePara, peerSend)); + } } - } - else - { - auto const peerRecv = mWorldConfig.isFirstPipelineParallelRank() ? mWorldConfig.getPipelineParallelism() - 1 - : mWorldConfig.getPipelineParallelRank() - 1; - DecoderStepAsyncSend::recv(decoderOutputBuffers, *mDecoderState, returnLogProbs, mOperatingBeamWidth, - mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), *mMpiCommPipelinePara, peerRecv); - auto const peerSend = mWorldConfig.getPipelineParallelRank() + 1; - if (peerSend != mWorldConfig.getPipelineParallelism() - 1) + else { - asyncHandles.emplace_back(std::make_unique(decoderOutputBuffers, *mDecoderState, - returnLogProbs, mOperatingBeamWidth, mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), - *mMpiCommPipelinePara, peerSend)); + auto const peerRecv = mWorldConfig.isFirstPipelineParallelRank() + ? 
mWorldConfig.getPipelineParallelism() - 1 + : mWorldConfig.getPipelineParallelRank() - 1; + DecoderStepAsyncSend::recv(decoderOutputBuffers, *mDecoderStates[vid], returnLogProbs, mOperatingBeamWidth, + mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), *mMpiCommPipelinePara, peerRecv); + auto const peerSend = mWorldConfig.getPipelineParallelRank() + 1; + if (peerSend != mWorldConfig.getPipelineParallelism() - 1) + { + asyncHandles.emplace_back(std::make_unique(decoderOutputBuffers, + *mDecoderStates[vid], returnLogProbs, mOperatingBeamWidth, + mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), *mMpiCommPipelinePara, peerSend)); + } } } - TLLM_CHECK_WITH_INFO(asyncHandles.size() <= 2, "Up to two decoder step async handles expected"); + TLLM_CHECK_WITH_INFO(asyncHandles.size() <= static_cast(2 * getNumVocabs()), + "Up to two decoder step async handles per vocab expected"); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); return asyncHandles; @@ -2194,15 +2277,10 @@ void TrtGptModelInflightBatching::updateRequests(ScheduledRequests const& schedu TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(updateRequests); - auto const& decoderOutputBuffers = mDecoderOutputBuffers.at(getFusedBufferId()); + auto const& decoderOutputBuffers = mDecoderOutputBuffers.front().at(getFusedBufferId()); - auto const hostNewOutputTokensShape = decoderOutputBuffers.newOutputTokensHost->getShape(); - auto const* const hostNewOutputTokensData - = bufferCast(*decoderOutputBuffers.newOutputTokensHost); auto const* const sequenceLengthsHostData = bufferCast(*decoderOutputBuffers.sequenceLengthsHost); auto const* const decoderFinishedSumPtr = bufferCast(*decoderOutputBuffers.finishedSumHost); - auto const* const cumLogProbsPtr = bufferCast(*decoderOutputBuffers.cumLogProbsHost); - auto const* const logProbsPtr = bufferCast(*decoderOutputBuffers.logProbsHost); auto const* const finishReasonsHostData = bufferCast(*decoderOutputBuffers.finishReasonsHost); @@ -2214,7 +2292,7 @@ void TrtGptModelInflightBatching::updateRequests(ScheduledRequests const& schedu continue; } auto const reqBeamWidth = llmReq->getBeamWidthByIter(true); - auto const seqSlot = llmReq->mSeqSlot.value(); + auto const seqSlot = llmReq->mSeqSlots.at(0); auto const currentNumOfTokens = llmReq->getMaxBeamNumTokens(); // Save the accepted token logits from target model @@ -2256,22 +2334,36 @@ void TrtGptModelInflightBatching::updateRequests(ScheduledRequests const& schedu numDroppedTokens[beam] = numGeneratedTokens - numNewTokens[beam]; for (SizeType32 step = 0; step < numNewTokens[beam]; ++step) { - auto const newTokenIdx = tc::flat_index(hostNewOutputTokensShape.d, step, seqSlot, beam); - auto const newToken = hostNewOutputTokensData[newTokenIdx]; - llmReq->addNewToken(newToken, beam); - TLLM_LOG_DEBUG("request ID %ld beam %d newToken %d", llmReq->mRequestId, beam, newToken); - - if (llmReq->returnLogProbs()) + SizeType32 vocabOffset = 0; + for (SizeType32 vid = 0; vid < getNumVocabs(); ++vid) { - auto const cumLogProb = cumLogProbsPtr[seqSlot * mOperatingBeamWidth + beam]; - llmReq->setCumLogProb(cumLogProb, beam); - - auto const beginLogProbsOffset = reqBeamWidth == 1 ? 
llmReq->mPromptLen : 0; - SizeType32 offset - = (seqSlot * mOperatingBeamWidth + beam) * getMaxSequenceLen() + beginLogProbsOffset; - auto const generatedLength = seqLen - llmReq->mPromptLen; - std::vector logProbs(logProbsPtr + offset, logProbsPtr + offset + generatedLength); - llmReq->setLogProbs(logProbs, beam); + + auto const& vocabDecoderOutputBuffers = mDecoderOutputBuffers[vid].at(getFusedBufferId()); + auto const hostNewOutputTokensShape = vocabDecoderOutputBuffers.newOutputTokensHost->getShape(); + auto const newTokenIdx = tc::flat_index(hostNewOutputTokensShape.d, step, seqSlot, beam); + auto const* const hostNewOutputTokensData + = bufferCast(*vocabDecoderOutputBuffers.newOutputTokensHost); + auto const newToken = hostNewOutputTokensData[newTokenIdx]; + llmReq->addNewToken(newToken + vocabOffset, beam); + TLLM_LOG_DEBUG("request ID %ld beam %d newToken %d", llmReq->mRequestId, beam, newToken); + + if (llmReq->returnLogProbs()) + { + auto const* const cumLogProbsPtr + = bufferCast(*vocabDecoderOutputBuffers.cumLogProbsHost); + auto const cumLogProb = cumLogProbsPtr[seqSlot * mOperatingBeamWidth + beam]; + llmReq->setCumLogProb(cumLogProb, beam); + + auto const beginLogProbsOffset = reqBeamWidth == 1 ? llmReq->mPromptLen : 0; + SizeType32 offset + = (seqSlot * mOperatingBeamWidth + beam) * getMaxSequenceLen() + beginLogProbsOffset; + auto const generatedLength = seqLen - llmReq->mPromptLen; + auto const* const logProbsPtr + = bufferCast(*vocabDecoderOutputBuffers.logProbsHost); + std::vector logProbs(logProbsPtr + offset, logProbsPtr + offset + generatedLength); + llmReq->setLogProbs(logProbs, beam); + } + vocabOffset += mModelConfig.getVocabSizes()[vid]; } } @@ -2340,7 +2432,8 @@ void TrtGptModelInflightBatching::updateRequests(ScheduledRequests const& schedu // Terminate if request has finished or if it is speculative decoding target model if (decoderFinishedSumPtr[seqSlot] == reqBeamWidth - || (mModelConfig.getSpeculativeDecodingMode().isDraftTokensExternal() && llmReq->hasDraftTokens())) + || (mModelConfig.getSpeculativeDecodingMode().isDraftTokensExternal() && llmReq->hasDraftTokens()) + || (mModelConfig.useAttentionPrior() && llmReq->isAttentionPriorFinished())) { postProcessRequest(*llmReq, numDroppedTokens); @@ -2441,7 +2534,7 @@ void TrtGptModelInflightBatching::rewindKVCacheBlocks(SizeType32 numSequences) TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto const bufferId = getFusedBufferId(); auto& runtimeBuffers = *mBuffers.at(bufferId); - auto& decoderOutputBuffers = mDecoderOutputBuffers.at(bufferId); + auto& decoderOutputBuffers = mDecoderOutputBuffers.front().at(bufferId); auto localNbLayers = mModelConfig.getNbAttentionLayers( mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank()); @@ -2469,9 +2562,9 @@ void TrtGptModelInflightBatching::rewindKVCacheBlocks(SizeType32 numSequences) commonRewindLen = 0; rewindLens = bufferCast(*decoderOutputBuffers.prevDraftTokensLengthsHost); } - + auto& decoderState = mDecoderStates.front(); tensorrt_llm::runtime::kernels::invokeUpdateKVBlockArrayDraftTokenLocation( - *mDecoderState->getAcceptedLengthsCumSum(), *mDecoderState->getAcceptedPackedPaths(), + *decoderState->getAcceptedLengthsCumSum(), *decoderState->getAcceptedPackedPaths(), *runtimeBuffers.sequenceLengthsDevice, pointerArrayPtr, offsetArrayPtr, localNbLayers, numSequences, mRewindInputs.numKvHeads, sizeInBytesPerKVHead, commonRewindLen, rewindLens, *runtimeBuffers.seqSlots, getMaxAttentionWindow(), mRewindInputs.maxBlocksPerSeq, 
tokensPerBlock, mRewindInputs.isUseOneMoreBlock, @@ -2590,7 +2683,8 @@ void TrtGptModelInflightBatching::changeSpecDecMode(ScheduledRequests const& sch setupSpeculativeDecodingModule(mDecodingConfig); mBuffers.at(bufferId)->mLookaheadBuffers->enableLookaheadDecoding( getMaxBatchSize(), mModelConfig.getMaxDecodingTokens()); - mDecoderOutputBuffers.at(getFusedBufferId()) + mDecoderOutputBuffers.front() + .at(getFusedBufferId()) .enableLookaheadDecoding(getMaxNumSequences(), mModelConfig.getMaxDecodingTokens()); createDecoder(mDecodingConfig.getDecodingMode()); } @@ -2601,10 +2695,13 @@ void TrtGptModelInflightBatching::changeSpecDecMode(ScheduledRequests const& sch mModelConfig.disableSeamlessLookaheadDecoding(); mDecodingConfig.setDecodingMode(executor::DecodingMode::Auto()); mBuffers.at(bufferId)->mLookaheadBuffers->disableLookaheadDecoding(); - mDecoderOutputBuffers.at(getFusedBufferId()).disableLookaheadDecoding(getMaxNumSequences()); - mDecoder->disableLookahead( - scheduledRequests.generationRequests, mDecoderInputBuffers.at(getFusedBufferId()).setupBatchSlots); - mDecoderState->disableLookahead(scheduledRequests.generationRequests); + mDecoderOutputBuffers.front().at(getFusedBufferId()).disableLookaheadDecoding(getMaxNumSequences()); + for (SizeType32 vid = 0; vid < getNumVocabs(); vid++) + { + mDecoders[vid]->disableLookahead( + scheduledRequests.generationRequests, mDecoderInputBuffers.at(getFusedBufferId()).setupBatchSlots); + mDecoderStates[vid]->disableLookahead(scheduledRequests.generationRequests); + } for (auto const& llmReq : scheduledRequests.generationRequests) { if (llmReq->getNumDraftTokens() > 0) diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h index 1478172ddf9..73d51bd90f5 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h @@ -351,7 +351,8 @@ class TrtGptModelInflightBatching : public TrtGptModel /// @brief Copies the content of the cache indirection outputs to the cache indirection inputs. /// @param[in] scheduledRequests The requests to copy the cache indirections for. /// @param[in] genBufferId The id of the generation buffers for those requests. - void copyCacheIndirectionFromOutputsToInputs(ScheduledRequests const& scheduledRequests, SizeType32 genBufferId); + void copyCacheIndirectionFromOutputsToInputs( + ScheduledRequests const& scheduledRequests, SizeType32 genBufferId, SizeType32 vocabId); [[nodiscard]] bool getGatherGenerationLogits() const override { @@ -461,6 +462,21 @@ class TrtGptModelInflightBatching : public TrtGptModel SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const override; + [[nodiscard]] SizeType32 getNumVocabs() const + { + return mModelConfig.getNumVocabs(); + } + + [[nodiscard]] std::vector getVocabSizes() const + { + return mModelConfig.getVocabSizes(); + } + + [[nodiscard]] SizeType32 getVocabSize() const + { + return mModelConfig.getVocabSize(); + } + private: /******************** Configs ********************/ // Parameters of the model (TRT engine) @@ -484,9 +500,12 @@ class TrtGptModelInflightBatching : public TrtGptModel // Runner for the TRT engine. The engine produces logits. std::unique_ptr mRuntime; // Decoder that generates new tokens from the logits. 
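Illustrative sketch (not part of the patch): when assembling new tokens in updateRequests above, each vocabulary's sampled id is shifted by the cumulative size of the preceding vocabularies before being appended, so all vocabularies share one token-id space. A small reference of that mapping; toGlobalTokenId is an invented name.

#include <numeric>
#include <vector>

// vocabSizes would come from modelConfig.getVocabSizes(); localToken is the id sampled by decoder vid.
inline int toGlobalTokenId(std::vector<int> const& vocabSizes, int vid, int localToken)
{
    int const vocabOffset = std::accumulate(vocabSizes.begin(), vocabSizes.begin() + vid, 0);
    return localToken + vocabOffset; // matches llmReq->addNewToken(newToken + vocabOffset, beam)
}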
- std::unique_ptr mDecoder; + + std::vector> mDecoders; + // Decoder state for all requests + // std::unique_ptr mDecoderState; // Decoder state for all requests - std::unique_ptr mDecoderState; + std::vector> mDecoderStates; // Synchronization handles for decoder std::vector> mDecoderFinishedEvents; @@ -563,13 +582,13 @@ class TrtGptModelInflightBatching : public TrtGptModel // Decoder input buffers for each micro batch. std::vector mDecoderInputBuffers; // Decoder output buffers for each micro batch. - std::vector mDecoderOutputBuffers; + std::vector> mDecoderOutputBuffers; // Buffers for each slot in the decoder std::vector> mSlotDecoderBuffers; // PEFT table for each micro batch std::vector mPeftTables; // Decoder input for each micro batch. - std::vector> mDecodingInputs; + std::vector>> mDecodingInputs; /******************** Book keeping ********************/ // List of requests in each micro batch diff --git a/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp b/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp index 74ed6102ebc..1c812732bbf 100644 --- a/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp +++ b/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp @@ -258,24 +258,28 @@ void terminateRequest(SequenceSlotManager& seqSlotManager, LlmRequest& llmReq, S { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); // If a sequence slot is associated with this request id, free it - seqSlotManager.freeSequenceSlot(llmReq.mRequestId); - // Remove the sequence from kvCacheManager - auto const requestId = llmReq.mRequestId; - if (kvCacheManager) + for (int i = 0; i < llmReq.getNumSequences(); i++) { - kvCacheManager->removeSequence(requestId, llmReq); - } - if (crossKvCacheManager) - { - crossKvCacheManager->removeSequence(requestId, llmReq); - } - if (pause && !llmReq.isGenerationCompleteState()) - { - llmReq.pause(maxInputLen); - } - else - { - TLLM_LOG_DEBUG("terminated: request ID %lu, paused: %d", requestId, pause); + auto const requestId = llmReq.getSeqSlotId(i); + seqSlotManager.freeSequenceSlot(requestId); + + // Remove the sequence from kvCacheManager + if (kvCacheManager) + { + kvCacheManager->removeSequence(requestId, llmReq); + } + if (crossKvCacheManager) + { + crossKvCacheManager->removeSequence(requestId, llmReq); + } + if (pause && !llmReq.isGenerationCompleteState()) + { + llmReq.pause(maxInputLen); + } + else + { + TLLM_LOG_DEBUG("terminated: request ID %lu, paused: %d", requestId, pause); + } } if (peftCacheManager) diff --git a/cpp/tensorrt_llm/common/attentionOp.cpp b/cpp/tensorrt_llm/common/attentionOp.cpp index 51bccb7e6e7..6bbd7b40368 100644 --- a/cpp/tensorrt_llm/common/attentionOp.cpp +++ b/cpp/tensorrt_llm/common/attentionOp.cpp @@ -115,6 +115,12 @@ struct FusedQKVMaskedAttentionDispatchParams KVCacheBuffer kv_block_array; KVLinearBuffer shift_k_cache_buffer; bool cross_attention = false; + float* attention_prior_scores = nullptr; + int const* attention_prior_focus = nullptr; + bool apply_attention_prior = false; + int attention_prior_lookahead = 5; + int attention_prior_window_left = 1; + int attention_prior_window_right = 5; int const* memory_length_per_sample = nullptr; int max_distance = 0; bool block_sparse_attention = false; @@ -631,6 +637,17 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params const& params, cud dispatch_params.rotary_cogvlm_vision_start = mVisionStart; dispatch_params.rotary_cogvlm_vision_length = mVisionLength; dispatch_params.cross_attention = isCrossAttention(); + if 
(ComputeAttentionPrior()) + { + dispatch_params.attention_prior_scores = params.attention_prior_scores; + dispatch_params.attention_prior_lookahead = mAttentionPriorLookahead; + dispatch_params.attention_prior_window_left = mAttentionPriorWindowLeft; + dispatch_params.attention_prior_window_right = mAttentionPriorWindowRight; + } + if (ApplyAttentionPrior() || ComputeAttentionPrior()) + { + dispatch_params.attention_prior_focus = params.attention_prior_focus; + dispatch_params.apply_attention_prior = mApplyAttentionPrior; + } dispatch_params.memory_length_per_sample = params.encoder_input_lengths; dispatch_params.block_sparse_attention = mMaskType == AttentionMaskType::BLOCKSPARSE; dispatch_params.block_sparse_params = mBlockSparseParams; @@ -2967,6 +2996,11 @@ std::string AttentionOp::toString() const ss << "mMaxContextLength: " << mMaxContextLength << std::endl; ss << "mQKVBiasEnabled: " << std::boolalpha << mQKVBiasEnabled << std::endl; ss << "mCrossAttention: " << std::boolalpha << mCrossAttention << std::endl; + ss << "mComputeAttentionPrior: " << std::boolalpha << mComputeAttentionPrior << std::endl; + ss << "mApplyAttentionPrior: " << std::boolalpha << mApplyAttentionPrior << std::endl; + ss << "mAttentionPriorLookahead: " << mAttentionPriorLookahead << std::endl; + ss << "mAttentionPriorWindowLeft: " << mAttentionPriorWindowLeft << std::endl; + ss << "mAttentionPriorWindowRight: " << mAttentionPriorWindowRight << std::endl; ss << "mMaxDistance: " << mMaxDistance << std::endl; ss << "mPosShiftEnabled: " << std::boolalpha << mPosShiftEnabled << std::endl; ss << "mPagedContextFMHA: " << std::boolalpha << mPagedContextFMHA << std::endl; diff --git a/cpp/tensorrt_llm/common/attentionOp.h b/cpp/tensorrt_llm/common/attentionOp.h index f33194c02fa..1b0ffcd0f18 100644 --- a/cpp/tensorrt_llm/common/attentionOp.h +++ b/cpp/tensorrt_llm/common/attentionOp.h @@ -219,6 +219,10 @@ class AttentionOp int32_t spec_decoding_max_generation_length = 1; // optional when fuse_fp4_quant is enabled int32_t start_token_idx_sf = 0; + + // optional when attention prior is used. 
+ float* attention_prior_scores = nullptr; + int32_t const* attention_prior_focus = nullptr; }; template @@ -323,6 +327,31 @@ class AttentionOp return mCrossAttention; } + [[nodiscard]] bool ComputeAttentionPrior() const + { + return mComputeAttentionPrior; + } + + [[nodiscard]] bool ApplyAttentionPrior() const + { + return mApplyAttentionPrior; + } + + [[nodiscard]] int AttentionPriorLookahead() const + { + return mAttentionPriorLookahead; + } + + [[nodiscard]] int AttentionPriorWindowLeft() const + { + return mAttentionPriorWindowLeft; + } + + [[nodiscard]] int AttentionPriorWindowRight() const + { + return mAttentionPriorWindowRight; + } + [[nodiscard]] bool useKVCache() const { return mUseKVCache; @@ -409,6 +438,11 @@ class AttentionOp int32_t mMaxContextLength = 0; bool mQKVBiasEnabled = false; bool mCrossAttention = false; + bool mComputeAttentionPrior = false; + bool mApplyAttentionPrior = false; + int mAttentionPriorLookahead = 5; + int mAttentionPriorWindowLeft = 1; + int mAttentionPriorWindowRight = 5; int mMaxDistance = 0; bool mPosShiftEnabled = false; bool mPagedContextFMHA = false; @@ -469,14 +503,15 @@ class AttentionOp mRotaryEmbeddingLongMscale, mRotaryEmbeddingMaxPositions, mRotaryEmbeddingOriginalMaxPositions, (int8_t) mPositionEmbeddingType, mUseLognScaling, mRemovePadding, (int32_t) mMaskType, mBlockSparseParams.data(), mPagedKVCache, mTokensPerBlock, mKVCacheQuantMode.value(), mTpSize, mTpRank, - mUnfuseQkvGemm, (int32_t) mType, mMaxContextLength, mQKVBiasEnabled, mCrossAttention, mMaxDistance, - mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8AttenOutput, mFP8ContextMLA, mFP8GenerationMLA, - mChunkPrefillBufferBatchSize, mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled, - mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength, - mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup, - mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, - mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, - mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1)); + mUnfuseQkvGemm, (int32_t) mType, mMaxContextLength, mQKVBiasEnabled, mCrossAttention, + mComputeAttentionPrior, mApplyAttentionPrior, mMaxDistance, mPosShiftEnabled, mPagedContextFMHA, + mFP8ContextFMHA, mFP8AttenOutput, mFP8ContextMLA, mFP8GenerationMLA, mChunkPrefillBufferBatchSize, + mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled, mUseSpecDecoding, mIsSpecDecTree, + mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mIsGenerationMLA, + mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, + mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, + mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant, + mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1)); }; private: diff --git a/cpp/tensorrt_llm/executor/request.cpp b/cpp/tensorrt_llm/executor/request.cpp index 987eeef894e..1ca138ba6ef 100644 --- a/cpp/tensorrt_llm/executor/request.cpp +++ b/cpp/tensorrt_llm/executor/request.cpp @@ -37,19 +37,20 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming, std::optional logitslogitsPostProcessor, std::optional encoderInputTokenIds, std::optional clientId, bool 
returnAllGeneratedTokens, float priority, RequestType type, std::optional contextPhaseParams, std::optional encoderInputFeatures, - std::optional encoderOutputLength, std::optional crossAttentionMask, - SizeType32 numReturnSequences, std::optional eagleConfig, std::optional skipCrossAttnBlocks, - std::optional guidedDecodingParams, std::optional languageAdapterUid, - std::optional allottedTimeMs, std::optional cacheSaltID) + std::optional encoderOutputLength, std::optional decoderContextFeatures, + std::optional crossAttentionMask, SizeType32 numReturnSequences, std::optional eagleConfig, + std::optional skipCrossAttnBlocks, std::optional guidedDecodingParams, + std::optional languageAdapterUid, std::optional allottedTimeMs, + std::optional cacheSaltID, SizeType32 numVocabs) : mImpl(std::make_unique(std::move(inputTokenIds), maxTokens, streaming, samplingConfig, outputConfig, endId, padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias), std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput), std::move(multimodalEmbedding), std::move(mRopeConfig), std::move(loraConfig), lookaheadConfig, std::move(kvCacheRetentionConfig), std::move(logitsPostProcessorName), std::move(logitslogitsPostProcessor), std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, type, - std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, crossAttentionMask, - numReturnSequences, eagleConfig, skipCrossAttnBlocks, std::move(guidedDecodingParams), languageAdapterUid, - allottedTimeMs, cacheSaltID)) + std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, + std::move(decoderContextFeatures), crossAttentionMask, numReturnSequences, eagleConfig, skipCrossAttnBlocks, + std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, cacheSaltID, numVocabs)) { } @@ -183,6 +184,11 @@ std::optional Request::getEncoderInputTokenIds() const return mImpl->getEncoderInputTokenIds(); } +std::optional Request::getDecoderContextFeatures() const +{ + return mImpl->getDecoderContextFeatures(); +} + std::optional Request::getClientId() const { return mImpl->getClientId(); @@ -253,6 +259,11 @@ std::optional Request::getCacheSaltID() const return mImpl->getCacheSaltID(); } +SizeType32 Request::getNumVocabs() const +{ + return mImpl->getNumVocabs(); +} + void Request::setStreaming(bool streaming) { mImpl->setStreaming(streaming); @@ -388,6 +399,11 @@ void Request::setEncoderOutputLength(SizeType32 encoderOutputLength) mImpl->setEncoderOutputLength(encoderOutputLength); } +void Request::setDecoderContextFeatures(Tensor decoderContextFeatures) +{ + return mImpl->setDecoderContextFeatures(decoderContextFeatures); +} + void Request::setCrossAttentionMask(Tensor crossAttentionMask) { mImpl->setCrossAttentionMask(crossAttentionMask); @@ -422,4 +438,9 @@ void Request::setCacheSaltID(CacheSaltIDType cacheSaltID) { return mImpl->setCacheSaltID(cacheSaltID); } + +void Request::setNumVocabs(SizeType32 numVocabs) +{ + return mImpl->setNumVocabs(numVocabs); +} } // namespace tensorrt_llm::executor diff --git a/cpp/tensorrt_llm/executor/requestImpl.h b/cpp/tensorrt_llm/executor/requestImpl.h index 94de53a7817..07828de8f97 100644 --- a/cpp/tensorrt_llm/executor/requestImpl.h +++ b/cpp/tensorrt_llm/executor/requestImpl.h @@ -45,10 +45,11 @@ class Request::Impl std::optional encoderInputTokenIds, std::optional clientId, bool returnAllGeneratedTokens, PriorityType priority, RequestType type, 
std::optional contextPhaseParams, std::optional encoderInputFeatures, std::optional encoderOutputLength, - std::optional crossAttentionMask, SizeType32 numReturnSequences, std::optional eagleConfig, + std::optional decoderContextFeatures, std::optional crossAttentionMask, + SizeType32 numReturnSequences, std::optional eagleConfig, std::optional skipCrossAttnBlocks, std::optional guidedDecodingParams, std::optional languageAdapterUid, std::optional allottedTimeMs, - std::optional cacheSaltID) + std::optional cacheSaltID, SizeType32 numVocabs = 1) : mInputTokenIds(std::move(inputTokenIds)) , mMaxNewTokens(maxNewTokens) , mStreaming(streaming) @@ -78,6 +79,7 @@ class Request::Impl , mContextPhaseParams(std::move(contextPhaseParams)) , mEncoderInputFeatures(std::move(encoderInputFeatures)) , mEncoderOutputLength(encoderOutputLength) + , mDecoderContextFeatures(std::move(decoderContextFeatures)) , mCrossAttentionMask(std::move(crossAttentionMask)) , mNumReturnSequences(numReturnSequences) , mEagleConfig(std::move(eagleConfig)) @@ -86,6 +88,7 @@ class Request::Impl , mLanguageAdapterUid(languageAdapterUid) , mAllottedTimeMs(allottedTimeMs) , mCacheSaltID(cacheSaltID) + , mNumVocabs(numVocabs) { validate(); } @@ -259,6 +262,11 @@ class Request::Impl return mEncoderInputFeatures; } + [[nodiscard]] std::optional getDecoderContextFeatures() const + { + return mDecoderContextFeatures; + } + [[nodiscard]] std::optional getCrossAttentionMask() const { return mCrossAttentionMask; @@ -302,6 +310,11 @@ class Request::Impl return mCacheSaltID; } + [[nodiscard]] SizeType32 getNumVocabs() const + { + return mNumVocabs; + } + void setStreaming(bool streaming) { mStreaming = streaming; @@ -432,6 +445,11 @@ class Request::Impl mEncoderInputFeatures = encoderInputFeatures; } + void setDecoderContextFeatures(Tensor decoderContextFeatures) + { + mDecoderContextFeatures = decoderContextFeatures; + } + void setCrossAttentionMask(Tensor crossAttentionMask) { mCrossAttentionMask = crossAttentionMask; @@ -481,6 +499,11 @@ class Request::Impl mCacheSaltID = cacheSaltID; } + void setNumVocabs(SizeType32 numVocabs) + { + mNumVocabs = numVocabs; + } + private: void validate() { @@ -540,6 +563,7 @@ class Request::Impl lambda(mKvCacheRetentionConfig); lambda(mLogitsPostProcessorName); lambda(mEncoderInputTokenIds); + lambda(mDecoderContextFeatures); lambda(mClientId); lambda(mReturnAllGeneratedTokens); lambda(mPriority); @@ -555,6 +579,7 @@ class Request::Impl lambda(mLanguageAdapterUid); lambda(mAllottedTimeMs ? 
std::make_optional(mAllottedTimeMs->count()) : std::nullopt); lambda(mCacheSaltID); + lambda(mNumVocabs); } VecTokens mInputTokenIds; @@ -586,6 +611,7 @@ class Request::Impl std::optional mContextPhaseParams; std::optional mEncoderInputFeatures; std::optional mEncoderOutputLength; + std::optional mDecoderContextFeatures; std::optional mCrossAttentionMask; SizeType32 mNumReturnSequences; std::optional mEagleConfig; @@ -594,6 +620,7 @@ class Request::Impl std::optional mLanguageAdapterUid; std::optional mAllottedTimeMs; std::optional mCacheSaltID; + SizeType32 mNumVocabs; }; } // namespace tensorrt_llm::executor diff --git a/cpp/tensorrt_llm/executor/samplingConfig.cpp b/cpp/tensorrt_llm/executor/samplingConfig.cpp index 176865340e2..87214aa81c2 100644 --- a/cpp/tensorrt_llm/executor/samplingConfig.cpp +++ b/cpp/tensorrt_llm/executor/samplingConfig.cpp @@ -36,7 +36,7 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF OptFloat const& beamSearchDiversityRate, OptFloat const& repetitionPenalty, OptFloat const& presencePenalty, OptFloat const& frequencyPenalty, OptFloat const& lengthPenalty, OptSize32 const& earlyStopping, OptSize32 const& noRepeatNgramSize, OptSize32 const& numReturnSequences, OptFloat const& minP, - OptVec const& beamWidthArray) + OptVec const& beamWidthArray, OptFloat const& cfgScale) : mBeamWidth(checkBeamWidth(beamWidth)) , mTopK(checkTopK(topK)) , mTopP(checkTopP(topP)) @@ -55,6 +55,7 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF , mNoRepeatNgramSize(checkNoRepeatNgramSize(noRepeatNgramSize)) , mNumReturnSequences(checkNumReturnSequences(numReturnSequences, beamWidth)) , mMinP(checkMinP(minP)) + , mCfgScale(checkCfgScale(cfgScale)) { updateNumReturnBeams(); std::tie(mBeamWidthArray, mBeamWidth) = checkBeamWidthArray(beamWidthArray, mBeamWidth); @@ -69,7 +70,7 @@ bool SamplingConfig::operator==(SamplingConfig const& other) const && mPresencePenalty == other.mPresencePenalty && mFrequencyPenalty == other.mFrequencyPenalty && mLengthPenalty == other.mLengthPenalty && mEarlyStopping == other.mEarlyStopping && mNoRepeatNgramSize == other.mNoRepeatNgramSize && mNumReturnSequences == other.mNumReturnSequences - && mMinP == other.mMinP && mBeamWidthArray == other.mBeamWidthArray; + && mMinP == other.mMinP && mBeamWidthArray == other.mBeamWidthArray && mCfgScale == other.mCfgScale; } // Getters @@ -163,6 +164,11 @@ OptSize32 SamplingConfig::getNumReturnSequences() const return mNumReturnSequences; } +OptFloat SamplingConfig::getCfgScale() const +{ + return mCfgScale; +} + std::optional SamplingConfig::getMinP() const { return mMinP; @@ -271,6 +277,11 @@ void SamplingConfig::setBeamWidthArray(OptVec const& beamWidthArray) std::tie(mBeamWidthArray, mBeamWidth) = checkBeamWidthArray(beamWidthArray, mBeamWidth); } +void SamplingConfig::setCfgScale(std::optional const& cfgScale) +{ + mCfgScale = checkCfgScale(cfgScale); +} + // Checkers SizeType32 SamplingConfig::checkBeamWidth(SizeType32 beamWidth) { @@ -278,6 +289,12 @@ SizeType32 SamplingConfig::checkBeamWidth(SizeType32 beamWidth) return beamWidth; } +OptFloat const& SamplingConfig::checkCfgScale(OptFloat const& cfgScale) +{ + // TODO: implement checking the cfg scale + return cfgScale; +} + OptFloat const& SamplingConfig::checkTopK(OptFloat const& topK) { if (topK.has_value()) diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index b3726029ed5..1a401110153 100644 --- 
a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -705,6 +705,7 @@ Request Serialization::deserializeRequest(std::istream& is) auto contextPhaseParams = su::deserialize>(is); auto encoderInputFeatures = su::deserialize>(is); auto encoderOutputLength = su::deserialize>(is); + auto decoderContextFeatures = su::deserialize>(is); auto crossAttentionMask = su::deserialize>(is); auto numReturnSequences = su::deserialize(is); auto eagleConfig = su::deserialize>(is); @@ -724,8 +725,9 @@ Request Serialization::deserializeRequest(std::istream& is) std::move(kvCacheRetentionConfig), std::move(logitsPostProcessorName), std::nullopt, std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, requestType, std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, - std::move(crossAttentionMask), numReturnSequences, std::move(eagleConfig), std::move(skipCrossAttnBlocks), - std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, cacheSaltID); + std::move(decoderContextFeatures), std::move(crossAttentionMask), numReturnSequences, std::move(eagleConfig), + std::move(skipCrossAttnBlocks), std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, + cacheSaltID); } void Serialization::serialize(Request const& request, std::ostream& os) diff --git a/cpp/tensorrt_llm/kernels/cfgKernels.h b/cpp/tensorrt_llm/kernels/cfgKernels.h new file mode 100644 index 00000000000..1f502271c5c --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cfgKernels.h @@ -0,0 +1,58 @@ +#pragma once + +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/opUtils.h" +#include "tensorrt_llm/runtime/iTensor.h" +#include +#include + +namespace tensorrt_llm::kernels +{ + +//! Apply classifier-free guidance (CFG) on GPU in-place using cuBLAS. +//! It overwrites `logitsView` with: logits = cfgScale * logits + (1 - cfgScale) * uncondLogits +//! Only the slice [vocabOffset, vocabOffset + vocabSize) is modified. +inline void invokeCfg(tensorrt_llm::runtime::CudaStream const& stream, runtime::ITensor::SharedPtr logitsView, + runtime::ITensor::SharedPtr uncondLogitsView, float cfgScale, runtime::SizeType32 vocabOffset, + runtime::SizeType32 vocabSize) +{ + using TensorPtr = runtime::ITensor::SharedPtr; + + // Restrict to current vocabulary segment. + TensorPtr logitsVocabView = runtime::ITensor::slice(logitsView, {0, vocabOffset}, vocabSize); + TensorPtr uncondLogitsVocabView = runtime::ITensor::slice(uncondLogitsView, {0, vocabOffset}, vocabSize); + + void* condPtr = logitsVocabView->data(); + void const* uncondPtr = uncondLogitsVocabView->data(); + + cudaDataType_t dataType{}; + switch (logitsVocabView->getDataType()) + { + case nvinfer1::DataType::kFLOAT: dataType = CUDA_R_32F; break; + case nvinfer1::DataType::kHALF: dataType = CUDA_R_16F; break; + default: TLLM_THROW("Unsupported data type for CFG"); + } + + auto handlePtr = getCublasHandle(); + auto& handle = *handlePtr; + tensorrt_llm::common::check_cuda_error(cublasSetStream(handle, stream.get())); + + int n = static_cast(vocabSize); + int inc = 1; + + // Use float for the scaling factors and always accumulate in FP32 to + // satisfy cuBLAS requirements (FP16 vectors must use FP32 compute/alpha). 
+ float alphaF = cfgScale; // Scaling factor in FP32 + float axpyF = 1.0f - cfgScale; // (1 - cfgScale) in FP32 + + tensorrt_llm::common::check_cuda_error(cublasScalEx(handle, n, &alphaF, CUDA_R_32F, // alpha + condPtr, dataType, // x and its type + inc, CUDA_R_32F)); // increments + compute type + + tensorrt_llm::common::check_cuda_error(cublasAxpyEx(handle, n, &axpyF, CUDA_R_32F, // alpha + uncondPtr, dataType, inc, // x + condPtr, dataType, inc, // y + CUDA_R_32F)); // compute type +} + +} // namespace tensorrt_llm::kernels diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h index 3f2705f2eea..6a6aad0ea0c 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h @@ -210,6 +210,16 @@ struct Multihead_attention_params_base int const* memory_length_per_sample = nullptr; int32_t const* mrope_position_deltas = nullptr; + + // fields related to attention prior: + // scores which accumulate window of cross attention probs + // and focus which specifies which index in encoder output sequence + float* attention_prior_scores = nullptr; + int const* attention_prior_focus = nullptr; + bool apply_attention_prior = false; + int attention_prior_lookahead = 5; + int attention_prior_window_left = 1; + int attention_prior_window_right = 5; }; template diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h index 21b9112b9fe..3d9714d4e92 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h @@ -1528,6 +1528,18 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske float const kv_scale_quant_orig_f = (ENABLE_8BITS_KV_CACHE ? params.kv_scale_quant_orig[0] : 1.0f); convert_from_float(&k_scale_quant_orig, k_scale_quant_orig_f); convert_from_float(&kv_scale_orig_quant, (ENABLE_8BITS_KV_CACHE ? params.kv_scale_orig_quant[0] : 1.0f)); + // parameters related to attention prior + int focus; + if (params.attention_prior_focus != nullptr) + { + focus = params.attention_prior_focus[batch_beam_idx]; + } + bool const store_scores = params.attention_prior_scores != nullptr; + float* scores_ptr = nullptr; + if (store_scores) + { + scores_ptr = &params.attention_prior_scores[batch_beam_idx * params.attention_prior_lookahead]; + } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); @@ -1849,7 +1861,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske { relative_attention_bias = convert_to_float(relative_attention_bias_ptr[tlength]); } - if (has_attention_mask && tidx == 0) + if (has_attention_mask && tidx == 0 && !DO_CROSS_ATTENTION) { // Note: reuse the relative_attention_bias variable. // attention_mask = 1.0 means that the position is not masked. @@ -2055,7 +2067,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske { relative_attention_bias = convert_to_float(relative_attention_bias_ptr[local_time_now]); } - if (is_active && has_attention_mask) + if (is_active && has_attention_mask && !DO_CROSS_ATTENTION) { // Note: reuse the relative_attention_bias variable. // attention_mask = 1.0 means that the position is not masked.
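For reference, the cuBLAS path in cfgKernels.h above implements the usual classifier-free guidance blend: with guidance scale s, s * cond + (1 - s) * uncond is algebraically the same as uncond + s * (cond - uncond), so s = 1 reproduces the conditional logits and s > 1 extrapolates away from the unconditional distribution. A minimal host-side sketch of the same in-place update, assuming plain float logits (illustrative only; the helper name cfgBlendReference is hypothetical and not part of this change):

#include <cstddef>
#include <vector>

// Reference semantics of invokeCfg(): blend conditional and unconditional logits
// in place over the slice [vocabOffset, vocabOffset + vocabSize).
inline void cfgBlendReference(std::vector<float>& cond, std::vector<float> const& uncond, float cfgScale,
    std::size_t vocabOffset, std::size_t vocabSize)
{
    for (std::size_t i = vocabOffset; i < vocabOffset + vocabSize; ++i)
    {
        // Equivalent to the cublasScalEx step (cond *= cfgScale) followed by
        // the cublasAxpyEx step (cond += (1 - cfgScale) * uncond).
        cond[i] = cfgScale * cond[i] + (1.0f - cfgScale) * uncond[i];
    }
}

With cfgScale = 1 the conditional logits pass through unchanged, which is why the scale is optional in SamplingConfig and defaults to no guidance.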
@@ -2268,13 +2280,30 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske float inv_sum = __fdividef(logit_scale, sum + 1.e-6f); int const normlization_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length; + float sum_rescale = 0.0f; for (int ti = tidx; ti <= normlization_loop_end; ti += THREADS_PER_BLOCK) { int const time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti; if (!MULTI_BLOCK_FLAG) { - convert_from_float(&logits_smem[ti], qk_smem[ti] * inv_sum); + float prob = qk_smem[ti] * inv_sum; + if (DO_CROSS_ATTENTION && params.attention_prior_focus != nullptr) + { + // do the masking to the prob + if (ti < (focus - params.attention_prior_window_left) + || ti > (focus + params.attention_prior_window_right)) + { + prob *= 0.1f; + } + // store back + qk_smem[ti] = prob; + sum_rescale += prob; + } + else + { + convert_from_float(&logits_smem[ti], prob); + } } else { @@ -2290,6 +2319,27 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske } } + // for the case when we apply prior, we need to perform additional normalization, + // dividing by the sum of the modified probs. + __syncthreads(); + if (!MULTI_BLOCK_FLAG && DO_CROSS_ATTENTION && params.attention_prior_focus != nullptr) + { + sum_rescale = block_sum(&red_smem[WARPS_PER_BLOCK], sum_rescale); + + // finally loop to compute probability, store probability to buffer if needed + float inv_sum_rescale = __fdividef(1.0f, sum_rescale + 1.e-6f); + for (int ti = tidx; ti <= kv_loop_length; ti += THREADS_PER_BLOCK) + { + float prob = qk_smem[ti] * inv_sum_rescale; + if (store_scores && ti >= focus && ti < focus + params.attention_prior_lookahead) + { + scores_ptr[ti - focus] = prob; + } + convert_from_float(&logits_smem[ti], prob); + } + __syncthreads(); + } + // Put Values part below so we leverage __syncthreads // from the previous step diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index e0325b51c8a..abf402e5dbc 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -134,7 +134,7 @@ void initBindings(nb::module_& m) .def_prop_rw("streaming", &GenLlmReq::isStreaming, &GenLlmReq::setStreaming) .def_rw("end_id", &GenLlmReq::mEndId) .def_rw("pad_id", &GenLlmReq::mPadId) - .def_rw("seq_slot", &GenLlmReq::mSeqSlot) + .def_rw("seq_slots", &GenLlmReq::mSeqSlots) .def_prop_ro("return_log_probs", &GenLlmReq::returnLogProbs) .def_prop_ro("return_context_logits", &GenLlmReq::getReturnContextLogits) .def_prop_ro("return_generation_logits", &GenLlmReq::getReturnGenerationLogits) @@ -281,6 +281,7 @@ void initBindings(nb::module_& m) bool apply_logits_post_processor_batched, std::optional encoder_input_tokens, bool return_encoder_output, std::optional client_id, executor::PriorityType priority, std::optional encoder_input_features, + std::optional decoder_context_features, std::optional encoder_output_length, std::optional cross_attention_mask, tb::LlmRequestType llm_request_type, std::optional input_token_extra_ids, @@ -317,6 +318,7 @@ void initBindings(nb::module_& m) auto lora_config_tensor_ptr = makeOptionalTensor(lora_config); auto draft_logits_tensor_ptr = makeOptionalTensor(draft_logits); auto encoder_input_features_tensor_ptr = makeOptionalTensor(encoder_input_features); + auto decoder_context_features_tensor_ptr = makeOptionalTensor(decoder_context_features); auto cross_attention_mask_tensor_ptr = 
makeOptionalTensor(cross_attention_mask); auto skip_cross_attn_blocks_tensor_ptr = makeOptionalTensor(skip_cross_attn_blocks); @@ -329,10 +331,10 @@ void initBindings(nb::module_& m) return_context_logits, return_generation_logits, draft_tokens, draft_logits_tensor_ptr, exclude_input_from_output, logits_post_processor, apply_logits_post_processor_batched, encoder_input_tokens, return_encoder_output, client_id, priority, encoder_input_features_tensor_ptr, - encoder_output_length, cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids, - num_return_sequences, eagle_config, skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, - guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id, - arrival_time}; + encoder_output_length, decoder_context_features_tensor_ptr, cross_attention_mask_tensor_ptr, + llm_request_type, input_token_extra_ids, num_return_sequences, eagle_config, + skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, guided_decoding_params, + language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id, arrival_time}; }, nb::arg("request_id"), nb::arg("max_new_tokens"), nb::arg("input_tokens"), nb::arg("sampling_config"), nb::arg("is_streaming"), nb::arg("end_id") = std::nullopt, nb::arg("pad_id") = std::nullopt, @@ -351,7 +353,8 @@ void initBindings(nb::module_& m) nb::arg("apply_logits_post_processor_batched") = false, nb::arg("encoder_input_tokens") = std::nullopt, nb::arg("return_encoder_output") = false, nb::arg("client_id") = std::nullopt, nb::arg("priority") = executor::Request::kDefaultPriority, nb::arg("encoder_input_features") = std::nullopt, - nb::arg("encoder_output_len") = std::nullopt, nb::arg("cross_attention_mask") = std::nullopt, + nb::arg("encoder_output_len") = std::nullopt, nb::arg("decoder_context_features") = std::nullopt, + nb::arg("cross_attention_mask") = std::nullopt, nb::arg("llm_request_type") = tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, nb::arg("input_token_extra_ids") = std::nullopt, nb::arg("num_return_sequences") = 1, nb::arg("eagle_config") = std::nullopt, nb::arg("skip_cross_attn_blocks") = std::nullopt, @@ -460,7 +463,7 @@ void initBindings(nb::module_& m) { if (contextRequests[i]->isLastContextChunk()) { - activeSlots.push_back(*contextRequests[i]->mSeqSlot); + activeSlots.push_back(contextRequests[i]->mSeqSlots.at(0)); generationSteps.push_back(contextRequests[i]->getDecodingIter()); auto contextLogitsOffset = numContextLogitsPrefixSum[i + 1] - 1; tr::ITensor::SharedPtr logitsView = ITensor::slice(logits, contextLogitsOffset, 1); @@ -489,7 +492,7 @@ void initBindings(nb::module_& m) { if (genRequests[i]->isGenerationInProgressState()) { - activeSlots.push_back(*genRequests[i]->mSeqSlot); + activeSlots.push_back(genRequests[i]->mSeqSlots.at(0)); generationSteps.push_back(genRequests[i]->getDecodingIter()); auto logitsOffset = genLogitsOffset + i * beamWidth; diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 68c719fb687..6df7c0b706d 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -488,7 +488,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) std::vector const&, std::optional const&, nvinfer1::DataType, SizeType32, int64_t, runtime::SizeType32, bool, bool, tbk::CacheType, std::optional, std::shared_ptr, - bool, bool, std::shared_ptr>(), + 
bool, bool, SizeType32, std::shared_ptr>(), nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"), nb::arg("tokens_per_block"), nb::arg("blocks_per_window"), nb::arg("max_num_sequences"), nb::arg("max_beam_width"), nb::arg("max_attention_window_vec"), nb::arg("temp_attention_window_inputs").none(), nb::arg("dtype"), @@ -496,8 +496,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) nb::arg("enable_block_reuse") = false, nb::arg("onboard_blocks") = true, nb::arg("cache_type") = tbk::CacheType::kSELF, nb::arg("secondary_offload_min_priority") = std::nullopt, nb::arg("event_manager") = nullptr, nb::arg("enable_partial_reuse") = true, - nb::arg("copy_on_partial_reuse") = true, nb::arg("kv_connector_manager") = nullptr, - nb::call_guard()); + nb::arg("copy_on_partial_reuse") = true, nb::arg("num_vocabs") = 1, + nb::arg("kv_connector_manager") = nullptr, nb::call_guard()); } void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m) diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp index 07d630cb3b2..2dd79c0a3dc 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp @@ -115,6 +115,7 @@ std::shared_ptr LlmRequest::toTrtLlm() const mPriority, // from_torch(mEncoderInputFeatures), // mEncoderOutputLength, // + from_torch(mDecoderContextFeatures), // from_torch(mCrossAttentionMask), // getLlmRequestType(), // std::nullopt, // inputTokenExtraIds diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h index 4ea47fdcc8c..cf395add156 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h @@ -76,6 +76,7 @@ class LlmRequest : public tb::GenericLlmRequest executor::PriorityType priority = executor::Request::kDefaultPriority, std::optional encoderInputFeatures = std::nullopt, std::optional encoderOutputLength = std::nullopt, + std::optional decoderContextFeatures = std::nullopt, std::optional crossAttentionMask = std::nullopt, tb::LlmRequestType llmRequestType = tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::optional inputTokenExtraIds = std::nullopt, SizeType32 numReturnSequences = 1, @@ -135,6 +136,7 @@ class LlmRequest : public tb::GenericLlmRequest priority, // encoderInputFeatures, // encoderOutputLength, // + decoderContextFeatures, // crossAttentionMask, // llmRequestType, // inputTokenExtraIds // diff --git a/cpp/tensorrt_llm/nanobind/bindings.cpp b/cpp/tensorrt_llm/nanobind/bindings.cpp index 5e0ece45636..21cc4a144d7 100644 --- a/cpp/tensorrt_llm/nanobind/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/bindings.cpp @@ -280,11 +280,13 @@ NB_MODULE(TRTLLM_NB_MODULE, m) .def(nb::self != nb::self); nb::class_(m, "ModelConfig") - .def(nb::init(), + .def(nb::init>>(), nb::arg("vocab_size"), nb::arg("num_layers"), nb::arg("num_attention_layers"), nb::arg("num_rnn_layers"), - nb::arg("num_heads"), nb::arg("hidden_size"), nb::arg("data_type")) + nb::arg("num_heads"), nb::arg("hidden_size"), nb::arg("data_type"), nb::arg("vocab_sizes") = nb::none()) .def_prop_ro("vocab_size", &tr::ModelConfig::getVocabSize) - .def("vocab_size_padded", &tr::ModelConfig::getVocabSizePadded, nb::arg("world_size")) + .def( + "vocab_size_padded", &tr::ModelConfig::getVocabSizePadded, nb::arg("world_size"), nb::arg("vocab_size") = 0) .def("num_layers", &tr::ModelConfig::getNbLayers, 
nb::arg("pipeline_parallelism") = 1, nb::arg("pipeline_parallelism_rank") = 0) .def("num_attention_layers", &tr::ModelConfig::getNbAttentionLayers, nb::arg("pipeline_parallelism") = 1, @@ -297,6 +299,8 @@ NB_MODULE(TRTLLM_NB_MODULE, m) .def_prop_ro("hidden_size", &tr::ModelConfig::getHiddenSize) .def_prop_ro("size_per_head", &tr::ModelConfig::getSizePerHead) .def_prop_ro("data_type", &tr::ModelConfig::getDataType) + .def_prop_ro("num_vocabs", &tr::ModelConfig::getNumVocabs) + .def_prop_ro("vocab_sizes", &tr::ModelConfig::getVocabSizes) .def_prop_ro("speculative_decoding_mode", &tr::ModelConfig::getSpeculativeDecodingMode) .def_prop_rw("head_size", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead) .def_prop_rw( @@ -306,6 +310,25 @@ NB_MODULE(TRTLLM_NB_MODULE, m) nb::overload_cast(&tr::ModelConfig::useGptAttentionPlugin)) .def_prop_rw("use_packed_input", nb::overload_cast<>(&tr::ModelConfig::usePackedInput, nb::const_), nb::overload_cast(&tr::ModelConfig::usePackedInput)) + .def_prop_rw("use_gpt_attention_plugin", + nb::overload_cast<>(&tr::ModelConfig::useGptAttentionPlugin, nb::const_), + nb::overload_cast(&tr::ModelConfig::useGptAttentionPlugin)) + .def_prop_rw("use_packed_input", nb::overload_cast<>(&tr::ModelConfig::usePackedInput, nb::const_), + nb::overload_cast(&tr::ModelConfig::usePackedInput)) + .def_prop_rw("use_attention_prior", nb::overload_cast<>(&tr::ModelConfig::useAttentionPrior, nb::const_), + nb::overload_cast(&tr::ModelConfig::useAttentionPrior)) + .def_prop_rw("use_context_embeddings", nb::overload_cast<>(&tr::ModelConfig::useContextEmbeddings, nb::const_), + nb::overload_cast(&tr::ModelConfig::useContextEmbeddings)) + .def_prop_rw("compute_attention_prior_from_layers", &tr::ModelConfig::getComputeAttentionPriorFromLayers, + &tr::ModelConfig::setComputeAttentionPriorFromLayers) + .def_prop_rw("apply_attention_prior_to_layers", &tr::ModelConfig::getApplyAttentionPriorToLayers, + &tr::ModelConfig::setApplyAttentionPriorToLayers) + .def_prop_rw("attention_prior_lookahead", &tr::ModelConfig::getAttentionPriorLookahead, + &tr::ModelConfig::setAttentionPriorLookahead) + .def_prop_rw("attention_prior_window_left", &tr::ModelConfig::getAttentionPriorWindowLeft, + &tr::ModelConfig::setAttentionPriorWindowLeft) + .def_prop_rw("attention_prior_window_right", &tr::ModelConfig::getAttentionPriorWindowRight, + &tr::ModelConfig::setAttentionPriorWindowRight) .def_prop_rw("kv_cache_type", nb::overload_cast<>(&tr::ModelConfig::getKVCacheType, nb::const_), nb::overload_cast(&tr::ModelConfig::setKVCacheType)) .def_prop_rw("tokens_per_block", &tr::ModelConfig::getTokensPerBlock, &tr::ModelConfig::setTokensPerBlock) @@ -375,7 +398,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m) }; auto SamplingConfigSetState = [](tr::SamplingConfig& self, nb::tuple t) { - if (t.size() != 19) + if (t.size() != 20) { throw std::runtime_error("Invalid SamplingConfig state!"); } @@ -400,7 +423,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m) config.numReturnSequences = nb::cast(t[16]); config.minP = nb::cast>(t[17]); config.beamWidthArray = nb::cast>>(t[18]); - + config.cfgScale = nb::cast>(t[19]); new (&self) tr::SamplingConfig(config); }; @@ -427,6 +450,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m) .def_rw("num_return_sequences", &tr::SamplingConfig::numReturnSequences) .def_rw("min_p", &tr::SamplingConfig::minP) .def_rw("beam_width_array", &tr::SamplingConfig::beamWidthArray) + .def_rw("cfg_scale", &tr::SamplingConfig::cfgScale) .def_rw("normalize_log_probs", &tr::SamplingConfig::normalizeLogProbs) .def("__getstate__", 
SamplingConfigGetState) .def("__setstate__", SamplingConfigSetState) diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp index de9aa8a8c07..a8057d89e10 100644 --- a/cpp/tensorrt_llm/nanobind/executor/request.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp @@ -81,7 +81,7 @@ void initRequestBindings(nb::module_& m) }; auto samplingConfigSetstate = [](tle::SamplingConfig& samplingConfig, nb::tuple const& state) { - if (state.size() != 19) + if (state.size() != 20) { throw std::runtime_error("Invalid SamplingConfig state!"); } @@ -103,29 +103,31 @@ void initRequestBindings(nb::module_& m) nb::cast>(state[15]), // NoRepeatNgramSize nb::cast>(state[16]), // NumReturnSequences nb::cast>(state[17]), // MinP - nb::cast>>(state[18]) // BeamWidthArray + nb::cast>>(state[18]), // BeamWidthArray + nb::cast>(state[19]) // CfgScale ); }; nb::class_(m, "SamplingConfig") .def(nb::init const&, // beamWidth - std::optional const&, // topP - std::optional const&, // topPMin - std::optional const&, // topPResetIds - std::optional const&, // topPDecay - std::optional const&, // seed - std::optional const&, // temperature - std::optional const&, // minTokens - std::optional const&, // beamSearchDiversityRate - std::optional const&, // repetitionPenalty - std::optional const&, // presencePenalty - std::optional const&, // frequencyPenalty - std::optional const&, // lengthPenalty - std::optional const&, // earlyStopping - std::optional const&, // noRepeatNgramSize - std::optional const&, // numReturnSequences - std::optional const&, // minP - std::optional> const& // beamWidthArray + std::optional const&, // beamWidth + std::optional const&, // topP + std::optional const&, // topPMin + std::optional const&, // topPResetIds + std::optional const&, // topPDecay + std::optional const&, // seed + std::optional const&, // temperature + std::optional const&, // minTokens + std::optional const&, // beamSearchDiversityRate + std::optional const&, // repetitionPenalty + std::optional const&, // presencePenalty + std::optional const&, // frequencyPenalty + std::optional const&, // lengthPenalty + std::optional const&, // earlyStopping + std::optional const&, // noRepeatNgramSize + std::optional const&, // numReturnSequences + std::optional const&, // minP + std::optional> const&, // beamWidthArray + std::optional const& // CfgScale >(), // clang-format off nb::arg("beam_width") = 1, @@ -147,7 +149,8 @@ void initRequestBindings(nb::module_& m) nb::arg("no_repeat_ngram_size") = nb::none(), nb::arg("num_return_sequences") = nb::none(), nb::arg("min_p") = nb::none(), - nb::arg("beam_width_array") = nb::none()) // clang-format on + nb::arg("beam_width_array") = nb::none(), + nb::arg("cfg_scale") = nb::none()) // clang-format on .def_prop_rw("beam_width", &tle::SamplingConfig::getBeamWidth, &tle::SamplingConfig::setBeamWidth) .def_prop_rw("top_k", &tle::SamplingConfig::getTopK, &tle::SamplingConfig::setTopK) .def_prop_rw("top_p", &tle::SamplingConfig::getTopP, &tle::SamplingConfig::setTopP) @@ -174,6 +177,7 @@ void initRequestBindings(nb::module_& m) .def_prop_rw("min_p", &tle::SamplingConfig::getMinP, &tle::SamplingConfig::setMinP) .def_prop_rw( "beam_width_array", &tle::SamplingConfig::getBeamWidthArray, &tle::SamplingConfig::setBeamWidthArray) + .def_prop_rw("cfg_scale", &tle::SamplingConfig::getCfgScale, &tle::SamplingConfig::setCfgScale) .def("__getstate__", samplingConfigGetstate) .def("__setstate__", samplingConfigSetstate); @@ -572,12 +576,12 @@ void 
initRequestBindings(nb::module_& m) self.getLogitsPostProcessorName(), self.getLogitsPostProcessor(), self.getEncoderInputTokenIds(), self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(), self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(), - self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(), - self.getGuidedDecodingParams(), self.getCacheSaltID()); + self.getDecoderContextFeatures(), self.getCrossAttentionMask(), self.getEagleConfig(), + self.getSkipCrossAttnBlocks(), self.getGuidedDecodingParams(), self.getCacheSaltID(), self.getNumVocabs()); }; auto requestSetstate = [](tle::Request& self, nb::tuple const& state) { - if (state.size() != 34) + if (state.size() != 38) { throw std::runtime_error("Invalid Request state!"); } @@ -599,10 +603,13 @@ void initRequestBindings(nb::module_& m) nb::cast(state[24]), nb::cast(state[25]), nb::cast>(state[26]), nb::cast>(state[27]), nb::cast>(state[28]), - nb::cast>(state[29]), 1, nb::cast>(state[30]), - nb::cast>(state[31]), - nb::cast>(state[32]), - nb::cast>(state[33])); + nb::cast>(state[29]), nb::cast>(state[30]), 1, + nb::cast>(state[31]), nb::cast>(state[32]), + nb::cast>(state[33]), + nb::cast>(state[34]), // languageAdapterUid + nb::cast>(state[35]), // allottedTimeMs + nb::cast>(state[36]), // cacheSaltID + nb::cast(state[37])); // numVocabs }; nb::class_ request(m, "Request", nb::dynamic_attr()); @@ -636,6 +643,7 @@ void initRequestBindings(nb::module_& m) std::optional, // contextPhaseParams std::optional, // encoderInputFeatures std::optional, // encoderOutputLength + std::optional, // decoderContextFeatures std::optional, // crossAttentionMask SizeType32, // numReturnSequences std::optional, // eagleConfig @@ -643,7 +651,8 @@ void initRequestBindings(nb::module_& m) std::optional, // guidedDecodingParams std::optional, // languageAdapterUid std::optional, // allottedTimeMs - std::optional // cacheSaltID + std::optional, // cacheSaltID + SizeType32 // numVocabs >(), // clang-format off nb::arg("input_token_ids"), @@ -676,6 +685,7 @@ void initRequestBindings(nb::module_& m) nb::arg("context_phase_params") = nb::none(), nb::arg("encoder_input_features") = nb::none(), nb::arg("encoder_output_length") = nb::none(), + nb::arg("decoder_context_features") = nb::none(), nb::arg("cross_attention_mask") = nb::none(), nb::arg("num_return_sequences") = 1, nb::arg("eagle_config") = nb::none(), @@ -683,8 +693,9 @@ void initRequestBindings(nb::module_& m) nb::arg("guided_decoding_params") = nb::none(), nb::arg("language_adapter_uid") = nb::none(), nb::arg("allotted_time_ms") = nb::none(), - nb::arg("cache_salt_id") = nb::none() - ) // clang-format on + nb::arg("cache_salt_id") = nb::none(), + nb::arg("num_vocabs") = 1 + ) // clang-format on .def_prop_ro("input_token_ids", &tle::Request::getInputTokenIds) .def_prop_ro("max_tokens", &tle::Request::getMaxTokens) .def_prop_rw("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming) @@ -719,6 +730,8 @@ void initRequestBindings(nb::module_& m) .def_prop_rw("request_type", &tle::Request::getRequestType, &tle::Request::setRequestType) .def_prop_rw( "encoder_input_features", &tle::Request::getEncoderInputFeatures, &tle::Request::setEncoderInputFeatures) + .def_prop_rw("decoder_context_features", &tle::Request::getDecoderContextFeatures, + &tle::Request::setDecoderContextFeatures) .def_prop_rw("cross_attention_mask", &tle::Request::getCrossAttentionMask, 
&tle::Request::setCrossAttentionMask) .def_prop_rw("eagle_config", &tle::Request::getEagleConfig, &tle::Request::setEagleConfig) .def_prop_rw( @@ -727,6 +740,7 @@ void initRequestBindings(nb::module_& m) "guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams) .def_prop_rw("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs) .def_prop_rw("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID) + .def_prop_rw("num_vocabs", &tle::Request::getNumVocabs, &tle::Request::setNumVocabs) .def_prop_rw("context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams) .def("__getstate__", requestGetstate) .def("__setstate__", requestSetstate); diff --git a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp index 388819b957a..5b445c2be7e 100644 --- a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp @@ -294,7 +294,7 @@ void initBindings(nb::module_& m) nb::call_guard()) .def("setup", &tr::GptDecoderBatched::setup, nb::arg("mode"), nb::arg("max_num_sequences"), nb::arg("max_beam_width"), nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config"), - nb::call_guard()) + nb::arg("vocab_size") = 0, nb::call_guard()) .def("forward_async", &tr::GptDecoderBatched::forwardAsync, nb::arg("decoder_state"), nb::arg("input"), nb::call_guard()) .def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, nb::rv_policy::reference) diff --git a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp index 717ab3083e5..041311a283a 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp @@ -40,12 +40,13 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, tensorrt_llm::kernels::ContextFMHAType context_fmha_type, int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type, tensorrt_llm::kernels::BlockSparseParams block_sparse_params, bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, - bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha, - bool use_paged_context_fmha, bool use_fp8_context_fmha, bool has_full_attention_mask, bool use_cache, - bool is_spec_decoding_enabled, bool spec_decoding_is_generation_length_variable, - int32_t spec_decoding_max_generation_length, bool is_mla_enabled, int q_lora_rank, int kv_lora_rank, - int qk_nope_head_dim, int qk_rope_head_dim, int v_head_dim, bool fuse_fp4_quant, bool skip_attn, int cp_size, - int cp_rank, std::set cp_group) + bool qkv_bias_enabled, bool cross_attention, bool compute_attention_prior, bool apply_attention_prior, + int attention_prior_lookahead, int attention_prior_window_left, int attention_prior_window_right, int max_distance, + bool pos_shift_enabled, bool dense_context_fmha, bool use_paged_context_fmha, bool use_fp8_context_fmha, + bool has_full_attention_mask, bool use_cache, bool is_spec_decoding_enabled, + bool spec_decoding_is_generation_length_variable, int32_t spec_decoding_max_generation_length, bool is_mla_enabled, + int q_lora_rank, int kv_lora_rank, int qk_nope_head_dim, int qk_rope_head_dim, int v_head_dim, bool fuse_fp4_quant, + bool skip_attn, int 
cp_size, int cp_rank, std::set cp_group) : mResource{DecoderXQARunner::getResourceGlobal()} { mLayerIdx = layer_idx; @@ -85,6 +86,11 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, mMaxContextLength = max_context_length; mQKVBiasEnabled = qkv_bias_enabled; mCrossAttention = cross_attention; + mComputeAttentionPrior = compute_attention_prior; + mApplyAttentionPrior = apply_attention_prior; + mAttentionPriorLookahead = attention_prior_lookahead; + mAttentionPriorWindowLeft = attention_prior_window_left; + mAttentionPriorWindowRight = attention_prior_window_right; mMaxDistance = max_distance; mPosShiftEnabled = pos_shift_enabled; mDenseContextFMHA = dense_context_fmha; @@ -149,6 +155,11 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng read(d, mMaxContextLength); read(d, mQKVBiasEnabled); read(d, mCrossAttention); + read(d, mComputeAttentionPrior); + read(d, mApplyAttentionPrior); + read(d, mAttentionPriorLookahead); + read(d, mAttentionPriorWindowLeft); + read(d, mAttentionPriorWindowRight); read(d, mMaxDistance); read(d, mPosShiftEnabled); read(d, mDenseContextFMHA); @@ -214,13 +225,14 @@ size_t GPTAttentionPluginCommon::getCommonSerializationSize() const noexcept + sizeof(mEnableXQA) + sizeof(unsigned int) // mKVCacheQuantMode + sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mBlockSparseParams) + sizeof(mPagedKVCache) + sizeof(mTokensPerBlock) + sizeof(mType) + sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled) - + sizeof(mCrossAttention) + sizeof(mMaxDistance) + sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA) - + sizeof(mPagedContextFMHA) + sizeof(mFP8ContextFMHA) + sizeof(mFP8AttenOutput) + sizeof(mHasFullAttentionMask) - + sizeof(mUseKVCache) + sizeof(mUnfuseQkvGemm) + sizeof(mUseLognScaling) + sizeof(mIsSpecDecodingEnabled) - + sizeof(mUseSpecDecoding) + sizeof(mSpecDecodingIsGenerationLengthVariable) - + sizeof(mSpecDecodingMaxGenerationLength) + sizeof(mNbMultiBlockSemaphores) + sizeof(mIsMLAEnabled) - + sizeof(mMLAParams) + sizeof(mFuseFp4Quant) + sizeof(mSkipAttn) - + sizeof(uint32_t) // size of DecoderXQARunnerResource buffer. + + sizeof(mCrossAttention) + sizeof(mComputeAttentionPrior) + sizeof(mApplyAttentionPrior) + + sizeof(mAttentionPriorLookahead) + sizeof(mAttentionPriorWindowLeft) + sizeof(mAttentionPriorWindowRight) + + sizeof(mMaxDistance) + sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA) + sizeof(mPagedContextFMHA) + + sizeof(mFP8ContextFMHA) + sizeof(mFP8AttenOutput) + sizeof(mHasFullAttentionMask) + sizeof(mUseKVCache) + + sizeof(mUnfuseQkvGemm) + sizeof(mUseLognScaling) + sizeof(mIsSpecDecodingEnabled) + sizeof(mUseSpecDecoding) + + sizeof(mSpecDecodingIsGenerationLengthVariable) + sizeof(mSpecDecodingMaxGenerationLength) + + sizeof(mNbMultiBlockSemaphores) + sizeof(mIsMLAEnabled) + sizeof(mMLAParams) + sizeof(mFuseFp4Quant) + + sizeof(mSkipAttn) + sizeof(uint32_t) // size of DecoderXQARunnerResource buffer. 
+ sizeof(mCpSize) + sizeof(mCpRank) + sizeof(int32_t) * mCpGroup.size() + mResource->getSerializationSize(); } @@ -264,6 +276,11 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept write(d, mMaxContextLength); write(d, mQKVBiasEnabled); write(d, mCrossAttention); + write(d, mComputeAttentionPrior); + write(d, mApplyAttentionPrior); + write(d, mAttentionPriorLookahead); + write(d, mAttentionPriorWindowLeft); + write(d, mAttentionPriorWindowRight); write(d, mMaxDistance); write(d, mPosShiftEnabled); write(d, mDenseContextFMHA); @@ -347,6 +364,11 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon() mPluginAttributes.emplace_back(PluginField("max_context_length", nullptr, PluginFieldType::kINT32)); mPluginAttributes.emplace_back(PluginField("qkv_bias_enabled", nullptr, PluginFieldType::kINT8)); mPluginAttributes.emplace_back(PluginField("do_cross_attention", nullptr, PluginFieldType::kINT8)); + mPluginAttributes.emplace_back(PluginField("compute_attention_prior", nullptr, PluginFieldType::kINT8)); + mPluginAttributes.emplace_back(PluginField("apply_attention_prior", nullptr, PluginFieldType::kINT8)); + mPluginAttributes.emplace_back(PluginField("attention_prior_lookahead", nullptr, PluginFieldType::kINT32)); + mPluginAttributes.emplace_back(PluginField("attention_prior_window_left", nullptr, PluginFieldType::kINT32)); + mPluginAttributes.emplace_back(PluginField("attention_prior_window_right", nullptr, PluginFieldType::kINT32)); mPluginAttributes.emplace_back(PluginField("max_distance", nullptr, PluginFieldType::kINT32)); mPluginAttributes.emplace_back(PluginField("pos_shift_enabled", nullptr, PluginFieldType::kINT8)); mPluginAttributes.emplace_back(PluginField("dense_context_fmha", nullptr, PluginFieldType::kINT8)); diff --git a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h index dd87d67aab9..54a4bd8ecf0 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h +++ b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h @@ -53,13 +53,15 @@ class GPTAttentionPluginCommon : public BasePlugin, public tensorrt_llm::common: tensorrt_llm::kernels::AttentionMaskType mask_type, tensorrt_llm::kernels::BlockSparseParams block_sparse_params, bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled, bool cross_attention = false, - int max_distance = 0, bool pos_shift_enabled = false, bool dense_context_fmha = false, - bool use_paged_context_fmha = true, bool use_fp8_context_fmha = true, bool has_full_attention_mask = false, - bool use_cache = true, bool is_spec_decoding_enabled = false, - bool spec_decoding_is_generation_length_variable = false, int32_t spec_decoding_max_generation_length = 1, - bool is_mla_enabled = false, int q_lora_rank = 0, int kv_lora_rank = 0, int qk_nope_head_dim = 0, - int qk_rope_head_dim = 0, int v_head_dim = 0, bool fuse_fp4_quant = false, bool skip_attn = false, - int cp_size = 1, int cp_rank = 0, std::set cp_group = {}); + bool compute_attention_prior = false, bool apply_attention_prior = false, int attention_prior_lookahead = 5, + int attention_prior_window_left = 1, int attention_prior_window_right = 5, int max_distance = 0, + bool pos_shift_enabled = false, bool dense_context_fmha = false, bool use_paged_context_fmha = true, + bool use_fp8_context_fmha = true, bool has_full_attention_mask = false, bool use_cache = true, + bool is_spec_decoding_enabled = 
false, bool spec_decoding_is_generation_length_variable = false, + int32_t spec_decoding_max_generation_length = 1, bool is_mla_enabled = false, int q_lora_rank = 0, + int kv_lora_rank = 0, int qk_nope_head_dim = 0, int qk_rope_head_dim = 0, int v_head_dim = 0, + bool fuse_fp4_quant = false, bool skip_attn = false, int cp_size = 1, int cp_rank = 0, + std::set cp_group = {}); GPTAttentionPluginCommon(void const* data, size_t length); diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp index 861b9332dd5..9b69865c2d6 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp @@ -59,23 +59,25 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_ tensorrt_llm::kernels::ContextFMHAType context_fmha_type, int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type, tensorrt_llm::kernels::BlockSparseParams block_sparse_params, bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, - bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha, - bool use_paged_context_fmha, bool use_fp8_context_fmha, bool has_full_attention_mask, bool use_cache, - bool is_spec_decoding_enabled, bool spec_decoding_is_generation_length_variable, - int spec_decoding_max_generation_length, bool is_mla_enabled, int q_lora_rank, int kv_lora_rank, - int qk_nope_head_dim, int qk_rope_head_dim, int v_head_dim, bool fuse_fp4_quant, bool skip_attn, int cp_size, - int cp_rank, std::set cp_group) + bool qkv_bias_enabled, bool cross_attention, bool compute_attention_prior, bool apply_attention_prior, + int attention_prior_lookahead, int attention_prior_window_left, int attention_prior_window_right, int max_distance, + bool pos_shift_enabled, bool dense_context_fmha, bool use_paged_context_fmha, bool use_fp8_context_fmha, + bool has_full_attention_mask, bool use_cache, bool is_spec_decoding_enabled, + bool spec_decoding_is_generation_length_variable, int spec_decoding_max_generation_length, bool is_mla_enabled, + int q_lora_rank, int kv_lora_rank, int qk_nope_head_dim, int qk_rope_head_dim, int v_head_dim, bool fuse_fp4_quant, + bool skip_attn, int cp_size, int cp_rank, std::set cp_group) : GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, num_kv_heads_origin, head_size, unidirectional, q_scaling, attn_logit_softcapping_scale, position_embedding_type, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_short_m_scale, rotary_embedding_long_m_scale, rotary_embedding_max_positions, rotary_embedding_original_max_positions, tp_size, tp_rank, unfuse_qkv_gemm, use_logn_scaling, context_fmha_type, kv_cache_quant_mode, remove_input_padding, mask_type, block_sparse_params, paged_kv_cache, tokens_per_block, - type, max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled, - dense_context_fmha, use_paged_context_fmha, use_fp8_context_fmha, has_full_attention_mask, use_cache, - is_spec_decoding_enabled, spec_decoding_is_generation_length_variable, spec_decoding_max_generation_length, - is_mla_enabled, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, fuse_fp4_quant, - skip_attn, cp_size, cp_rank, cp_group) + type, max_context_length, 
qkv_bias_enabled, cross_attention, compute_attention_prior, apply_attention_prior, + attention_prior_lookahead, attention_prior_window_left, attention_prior_window_right, max_distance, + pos_shift_enabled, dense_context_fmha, use_paged_context_fmha, use_fp8_context_fmha, has_full_attention_mask, + use_cache, is_spec_decoding_enabled, spec_decoding_is_generation_length_variable, + spec_decoding_max_generation_length, is_mla_enabled, q_lora_rank, kv_lora_rank, qk_nope_head_dim, + qk_rope_head_dim, v_head_dim, fuse_fp4_quant, skip_attn, cp_size, cp_rank, cp_group) { TLLM_CHECK_WITH_INFO( !is_mla_enabled, "GPTAttentionPlugin no longer supports MLA. Please use the PyTorch workflow instead."); @@ -123,6 +125,7 @@ std::string GPTAttentionPlugin::toString(IdxEntry const& entry) const TLLM_GPT_ATTN_IDX_ENTRY_TO_STRING(CROSS_KV); TLLM_GPT_ATTN_IDX_ENTRY_TO_STRING(CROSS_KV_LENGTH); TLLM_GPT_ATTN_IDX_ENTRY_TO_STRING(ENCODER_INPUT_LENGTH); + TLLM_GPT_ATTN_IDX_ENTRY_TO_STRING(ATTENTION_PRIOR_FOCUS); TLLM_GPT_ATTN_IDX_ENTRY_TO_STRING(HOST_CONTEXT_LENGTH); TLLM_GPT_ATTN_IDX_ENTRY_TO_STRING(QKV_BIAS_TENSOR); TLLM_GPT_ATTN_IDX_ENTRY_TO_STRING(SPEC_DECODING_GENERATION_LENGTHS); @@ -180,6 +183,7 @@ bool GPTAttentionPlugin::isEntryUsed(IdxEntry const& entry) const case IdxEntry::CROSS_KV_LENGTH: return isCrossAttention(); case IdxEntry::LOGN_SCALING: return isLognScaling(); case IdxEntry::ENCODER_INPUT_LENGTH: return isCrossAttention(); + case IdxEntry::ATTENTION_PRIOR_FOCUS: return ApplyAttentionPrior(); case IdxEntry::HOST_CONTEXT_LENGTH: return mRemovePadding; case IdxEntry::QKV_BIAS_TENSOR: return mQKVBiasEnabled; case IdxEntry::SPEC_DECODING_GENERATION_LENGTHS: return mIsSpecDecodingEnabled; @@ -304,7 +308,6 @@ nvinfer1::DimsExprs GPTAttentionPlugin::getOutputDimensions( } else { - TLLM_CHECK(outputIndex == 0 || (!mPagedKVCache && useKVCache() && outputIndex == 1)); if (outputIndex == 0) { auto ret = inputs[getIdx(IdxEntry::QKV_TENSOR)]; @@ -315,7 +318,27 @@ nvinfer1::DimsExprs GPTAttentionPlugin::getOutputDimensions( return ret; } } - return inputs[getIdx(IdxEntry::PAST_KEY_VALUE)]; + int out_idx = mFuseFp4Quant ? 
2 : 1; + if (!mPagedKVCache && useKVCache()) + { + if (outputIndex == out_idx) + { + return inputs[getIdx(IdxEntry::PAST_KEY_VALUE)]; + } + out_idx++; + } + if (mComputeAttentionPrior) + { + if (outputIndex == out_idx) + { + auto shape = nvinfer1::DimsExprs{1, + {exprBuilder.operation(DimensionOperation::kPROD, *inputs[getIdx(IdxEntry::ATTENTION_PRIOR_FOCUS)].d[0], + *exprBuilder.constant(mAttentionPriorLookahead))}}; + return shape; + } + out_idx++; + } + TLLM_CHECK_WITH_INFO(false, "Can't fetch output dimension for %d", outputIndex); } bool GPTAttentionPlugin::supportsFormatCombination( @@ -438,6 +461,16 @@ bool GPTAttentionPlugin::supportsFormatCombination( posCaseLine = __LINE__; result = inOut[pos].type == nvinfer1::DataType::kINT32; } + else if (ComputeAttentionPrior() && pos == (nbInputs + nbOutputs - 1)) + { + posCaseLine = __LINE__; + result = inOut[pos].type == nvinfer1::DataType::kFLOAT; + } + else if (ApplyAttentionPrior() && (pos == getIdx(IdxEntry::ATTENTION_PRIOR_FOCUS))) + { + posCaseLine = __LINE__; + result = inOut[pos].type == nvinfer1::DataType::kINT32; + } else if (isLognScaling() && pos == getIdx(IdxEntry::LOGN_SCALING)) { return inOut[pos].type == nvinfer1::DataType::kFLOAT; @@ -470,8 +503,8 @@ bool GPTAttentionPlugin::supportsFormatCombination( posCaseLine = __LINE__; result = (inOut[pos].type == mType) && (inOut[pos].format == TensorFormat::kLINEAR); } - TLLM_LOG_DEBUG( - "%s: pos: %d, result: %d, posCaseLine: %d", __PRETTY_FUNCTION__, pos, static_cast(result), posCaseLine); + TLLM_LOG_DEBUG("%s: pos: %d, result: %d, posCaseLine: %d. Number of inputs %d, number of outputs %d", + __PRETTY_FUNCTION__, pos, static_cast(result), posCaseLine, nbInputs, nbOutputs); return result; } @@ -645,16 +678,17 @@ int GPTAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, TLLM_CHECK(mRemovePadding && mPagedKVCache); } + auto nbGenerationSeq = nbSeq - nbContextRequests; if (nbContextRequests > 0) { auto seqIdxBeg = 0; auto tokenIdxBeg = 0; auto localNbTokens = contextTokenIdxEnd; - enqueueSome(seqIdxBeg, nbContextRequests, tokenIdxBeg, localNbTokens, + enqueueSome(seqIdxBeg, nbContextRequests, 0, tokenIdxBeg, localNbTokens, inputDesc, outputDesc, inputs, outputs, workspace, stream); } - if (auto nbGenerationSeq = nbSeq - nbContextRequests; nbGenerationSeq > 0) + if (nbGenerationSeq > 0) { auto seqIdxBeg = nbContextRequests; auto tokenIdxBeg = mCpSize > 1 ? contextTokenIdxEndForCp : contextTokenIdxEnd; @@ -664,8 +698,8 @@ int GPTAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, auto localNbTokens = mRemovePadding ? 
inputDesc[getIdx(IdxEntry::QKV_TENSOR)].dims.d[0] - tokenIdxBeg : inputDesc[getIdx(IdxEntry::QKV_TENSOR)].dims.d[0] * inputDesc[getIdx(IdxEntry::QKV_TENSOR)].dims.d[1]; - enqueueSome(seqIdxBeg, nbGenerationSeq, tokenIdxBeg, localNbTokens, inputDesc, - outputDesc, inputs, outputs, workspace, stream); + enqueueSome(seqIdxBeg, nbGenerationSeq, nbContextRequests, tokenIdxBeg, + localNbTokens, inputDesc, outputDesc, inputs, outputs, workspace, stream); } sync_check_cuda_error(stream); @@ -675,8 +709,8 @@ int GPTAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, } template -int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32_t tokenIdxBeg, int32_t localNbTokens, - nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, +int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32_t contextNbSeq, int32_t tokenIdxBeg, + int32_t localNbTokens, nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) { // relative_attention_bias [head_num, max_seq_len, max_seq_len] (optional in relative position) @@ -787,7 +821,7 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 int max_encoder_context_len = isCrossAttention() ? inputDesc[getIdx(IdxEntry::CROSS_KV_LENGTH)].dims.d[0] : 0; // for enc-dec model, since decoder_input_ids could be longer than 1, // such model has an encoder context (for cross attn) and an decoder context (for self attn) - // clarify 3 lens: + // clarify 3 lensii: // -- max_context_q_len: len of decoder input. No "max" concept, it's what it is given. // Also called (decoder_)input_seq_length, normally 1 for encoder-decoder start token // -- max_seq_len: max allowed len of decoder output, i.e. final results @@ -1118,6 +1152,30 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 enqueue_params.spec_decoding_is_generation_length_variable = mSpecDecodingIsGenerationLengthVariable; enqueue_params.spec_decoding_max_generation_length = mSpecDecodingMaxGenerationLength; } + if (isCrossAttention()) + { + if (mComputeAttentionPrior) + { + // the attention prior is always last + float* attention_prior_scores_out = static_cast(outputs[getNbOutputs() - 1]); + // advance the prior scores pointer, skipping the space reserved for context requests + attention_prior_scores_out += contextNbSeq * mAttentionPriorLookahead; + enqueue_params.attention_prior_scores = attention_prior_scores_out; + } + else + { + enqueue_params.attention_prior_scores = nullptr; + } + if (mApplyAttentionPrior || mComputeAttentionPrior) + { + enqueue_params.attention_prior_focus + = static_cast(inputs[getIdx(IdxEntry::ATTENTION_PRIOR_FOCUS)]); + } + else + { + enqueue_params.attention_prior_focus = nullptr; + } + } if (mFuseFp4Quant) { enqueue_params.start_token_idx_sf = tokenIdxBeg; @@ -1226,26 +1284,43 @@ nvinfer1::DataType GPTAttentionPlugin::getOutputDataType( { if (mFuseFp4Quant) { - TLLM_CHECK(index == 0 || index == 1 || (!mPagedKVCache && useKVCache() && index == 2)); + if (index == 0) + { + return nvinfer1::DataType::kFP4; + } + if (index == 1) + { + return nvinfer1::DataType::kFP8; + } } else { - TLLM_CHECK(index == 0 || (!mPagedKVCache && useKVCache() && index == 1)); + if (index == 0) + { + return mFP8ContextFMHA && mEnableContextFMHA ? 
nvinfer1::DataType::kFP8 + : inputTypes[getIdx(IdxEntry::QKV_TENSOR)]; + } } - if (index == 0) + int out_idx = mFuseFp4Quant ? 2 : 1; + if (!mPagedKVCache && useKVCache()) { - if (mFuseFp4Quant) + if (index == out_idx) { - return nvinfer1::DataType::kFP4; + return inputTypes[getIdx(IdxEntry::PAST_KEY_VALUE)]; } - return mFP8ContextFMHA && mEnableContextFMHA ? nvinfer1::DataType::kFP8 - : inputTypes[getIdx(IdxEntry::QKV_TENSOR)]; + out_idx++; } - if (mFuseFp4Quant && index == 1) + + if (mComputeAttentionPrior) { - return nvinfer1::DataType::kFP8; + if (index == out_idx) + { + return nvinfer1::DataType::kFLOAT; + } + out_idx++; } - return inputTypes[getIdx(IdxEntry::PAST_KEY_VALUE)]; + + TLLM_CHECK_WITH_INFO(false, "Can't fetch output type for %d", index); } // IPluginV2 Methods @@ -1263,7 +1338,11 @@ char const* GPTAttentionPlugin::getPluginVersion() const noexcept int GPTAttentionPlugin::getNbOutputs() const noexcept { int nbOutputs = mFuseFp4Quant ? 2 : 1; - if (!mPagedKVCache && useKVCache()) + if (!mPagedKVCache && useKVCache()) // corresponds to PAST_KEY_VALUE + { + nbOutputs += 1; + } + if (mComputeAttentionPrior) // corresponds to ATTENTION_PRIOR_SCORES { nbOutputs += 1; } @@ -1340,6 +1419,11 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginField p.getScalar("max_context_length").value(), static_cast(p.getScalar("qkv_bias_enabled").value()), static_cast(p.getScalar("do_cross_attention").value()), + static_cast(p.getScalar("compute_attention_prior").value()), + static_cast(p.getScalar("apply_attention_prior").value()), + p.getScalar("attention_prior_lookahead").value(), + p.getScalar("attention_prior_window_left").value(), + p.getScalar("attention_prior_window_right").value(), static_cast(p.getScalar("max_distance").value()), static_cast(p.getScalar("pos_shift_enabled").value()), static_cast(p.getScalar("dense_context_fmha").value()), diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h index 13a3f0ecc66..f7d2b749858 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h @@ -114,13 +114,15 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon tensorrt_llm::kernels::AttentionMaskType mask_type, tensorrt_llm::kernels::BlockSparseParams block_sparse_params, bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled, bool cross_attention = false, - int max_distance = 0, bool pos_shift_enabled = false, bool dense_context_fmha = false, - bool use_paged_context_fmha = true, bool use_fp8_context_fmha = true, bool has_full_attention_mask = false, - bool use_cache = true, bool is_spec_decoding_enabled = false, - bool spec_decoding_is_generation_length_variable = false, int spec_decoding_max_generation_length = 1, - bool is_mla_enabled = false, int q_lora_rank = 0, int kv_lora_rank = 0, int qk_nope_head_dim = 0, - int qk_rope_head_dim = 0, int v_head_dim = 0, bool fuse_fp4_quant = false, bool skip_attn = false, - int cp_size = 1, int cp_rank = 0, std::set cp_group = {}); + bool compute_attention_prior = false, bool apply_attention_prior = false, int attention_prior_lookahead = 5, + int attention_prior_window_left = 1, int attention_prior_window_right = 5, int max_distance = 0, + bool pos_shift_enabled = false, bool dense_context_fmha = false, bool use_paged_context_fmha = true, + bool use_fp8_context_fmha = true, bool 
has_full_attention_mask = false, bool use_cache = true, + bool is_spec_decoding_enabled = false, bool spec_decoding_is_generation_length_variable = false, + int spec_decoding_max_generation_length = 1, bool is_mla_enabled = false, int q_lora_rank = 0, + int kv_lora_rank = 0, int qk_nope_head_dim = 0, int qk_rope_head_dim = 0, int v_head_dim = 0, + bool fuse_fp4_quant = false, bool skip_attn = false, int cp_size = 1, int cp_rank = 0, + std::set cp_group = {}); GPTAttentionPlugin(void const* data, size_t length); @@ -173,9 +175,10 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon private: template - int enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32_t tokenIdxBeg, int32_t localNbTokens, - nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, - void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); + int enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32_t contextNbSeq, int32_t tokenIdxBeg, + int32_t localNbTokens, nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream); using IndexType = std::int32_t; @@ -210,6 +213,7 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon CROSS_KV, CROSS_KV_LENGTH, ENCODER_INPUT_LENGTH, + ATTENTION_PRIOR_FOCUS, HOST_CONTEXT_LENGTH, QKV_BIAS_TENSOR, SPEC_DECODING_GENERATION_LENGTHS, diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index 9bcd22e39e4..420d7fef2d5 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -138,7 +138,7 @@ void initBindings(pybind11::module_& m) .def_property("streaming", &GenLlmReq::isStreaming, &GenLlmReq::setStreaming) .def_readwrite("end_id", &GenLlmReq::mEndId) .def_readwrite("pad_id", &GenLlmReq::mPadId) - .def_readwrite("seq_slot", &GenLlmReq::mSeqSlot) + .def_readwrite("seq_slots", &GenLlmReq::mSeqSlots) .def_property_readonly("return_log_probs", &GenLlmReq::returnLogProbs) .def_property_readonly("return_context_logits", &GenLlmReq::getReturnContextLogits) .def_property_readonly("return_generation_logits", &GenLlmReq::getReturnGenerationLogits) @@ -287,7 +287,8 @@ void initBindings(pybind11::module_& m) std::optional client_id, executor::PriorityType priority, std::optional encoder_input_features, std::optional encoder_output_length, - std::optional cross_attention_mask, tb::LlmRequestType llm_request_type, + std::optional decoder_context_features, std::optional cross_attention_mask, + tb::LlmRequestType llm_request_type, std::optional input_token_extra_ids, tb::LlmRequest::SizeType32 num_return_sequences, std::optional eagle_config, std::optional skip_cross_attn_blocks, bool return_perf_metrics, @@ -322,6 +323,7 @@ void initBindings(pybind11::module_& m) auto lora_config_tensor_ptr = makeOptionalTensor(lora_config); auto draft_logits_tensor_ptr = makeOptionalTensor(draft_logits); auto encoder_input_features_tensor_ptr = makeOptionalTensor(encoder_input_features); + auto decoder_context_features_tensor_ptr = makeOptionalTensor(decoder_context_features); auto cross_attention_mask_tensor_ptr = makeOptionalTensor(cross_attention_mask); auto skip_cross_attn_blocks_tensor_ptr = makeOptionalTensor(skip_cross_attn_blocks); @@ -334,9 +336,9 @@ void initBindings(pybind11::module_& m) return_context_logits, return_generation_logits, draft_tokens, draft_logits_tensor_ptr, 
exclude_input_from_output, logits_post_processor, apply_logits_post_processor_batched, encoder_input_tokens, return_encoder_output, client_id, priority, - encoder_input_features_tensor_ptr, encoder_output_length, cross_attention_mask_tensor_ptr, - llm_request_type, input_token_extra_ids, num_return_sequences, eagle_config, - skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, guided_decoding_params, + encoder_input_features_tensor_ptr, encoder_output_length, decoder_context_features_tensor_ptr, + cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids, num_return_sequences, + eagle_config, skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id, arrival_time}; }), py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"), @@ -356,7 +358,8 @@ void initBindings(pybind11::module_& m) py::arg("apply_logits_post_processor_batched") = false, py::arg("encoder_input_tokens") = std::nullopt, py::arg("return_encoder_output") = false, py::arg("client_id") = std::nullopt, py::arg("priority") = executor::Request::kDefaultPriority, py::arg("encoder_input_features") = std::nullopt, - py::arg("encoder_output_len") = std::nullopt, py::arg("cross_attention_mask") = std::nullopt, + py::arg("encoder_output_len") = std::nullopt, py::arg("decoder_context_features") = std::nullopt, + py::arg("cross_attention_mask") = std::nullopt, py::arg_v("llm_request_type", tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, "LlmRequestType.LLMREQUEST_TYPE_CONTEXT_AND_GENERATION"), py::arg("input_token_extra_ids") = std::nullopt, py::arg("num_return_sequences") = 1, diff --git a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp index bcc9d4bf13f..42524dee0dd 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp @@ -114,6 +114,7 @@ std::shared_ptr LlmRequest::toTrtLlm() const mPriority, // from_torch(mEncoderInputFeatures), // mEncoderOutputLength, // + from_torch(mDecoderContextFeatures), // from_torch(mCrossAttentionMask), // getLlmRequestType(), // std::nullopt, // inputTokenExtraIds diff --git a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h index b43fb8dd073..0c879692cb8 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h +++ b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h @@ -76,6 +76,7 @@ class LlmRequest : public tb::GenericLlmRequest executor::PriorityType priority = executor::Request::kDefaultPriority, std::optional encoderInputFeatures = std::nullopt, std::optional encoderOutputLength = std::nullopt, + std::optional decoderContextFeatures = std::nullopt, std::optional crossAttentionMask = std::nullopt, tb::LlmRequestType llmRequestType = tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::optional inputTokenExtraIds = std::nullopt, SizeType32 numReturnSequences = 1, @@ -135,6 +136,7 @@ class LlmRequest : public tb::GenericLlmRequest priority, // encoderInputFeatures, // encoderOutputLength, // + decoderContextFeatures, // crossAttentionMask, // llmRequestType, // inputTokenExtraIds // diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp index 8fe558d973d..431a1195c9c 100644 --- a/cpp/tensorrt_llm/pybind/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/bindings.cpp @@ -272,9 +272,10 @@ 
PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) py::class_(m, "ModelConfig") .def(py::init(), py::arg("vocab_size"), py::arg("num_layers"), py::arg("num_attention_layers"), py::arg("num_rnn_layers"), - py::arg("num_heads"), py::arg("hidden_size"), py::arg("data_type")) + py::arg("num_heads"), py::arg("hidden_size"), py::arg("data_type"), py::arg("vocab_sizes") = py::none()) .def_property_readonly("vocab_size", &tr::ModelConfig::getVocabSize) - .def("vocab_size_padded", &tr::ModelConfig::getVocabSizePadded, py::arg("world_size")) + .def( + "vocab_size_padded", &tr::ModelConfig::getVocabSizePadded, py::arg("world_size"), py::arg("vocab_size") = 0) .def("num_layers", &tr::ModelConfig::getNbLayers, py::arg("pipeline_parallelism") = 1, py::arg("pipeline_parallelism_rank") = 0) .def("num_attention_layers", &tr::ModelConfig::getNbAttentionLayers, py::arg("pipeline_parallelism") = 1, @@ -287,6 +288,8 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def_property_readonly("hidden_size", &tr::ModelConfig::getHiddenSize) .def_property_readonly("size_per_head", &tr::ModelConfig::getSizePerHead) .def_property_readonly("data_type", &tr::ModelConfig::getDataType) + .def_property_readonly("num_vocabs", &tr::ModelConfig::getNumVocabs) + .def_property_readonly("vocab_sizes", &tr::ModelConfig::getVocabSizes) .def_property_readonly("speculative_decoding_mode", &tr::ModelConfig::getSpeculativeDecodingMode) .def_property("head_size", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead) .def_property( @@ -296,6 +299,20 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) py::overload_cast(&tr::ModelConfig::useGptAttentionPlugin)) .def_property("use_packed_input", py::overload_cast<>(&tr::ModelConfig::usePackedInput, py::const_), py::overload_cast(&tr::ModelConfig::usePackedInput)) + .def_property("use_attention_prior", py::overload_cast<>(&tr::ModelConfig::useAttentionPrior, py::const_), + py::overload_cast(&tr::ModelConfig::useAttentionPrior)) + .def_property("use_context_embeddings", py::overload_cast<>(&tr::ModelConfig::useContextEmbeddings, py::const_), + py::overload_cast(&tr::ModelConfig::useContextEmbeddings)) + .def_property("compute_attention_prior_from_layers", &tr::ModelConfig::getComputeAttentionPriorFromLayers, + &tr::ModelConfig::setComputeAttentionPriorFromLayers) + .def_property("apply_attention_prior_to_layers", &tr::ModelConfig::getApplyAttentionPriorToLayers, + &tr::ModelConfig::setApplyAttentionPriorToLayers) + .def_property("attention_prior_lookahead", &tr::ModelConfig::getAttentionPriorLookahead, + &tr::ModelConfig::setAttentionPriorLookahead) + .def_property("attention_prior_window_left", &tr::ModelConfig::getAttentionPriorWindowLeft, + &tr::ModelConfig::setAttentionPriorWindowLeft) + .def_property("attention_prior_window_right", &tr::ModelConfig::getAttentionPriorWindowRight, + &tr::ModelConfig::setAttentionPriorWindowRight) .def_property("kv_cache_type", py::overload_cast<>(&tr::ModelConfig::getKVCacheType, py::const_), py::overload_cast(&tr::ModelConfig::setKVCacheType)) .def_property("tokens_per_block", &tr::ModelConfig::getTokensPerBlock, &tr::ModelConfig::setTokensPerBlock) @@ -366,7 +383,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) }; auto SamplingConfigSetState = [](py::tuple t) -> tr::SamplingConfig { - if (t.size() != 19) + if (t.size() != 20) { throw std::runtime_error("Invalid SamplingConfig state!"); } @@ -391,6 +408,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) config.numReturnSequences = t[16].cast(); config.minP = t[17].cast>(); config.beamWidthArray = t[18].cast>>(); + 
config.cfgScale = t[19].cast>(); return config; }; @@ -418,6 +436,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def_readwrite("num_return_sequences", &tr::SamplingConfig::numReturnSequences) .def_readwrite("min_p", &tr::SamplingConfig::minP) .def_readwrite("beam_width_array", &tr::SamplingConfig::beamWidthArray) + .def_readwrite("cfg_scale", &tr::SamplingConfig::cfgScale) .def_readwrite("normalize_log_probs", &tr::SamplingConfig::normalizeLogProbs) .def(py::pickle(SamplingConfigGetState, SamplingConfigSetState)) .def("__eq__", &tr::SamplingConfig::operator==); diff --git a/cpp/tensorrt_llm/pybind/executor/request.cpp b/cpp/tensorrt_llm/pybind/executor/request.cpp index 097f598557b..47d8faab20f 100644 --- a/cpp/tensorrt_llm/pybind/executor/request.cpp +++ b/cpp/tensorrt_llm/pybind/executor/request.cpp @@ -77,51 +77,53 @@ void initRequestBindings(pybind11::module_& m) }; auto samplingConfigSetstate = [](py::tuple const& state) { - if (state.size() != 19) + if (state.size() != 20) { throw std::runtime_error("Invalid SamplingConfig state!"); } - return tle::SamplingConfig(state[0].cast(), // BeamWidth - state[1].cast>(), // TopK - state[2].cast>(), // TopP - state[3].cast>(), // TopPMin - state[4].cast>(), // TopPResetIds - state[5].cast>(), // TopPDecay - state[6].cast>(), // Seed - state[7].cast>(), // Temperature - state[8].cast>(), // MinTokens - state[9].cast>(), // BeamSearchDiversityRate - state[10].cast>(), // RepetitionPenalty - state[11].cast>(), // PresencePenalty - state[12].cast>(), // FrequencyPenalty - state[13].cast>(), // LengthPenalty - state[14].cast>(), // EarlyStopping - state[15].cast>(), // NoRepeatNgramSize - state[16].cast>(), // NumReturnSequences - state[17].cast>(), // MinP - state[18].cast>>() // BeamWidthArray + return tle::SamplingConfig(state[0].cast(), // BeamWidth + state[1].cast>(), // TopK + state[2].cast>(), // TopP + state[3].cast>(), // TopPMin + state[4].cast>(), // TopPResetIds + state[5].cast>(), // TopPDecay + state[6].cast>(), // Seed + state[7].cast>(), // Temperature + state[8].cast>(), // MinTokens + state[9].cast>(), // BeamSearchDiversityRate + state[10].cast>(), // RepetitionPenalty + state[11].cast>(), // PresencePenalty + state[12].cast>(), // FrequencyPenalty + state[13].cast>(), // LengthPenalty + state[14].cast>(), // EarlyStopping + state[15].cast>(), // NoRepeatNgramSize + state[16].cast>(), // NumReturnSequences + state[17].cast>(), // MinP + state[18].cast>>(), // BeamWidthArray + state[19].cast>() // CfgScale ); }; py::class_(m, "SamplingConfig") .def(py::init const&, // beamWidth - std::optional const&, // topP - std::optional const&, // topPMin - std::optional const&, // topPResetIds - std::optional const&, // topPDecay - std::optional const&, // seed - std::optional const&, // temperature - std::optional const&, // minTokens - std::optional const&, // beamSearchDiversityRate - std::optional const&, // repetitionPenalty - std::optional const&, // presencePenalty - std::optional const&, // frequencyPenalty - std::optional const&, // lengthPenalty - std::optional const&, // earlyStopping - std::optional const&, // noRepeatNgramSize - std::optional const&, // numReturnSequences - std::optional const&, // minP - std::optional> const& // beamWidthArray + std::optional const&, // beamWidth + std::optional const&, // topP + std::optional const&, // topPMin + std::optional const&, // topPResetIds + std::optional const&, // topPDecay + std::optional const&, // seed + std::optional const&, // temperature + std::optional const&, // minTokens 
+ std::optional const&, // beamSearchDiversityRate + std::optional const&, // repetitionPenalty + std::optional const&, // presencePenalty + std::optional const&, // frequencyPenalty + std::optional const&, // lengthPenalty + std::optional const&, // earlyStopping + std::optional const&, // noRepeatNgramSize + std::optional const&, // numReturnSequences + std::optional const&, // minP + std::optional> const&, // beamWidthArray + std::optional const& // CfgScale >(), // clang-format off py::arg("beam_width") = 1, @@ -143,7 +145,8 @@ void initRequestBindings(pybind11::module_& m) py::arg("no_repeat_ngram_size") = py::none(), py::arg("num_return_sequences") = py::none(), py::arg("min_p") = py::none(), - py::arg("beam_width_array") = py::none()) // clang-format on + py::arg("beam_width_array") = py::none(), + py::arg("cfg_scale") = py::none()) // clang-format on .def_property("beam_width", &tle::SamplingConfig::getBeamWidth, &tle::SamplingConfig::setBeamWidth) .def_property("top_k", &tle::SamplingConfig::getTopK, &tle::SamplingConfig::setTopK) .def_property("top_p", &tle::SamplingConfig::getTopP, &tle::SamplingConfig::setTopP) @@ -170,6 +173,7 @@ void initRequestBindings(pybind11::module_& m) .def_property("min_p", &tle::SamplingConfig::getMinP, &tle::SamplingConfig::setMinP) .def_property( "beam_width_array", &tle::SamplingConfig::getBeamWidthArray, &tle::SamplingConfig::setBeamWidthArray) + .def_property("cfg_scale", &tle::SamplingConfig::getCfgScale, &tle::SamplingConfig::setCfgScale) .def(py::pickle(samplingConfigGetstate, samplingConfigSetstate)); auto additionalModelOutputGetstate @@ -525,12 +529,12 @@ void initRequestBindings(pybind11::module_& m) self.getLogitsPostProcessorName(), self.getLogitsPostProcessor(), self.getEncoderInputTokenIds(), self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(), self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(), - self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(), - self.getGuidedDecodingParams(), self.getCacheSaltID()); + self.getDecoderContextFeatures(), self.getCrossAttentionMask(), self.getEagleConfig(), + self.getSkipCrossAttnBlocks(), self.getGuidedDecodingParams(), self.getCacheSaltID(), self.getNumVocabs()); }; auto requestSetstate = [](py::tuple const& state) { - if (state.size() != 34) + if (state.size() != 36) { throw std::runtime_error("Invalid Request state!"); } @@ -549,9 +553,10 @@ void initRequestBindings(pybind11::module_& m) state[22].cast>(), state[23].cast(), state[24].cast(), state[25].cast(), state[26].cast>(), state[27].cast>(), state[28].cast>(), - state[29].cast>(), 1, state[30].cast>(), - state[31].cast>(), state[32].cast>(), - state[33].cast>()); + state[29].cast>(), state[30].cast>(), 1, + state[31].cast>(), state[32].cast>(), + state[33].cast>(), + state[34].cast>(), state[35].cast()); }; py::class_ request(m, "Request", pybind11::dynamic_attr()); @@ -585,6 +590,7 @@ void initRequestBindings(pybind11::module_& m) std::optional, // contextPhaseParams std::optional, // encoderInputFeatures std::optional, // encoderOutputLength + std::optional, // decoderContextFeatures std::optional, // crossAttentionMask SizeType32, // numReturnSequences std::optional, // eagleConfig @@ -592,7 +598,8 @@ void initRequestBindings(pybind11::module_& m) std::optional, // guidedDecodingParams std::optional, // languageAdapterUid std::optional, // allottedTimeMs - std::optional // cacheSaltID + std::optional, // cacheSaltID +
tle::SizeType32 // numVocabs >(), // clang-format off py::arg("input_token_ids"), @@ -626,6 +633,7 @@ void initRequestBindings(pybind11::module_& m) py::arg("context_phase_params") = py::none(), py::arg("encoder_input_features") = py::none(), py::arg("encoder_output_length") = py::none(), + py::arg("decoder_context_features") = py::none(), py::arg("cross_attention_mask") = py::none(), py::arg("num_return_sequences") = 1, py::arg("eagle_config") = py::none(), @@ -633,7 +641,8 @@ void initRequestBindings(pybind11::module_& m) py::arg("guided_decoding_params") = py::none(), py::arg("language_adapter_uid") = py::none(), py::arg("allotted_time_ms") = py::none(), - py::arg("cache_salt_id") = py::none() + py::arg("cache_salt_id") = py::none(), + py::arg("num_vocabs") = 1 ) // clang-format on .def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds) .def_property_readonly("max_tokens", &tle::Request::getMaxTokens) @@ -670,6 +679,8 @@ void initRequestBindings(pybind11::module_& m) .def_property("request_type", &tle::Request::getRequestType, &tle::Request::setRequestType) .def_property( "encoder_input_features", &tle::Request::getEncoderInputFeatures, &tle::Request::setEncoderInputFeatures) + .def_property( + "decoder_context_features", &tle::Request::getDecoderContextFeatures, &tle::Request::setDecoderContextFeatures) .def_property( "cross_attention_mask", &tle::Request::getCrossAttentionMask, &tle::Request::setCrossAttentionMask) .def_property("eagle_config", &tle::Request::getEagleConfig, &tle::Request::setEagleConfig) @@ -679,6 +690,7 @@ void initRequestBindings(pybind11::module_& m) "guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams) .def_property("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs) .def_property("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID) + .def_property("num_vocabs", &tle::Request::getNumVocabs, &tle::Request::setNumVocabs) .def_property( "context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams) .def(py::pickle(requestGetstate, requestSetstate)); diff --git a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp index 469aafe6476..c60ed049401 100644 --- a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp @@ -388,7 +388,7 @@ void initBindings(pybind11::module_& m) py::call_guard()) .def("setup", &tr::GptDecoderBatched::setup, py::arg("mode"), py::arg("max_num_sequences"), py::arg("max_beam_width"), py::arg("dtype"), py::arg("model_config"), py::arg("world_config"), - py::call_guard()) + py::arg("vocab_size") = 0, py::call_guard()) .def("forward_async", &tr::GptDecoderBatched::forwardAsync, py::arg("decoder_state"), py::arg("input"), py::call_guard()) .def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, py::return_value_policy::reference) diff --git a/cpp/tensorrt_llm/runtime/decoderState.cpp b/cpp/tensorrt_llm/runtime/decoderState.cpp index b5851dc1c2d..596d1ab5a3c 100644 --- a/cpp/tensorrt_llm/runtime/decoderState.cpp +++ b/cpp/tensorrt_llm/runtime/decoderState.cpp @@ -459,9 +459,10 @@ void DecoderState::disableLookahead(RequestVector const& genRequests) for (auto const& llmReq : genRequests) { - if (llmReq->mSeqSlot) + if (!llmReq->mSeqSlots.empty()) + { - setNumDecodingEngineTokens(llmReq->mSeqSlot.value(), 1); + setNumDecodingEngineTokens(llmReq->mSeqSlots.at(0), 1); } } diff --git 
a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp index 6df7b1634b8..9513adbb62c 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp @@ -53,7 +53,7 @@ void GptDecoderBatched::disableLookahead(RequestVector const& genRequests, Tenso for (auto const& llmReq : genRequests) { samplingConfigs.push_back(llmReq->mSamplingConfig); - batchSlotsRange[batchIdx] = llmReq->mSeqSlot.value(); + batchSlotsRange[batchIdx] = llmReq->mSeqSlots.at(0); batchIdx += 1; } auto const batchSize = batchIdx; @@ -73,7 +73,7 @@ void GptDecoderBatched::disableLookahead(RequestVector const& genRequests, Tenso } void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, - nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) + nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig, SizeType32 vocabSize) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(maxNumSequences > 0); @@ -89,8 +89,11 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max mDecoderStream = std::make_shared(); TLLM_CHECK(mDecoderStream->getDevice() == device); - auto const vocabSize = modelConfig.getVocabSize(); - auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize()); + if (vocabSize == 0) + { + vocabSize = modelConfig.getVocabSize(); + } + auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize(), vocabSize); mDecoder = IGptDecoder::create(mode, dtype, maxNumSequences, maxBeamWidth, vocabSize, vocabSizePadded, mDecoderStream, speculativeDecodingModulePtr); diff --git a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp index 311f63eaf1e..0519a2d243e 100644 --- a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp +++ b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp @@ -193,6 +193,8 @@ ModelConfig createModelConfig(Json const& json, bool engineVersionNone, SizeType auto const vocabSize = config.at("vocab_size").template get(); auto const hiddenSize = config.at("hidden_size").template get() / tensorParallelism; auto const sizePerHead = parseJsonFieldOr(config, "head_size", hiddenSize / numHeads); + // Read vocab sizes if available, otherwise use single vocab size + auto vocabSizes = parseJsonFieldOptional>(config, "vocab_sizes"); // Logits datatype auto const logitsDtypeStr = parseJsonFieldOr(config, "logits_dtype", std::string("float32")); @@ -211,8 +213,8 @@ ModelConfig createModelConfig(Json const& json, bool engineVersionNone, SizeType auto numKvHeadsPerCrossAttentionLayer = parseJsonFieldOr>( config, "num_kv_heads_per_cross_attn_layer", std::vector()); - auto modelConfig - = ModelConfig{vocabSize, numLayers, numAttentionLayers, numRnnLayers, numHeads, hiddenSize, dataType}; + auto modelConfig = ModelConfig{ + vocabSize, numLayers, numAttentionLayers, numRnnLayers, numHeads, hiddenSize, dataType, vocabSizes}; if (!numKvHeadsPerAttentionLayer.empty()) { @@ -229,6 +231,23 @@ ModelConfig createModelConfig(Json const& json, bool engineVersionNone, SizeType modelConfig.setNbKvHeads(numKvHeads); } + auto const useAttentionPrior = parseJsonFieldOr(config, "use_attention_prior", false); + auto const useContextEmbeddings = parseJsonFieldOr(config, "use_context_embeddings", false); + auto const computeAttentionPriorFromLayers = parseJsonFieldOr>( + config, "compute_attention_prior_from_layers", 
std::vector()); + auto const applyAttentionPriorToLayers = parseJsonFieldOr>( + config, "apply_attention_prior_to_layers", std::vector()); + auto const attentionPriorLookahead = parseJsonFieldOr(config, "attention_prior_lookahead", 0); + auto const attentionPriorWindowLeft = parseJsonFieldOr(config, "attention_prior_window_left", 0); + auto const attentionPriorWindowRight = parseJsonFieldOr(config, "attention_prior_window_right", 0); + modelConfig.useAttentionPrior(useAttentionPrior); + modelConfig.useContextEmbeddings(useContextEmbeddings); + modelConfig.setComputeAttentionPriorFromLayers(computeAttentionPriorFromLayers); + modelConfig.setApplyAttentionPriorToLayers(applyAttentionPriorToLayers); + modelConfig.setAttentionPriorLookahead(attentionPriorLookahead); + modelConfig.setAttentionPriorWindowLeft(attentionPriorWindowLeft); + modelConfig.setAttentionPriorWindowRight(attentionPriorWindowRight); + if (!numKvHeadsPerCrossAttentionLayer.empty()) { std::transform(numKvHeadsPerCrossAttentionLayer.cbegin(), numKvHeadsPerCrossAttentionLayer.cend(), diff --git a/cpp/tensorrt_llm/runtime/promptTuningParams.cpp b/cpp/tensorrt_llm/runtime/promptTuningParams.cpp index 4730a93b0e3..02650ef00d9 100644 --- a/cpp/tensorrt_llm/runtime/promptTuningParams.cpp +++ b/cpp/tensorrt_llm/runtime/promptTuningParams.cpp @@ -51,13 +51,13 @@ void PromptTuningParams::fillTasksTensor(TensorPtr tasksHost, SizeType32 const b if (bid < numContextRequests) { totalInputSize += reqPromptLengths[bid]; - promptTasksHost.insert(promptTasksHost.end(), reqPromptLengths[bid], taskId); + promptTasksHost.insert(promptTasksHost.end(), reqPromptLengths[bid] * numVocabs, taskId); } else { for (SizeType32 beam = 0; beam < reqBeamWidths[bid]; ++beam) { - promptTasksHost.insert(promptTasksHost.end(), 1, taskId); + promptTasksHost.insert(promptTasksHost.end(), numVocabs, taskId); totalInputSize++; } } @@ -80,7 +80,7 @@ void PromptTuningParams::fillTasksTensor(TensorPtr tasksHost, SizeType32 const b if (packedInput) { tasks = manager.copyFrom( - promptTasksHost, runtime::ITensor::makeShape({totalInputSize}), runtime::MemoryType::kGPU); + promptTasksHost, runtime::ITensor::makeShape({totalInputSize * numVocabs}), runtime::MemoryType::kGPU); } else { diff --git a/cpp/tensorrt_llm/runtime/runtimeKernels.cu b/cpp/tensorrt_llm/runtime/runtimeKernels.cu index 3b3dbcac894..4538b40bae2 100644 --- a/cpp/tensorrt_llm/runtime/runtimeKernels.cu +++ b/cpp/tensorrt_llm/runtime/runtimeKernels.cu @@ -404,6 +404,38 @@ void invokeCopyBatch(IBuffer const& srcBuffer, IBuffer& dstBuffer, IBuffer const srcDataPtr, dstDataPtr, srcOffsetsPtr, dstOffsetsPtr, sizesPtr, static_cast(dataTypeSize)); } +namespace +{ +template +__global__ void add(T* data, std::size_t size, T const value) +{ + auto const tidx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + auto const stride = static_cast(blockDim.x) * gridDim.x; + + for (auto idx = tidx; idx < size; idx += stride) + { + data[idx] += value; + } +} +} // namespace + +template +void invokeAdd(IBuffer& buffer, T const value, CudaStream const& stream) +{ + auto data = bufferCast(buffer); + auto const size = buffer.getSize(); + dim3 const blockSize{256}; + std::size_t const gridx{tc::ceilDiv(size, blockSize.x)}; + std::size_t const gridMax{std::numeric_limits::max()}; + dim3 const gridSize{static_cast(std::min(gridx, gridMax))}; + + add<<>>(data, size, value); +} + +template void invokeAdd(IBuffer&, std::int32_t, CudaStream const&); +template void invokeAdd(IBuffer&, std::int8_t, CudaStream const&); +template void 
invokeAdd(IBuffer&, float, CudaStream const&); + void scatterTensor(ITensor& output, ITensor const& input, SizeType32 beamWidth, CudaStream const& stream) { switch (input.getDataType()) diff --git a/cpp/tensorrt_llm/runtime/runtimeKernels.h b/cpp/tensorrt_llm/runtime/runtimeKernels.h index ae251877e58..7b5cd88cb39 100644 --- a/cpp/tensorrt_llm/runtime/runtimeKernels.h +++ b/cpp/tensorrt_llm/runtime/runtimeKernels.h @@ -43,6 +43,9 @@ void invokeGatherBatch(IBuffer& buffer, IBuffer const& values, IBuffer const& sl void invokeCopyBatch(IBuffer const& srcBuffer, IBuffer& dstBuffer, IBuffer const& srcOffsets, IBuffer const& dstOffsets, IBuffer const& sizes, std::size_t maxStride, CudaStream const& stream); +template +void invokeAdd(IBuffer& buffer, T value, CudaStream const& stream); + void scatterTensor(ITensor& output, ITensor const& input, SizeType32 beamWidth, CudaStream const& stream); void tileTensor(ITensor& output, ITensor const& input, SizeType32 beamWidth, CudaStream const& stream); diff --git a/cpp/tests/e2e_tests/batch_manager/guidedDecoderTest.cpp b/cpp/tests/e2e_tests/batch_manager/guidedDecoderTest.cpp index 8358e987334..7b291ce771e 100644 --- a/cpp/tests/e2e_tests/batch_manager/guidedDecoderTest.cpp +++ b/cpp/tests/e2e_tests/batch_manager/guidedDecoderTest.cpp @@ -120,10 +120,10 @@ class GuidedDecoderTest : public ::testing::Test auto llmReq1 = std::make_shared(1, 100, std::make_shared(10), SamplingConfig(), false); texec::GuidedDecodingParams guidedDecodingParams(texec::GuidedDecodingParams::GuideType::kJSON); llmReq1->setGuidedDecodingParams(guidedDecodingParams); - llmReq1->mSeqSlot = 1; + llmReq1->mSeqSlots.push_back(1); auto llmReq2 = std::make_shared(1, 100, std::make_shared(10), SamplingConfig(), false); - llmReq2->mSeqSlot = 2; + llmReq2->mSeqSlots.push_back(2); RequestVector contextRequests{llmReq1, llmReq2}; RequestVector generationRequests{}; diff --git a/cpp/tests/unit_tests/batch_manager/llmRequestTest.cpp b/cpp/tests/unit_tests/batch_manager/llmRequestTest.cpp index d08d2f5bc47..041ce15d0fa 100644 --- a/cpp/tests/unit_tests/batch_manager/llmRequestTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/llmRequestTest.cpp @@ -54,7 +54,7 @@ TEST_F(LlmRequestTest, fromExecutorRequest) EXPECT_EQ(llmReq.getOrigPromptLen(), inputTokens.size()); EXPECT_EQ(llmReq.getMaxSentTokenLen(), inputTokens.size()); EXPECT_EQ(llmReq.getState(), tb::LlmRequestState::kCONTEXT_INIT); - EXPECT_FALSE(llmReq.mSeqSlot); + EXPECT_TRUE(llmReq.mSeqSlots.empty()); // No speculative decoding config, draft tokens should be empty EXPECT_EQ(llmReq.getDraftTokens()->size(), 0); EXPECT_FALSE(llmReq.getEmbeddingBias().has_value()); @@ -488,7 +488,7 @@ TEST_F(LlmRequestTest, testCreateRequests) EXPECT_EQ(childReq1->getState(), llmReq.getState()); EXPECT_EQ(childReq1->mSamplingConfig.randomSeed.value(), std::vector{8}); EXPECT_EQ(llmReq2.mSamplingConfig.randomSeed.value(), std::vector{7}); - EXPECT_FALSE(childReq1->mSeqSlot); + EXPECT_TRUE(childReq1->mSeqSlots.empty()); } { diff --git a/examples/models/core/enc_dec/convert_checkpoint.py b/examples/models/core/enc_dec/convert_checkpoint.py index 9c1951975bf..c1cd5b8c372 100755 --- a/examples/models/core/enc_dec/convert_checkpoint.py +++ b/examples/models/core/enc_dec/convert_checkpoint.py @@ -1734,6 +1734,7 @@ def convert_checkpoint(args): 'hidden_size': decoder_config.hidden_size, 'norm_epsilon': decoder_config.layernorm_eps, 'vocab_size': decoder_config.vocab_size, + 'vocab_sizes': [decoder_config.vocab_size], 'position_embedding_type': 
decoder_config.position_embedding_type, 'hidden_act': decoder_config.hidden_act, 'quantization': { diff --git a/tensorrt_llm/builder.py b/tensorrt_llm/builder.py index 8294a9b67b4..a28061f5a73 100644 --- a/tensorrt_llm/builder.py +++ b/tensorrt_llm/builder.py @@ -909,7 +909,9 @@ def optimize_model_with_config(model: PretrainedModel, if build_config.plugin_config.lora_plugin is not None: model.use_lora(build_config.lora_config) - is_enc_dec = model.config.architecture in ["EncoderModel", "DecoderModel"] + is_enc_dec = model.config.architecture in [ + "EncoderModel", "DecoderModel", "T5TTSEncoderModel", "T5TTSDecoderModel" + ] # FusedMLP does not support RecurrentGemma FP8 currently. is_recurrent_gemma = model.config.architecture in [ "RecurrentGemmaForCausalLM" @@ -1349,8 +1351,11 @@ def build(model: PretrainedModel, build_config: BuildConfig) -> Engine: build_config.lora_config.lora_target_modules } - if model.config.architecture == "DecoderModel" or "mllama" in model.config.architecture.lower( - ): + + if "mllama" in model.config.architecture.lower() or \ + model.config.architecture == "DecoderModel" or \ + model.config.architecture == "T5TTSDecoderModel": + prepare_input_args["max_seq_len"] = build_config.max_seq_len prepare_input_args[ "max_decoder_input_len"] = build_config.max_input_len diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index 282febd262e..f99a9969100 100755 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -5280,9 +5280,16 @@ def gpt_attention( host_kv_cache_pool_pointers: Tensor = None, host_kv_cache_pool_mapping: Tensor = None, do_cross_attention: bool = False, + compute_attention_prior: bool = False, + apply_attention_prior: bool = False, + attention_prior_lookahead: int = 5, + attention_prior_window_left: int = 1, + attention_prior_window_right: int = 5, cross_kv: Optional[Tensor] = None, # for cross attention cross_kv_length: Optional[Tensor] = None, # for cross attention encoder_input_lengths: Optional[Tensor] = None, # for cross attention + attention_prior_focus: Optional[ + Tensor] = None, # when applying prior is enabled relative_attention_bias: Optional[Tensor] = None, # for relative attention logn_scaling: Optional[Tensor] = None, # for logn scaling max_distance: int = 0, # for relative attention @@ -5504,6 +5511,16 @@ def gpt_attention( do_cross_attention: bool = False Do we use this as cross attention instead of self attention, + compute_attention_prior: bool = False, + Whether to accumulate attention prior scores in the kernel. + Only valid for generation requests and cross attention. + The accumulated scores are returned through the additional attention_prior_scores output. + + apply_attention_prior: bool = False, + Whether to apply the attention prior. + Only valid for generation requests and cross attention. + Uses the attention_prior_focus tensor described below. + cross_kv: Tensor = None The KV tensor of encoder output hidden states. Its shape is [batch_size, max_seqlen, 2 * kvHeadNum * headSize] in padded mode and [1, num_tokens, 2 * kvHeadNum * headSize] in packed mode, @@ -5514,6 +5531,10 @@ def gpt_attention( encoder_input_lengths: Tensor The tensor that stores the length of each encoder input sequence. Its shape is [batch_size], + attention_prior_focus: Optional[Tensor] = None + Shape (B,). For each sequence, specifies the start of the encoder region to focus on in + cross attention; the rest of the encoder outputs are masked out.
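The docstring entries above describe attention_prior_focus only informally. As a rough illustration of the masking semantics implied by the wording (an assumption based on the parameter names, not the actual fused kernel), the NumPy sketch below keeps, per sequence, the encoder positions starting at the focus index for attention_prior_lookahead steps and masks out the rest:

import numpy as np

def focus_mask(focus, encoder_lengths, lookahead=5):
    """Illustration only: True for encoder positions in [focus, focus + lookahead),
    clamped to each sequence's real encoder length; everything else is masked out.
    The real plugin applies the equivalent logic inside the attention kernel."""
    max_len = int(encoder_lengths.max())
    positions = np.arange(max_len)[None, :]              # (1, max_len)
    keep = (positions >= focus[:, None]) & (positions < focus[:, None] + lookahead)
    keep &= positions < encoder_lengths[:, None]         # never attend past the real length
    return keep

# Batch of 2 sequences with encoder lengths 10 and 7, focusing at positions 3 and 5.
mask = focus_mask(np.array([3, 5]), np.array([10, 7]))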
+ logn_scaling: Tensor = None The logn scaling tensor [max_position_embedding_len], which is applied to q in order to help extrapolation @@ -5790,6 +5811,26 @@ def gpt_attention( "do_cross_attention", np.array(np.int8(do_cross_attention), dtype=np.int8), trt.PluginFieldType.INT8) + compute_attention_prior_field = trt.PluginField( + "compute_attention_prior", + np.array(np.int8(compute_attention_prior), dtype=np.int8), + trt.PluginFieldType.INT8) + apply_attention_prior_field = trt.PluginField( + "apply_attention_prior", + np.array(np.int8(apply_attention_prior), dtype=np.int8), + trt.PluginFieldType.INT8) + attention_prior_lookahead_field = trt.PluginField( + "attention_prior_lookahead", + np.array(attention_prior_lookahead, dtype=np.int32), + trt.PluginFieldType.INT32) + attention_prior_window_left_field = trt.PluginField( + "attention_prior_window_left", + np.array(attention_prior_window_left, dtype=np.int32), + trt.PluginFieldType.INT32) + attention_prior_window_right_field = trt.PluginField( + "attention_prior_window_right", + np.array(attention_prior_window_right, dtype=np.int32), + trt.PluginFieldType.INT32) max_distance = trt.PluginField("max_distance", np.array(max_distance, dtype=np.int32), trt.PluginFieldType.INT32) @@ -5838,8 +5879,11 @@ def gpt_attention( block_sparse_block_size, block_sparse_homo_head_pattern, block_sparse_num_local_blocks, block_sparse_vertical_stride, paged_kv_cache, tokens_per_block, pf_type, max_context_length, - qkv_bias_enabled, do_cross_attention_field, max_distance, - pos_shift_enabled, dense_context_fmha, use_paged_context_fmha_field, + qkv_bias_enabled, do_cross_attention_field, + compute_attention_prior_field, apply_attention_prior_field, + attention_prior_lookahead_field, attention_prior_window_left_field, + attention_prior_window_right_field, max_distance, pos_shift_enabled, + dense_context_fmha, use_paged_context_fmha_field, use_fp8_context_fmha_field, has_full_attention_mask_field, use_cache_pf, is_spec_decoding_enabled, spec_decoding_is_generation_length_variable, spec_decoding_max_generation_length, is_mla_enabled, q_lora_rank, @@ -5913,6 +5957,9 @@ def gpt_attention( if do_cross_attention: plug_inputs += [cross_kv, cross_kv_length, encoder_input_lengths] + if apply_attention_prior: + plug_inputs += [attention_prior_focus] + if default_net().plugin_config.remove_input_padding: plug_inputs += [host_context_lengths] @@ -5973,10 +6020,19 @@ def gpt_attention( expected_outputs += 1 present_key_value = None + present_key_idx = -1 if use_cache and not paged_kv_cache_flag: present_key_value = _create_tensor(layer.get_output(expected_outputs), layer) assert present_key_value is not None + present_key_idx = expected_outputs + expected_outputs += 1 + + attention_prior_scores_out = None + if compute_attention_prior: + attention_prior_scores_out = _create_tensor( + layer.get_output(expected_outputs), layer) + assert attention_prior_scores_out is not None expected_outputs += 1 assert layer.num_outputs == expected_outputs, \ @@ -5984,6 +6040,9 @@ def gpt_attention( if kv_cache_quant_mode.has_int8_kv_cache( ) and not default_net().strongly_typed: + if present_key_idx >= 0: + # present key value + layer.get_output(present_key_idx).set_dynamic_range(-127, 127) if not paged_kv_cache_flag: # past key value layer.get_input(8).set_dynamic_range(-127, 127) @@ -5995,10 +6054,14 @@ def gpt_attention( layer.get_output(expected_outputs - 1).set_dynamic_range(-127, 127) assert output is not None + outputs = [output] if fuse_fp4_quant: assert output_sf is not None - 
return (output, output_sf), present_key_value - return output, present_key_value + outputs.append(output_sf) + outputs.append(present_key_value) + if compute_attention_prior: + outputs.append(attention_prior_scores_out) + return outputs def assertion(condition: Tensor, message: str = '') -> None: diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py index f995b6390d3..85e4085f3ea 100755 --- a/tensorrt_llm/layers/attention.py +++ b/tensorrt_llm/layers/attention.py @@ -30,7 +30,8 @@ allgather, arange, bert_attention, cast, clip, concat, constant, embedding, expand, expand_dims, expand_mask, generate_alibi_biases, identity, generate_alibi_slopes, generate_logn_scaling, gpt_attention, matmul, - minimum, repeat_interleave, shape, slice, softmax, split, unsqueeze, where) + maximum, minimum, repeat_interleave, shape, slice, softmax, split, + unsqueeze, where) # isort: on from ..mapping import Mapping from ..module import Module, ModuleList @@ -386,6 +387,11 @@ def __init__(self, quant_mode: QuantMode = QuantMode(0), q_scaling=1.0, cross_attention=False, + compute_attention_prior=False, + apply_attention_prior=False, + attention_prior_lookahead=5, + attention_prior_window_left=1, + attention_prior_window_right=5, relative_attention=False, max_distance=0, num_buckets=0, @@ -408,6 +414,11 @@ def __init__(self, self.local_layer_idx = local_layer_idx self.cross_attention = cross_attention + self.compute_attention_prior = compute_attention_prior + self.apply_attention_prior = apply_attention_prior + self.attention_prior_lookahead = attention_prior_lookahead + self.attention_prior_window_left = attention_prior_window_left + self.attention_prior_window_right = attention_prior_window_right self.attention_mask_type = attention_mask_type self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size assert num_attention_heads % tp_size == 0, \ @@ -792,6 +803,7 @@ def forward( kv_cache_params=None, attention_params=None, encoder_output: Optional[Tensor] = None, + attention_prior_focus: Optional[Tensor] = None, position_embedding=None, norm_before_bmm1=False, lora_layer_params=None, @@ -1118,7 +1130,7 @@ def compute_cross_kv(encoder_output): assert long_rope_rotary_inv_freq is not None assert long_rope_rotary_cos_sin is not None - context, past_key_value = gpt_attention( + attn_outputs = gpt_attention( qkv=qkv, past_key_value=past_key_value, attention_mask=attention_mask, @@ -1184,9 +1196,15 @@ def compute_cross_kv(encoder_output): host_kv_cache_pool_mapping if not self.cross_attention else kv_cache_params.host_cross_kv_cache_pool_mapping, do_cross_attention=self.cross_attention, + compute_attention_prior=self.compute_attention_prior, + apply_attention_prior=self.apply_attention_prior, + attention_prior_lookahead=self.attention_prior_lookahead, + attention_prior_window_left=self.attention_prior_window_left, + attention_prior_window_right=self.attention_prior_window_right, cross_kv=cross_kv, cross_kv_length=attention_params.encoder_max_input_length, encoder_input_lengths=attention_params.encoder_input_lengths, + attention_prior_focus=attention_prior_focus, logn_scaling=logn_scaling, relative_attention_bias=self.rel_attn_table.value if self.relative_attention else None, @@ -1217,6 +1235,13 @@ def compute_cross_kv(encoder_output): cp_rank=self.cp_rank, cp_group=self.cp_group) + if self.compute_attention_prior: + attention_prior_scores = attn_outputs.pop() + past_key_value = attn_outputs.pop() + # unpack if context is a single 
tensor + context = attn_outputs[0] if len( + attn_outputs) == 1 else attn_outputs + else: # plain TensorRT mode assert paged_kv_cache == False @@ -1582,10 +1607,12 @@ def transpose_for_scores(x, ).plugin_config.use_fp8_context_fmha: context = dense_conditional.add_output(skip_case, context) + outputs = [context] if use_cache: - return (context, past_key_value) - else: - return context + outputs.append(past_key_value) + if self.compute_attention_prior: + outputs.append(attention_prior_scores) + return outputs def set_rel_attn_table(self, max_seq_len, precomputed_relative_attention): self.rel_attn_table = Parameter(shape=(self.num_attention_heads, diff --git a/tensorrt_llm/layers/embedding.py b/tensorrt_llm/layers/embedding.py index fd77aa9e643..217d9214e67 100644 --- a/tensorrt_llm/layers/embedding.py +++ b/tensorrt_llm/layers/embedding.py @@ -111,6 +111,8 @@ def weight_loader(self, mapping: Mapping, param: Parameter, def postprocess(self, tllm_key, weights, **kwargs): if weights is None: return {} + if isinstance(weights, Sequence): + weights = torch.vstack(weights) weights = weights.to(str_dtype_to_torch(self.dtype)) return {tllm_key: weights} diff --git a/tensorrt_llm/llmapi/mgmn_worker_node.py b/tensorrt_llm/llmapi/mgmn_worker_node.py index e58ec68ae4e..ea9566926d8 100644 --- a/tensorrt_llm/llmapi/mgmn_worker_node.py +++ b/tensorrt_llm/llmapi/mgmn_worker_node.py @@ -1,9 +1,14 @@ #!/usr/bin/env python3 import logging +import torch from mpi4py.futures import MPICommExecutor from mpi4py.MPI import COMM_WORLD +from tensorrt_llm._utils import global_mpi_rank, local_mpi_size + +device_id = global_mpi_rank() % local_mpi_size() +torch.cuda.set_device(device_id) # For multi-node MPI, the worker nodes should launch MPICommExecutor to accept tasks sent from rank0 with MPICommExecutor(COMM_WORLD) as executor: if executor is not None: diff --git a/tensorrt_llm/models/__init__.py b/tensorrt_llm/models/__init__.py index 96bd4eff96d..ef7cc06f1dc 100755 --- a/tensorrt_llm/models/__init__.py +++ b/tensorrt_llm/models/__init__.py @@ -62,6 +62,7 @@ from .recurrentgemma.model import RecurrentGemmaForCausalLM from .redrafter.model import ReDrafterForLLaMALM, ReDrafterForQWenLM from .stdit.model import STDiT3Model +from .t5tts.model import T5TTSDecoderModel, T5TTSEncoderModel __all__ = [ 'BertModel', @@ -134,6 +135,8 @@ 'SpeculativeDecodingMode', 'CohereForCausalLM', 'MLLaMAForCausalLM', + 'T5TTSEncoderModel', + 'T5TTSDecoderModel', ] MODEL_MAP = { @@ -220,4 +223,6 @@ 'RobertaModel': RobertaModel, 'RobertaForQuestionAnswering': RobertaForQuestionAnswering, 'RobertaForSequenceClassification': RobertaForSequenceClassification, + 'T5TTSEncoderModel': T5TTSEncoderModel, + 'T5TTSDecoderModel': T5TTSDecoderModel, } diff --git a/tensorrt_llm/models/enc_dec/model.py b/tensorrt_llm/models/enc_dec/model.py index be3c5afc49f..0e310ec21b2 100644 --- a/tensorrt_llm/models/enc_dec/model.py +++ b/tensorrt_llm/models/enc_dec/model.py @@ -24,9 +24,10 @@ str_dtype_to_torch) from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType, MLPType, PositionEmbeddingType, Tensor, - assertion, cast, gather_last_token_logits, - gelu, maximum, minimum, recv, send, shape, - transpose, unsqueeze) + assertion, cast, concat, + gather_last_token_logits, gelu, maximum, + minimum, recv, select, send, shape, + transpose, unsqueeze, view) # yapf: disable from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams, AttentionMaskType, AttentionParams, @@ -63,6 +64,7 @@ class EncDecTransformer(Module): def __init__(self, 
vocab_size, + num_vocabs, hidden_size, has_vocab_embeddings=True, max_position_embeddings=None, @@ -85,6 +87,7 @@ def __init__(self, ln_type = layernorm_map[layernorm_type] self.mapping = mapping + self.num_vocabs = num_vocabs if self.mapping.is_first_pp_rank(): if has_vocab_embeddings: self.vocab_embedding = Embedding( @@ -165,6 +168,19 @@ def embedding(self, x = self.vocab_embedding(input_ids, *args) * self.embedding_scale self.register_network_output('word_embeddings', x) + if self.num_vocabs > 1: + x = view(x, + concat( + [shape(x, 0) / self.num_vocabs, self.num_vocabs, + -1])) # shape [totalSeqLen, nVocab, embDim] + # TODO: selecting just embedding from first vocab for e2e test + x = select(x, 1, 0) + + if self.position_embedding or self.token_type_embedding: + raise RuntimeError( + "For multi-vocab enc-dec, position ids and token type ids have to be repeated" + ) + if self.position_embedding: pos_emb = self.position_embedding(position_ids) self.register_network_output('position_embeddings', pos_emb) @@ -641,6 +657,7 @@ def __init__(self, config: PretrainedConfig): self.transformer = EncDecTransformer( self.config.vocab_size, + 1, self.config.hidden_size, max_position_embeddings=self.config.max_position_embeddings, has_position_embedding=self.has_position_embedding, @@ -1104,8 +1121,12 @@ def __init__(self, config: PretrainedConfig): 'language_adapter_config') else LanguageAdapterConfig.from_dict( self.config.language_adapter_config) + if not hasattr(self.config, 'vocab_sizes'): + self.config.vocab_sizes = [self.config.vocab_size] + self.transformer = EncDecTransformer( self.config.vocab_size, + len(self.config.vocab_sizes) if self.config.vocab_sizes else 1, self.config.hidden_size, max_position_embeddings=self.config.max_position_embeddings, has_position_embedding=self.config.has_position_embedding, @@ -1375,6 +1396,9 @@ def prepare_inputs(self, inlen_range = [ 1, 1, max_decoder_input_len ] # context phase >= 1 (if forced_input_ids), generation phase = 1 + num_vocabs = len(self.config.vocab_sizes) + decoder_inlen_range = [x * num_vocabs for x in inlen_range] + encoder_inlen_range = [ 1, (max_encoder_input_len + 1) // 2, max_encoder_input_len ] @@ -1393,6 +1417,10 @@ def prepare_inputs(self, max_beam_width * max_batch_size), ] + io_decoder_num_tokens_range = [ + x * num_vocabs for x in decoder_num_tokens_range + ] + # No enable_two_optimization_profiles support yet encoder_input_len_range = [ @@ -1437,7 +1465,7 @@ def prepare_inputs(self, shape=[-1], dim_range=OrderedDict([ ('decoder_num_tokens', - [decoder_num_tokens_range]), + [io_decoder_num_tokens_range]), ])) if self.has_position_embedding: position_ids = Tensor(name='position_ids', @@ -1445,7 +1473,7 @@ def prepare_inputs(self, shape=[-1], dim_range=OrderedDict([ ('decoder_num_tokens', - [decoder_num_tokens_range]), + [io_decoder_num_tokens_range]), ])) if self.has_token_type_embedding: token_type_ids = Tensor( @@ -1453,7 +1481,8 @@ def prepare_inputs(self, dtype=trt.int32, shape=[-1], dim_range=OrderedDict([('decoder_num_tokens', - [decoder_num_tokens_range])]), + [io_decoder_num_tokens_range]) + ]), ) else: hidden_states = Tensor(name='hidden_states_input', @@ -1471,7 +1500,7 @@ def prepare_inputs(self, shape=[-1, -1], dim_range=OrderedDict([ ('batch_size_beam_width', [bb_range]), - ('input_len', [inlen_range]), + ('input_len', [decoder_inlen_range]), ])) if self.has_position_embedding: position_ids = Tensor(name='position_ids', @@ -1480,16 +1509,18 @@ def prepare_inputs(self, dim_range=OrderedDict([ ('batch_size_beam_width', 
[bb_range]), - ('input_len', [inlen_range]), + ('input_len', + [decoder_inlen_range]), ])) if self.has_token_type_embedding: token_type_ids = Tensor( name='token_type_ids', dtype=trt.int32, shape=[-1, -1], - dim_range=OrderedDict([('batch_size_beam_width', - [bb_range]), - ('input_len', [inlen_range])]), + dim_range=OrderedDict([ + ('batch_size_beam_width', [bb_range]), + ('input_len', [decoder_inlen_range]) + ]), ) else: hidden_states = Tensor(name='hidden_states_input', @@ -2063,6 +2094,7 @@ def __init__(self, config: PretrainedConfig): dtype=self._dtype) self.transformer = EncDecTransformer( 0, + 1, self.config.hidden_size, has_vocab_embeddings=False, max_position_embeddings=self.config.max_position_embeddings, diff --git a/tensorrt_llm/models/generation_mixin.py b/tensorrt_llm/models/generation_mixin.py index f97b8d436b7..7766c7bf98e 100644 --- a/tensorrt_llm/models/generation_mixin.py +++ b/tensorrt_llm/models/generation_mixin.py @@ -80,18 +80,18 @@ def split_num_tokens_range(max_num_tokens): return num_tokens_ranges @staticmethod - def get_profiles_ranges( - *, - max_batch_size, - max_beam_width, - max_input_len, - max_num_tokens, - max_draft_len, - opt_batch_size, - opt_num_tokens, - enable_ctx_gen_opt_profiles, - multiple_profiles, - kv_cache_type: KVCacheType = KVCacheType.CONTINUOUS): + def get_profiles_ranges(*, + max_batch_size, + max_beam_width, + max_input_len, + max_num_tokens, + max_draft_len, + opt_batch_size, + opt_num_tokens, + enable_ctx_gen_opt_profiles, + multiple_profiles, + kv_cache_type: KVCacheType = KVCacheType.CONTINUOUS, + num_vocabs: int = 1): default_range = GenerationMixin.default_range if opt_batch_size: bb_range_cxt = [1, opt_batch_size, max_batch_size] @@ -168,6 +168,8 @@ def get_profiles_ranges( [math.ceil(x[0] / (max_draft_len + 1)), x[1], x[2]], num_tokens_range)) + num_tokens_range = [[rr * num_vocabs for rr in r] + for r in num_tokens_range] ranges = { 'bb_range': bb_range, 'bbd_range': bbd_range, @@ -542,6 +544,7 @@ def prepare_basic_inputs( opt_batch_size=None, pp_reduce_scatter: bool = False, mrope_rotary_cos_sin_size: int = None, + num_vocabs: int = 1, ): enable_ctx_gen_opt_profiles = GenerationMixin.has_ctx_gen_opt_profiles( @@ -560,7 +563,8 @@ def prepare_basic_inputs( opt_num_tokens=opt_num_tokens, enable_ctx_gen_opt_profiles=enable_ctx_gen_opt_profiles, multiple_profiles=multiple_profiles, - kv_cache_type=kv_cache_type) + kv_cache_type=kv_cache_type, + num_vocabs=num_vocabs) bb_range = ranges['bb_range'] bbd_range = ranges['bbd_range'] inlen_range = ranges['inlen_range'] diff --git a/tensorrt_llm/models/llama/config.py b/tensorrt_llm/models/llama/config.py index 7e0369a4ba0..5df028bdc90 100644 --- a/tensorrt_llm/models/llama/config.py +++ b/tensorrt_llm/models/llama/config.py @@ -16,7 +16,7 @@ import math import sys from pathlib import Path -from typing import Optional, Union +from typing import List, Optional, Union from ...layers import MoeConfig from ...mapping import Mapping @@ -40,6 +40,7 @@ def __init__(self, attention_multiplier: float = 1.0, residual_multiplier: float = 1.0, output_multiplier_scale: float = 1.0, + vocab_sizes: Optional[List[int]] = None, **kwargs): self.mlp_bias = mlp_bias self.attn_bias = attn_bias @@ -68,6 +69,7 @@ def __init__(self, self.attention_multiplier = attention_multiplier self.residual_multiplier = residual_multiplier self.output_multiplier_scale = output_multiplier_scale + self.vocab_sizes = vocab_sizes self.has_partial_lora_mask = False super().__init__(**kwargs) @@ -87,6 +89,7 @@ def to_dict(self): 
'use_input_layernorm_in_first_layer'] = self.use_input_layernorm_in_first_layer output['use_last_layernorm'] = self.use_last_layernorm output['layer_idx_offset'] = self.layer_idx_offset + output['vocab_sizes'] = self.vocab_sizes output['moe'] = self.moe.to_dict() return output @@ -189,6 +192,7 @@ def from_hugging_face( dtype = infer_dtype(dtype, getattr(hf_config, 'torch_dtype', None)) tie_word_embeddings = getattr(hf_config, 'tie_word_embeddings', False) + vocab_sizes = getattr(hf_config, 'vocab_sizes', None) return cls( architecture=hf_config.architectures[0], @@ -219,6 +223,7 @@ def from_hugging_face( attention_multiplier=attention_multiplier, residual_multiplier=residual_multiplier, output_multiplier_scale=output_multiplier_scale, + vocab_sizes=vocab_sizes, **kwargs) @classmethod diff --git a/tensorrt_llm/models/llama/convert.py b/tensorrt_llm/models/llama/convert.py index d72b68ff923..7e8c71d3b70 100644 --- a/tensorrt_llm/models/llama/convert.py +++ b/tensorrt_llm/models/llama/convert.py @@ -1616,9 +1616,10 @@ def load_and_set(target, weights[target.replace("weight", "bias")] = bias if mapping.is_first_pp_rank(): - weights['transformer.vocab_embedding.weight'] = load( - param_name_map["vocab_embedding"], config.embedding_sharding_dim - if config.use_parallel_embedding else -1) # vocab_embedding + v = load(param_name_map["vocab_embedding"], + config.embedding_sharding_dim + if config.use_parallel_embedding else -1) # vocab_embedding + weights['transformer.vocab_embedding.weight'] = v if mapping.is_last_pp_rank(): v = load(param_name_map["lm_head"], -1, 1) if pad_vocab else load( diff --git a/tensorrt_llm/models/llama/model.py b/tensorrt_llm/models/llama/model.py index 2e272772adb..10b56535e15 100644 --- a/tensorrt_llm/models/llama/model.py +++ b/tensorrt_llm/models/llama/model.py @@ -21,7 +21,7 @@ from ..._utils import pad_vocab_size from ...functional import (AllReduceFusionOp, AllReduceParams, Tensor, allgather, concat, constant, div, non_gated_version, - recv, send, unsqueeze) + recv, select, send, shape, unsqueeze, view) from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear, Embedding, FusedGatedMLP, GatedMLP, PositionEmbeddingType, RmsNorm) @@ -313,6 +313,7 @@ def __init__(self, config: LLaMAConfig) -> None: self.vocab_size = config.vocab_size self.has_partial_lora_mask = config.has_partial_lora_mask self.hidden_size = config.hidden_size + self.num_vocabs = len(config.vocab_sizes) if config.vocab_sizes else 1 if self.mapping.is_first_pp_rank(): self.vocab_embedding = Embedding(config.vocab_size, config.hidden_size, @@ -357,9 +358,22 @@ def forward(self, ] if prompt_embedding_table is not None else [] if self.mapping.is_first_pp_rank(): - hidden_states = self.vocab_embedding(input_ids, *ptuning_args) + hidden_states = self.vocab_embedding( + input_ids, *ptuning_args) # seqlen x num_vocabs x hidden_size + hidden_states = view(hidden_states, + concat([ + shape(hidden_states, 0) / self.num_vocabs, + self.num_vocabs, -1 + ])) # shape [totalSeqLen, nVocab, embDim] + # TODO: for debug pick the very first sampled token, ignore rest + # in the future: + # hidden_states = sum(hidden_states, 1, keepdim=False) + # shape [totalSeqLen, embDim] + # hidden_states[:, 0, :] + hidden_states = select(hidden_states, 1, 0) hidden_states *= self.embedding_multiplier else: + # TODO: not supported for multi-vocab yet. 
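The multi-vocab embedding handling above flattens num_vocabs token ids per position into the leading dimension and then reduces across the vocab axis: the T5TTS embedding added later in this patch averages the per-vocab embeddings, while the enc-dec and LLaMA paths currently select the first vocab only, as the TODOs note. A standalone NumPy sketch of that layout and both reductions:

import numpy as np

def reduce_multivocab_embeddings(flat_embeddings, num_vocabs, mode="mean"):
    """flat_embeddings: (seq_len * num_vocabs, hidden); num_vocabs consecutive rows
    belong to the same output position. Returns (seq_len, hidden)."""
    total, hidden = flat_embeddings.shape
    assert total % num_vocabs == 0
    x = flat_embeddings.reshape(total // num_vocabs, num_vocabs, hidden)
    if mode == "mean":       # T5TTS embedding path: average across vocabs
        return x.mean(axis=1)
    return x[:, 0, :]        # current enc-dec / LLaMA debug path: first vocab only

# 4 positions, 2 vocab streams, hidden size 3
emb = np.arange(4 * 2 * 3, dtype=np.float32).reshape(8, 3)
assert reduce_multivocab_embeddings(emb, 2).shape == (4, 3)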
hidden_states = recv(hidden_states, self.mapping.prev_pp_rank()) if default_net().plugin_config.pp_reduce_scatter: hidden_states = allgather(hidden_states, diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 6f4bfdf0bb0..d7c49eb8f12 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -899,7 +899,10 @@ def prepare_inputs( streamingllm=streamingllm, opt_batch_size=opt_batch_size, pp_reduce_scatter=pp_reduce_scatter, - mrope_rotary_cos_sin_size=mrope_rotary_cos_sin_size) + mrope_rotary_cos_sin_size=mrope_rotary_cos_sin_size, + num_vocabs=len(self.config.vocab_sizes) if + (hasattr(self.config, 'vocab_sizes') + and self.config.vocab_sizes) else 1) result = { 'input_ids': diff --git a/tensorrt_llm/models/t5tts/__init__.py b/tensorrt_llm/models/t5tts/__init__.py new file mode 100644 index 00000000000..0d106f45ce4 --- /dev/null +++ b/tensorrt_llm/models/t5tts/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tensorrt_llm/models/t5tts/model.py b/tensorrt_llm/models/t5tts/model.py new file mode 100644 index 00000000000..57562675181 --- /dev/null +++ b/tensorrt_llm/models/t5tts/model.py @@ -0,0 +1,1909 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
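A side note on the profile ranges touched earlier in this patch (generation_mixin.get_profiles_ranges and the enc-dec prepare_inputs): because decoder input_ids now carries num_vocabs token ids per position, every token-count range is scaled by num_vocabs. A small self-contained sketch of that bookkeeping, with illustrative numbers rather than a real build configuration:

# Illustrative numbers only; the real ranges come from the build config.
num_vocabs = 2
inlen_range = [1, 64, 128]                      # per-position decoder input lengths
num_tokens_range = [[1, 256, 512], [1, 512, 1024]]

# Mirrors the patch: the engine sees num_vocabs ids per position, so every
# range on the token axis grows by the same factor.
decoder_inlen_range = [x * num_vocabs for x in inlen_range]
scaled_num_tokens_range = [[rr * num_vocabs for rr in r] for r in num_tokens_range]

assert decoder_inlen_range == [2, 128, 256]
assert scaled_num_tokens_range == [[2, 512, 1024], [2, 1024, 2048]]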
+import math +from collections import OrderedDict +from typing import Optional + +import tensorrt as trt +import torch + +from tensorrt_llm._common import default_net +from tensorrt_llm._utils import numpy_to_torch, str_dtype_to_torch +from tensorrt_llm.functional import ( + ACT2FN, LayerNormPositionType, LayerNormType, MLPType, + PositionEmbeddingType, Tensor, assertion, concat, expand_dims, + gather_last_token_logits, maximum, mean, minimum, recv, send, shape, + squeeze, stack, transpose, unsqueeze, view, where) +from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams, + AttentionMaskType, AttentionParams, + BertAttention, ColumnLinear, Conv1d, Embedding, + FusedGatedMLP, GatedMLP, GroupNorm, + KeyValueCacheParams, LayerNorm, + PromptTuningEmbedding, RmsNorm) +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import PretrainedConfig, PretrainedModel +from tensorrt_llm.module import Module, ModuleList +from tensorrt_llm.parameter import Parameter +from tensorrt_llm.plugin.plugin import current_all_reduce_helper + +layernorm_map = { + LayerNormType.LayerNorm: LayerNorm, + LayerNormType.RmsNorm: RmsNorm, + LayerNormType.GroupNorm: GroupNorm, +} + +mlp_map = { + MLPType.MLP: MLP, + MLPType.GatedMLP: GatedMLP, + MLPType.FusedGatedMLP: FusedGatedMLP, +} + +COMPUTE_SCORES_FROM_LAYERS = [4, 6, 10] +APPLY_PRIOR_TO_LAYERS = [4, 5, 6, 7, 8, 9, 10] + + +class PositionwiseConvFF(Module): + + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + kernel_size: int = 3, + has_bias: bool = False, + is_causal: bool = True, + hidden_act: str = 'gelu', + padding: int = 0, + dilation: int = 1, + dtype=None, + groups: int = 1, + ): + super().__init__() + + self.is_causal = is_causal + self.hidden_size = hidden_size + self.pos_ffn_hidden_size = ffn_hidden_size + + self.hidden_act = ACT2FN[hidden_act] + + if self.is_causal: + self.causal_padding = ((kernel_size - 1) * dilation, 0) + + padding = 0 + + self.proj = Conv1d(hidden_size, + ffn_hidden_size, + kernel_size=kernel_size, + padding=padding, + bias=has_bias, + dilation=dilation, + dtype=dtype) + self.o_net = Conv1d(ffn_hidden_size, + hidden_size, + kernel_size=kernel_size, + padding=padding, + bias=has_bias, + dilation=dilation, + dtype=dtype) + + def forward(self, x: Tensor) -> Tensor: + # input is BT x DIM + x = transpose(x, 0, 1) # DIM x BT + x = unsqueeze(x, 0) # 1 x DIM x BT + + x = self.proj(x) + x = self.o_net(self.hidden_act(x)) + + x = squeeze(x, 0) # DIM x BT + x = transpose(x, 1, 0) # BT x DIM + return x + + +class PositionalEmbedding(Module): + + def __init__(self, + max_position_embeddings, + hidden_size, + has_embedding_layernorm=False, + has_embedding_scale=False, + layernorm_eps=1e-5, + layernorm_type=LayerNormType.LayerNorm, + dtype=None, + use_parallel_embedding=False, + embedding_sharding_dim=0, + mapping=Mapping()): + super().__init__() + + self.layernorm_type = layernorm_type + ln_type = layernorm_map[layernorm_type] + + self.max_position_embeddings = max_position_embeddings + self.position_embedding = None + self.position_embedding = Embedding( + max_position_embeddings, + hidden_size, + dtype=dtype, + tp_size=mapping.tp_size if use_parallel_embedding else 1, + tp_group=mapping.tp_group if use_parallel_embedding else None, + sharding_dim=embedding_sharding_dim, + tp_rank=mapping.tp_rank) + + self.embedding_layernorm = None + if has_embedding_layernorm: + self.embedding_layernorm = ln_type(normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype, + bias=False) + + 
self.embedding_scale = 1.0 + if has_embedding_scale: + self.embedding_scale = math.sqrt(hidden_size) + + def forward(self, + input_ids, + position_ids=None, + prompt_tasks=None, + prompt_vocab_size=None): + pos_emb = self.position_embedding(position_ids) + x = input_ids + pos_emb + if self.embedding_layernorm: + x = self.embedding_layernorm(x) + return x + + +class EncoderDecoderEmbedding(Module): + + def __init__(self, + vocab_size, + num_vocabs, + hidden_size, + max_position_embeddings=None, + has_position_embedding=False, + type_vocab_size=None, + has_embedding_layernorm=False, + has_embedding_scale=False, + layernorm_eps=1e-5, + layernorm_type=LayerNormType.LayerNorm, + dtype=None, + use_parallel_embedding=False, + embedding_sharding_dim=0, + use_context_embeddings=False, + mapping=Mapping()): + super().__init__() + + self.num_vocabs = num_vocabs + self.use_context_embeddings = use_context_embeddings + self.layernorm_type = layernorm_type + ln_type = layernorm_map[layernorm_type] + + self.vocab_embedding = Embedding( + vocab_size, + hidden_size, + dtype=dtype, + tp_size=mapping.tp_size if use_parallel_embedding else 1, + tp_group=mapping.tp_group if use_parallel_embedding else None, + sharding_dim=embedding_sharding_dim, + tp_rank=mapping.tp_rank) + + self.position_embedding = None + self.max_position_embeddings = max_position_embeddings + if has_position_embedding: + self.position_embedding = Embedding( + max_position_embeddings, + hidden_size, + dtype=dtype, + tp_size=mapping.tp_size if use_parallel_embedding else 1, + tp_group=mapping.tp_group if use_parallel_embedding else None, + sharding_dim=embedding_sharding_dim, + tp_rank=mapping.tp_rank) + + self.token_type_embedding = None + if type_vocab_size: + self.token_type_embedding = Embedding( + type_vocab_size, + hidden_size, + dtype=dtype, + tp_size=mapping.tp_size if use_parallel_embedding else 1, + tp_group=mapping.tp_group if use_parallel_embedding else None, + sharding_dim=embedding_sharding_dim, + tp_rank=mapping.tp_rank) + + # e.g. BART true, T5 false + self.embedding_layernorm = None + if has_embedding_layernorm: + self.embedding_layernorm = ln_type(normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype, + bias=False) + + # e.g. BART true, T5 false + self.embedding_scale = 1.0 + if has_embedding_scale: + self.embedding_scale = math.sqrt(hidden_size) + + # Note: embedding offset in BART is not considered as a standard. 
For the specific case, + # we just need to shrink its position embedding table by [offset:] during weight loading + + def forward(self, + input_ids, + position_ids=None, + token_type_ids=None, + prompt_embedding_table=None, + prompt_tasks=None, + prompt_vocab_size=None, + decoder_context_features=None, + decoder_context_features_mask=None): + # position_ids and token_type_ids are provided inputs + # and should not be formulated deterministically + + args = [prompt_embedding_table, prompt_tasks, prompt_vocab_size + ] if prompt_embedding_table is not None else [] + + x = self.vocab_embedding(input_ids, *args) * self.embedding_scale + if self.num_vocabs > 1: + x = view(x, + concat( + [shape(x, 0) / self.num_vocabs, self.num_vocabs, + -1])) # shape [totalSeqLen, nVocab, embDim] + # average across vocabs + x = mean(x, 1) # shape [totalSeqLen, embDim] + if self.use_context_embeddings and decoder_context_features is not None: + x = where(expand_dims(decoder_context_features_mask, 1), + decoder_context_features, x) + + if self.position_embedding: + pos_emb = self.position_embedding(position_ids) + x = x + pos_emb + if self.token_type_embedding: + x = x + self.token_type_embedding(token_type_ids) + + if self.embedding_layernorm: + x = self.embedding_layernorm(x) + + return x + + +class T5TTSEncoderLayer(Module): + + def __init__(self, + hidden_size, + ffn_hidden_size, + num_attention_heads, + num_kv_heads, + head_size, + max_position_embeddings=None, + q_scaling=1.0, + has_attention_qkvo_bias=False, + has_pos_ff_bias=False, + layernorm_position=LayerNormPositionType.pre_layernorm, + layernorm_type=LayerNormType.LayerNorm, + layernorm_eps=1e-5, + hidden_act="gelu", + mapping=Mapping(), + dtype=None, + residual_scaling=1.0, + relative_attention=False, + max_distance=0, + num_buckets=0, + fp16_clamping=False, + conv_is_causal=False): + super().__init__() + + # e.g. BART regular, T5 RMS + self.layernorm_type = layernorm_type + ln_type = layernorm_map[layernorm_type] + + # e.g. BART post, T5 pre + self.layernorm_position = layernorm_position + + # e.g. BART q_scaling = 1.f, T5 q_scaling = 1.f/sqrt(head_size) + self.attention = BertAttention( + hidden_size, + num_attention_heads, + attention_head_size=head_size, + num_kv_heads=num_kv_heads, + max_position_embeddings=max_position_embeddings, + bias=has_attention_qkvo_bias, + tp_group=mapping.tp_group, + tp_size=mapping.tp_size, + tp_rank=mapping.tp_rank, + dtype=dtype, + relative_attention=relative_attention, + max_distance=max_distance, + num_buckets=num_buckets) + + self.attention_layernorm = ln_type(normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype, + bias=False) + + self.pos_ff = PositionwiseConvFF( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + hidden_act=hidden_act, + has_bias=has_pos_ff_bias, + kernel_size=3, + padding=1, + groups=mapping.tp_group, + dtype=dtype, + is_causal=conv_is_causal, + ) + + self.pos_ff_layernorm = ln_type(normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype, + bias=False) + + self.residual_scaling = residual_scaling + + # T5-series model(e.g. t5-large, t5-3b, flan-t5-small) has accuracy issue due to fp16 overflow + # after residual add. We add workaround for clamping fp16 range [-64000, 64000] after every + # residual add to avoid accuracy drop. 
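+ # The +/-64000 clamp applied in forward() stays safely inside fp16's finite range (max ~65504).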
+ self.fp16_clamping = fp16_clamping + + def forward(self, + hidden_states: Tensor, + attention_mask=None, + input_lengths=None, + max_input_length=None): + assert isinstance(hidden_states, Tensor) + + # self attention + residual = hidden_states * self.residual_scaling + + if self.layernorm_position == LayerNormPositionType.pre_layernorm: + hidden_states = self.attention_layernorm(hidden_states) + + attention_output = self.attention(hidden_states, + attention_mask=attention_mask, + input_lengths=input_lengths, + max_input_length=max_input_length) + + self.register_network_output('attention_output', attention_output) + + hidden_states = residual + attention_output + + if self.fp16_clamping: + hidden_states = maximum(-64000.0, hidden_states) + hidden_states = minimum(64000.0, hidden_states) + + if self.layernorm_position == LayerNormPositionType.post_layernorm: + hidden_states = self.attention_layernorm(hidden_states) + + # MLP + residual = hidden_states * self.residual_scaling + + if self.layernorm_position == LayerNormPositionType.pre_layernorm: + hidden_states = self.pos_ff_layernorm(hidden_states) + + hidden_states = self.pos_ff(hidden_states) + + self.register_network_output('pos_ff_output', hidden_states) + + hidden_states = residual + hidden_states + + if self.fp16_clamping: + hidden_states = maximum(-64000.0, hidden_states) + hidden_states = minimum(64000.0, hidden_states) + + if self.layernorm_position == LayerNormPositionType.post_layernorm: + hidden_states = self.pos_ff_layernorm(hidden_states) + + return hidden_states + + +class T5TTSDecoderLayer(Module): + + def __init__(self, + *, + local_layer_idx, + hidden_size, + ffn_hidden_size, + num_attention_heads, + num_kv_heads, + head_size, + max_position_embeddings=None, + q_scaling=1.0, + has_attention_qkvo_bias=False, + has_pos_ff_bias=False, + has_encoder_input_layernorm=False, + layernorm_position=LayerNormPositionType.pre_layernorm, + layernorm_type=LayerNormType.LayerNorm, + layernorm_eps=1e-5, + hidden_act="gelu", + mapping=Mapping(), + dtype=None, + residual_scaling=1.0, + relative_attention=False, + max_distance=0, + num_buckets=0, + fp16_clamping=False, + skip_cross_kv=False, + use_implicit_relative_attention=False, + compute_attention_prior=False, + apply_attention_prior=False, + attention_prior_lookahead=5, + attention_prior_window_left=1, + attention_prior_window_right=5): + super().__init__() + + self.has_encoder_input_layernorm = has_encoder_input_layernorm + self.compute_attention_prior = compute_attention_prior + + # e.g. BART regular, T5 RMS + self.layernorm_type = layernorm_type + ln_type = layernorm_map[layernorm_type] + + # e.g. BART post, T5 pre + self.layernorm_position = layernorm_position + self.hidden_size = hidden_size + + # e.g. 
BART q_scaling = 1.f, T5 q_scaling = 1.f/sqrt(head_size) + self.self_attention = Attention( + local_layer_idx=local_layer_idx, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_head_size=head_size, + num_kv_heads=num_kv_heads, + max_position_embeddings=max_position_embeddings, + bias=has_attention_qkvo_bias, + attention_mask_type=AttentionMaskType.causal, + tp_group=mapping.tp_group, + tp_size=mapping.tp_size, + tp_rank=mapping.tp_rank, + dtype=dtype, + cross_attention=False, + relative_attention=relative_attention, + max_distance=max_distance if use_implicit_relative_attention else 0, + num_buckets=num_buckets, + position_embedding_type=PositionEmbeddingType.relative + if relative_attention else PositionEmbeddingType.learned_absolute, + use_implicit_relative_attention=use_implicit_relative_attention) + + self.self_attention_layernorm = ln_type(normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype, + bias=False) + + # Note: self attn uses MMHA, mask is always causal triangular + # cross attn has two scenarios: + # - in context phase, all ones mask, same as padding type + # - in generation phase, same causal triangular mask as MMHA + # - context phase special handling is done in plugin by resetting mask type + # + # e.g. BART q_scaling = 1.f, T5 q_scaling = 1.f/sqrt(head_size) + self.cross_attention = Attention( + local_layer_idx=local_layer_idx, + hidden_size=hidden_size, + num_attention_heads=1, + attention_head_size=128, # TODO: make this part of model config + num_kv_heads=1, + max_position_embeddings=max_position_embeddings, + bias=has_attention_qkvo_bias, + attention_mask_type=AttentionMaskType.causal, + tp_group=mapping.tp_group, + tp_size=mapping.tp_size, + tp_rank=mapping.tp_rank, + dtype=dtype, + cross_attention=True, + compute_attention_prior=self.compute_attention_prior, + apply_attention_prior=apply_attention_prior, + attention_prior_lookahead=attention_prior_lookahead, + attention_prior_window_left=attention_prior_window_left, + attention_prior_window_right=attention_prior_window_right, + relative_attention= + False, # Cross attention has no relative attention bias + max_distance=max_distance, + num_buckets=num_buckets, + position_embedding_type=PositionEmbeddingType.learned_absolute, + skip_cross_kv=skip_cross_kv) + + self.cache_cross_attention_memory = None + if has_encoder_input_layernorm: + self.cross_attention_memory_layernorm = ln_type( + normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype, + bias=False) + + self.cross_attention_layernorm = ln_type(normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype, + bias=False) + + self.pos_ff = PositionwiseConvFF( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + kernel_size=1, + padding=0, + hidden_act=hidden_act, + has_bias=has_pos_ff_bias, + groups=mapping.tp_group, + dtype=dtype, + is_causal=True, + ) + + self.pos_ff_layernorm = ln_type(normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype, + bias=False) + + self.residual_scaling = residual_scaling + + # T5-series model(e.g. t5-large, t5-3b, flan-t5-small) has accuracy issue due to fp16 overflow + # after residual add. We add workaround for clamping fp16 range [-64000, 64000] after every + # residual add to avoid accuracy drop. 
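+ # forward() below runs: causal self-attention -> single-head cross-attention (which can emit
+ # attention-prior scores and/or consume attention_prior_focus) -> causal PositionwiseConvFF,
+ # each preceded by a layernorm and followed by a residual add.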
+ self.fp16_clamping = fp16_clamping + + def forward(self, + hidden_states: Tensor, + encoder_output: Optional[Tensor] = None, + attention_prior_focus: Optional[Tensor] = None, + attention_mask_params=None, + use_cache=False, + kv_cache_params=None, + attention_params=None, + cross_kv_cache_gen: Optional[Tensor] = None, + cross_kv_reuse: Optional[Tensor] = None): + assert isinstance(hidden_states, Tensor) + + if encoder_output: + assert isinstance(encoder_output, Tensor) + + # self-attention + residual = hidden_states + hidden_states = self.self_attention_layernorm(hidden_states) + attention_outputs = self.self_attention( + hidden_states=hidden_states, + attention_mask=attention_mask_params.self_attention_mask, + use_cache=use_cache, + kv_cache_params=kv_cache_params, + attention_params=attention_params) + # pop past key value for self attention + if use_cache: + presents_self = attention_outputs.pop() + assert len(attention_outputs) == 1 + attention_output = attention_outputs[0] + hidden_states = residual + attention_output + + # cross attention + residual = hidden_states + + hidden_states = self.cross_attention_layernorm(hidden_states) + encoder_output = self.cross_attention_memory_layernorm(encoder_output) + cross_attention_mask = attention_mask_params.cross_attention_mask + attention_outputs = self.cross_attention( + hidden_states=hidden_states, + attention_mask=cross_attention_mask, + attention_packed_mask=attention_mask_params. + cross_attention_packed_mask, + encoder_output=encoder_output, + attention_prior_focus=attention_prior_focus, + use_cache=use_cache, + kv_cache_params=kv_cache_params, + attention_params=attention_params, + cross_kv_cache_gen=cross_kv_cache_gen, + cross_kv_reuse=cross_kv_reuse) + attention_prior_scores = None + if self.compute_attention_prior: + attention_prior_scores = attention_outputs.pop() + if use_cache: + presents_cross = attention_outputs.pop() + assert len(attention_outputs) == 1 + attention_output = attention_outputs[0] + hidden_states = residual + attention_output + + # conv ff (norm -> conv -> residual) + residual = hidden_states + hidden_states = self.pos_ff_layernorm(hidden_states) + hidden_states = self.pos_ff(hidden_states) + result = residual + hidden_states + + results = [result] + if use_cache: + results.extend([presents_self, presents_cross]) + if self.compute_attention_prior: + results.append(attention_prior_scores) + return results + + +class T5TTSEncoderModel(PretrainedModel): + + def __init__(self, config: PretrainedConfig): + self.check_config(config) + super().__init__(config) + self.mapping = self.config.mapping + + self.has_position_embedding = self.config.has_position_embedding + type_vocab_size = self.config.type_vocab_size + self.has_token_type_embedding = False if type_vocab_size is None else True + + # e.g. BART regular, T5 RMS + self.layernorm_type = self.config.layernorm_type + ln_type = layernorm_map[self.layernorm_type] + + # e.g. BART true, T5 false + self.has_attention_qkvo_bias = self.config.has_attention_qkvo_bias + self.has_pos_ff_bias = self.config.has_pos_ff_bias + + # e.g. 
BART false, T5 true + self.has_model_final_layernorm = self.config.has_model_final_layernorm + + self._dtype = self.config.dtype + + self.total_num_layers = self.config.num_hidden_layers + self.num_layers = self.config.num_hidden_layers // self.mapping.pp_size + + self.hidden_size = self.config.hidden_size + self.num_heads = self.config.num_attention_heads + num_kv_heads = self.num_heads + if num_kv_heads is None or num_kv_heads <= 0: + num_kv_heads = self.config.num_attention_heads + self.num_kv_heads = num_kv_heads + self.head_size = self.hidden_size // self.num_heads if self.config.head_size is None else self.config.head_size + + self.fp16_clamping = (self.config.dtype + == 'float16') and (self.config.model_type == 't5') + self.mlp_type = MLPType.MLP if not hasattr( + self.config, "mlp_type") else self.config.mlp_type + + if self.mapping.is_first_pp_rank(): + self.embedding = EncoderDecoderEmbedding( + self.config.vocab_size, + 1, # number of vocabs + self.config.hidden_size, + max_position_embeddings=self.config.max_position_embeddings, + has_position_embedding=self.has_position_embedding, + type_vocab_size=type_vocab_size, + has_embedding_layernorm=self.config.has_embedding_layernorm, + has_embedding_scale=self.config.has_embedding_scale, + layernorm_eps=self.config.norm_epsilon, + layernorm_type=self.layernorm_type, + dtype=self.config.dtype, + use_parallel_embedding=self.config.use_parallel_embedding, + embedding_sharding_dim=self.config.embedding_sharding_dim, + mapping=self.mapping) + + self.encoder_layers = ModuleList([ + T5TTSEncoderLayer( + hidden_size=self.hidden_size, + ffn_hidden_size=self.config.intermediate_size, + num_attention_heads=self.num_heads, + num_kv_heads=num_kv_heads, + head_size=self.head_size, + max_position_embeddings=self.config.max_position_embeddings, + q_scaling=self.config.q_scaling, + has_attention_qkvo_bias=self.has_attention_qkvo_bias, + has_pos_ff_bias=self.has_pos_ff_bias, + layernorm_position=self.config.layernorm_position, + layernorm_eps=self.config.norm_epsilon, + layernorm_type=self.layernorm_type, + hidden_act=self.config.hidden_act, + mapping=self.mapping, + dtype=self.config.dtype, + residual_scaling=1.0 + if not hasattr(self.config, "residual_scaling") else + self.config.residual_scaling, + relative_attention=self.config.relative_attention, + max_distance=self.config.max_distance, + num_buckets=self.config.num_buckets, + fp16_clamping=self.fp16_clamping) + for _ in self.mapping.pp_layers(self.total_num_layers) + ]) + + if self.mapping.is_last_pp_rank(): + if self.has_model_final_layernorm: + self.final_layernorm = ln_type( + normalized_shape=self.config.hidden_size, + eps=self.config.norm_epsilon, + dtype=self.config.dtype, + bias=False) + + def check_config(self, config: PretrainedConfig): + config.set_if_not_exist('has_position_embedding', False) + config.set_if_not_exist('type_vocab_size', None) + config.set_if_not_exist('rescale_before_lm_head', False) + config.set_if_not_exist('layernorm_type', LayerNormType.LayerNorm) + config.set_if_not_exist('layernorm_position', + LayerNormPositionType.pre_layernorm) + config.set_if_not_exist('has_attention_qkvo_bias', False) + config.set_if_not_exist('has_pos_ff_bias', False) + config.set_if_not_exist('has_model_final_layernorm', False) + config.set_if_not_exist('encoder_hidden_size', None) + config.set_if_not_exist('encoder_num_heads', None) + config.set_if_not_exist('encoder_num_kv_heads', None) + config.set_if_not_exist('encoder_head_size', None) + config.set_if_not_exist('model_type', 't5') + 
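+ # set_if_not_exist() only fills attributes missing from the converted checkpoint config;
+ # values already present are preserved.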
config.set_if_not_exist('skip_cross_kv', False) + config.set_if_not_exist('has_embedding_scale', False) + config.set_if_not_exist('residual_scaling', 1.0) + config.set_if_not_exist('has_lm_head_bias', False) + config.set_if_not_exist('num_buckets', None) + config.set_if_not_exist('max_distance', None) + config.set_if_not_exist('relative_attention', False) + config.set_if_not_exist('residual_scaling', 1.0) + + def forward(self, + input_ids: Tensor, + input_lengths=None, + position_ids=None, + token_type_ids=None, + hidden_states=None, + max_input_length=None, + prompt_embedding_table=None, + prompt_tasks=None, + prompt_vocab_size=None, + attention_mask=None): + + # In PP, layer 0 has ids as inputs, all other layers have hidden_states as inputs + if self.mapping.is_first_pp_rank(): + ptuning_args = [ + prompt_embedding_table, prompt_tasks, prompt_vocab_size + ] if prompt_embedding_table is not None else [] + + hidden_states = self.embedding(input_ids, position_ids, + token_type_ids, *ptuning_args) + else: + hidden_states = recv(hidden_states, self.mapping.prev_pp_rank()) + + for layer_idx, encoder_layer in enumerate(self.encoder_layers): + + hidden_states = encoder_layer(hidden_states=hidden_states, + attention_mask=attention_mask, + input_lengths=input_lengths, + max_input_length=max_input_length) + + if self.mapping.is_last_pp_rank(): + if self.has_model_final_layernorm: + hidden_states = self.final_layernorm(hidden_states) + hidden_states.mark_output('encoder_output', self._dtype) + else: + hidden_states = send(hidden_states, self.mapping.next_pp_rank()) + hidden_states.mark_output('hidden_states_output', self._dtype) + + return hidden_states + + def prepare_inputs(self, + max_batch_size, + max_input_len, + prompt_embedding_table_size: int = 0, + *args, + **kwargs): + '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the + ranges of the dimensions of when using TRT dynamic shapes. 
+ + @return: a list contains values which can be fed into the self.forward() + ''' + + hidden_size = self.hidden_size + + bs_range = [1, (max_batch_size + 1) // 2, max_batch_size] + inlen_range = [1, (max_input_len + 1) // 2, max_input_len] + num_tokens_range = [ + 1, + (max_input_len * max_batch_size + 1) // 2, + max_input_len * max_batch_size, + ] + + input_ids, position_ids, token_type_ids, hidden_states = None, None, None, None + remove_input_padding = default_net().plugin_config.remove_input_padding + + attention_mask = None + if remove_input_padding: + if self.mapping.is_first_pp_rank(): + input_ids = Tensor( + name="input_ids", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("num_tokens", [num_tokens_range])]), + ) + if self.has_position_embedding: + position_ids = Tensor( + name='position_ids', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([('num_tokens', + [num_tokens_range])]), + ) + if self.has_token_type_embedding: + token_type_ids = Tensor( + name='token_type_ids', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([('num_tokens', + [num_tokens_range])]), + ) + else: + hidden_states = Tensor(name='hidden_states_input', + dtype=self._dtype, + shape=[-1, hidden_size], + dim_range=OrderedDict([ + ('num_tokens', [num_tokens_range]), + ('hidden_size', [hidden_size]), + ])) + else: + if self.mapping.is_first_pp_rank(): + input_ids = Tensor( + name="input_ids", + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([("batch_size", [bs_range]), + ("input_len", [inlen_range])]), + ) + if self.has_position_embedding: + position_ids = Tensor( + name='position_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([('batch_size', [bs_range]), + ('input_len', [inlen_range])]), + ) + if self.has_token_type_embedding: + token_type_ids = Tensor( + name='token_type_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([('batch_size', [bs_range]), + ('input_len', [inlen_range])]), + ) + else: + hidden_states = Tensor(name='hidden_states_input', + dtype=self._dtype, + shape=[-1, -1, hidden_size], + dim_range=OrderedDict([ + ('batch_size', [bs_range]), + ('input_len', [inlen_range]), + ('hidden_size', [hidden_size]), + ])) + + if not default_net().plugin_config.bert_attention_plugin: + attention_mask = Tensor( + name='attention_mask', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([ + ('batch_size', [bs_range]), + ('input_len', [inlen_range]), + ]), + ) + + # if self.mapping.tp_size > 1: + # current_all_reduce_helper().set_workspace_tensor(self.mapping, 1) + # FIXME(TRTLLM-996): Support custom allreduce for encoder models on C++ runtime + + input_lengths = Tensor( + name="input_lengths", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("batch_size", [bs_range])]), + ) + max_input_length = Tensor( + name="max_input_length", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("max_input_length", [inlen_range])]), + ) + + prompt_embedding_table = None + tasks = None + prompt_vocab_size = None + + if self.mapping.is_first_pp_rank() and prompt_embedding_table_size > 0: + p_embedding_range = [[ + 1, prompt_embedding_table_size // 2, prompt_embedding_table_size + ]] + + prompt_embedding_table = Tensor(name='prompt_embedding_table', + dtype=self._dtype, + shape=[-1, hidden_size], + dim_range=OrderedDict([ + ('prompt_embedding_table_size', + p_embedding_range), + ('hidden_size', [hidden_size]), + ])) + if remove_input_padding: + tasks = Tensor(name='tasks', + dtype=trt.int32, + shape=[-1], + 
dim_range=OrderedDict([('input_len_task', + [num_tokens_range])])) + else: + tasks = Tensor(name='tasks', + dtype=trt.int32, + shape=[-1, 1], + dim_range=OrderedDict([ + ('batch_size', bs_range), + ('broadcast_dim', [1]), + ])) + prompt_vocab_size = Tensor(name='prompt_vocab_size', + dtype=trt.int32, + shape=[1], + dim_range=OrderedDict([('size', [1])])) + + result = { + 'input_ids': input_ids, + 'input_lengths': input_lengths, + 'position_ids': position_ids, + 'token_type_ids': token_type_ids, + 'hidden_states': hidden_states, + 'max_input_length': max_input_length, + 'prompt_embedding_table': prompt_embedding_table, + 'prompt_tasks': tasks, + 'prompt_vocab_size': prompt_vocab_size, + 'attention_mask': attention_mask, + } + + return result + + def use_prompt_tuning(self): + embedding = self.embedding.vocab_embedding + self.embedding.vocab_embedding = PromptTuningEmbedding( + num_embeddings=embedding.num_embeddings, + embedding_dim=embedding.embedding_dim, + dtype=embedding.dtype, + tp_size=embedding.tp_size, + tp_group=embedding.tp_group, + sharding_dim=embedding.sharding_dim, + tp_rank=embedding.tp_rank) + + self.embedding.vocab_embedding.weight.value = embedding.weight.raw_value + + def precompute_relative_attention_bias(self, build_config): + pass + + +class T5TTSDecoderModel(PretrainedModel): + + def __init__(self, config: PretrainedConfig): + self.check_config(config) + super().__init__(config) + + self.mapping = self.config.mapping + self.num_vocabs = len(self.config.vocab_sizes) + self.use_context_embeddings = self.config.use_context_embeddings + + self.has_position_embedding = self.config.has_position_embedding + type_vocab_size = self.config.type_vocab_size + self.has_token_type_embedding = (type_vocab_size is not None) + self.rescale_before_lm_head = self.config.rescale_before_lm_head + + # e.g. BART regular, T5 RMS + self.layernorm_type = self.config.layernorm_type + ln_type = layernorm_map[self.layernorm_type] + + # e.g. BART true, T5 false + self.has_attention_qkvo_bias = self.config.has_attention_qkvo_bias + self.has_pos_ff_bias = self.config.has_pos_ff_bias + self.has_encoder_input_layernorm = self.config.has_encoder_input_layernorm + + # e.g. 
BART false, T5 true + self.has_model_final_layernorm = self.config.has_model_final_layernorm + self._dtype = self.config.dtype + # no quantization considered for now + self._kv_dtype = self._dtype + self._logits_dtype = self.config.logits_dtype + + self.total_num_layers = self.config.num_hidden_layers + self.num_layers = self.config.num_hidden_layers // self.mapping.pp_size + + self.hidden_size = self.config.hidden_size + self.num_heads = self.config.num_attention_heads + + num_kv_heads = self.num_heads + if num_kv_heads is None or num_kv_heads <= 0: + num_kv_heads = self.num_heads + self.num_kv_heads = num_kv_heads + self.head_size = self.hidden_size // self.num_heads if self.config.head_size is None else self.config.head_size + + self.encoder_hidden_size = self.config.encoder_hidden_size + self.encoder_num_heads = self.config.encoder_num_heads + encoder_num_kv_heads = None if not hasattr( + self.config, + "encoder_num_kv_heads") else self.config.encoder_num_kv_heads + if encoder_num_kv_heads is None or encoder_num_kv_heads <= 0: + encoder_num_kv_heads = self.encoder_num_heads + self.encoder_num_kv_heads = encoder_num_kv_heads + self.encoder_head_size = self.encoder_hidden_size // self.num_heads if self.config.encoder_head_size is None else self.config.encoder_head_size + + self.has_position_embedding = self.config.has_position_embedding + self.has_token_type_embedding = type_vocab_size is not None + + self.fp16_clamping = (self.config.dtype + == 'float16') and (self.config.model_type + in ['t5', 'pix2struct']) + + self.skip_cross_kv = self.config.skip_cross_kv + self.mlp_type = MLPType.MLP if not hasattr( + self.config, "mlp_type") else self.config.mlp_type + self.use_implicit_relative_attention = self.config.use_implicit_relative_attention if hasattr( + self.config, "use_implicit_relative_attention") else False + + if self.mapping.is_first_pp_rank(): + self.embedding = EncoderDecoderEmbedding( + # TODO: vocab is expanded to incorporate service token used for unconditional generation + # during CFG + self.config.vocab_size + 1, + self.num_vocabs, + self.config.hidden_size, + max_position_embeddings=self.config.max_position_embeddings, + has_position_embedding=self.has_position_embedding, + type_vocab_size=type_vocab_size, + has_embedding_layernorm=self.config.has_embedding_layernorm, + has_embedding_scale=self.config.has_embedding_scale, + layernorm_eps=self.config.norm_epsilon, + layernorm_type=self.layernorm_type, + dtype=self.config.dtype, + use_parallel_embedding=self.config.use_parallel_embedding, + embedding_sharding_dim=self.config.embedding_sharding_dim, + use_context_embeddings=self.config.use_context_embeddings, + mapping=self.mapping) + + layers_range = self.mapping.pp_layers(self.total_num_layers) + self.decoder_layers = ModuleList([ + T5TTSDecoderLayer( + local_layer_idx=layer_idx - layers_range[0], + hidden_size=self.config.hidden_size, + ffn_hidden_size=self.config.intermediate_size, + num_attention_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + max_position_embeddings=self.config.max_position_embeddings, + q_scaling=self.config.q_scaling, + has_attention_qkvo_bias=self.config.has_attention_qkvo_bias, + has_pos_ff_bias=self.config.has_pos_ff_bias, + has_encoder_input_layernorm=self.config. 
+ has_encoder_input_layernorm, + layernorm_position=self.config.layernorm_position, + layernorm_eps=self.config.norm_epsilon, + layernorm_type=self.config.layernorm_type, + hidden_act=self.config.hidden_act, + mapping=self.mapping, + dtype=self._dtype, + residual_scaling=self.config.residual_scaling, + relative_attention=self.config.relative_attention, + max_distance=self.config.max_distance, + num_buckets=self.config.num_buckets, + fp16_clamping=self.fp16_clamping, + skip_cross_kv=self.skip_cross_kv, + use_implicit_relative_attention=self. + use_implicit_relative_attention, + compute_attention_prior=( + layer_idx in self.config.compute_attention_prior_from_layers + and self.config.use_attention_prior), + apply_attention_prior=( + layer_idx in self.config.apply_attention_prior_to_layers + and self.config.use_attention_prior), + attention_prior_lookahead=self.config.attention_prior_lookahead, + attention_prior_window_left=self.config. + attention_prior_window_left, + attention_prior_window_right=self.config. + attention_prior_window_right, + ) for layer_idx in layers_range + ]) + + if self.mapping.is_last_pp_rank(): + if self.has_model_final_layernorm: + self.final_layernorm = ln_type( + normalized_shape=self.config.hidden_size, + eps=self.config.norm_epsilon, + dtype=self.config.dtype, + bias=False) + + self.lm_head = ColumnLinear( + self.config.hidden_size, + self.config.vocab_size, + bias=False if not hasattr(self.config, "has_lm_head_bias") else + self.config.has_lm_head_bias, + dtype=self.config.dtype, + tp_group=self.config.mapping.tp_group, + tp_size=self.config.mapping.tp_size, + gather_output=True, + ) + + if self.config.relative_attention and not self.use_implicit_relative_attention: + self.rel_attn_table = Parameter( + shape=(self.config.num_attention_heads // self.mapping.tp_size, + self.config.num_buckets), + dtype=self._dtype) + + def check_config(self, config: PretrainedConfig): + config.set_if_not_exist('use_context_embeddings', True) + config.set_if_not_exist('has_position_embedding', False) + config.set_if_not_exist('type_vocab_size', None) + config.set_if_not_exist('rescale_before_lm_head', False) + config.set_if_not_exist('layernorm_type', LayerNormType.LayerNorm) + config.set_if_not_exist('layernorm_position', + LayerNormPositionType.pre_layernorm) + config.set_if_not_exist('has_attention_qkvo_bias', False) + config.set_if_not_exist('has_pos_ff_bias', False) + config.set_if_not_exist('has_encoder_input_layernorm', True) + config.set_if_not_exist('has_model_final_layernorm', False) + config.set_if_not_exist('audio_embedding_dim', 768) + + config.set_if_not_exist('encoder_hidden_size', None) + config.set_if_not_exist('encoder_num_heads', None) + config.set_if_not_exist('encoder_num_kv_heads', None) + config.set_if_not_exist('encoder_head_size', None) + config.set_if_not_exist('model_type', 't5') + config.set_if_not_exist('skip_cross_kv', False) + config.set_if_not_exist('has_embedding_scale', False) + config.set_if_not_exist('residual_scaling', 1.0) + config.set_if_not_exist('has_lm_head_bias', False) + config.set_if_not_exist('num_buckets', None) + config.set_if_not_exist('max_distance', None) + config.set_if_not_exist('relative_attention', False) + config.set_if_not_exist('residual_scaling', 1.0) + + def forward(self, + decoder_input_ids: Tensor, + encoder_output: Tensor, + decoder_context_features: Optional[Tensor] = None, + decoder_context_features_mask: Optional[Tensor] = None, + attention_prior_focus: Optional[Tensor] = None, + position_ids=None, + 
token_type_ids=None, + use_cache=False, + attention_mask_params=None, + last_token_ids=None, + kv_cache_params=None, + attention_params=None, + hidden_states=None, + cross_kv_cache_gen: Optional[Tensor] = None, + cross_kv_reuse: Optional[Tensor] = None): + if self.mapping.is_first_pp_rank(): + assert isinstance(decoder_input_ids, Tensor) + else: + assert isinstance(hidden_states, Tensor) + + # In PP, layer 0 has ids as inputs, all other layers have hidden_states as inputs + if self.mapping.is_first_pp_rank(): + hidden_states = self.embedding( + decoder_input_ids, + position_ids, + None, + decoder_context_features=decoder_context_features, + decoder_context_features_mask=decoder_context_features_mask) + else: + hidden_states = recv(hidden_states, self.mapping.prev_pp_rank()) + + kv_cache_params.fill_none_tensor_list(len(self.decoder_layers)) + + if use_cache: + presents = [] + + all_attention_prior_scores = [] + for i, (decoder_layer, past) in enumerate( + zip(self.decoder_layers, kv_cache_params.past_key_value)): + + hidden_states = decoder_layer( + hidden_states, + encoder_output=encoder_output, + attention_prior_focus=attention_prior_focus, + attention_mask_params=attention_mask_params, + use_cache=use_cache, + kv_cache_params=KeyValueCacheParams( + past_key_value=past, + host_past_key_value_lengths=kv_cache_params. + host_past_key_value_lengths, + host_max_attention_window_sizes=kv_cache_params. + host_max_attention_window_sizes, + host_sink_token_length=kv_cache_params. + host_sink_token_length, + cache_indirection=kv_cache_params.cache_indirection, + kv_cache_block_offsets=kv_cache_params. + kv_cache_block_offsets, + host_kv_cache_block_offsets=kv_cache_params. + host_cross_kv_cache_block_offsets, + host_kv_cache_pool_pointers=kv_cache_params. + host_kv_cache_pool_pointers, + host_kv_cache_pool_mapping=kv_cache_params. + host_kv_cache_pool_mapping, + cross_kv_cache_block_offsets=kv_cache_params. + cross_kv_cache_block_offsets, + host_cross_kv_cache_block_offsets=kv_cache_params. + host_cross_kv_cache_block_offsets, + host_cross_kv_cache_pool_pointers=kv_cache_params. + host_cross_kv_cache_pool_pointers, + host_cross_kv_cache_pool_mapping=kv_cache_params. 
+ host_cross_kv_cache_pool_mapping), + attention_params=attention_params, + cross_kv_cache_gen=cross_kv_cache_gen, + cross_kv_reuse=cross_kv_reuse) + + if decoder_layer.compute_attention_prior: + attention_prior_scores = hidden_states.pop() + all_attention_prior_scores.append(attention_prior_scores) + if use_cache: + presents_cross = hidden_states.pop() + presents_self = hidden_states.pop() + presents.append((presents_self, presents_cross)) + assert len(hidden_states) == 1 + hidden_states = hidden_states[0] + + scores_stacked = stack(all_attention_prior_scores, 0) # [layers x b*5] + mean_scores = mean(scores_stacked, 0) # [b*5] + mean_scores.mark_output("attention_prior_scores") + + if self.mapping.is_last_pp_rank(): + if self.has_model_final_layernorm: + hidden_states = self.final_layernorm(hidden_states) + + # [bs, seq, hidden_size] or [num_tokens, hidden_size] -> [bs, hidden_size] + hidden_states = gather_last_token_logits( + hidden_states, last_token_ids, + default_net().plugin_config.remove_input_padding) + + # [bs, hidden_size] -> [bs, vocab_size] + lm_logits = self.lm_head(hidden_states) + lm_logits.mark_output('logits', self._logits_dtype) + else: + hidden_states = send(hidden_states, self.mapping.next_pp_rank()) + hidden_states.mark_output('hidden_states_output', self._dtype) + + if use_cache and default_net().plugin_config.paged_kv_cache == False: + for i, present in zip(self.mapping.pp_layers(self.total_num_layers), + presents): + present[0].mark_output(f'present_key_value_{i}', self._kv_dtype) + if default_net().plugin_config.gpt_attention_plugin: + present[1].mark_output(f'cross_present_key_value_{i}', + self._kv_dtype) + if self.mapping.is_last_pp_rank(): + return (lm_logits, tuple(presents)) + return (hidden_states, tuple(presents)) + else: + if self.mapping.is_last_pp_rank(): + return lm_logits + return hidden_states + + def prepare_inputs(self, + max_batch_size, + max_decoder_input_len, + max_seq_len, + max_encoder_input_len, + gather_context_logits: bool = False, + gather_generation_logits: bool = False, + use_cache=True, + max_beam_width=1, + *args, + **kwargs): + '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the + ranges of the dimensions of when using TRT dynamic shapes. 
+ + @return: a list contains values which can be fed into the self.forward() + ''' + # Prepare inputs + max_output_len = max_decoder_input_len + max_seq_len + + head_size = self.head_size + num_kv_heads = (self.num_kv_heads + self.mapping.tp_size - + 1) // self.mapping.tp_size + + encoder_head_size = self.encoder_head_size + encoder_num_kv_heads = (self.encoder_num_kv_heads + self.mapping.tp_size + - 1) // self.mapping.tp_size + + bb_range = [ + 1, (max_batch_size * max_beam_width + 1) // 2, + max_batch_size * max_beam_width + ] + bs_range = [1, (max_batch_size + 1) // 2, max_batch_size] + beam_width_range = [1, (max_beam_width + 1) // 2, max_beam_width] + inlen_range = [ + 1, 1, max_decoder_input_len + ] # context phase >= 1 (if forced_input_ids), generation phase = 1 + multivocab_inlen_range = [x * self.num_vocabs for x in inlen_range] + encoder_inlen_range = [ + 1, (max_encoder_input_len + 1) // 2, max_encoder_input_len + ] + mask_len_range = [1, (max_output_len + 1) // 2 + 1, max_output_len + 1] + max_output_len_range = [0, (max_output_len + 1) // 2, max_output_len] + + encoder_num_tokens_range = [ + 0, # 0 for generation phase, >0 for context phase + (max_encoder_input_len * max_batch_size + 1) // 2, + max_encoder_input_len * max_batch_size, + ] + decoder_num_tokens_range = [ + 1, + max_batch_size * max_beam_width, + max(max_decoder_input_len * max_batch_size, + max_beam_width * max_batch_size), + ] + multivocab_decoder_num_tokens_range = [ + x * self.num_vocabs for x in decoder_num_tokens_range + ] + + # No enable_two_optimization_profiles support yet + + encoder_input_len_range = [ + 0, # 0 for generation phase, >0 for context phase + (max_encoder_input_len + 1) // 2, + max_encoder_input_len + ] + max_cross_packed_mask_dim0 = max_batch_size * ( + (max_decoder_input_len + 128 - 1) // 128) * 128 + max_cross_packed_mask_dim1 = ( + (max_encoder_input_len + 256 - 1) // 256) * 256 // 32 + cross_packed_mask_dim0_range = [ + 1, (max_cross_packed_mask_dim0 + 1) // 2, max_cross_packed_mask_dim0 + ] + cross_packed_mask_dim1_range = [ + 0, # 0 for generation phase, >0 for context phase + (max_cross_packed_mask_dim1 + 1) // 2, + max_cross_packed_mask_dim1 + ] + + past_key_value = [] + sequence_length = None + host_past_key_value_lengths = None + runtime_perf_knobs = None + context_progress = None + attention_mask = None + cross_attention_mask = None + cross_attention_packed_mask = None + attention_mask_params = AttentionMaskParams() + use_gpt_attention_plugin = default_net( + ).plugin_config.gpt_attention_plugin + remove_input_padding = default_net().plugin_config.remove_input_padding + paged_kv_cache = default_net().plugin_config.paged_kv_cache + tokens_per_block = default_net().plugin_config.tokens_per_block + + input_ids, position_ids, token_type_ids, hidden_states = None, None, None, None + decoder_context_features = None + decoder_context_features_mask = None + if remove_input_padding: + if self.mapping.is_first_pp_rank(): + input_ids = Tensor(name='input_ids', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + ('multivocab_decoder_num_tokens', + [multivocab_decoder_num_tokens_range]) + ])) + if self.use_context_embeddings: + decoder_context_features = Tensor( + name='decoder_context_features', + dtype=self._dtype, + shape=[-1, self.hidden_size], + dim_range=OrderedDict([ + ('decoder_num_tokens', [decoder_num_tokens_range]), + ('hidden_size', [self.hidden_size]), + ])) + decoder_context_features_mask = Tensor( + name='decoder_context_features_mask', + dtype=trt.bool, + shape=[-1], 
+ dim_range=OrderedDict([ + ('decoder_num_tokens', [decoder_num_tokens_range]), + ])) + if self.has_position_embedding: + position_ids = Tensor(name='position_ids', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + ('decoder_num_tokens', + [decoder_num_tokens_range]), + ])) + if self.has_token_type_embedding: + token_type_ids = Tensor( + name='token_type_ids', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([('decoder_num_tokens', + [decoder_num_tokens_range])]), + ) + else: + hidden_states = Tensor(name='hidden_states_input', + dtype=self._dtype, + shape=[-1, self.hidden_size], + dim_range=OrderedDict([ + ('decoder_num_tokens', + [decoder_num_tokens_range]), + ('hidden_size', [self.hidden_size]), + ])) + else: + if self.mapping.is_first_pp_rank(): + input_ids = Tensor(name='input_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([ + ('batch_size_beam_width', [bb_range]), + ('multivocab_input_len', + [multivocab_inlen_range]), + ])) + if self.has_position_embedding: + position_ids = Tensor(name='position_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([ + ('batch_size_beam_width', + [bb_range]), + ('input_len', [inlen_range]), + ])) + if self.has_token_type_embedding: + token_type_ids = Tensor( + name='token_type_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([('batch_size_beam_width', + [bb_range]), + ('input_len', [inlen_range])]), + ) + else: + hidden_states = Tensor(name='hidden_states_input', + dtype=self._dtype, + shape=[-1, -1, self.hidden_size], + dim_range=OrderedDict([ + ('batch_size_beam_width', [bb_range + ]), + ('input_len', [inlen_range]), + ('hidden_size', [self.hidden_size]), + ])) + + encoder_input_lengths = Tensor( + name="encoder_input_lengths", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("batch_size_beam_width", [bb_range])]), + ) + encoder_max_input_length = Tensor( + name="encoder_max_input_length", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("encoder_max_input_length", + [encoder_inlen_range])]), + ) + encoder_output = None + if remove_input_padding: + encoder_output = Tensor( + name="encoder_output", + dtype=self._dtype, + shape=[-1, self.encoder_hidden_size], + dim_range=OrderedDict([ + ("encoder_num_tokens", [encoder_num_tokens_range]), + ("encoder_hidden_size", [self.encoder_hidden_size]), + ]), + ) + else: + encoder_output = Tensor( + name="encoder_output", + dtype=self._dtype, + shape=[-1, -1, self.encoder_hidden_size], + dim_range=OrderedDict([ + ("batch_size_beam_width_encoder", [bb_range]), + ("encoder_input_len", [encoder_input_len_range]), + ("encoder_hidden_size", [self.encoder_hidden_size]), + ]), + ) + attention_prior_focus = None + if remove_input_padding and use_gpt_attention_plugin: + focus_dim_range = list(bb_range) + focus_dim_range[0] = 0 # could be zero if not provided + attention_prior_focus = Tensor( + name="attention_prior_focus", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + ("batch_size_beam_width_focus", [focus_dim_range]), + ]), + ) + + if use_gpt_attention_plugin: + host_past_key_value_lengths = Tensor( + name='host_past_key_value_lengths', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([('batch_size_beam_width', [bb_range])]), + ) + + context_lengths = None + host_context_lengths = None + host_request_types = None + if use_gpt_attention_plugin and remove_input_padding: + host_context_lengths = Tensor(name='host_context_lengths', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + 
('batch_size_beam_width', + [bb_range]) + ])) + + if use_gpt_attention_plugin: + sequence_length = Tensor( + name='sequence_length', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([('batch_size_beam_width', [bb_range])]), + ) + + context_lengths = Tensor(name='context_lengths', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + ('batch_size_beam_width', [bb_range]) + ])) + host_request_types = Tensor(name='host_request_types', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + ('batch_size_beam_width', + [bb_range]) + ])) + runtime_perf_knobs = Tensor(name='host_runtime_perf_knobs', + dtype=trt.int64, + shape=[16], + dim_range=OrderedDict([ + ('perf_knob_size', [16]) + ])) + context_progress = Tensor(name='host_context_progress', + dtype=trt.int64, + shape=[1], + dim_range=OrderedDict([ + ('context_progress_size', [1]) + ])) + + last_token_ids = None + if self.mapping.is_last_pp_rank() and not gather_context_logits: + last_token_ids = Tensor( + name="last_token_ids", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("batch_size_last_token_ids", [bb_range]) + ]), + ) + + if not use_gpt_attention_plugin: + attention_mask = Tensor( + name='attention_mask', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([ + ('batch_size_beam_width', [bb_range]), + ('mask_len', [mask_len_range]), + ]), + ) + + cross_attention_mask = Tensor( + name='cross_attention_mask', + dtype=trt.int32, + shape=[-1, -1, -1], + dim_range=OrderedDict([ + ('batch_size_beam_width', [bb_range]), + ('query_len', [1]), + ('encoder_input_len_2', [encoder_input_len_range]), + ]), + ) + else: + cross_attention_mask = Tensor( + name='cross_attention_mask', + dtype=trt.bool, + shape=[-1, -1], + dim_range=OrderedDict([ + ('decoder_num_tokens_2', + [decoder_num_tokens_range + ]), # TODO (bhsueh) should use same name as input_ids + ('encoder_input_len_2', [encoder_input_len_range]), + ]), + ) + + cross_attention_packed_mask = Tensor( + name='cross_attention_packed_mask', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([ + ('cross_packed_mask_dim0', [cross_packed_mask_dim0_range]), + ('cross_packed_mask_dim1', [cross_packed_mask_dim1_range]), + ]), + ) + + # create the attention_mask_params. 
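+ # Mask routing, based on the tensors declared above: without the gpt_attention_plugin a dense int32
+ # attention_mask [bb, mask_len] and a [bb, 1, encoder_len] cross_attention_mask are fed in; with the
+ # plugin the self-attention mask is implicit (causal) and only the bool cross_attention_mask plus the
+ # bit-packed cross_attention_packed_mask are passed.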
+ attention_mask_params = AttentionMaskParams( + attention_mask, None, cross_attention_mask, + cross_attention_packed_mask) + + cache_indirection = Tensor( + name='cache_indirection', + dtype=trt.int32, + shape=[-1, -1, -1], + dim_range=OrderedDict([ + ('batch_size_cache', [bs_range]), + ('beam_width', [beam_width_range]), + ('max_seq_len', [max_output_len_range]), + ]), + ) + + if self.mapping.tp_size > 1: + current_all_reduce_helper().set_workspace_tensor(self.mapping, 1) + + layers_range = self.mapping.pp_layers(self.total_num_layers) + num_pp_layers = len(layers_range) + + host_max_attention_window_sizes = None + host_sink_token_length = None + if use_gpt_attention_plugin: + host_max_attention_window_sizes = Tensor( + name=f'host_max_attention_window_sizes', + dtype=trt.int32, + shape=[num_pp_layers], + dim_range=OrderedDict([('num_layers', [num_pp_layers])])) + host_sink_token_length = Tensor(name='host_sink_token_length', + dtype=trt.int32, + shape=[1], + dim_range=OrderedDict([('scalar', + [1])])) + + kv_cache_block_offsets = None + host_kv_cache_block_offsets = None + host_kv_cache_pool_pointers = None + host_kv_cache_pool_mapping = None + + cross_kv_cache_block_offsets = None + host_cross_kv_cache_block_offsets = None + host_cross_kv_cache_pool_pointers = None + host_cross_kv_cache_pool_mapping = None + + if use_cache: + if not paged_kv_cache: + for i in layers_range: + kv_dim_range = OrderedDict([ + ('batch_size_beam_width', [bb_range]), + ('kv', [2]), + ('num_heads', [num_kv_heads]), + ('past_key_len', [max_output_len_range]), + ('head_size', [head_size]), + ]) + kv = Tensor(name=f'past_key_value_{i}', + dtype=self._kv_dtype, + shape=[-1, 2, num_kv_heads, -1, head_size], + dim_range=kv_dim_range) + + if use_gpt_attention_plugin: + cross_kv_dim_range = OrderedDict([ + ('batch_size_beam_width', [bb_range]), + ('kv', [2]), + ('cross_num_heads', [encoder_num_kv_heads]), + ('cross_past_key_len', [encoder_input_len_range]), + ('cross_head_size', [encoder_head_size]), + ]) + cross_kv = Tensor(name=f'cross_past_key_value_{i}', + dtype=self._kv_dtype, + shape=[ + -1, 2, encoder_num_kv_heads, -1, + encoder_head_size + ], + dim_range=cross_kv_dim_range) + past_key_value.append((kv, cross_kv)) + else: + # use encoder_output directly, no need to save cross_past_key_value + past_key_value.append((kv, )) + + # TODO: Remove this when TRT fix the named dimension + if not remove_input_padding: + assertion( + shape( + input_ids if self.mapping.is_first_pp_rank() else + hidden_states, 0) == shape(kv, 0), 'batch size') + + else: # paged_kv_cache == True + # PagedKV setup for KV cache of self-attention + max_blocks_per_seq_range = [[ + math.ceil(max_output_len_range[0] / tokens_per_block), + math.ceil(max_output_len_range[1] / tokens_per_block), + math.ceil(max_output_len_range[2] / tokens_per_block) + ]] + max_blocks_per_seq_range = [[ + x for x in max_blocks_per_seq_range[0] + ]] + + # PagedKV setup for KV cache of cross-attention + max_cross_blocks_per_seq_range = [[ + math.ceil(encoder_input_len_range[0] / tokens_per_block), + math.ceil(encoder_input_len_range[1] / tokens_per_block), + math.ceil(encoder_input_len_range[2] / tokens_per_block) + ]] + max_cross_blocks_per_seq_range = [[ + x for x in max_cross_blocks_per_seq_range[0] + ]] + + # TODO(oargov): add support for vgqa, meanwhile assume a single kv cache pool + num_kv_cache_pools = 1 + + kv_cache_block_offsets = Tensor( + name=f'kv_cache_block_offsets', + dtype=trt.int32, + shape=[num_kv_cache_pools, -1, 2, -1], + dim_range=OrderedDict([ + 
('num_kv_cache_pools', [num_kv_cache_pools]), + ('batch_size_beam_width', [bb_range]), + ('kv', [2]), + ('max_blocks_per_seq', max_blocks_per_seq_range), + ])) + host_kv_cache_block_offsets = Tensor( + name=f'host_kv_cache_block_offsets', + dtype=trt.int32, + shape=[num_kv_cache_pools, -1, 2, -1], + dim_range=OrderedDict([ + ('num_kv_cache_pools', [num_kv_cache_pools]), + ('batch_size_beam_width', [bb_range]), + ('kv', [2]), + ('max_blocks_per_seq', max_blocks_per_seq_range), + ])) + host_kv_cache_pool_pointers = Tensor( + name=f'host_kv_cache_pool_pointers', + dtype=trt.int64, + shape=[num_kv_cache_pools, 2], + dim_range=OrderedDict([ + ('num_pools_layers', [num_kv_cache_pools]), + ('num_pools_kv', [2]), + ])) + host_kv_cache_pool_mapping = Tensor( + name=f"host_kv_cache_pool_mapping", + dtype=trt.int32, + # 2: (Index of pool, Index of layer within pool) + shape=[num_pp_layers, 2], + dim_range=OrderedDict([ + ('pools_mapping', [num_pp_layers]), + ('layer_cache_pool_locator', [2]), + ])) + + # paged blocks for cross kv + cross_kv_cache_block_offsets = Tensor( + name=f'cross_kv_cache_block_offsets', + dtype=trt.int32, + shape=[num_kv_cache_pools, -1, 2, -1], + dim_range=OrderedDict([ + ('num_kv_cache_pools', [num_kv_cache_pools]), + ('batch_size_beam_width', [bb_range]), + ('kv', [2]), + ('max_cross_blocks_per_seq', + max_cross_blocks_per_seq_range), + ])) + host_cross_kv_cache_block_offsets = Tensor( + name=f'host_cross_kv_cache_block_offsets', + dtype=trt.int32, + shape=[num_kv_cache_pools, -1, 2, -1], + dim_range=OrderedDict([ + ('num_kv_cache_pools', [num_kv_cache_pools]), + ('batch_size_beam_width', [bb_range]), + ('kv', [2]), + ('max_cross_blocks_per_seq', + max_cross_blocks_per_seq_range), + ])) + host_cross_kv_cache_pool_pointers = Tensor( + name=f'host_cross_kv_cache_pool_pointers', + dtype=trt.int64, + shape=[num_kv_cache_pools, 2], + dim_range=OrderedDict([ + ('num_kv_cache_pools', [num_kv_cache_pools]), + ('num_pools', [2]), + ])) + host_cross_kv_cache_pool_mapping = Tensor( + name=f"host_cross_kv_cache_pool_mapping", + dtype=trt.int32, + # 2: (Index of pool, Index of layer within pool) + shape=[num_pp_layers, 2], + dim_range=OrderedDict([ + ('pools_mapping', [num_pp_layers]), + ('layer_cache_pool_locator', [2]), + ])) + + for i in layers_range: + past_key_value.append(None) + + kv_cache_params = KeyValueCacheParams( + past_key_value=past_key_value, + host_past_key_value_lengths=host_past_key_value_lengths, + host_max_attention_window_sizes=host_max_attention_window_sizes, + host_sink_token_length=host_sink_token_length, + cache_indirection=cache_indirection, + kv_cache_block_offsets=kv_cache_block_offsets, + host_kv_cache_block_offsets=host_kv_cache_block_offsets, + host_kv_cache_pool_pointers=host_kv_cache_pool_pointers, + host_kv_cache_pool_mapping=host_kv_cache_pool_mapping, + cross_kv_cache_block_offsets=cross_kv_cache_block_offsets, + host_cross_kv_cache_block_offsets= + host_cross_kv_cache_block_offsets, + host_cross_kv_cache_pool_pointers= + host_cross_kv_cache_pool_pointers, + host_cross_kv_cache_pool_mapping= + host_cross_kv_cache_pool_mapping, + ) + + attention_params = AttentionParams( + sequence_length=sequence_length, + context_lengths=context_lengths, + host_context_lengths=host_context_lengths, + max_context_length=max_decoder_input_len, + host_request_types=host_request_types, + encoder_input_lengths=encoder_input_lengths, + encoder_max_input_length=encoder_max_input_length, + host_runtime_perf_knobs=runtime_perf_knobs, + host_context_progress=context_progress) 
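+ # cross_kv_cache_gen tells the cross-attention whether this invocation must build the cross K/V from
+ # encoder_output (context phase) or can rely on the cached values (generation phase). When
+ # skip_cross_kv is enabled, cross_kv_reuse carries the previously projected cross K/V so it is not
+ # recomputed on later calls.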
+ + cross_kv_cache_gen = Tensor(name='cross_kv_cache_gen', + dtype=trt.bool, + shape=[1], + dim_range=OrderedDict([ + ('boolean', [1]), + ])) + cross_kv_reuse = None + num_heads = (self.num_heads + self.mapping.tp_size - + 1) // self.mapping.tp_size + cross_kv_out_dim = 2 * num_kv_heads * self.head_size + if self.skip_cross_kv: + if remove_input_padding: + cross_kv_reuse = Tensor( + name="cross_kv_reuse", + dtype=self._dtype, + shape=[-1, cross_kv_out_dim], + dim_range=OrderedDict([ + ("encoder_num_tokens", [encoder_num_tokens_range]), + ("encoder_kv_size", [cross_kv_out_dim]), + ]), + ) + else: + cross_kv_reuse = Tensor( + name="cross_kv_reuse", + dtype=self._dtype, + shape=[-1, -1, cross_kv_out_dim], + dim_range=OrderedDict([ + ("batch_size_beam_width_encoder", [bb_range]), + ("encoder_input_len", [encoder_input_len_range]), + ("encoder_kv_size", [cross_kv_out_dim]), + ]), + ) + + result = { + 'decoder_input_ids': input_ids, + 'encoder_output': encoder_output, + 'decoder_context_features': decoder_context_features, + 'decoder_context_features_mask': decoder_context_features_mask, + 'attention_prior_focus': attention_prior_focus, + 'position_ids': position_ids, + 'token_type_ids': token_type_ids, + 'use_cache': True, + 'attention_mask_params': attention_mask_params, + 'last_token_ids': last_token_ids, + 'kv_cache_params': kv_cache_params, + 'attention_params': attention_params, + 'hidden_states': hidden_states, + 'cross_kv_cache_gen': cross_kv_cache_gen, + 'cross_kv_reuse': cross_kv_reuse, + } + + return result + + def precompute_relative_attention_bias(self, build_config): + if self.config.relative_attention and not self.use_implicit_relative_attention: + relative_attention_bias_builder = torch.ops.tensorrt_llm.relative_attention_bias + rel_attn_precomputed = torch.zeros( + (self.config.num_attention_heads // self.mapping.tp_size, + build_config.max_seq_len + 1, build_config.max_seq_len + 1), + dtype=str_dtype_to_torch(self.config.dtype), + device='cuda') + rel_attn_table = numpy_to_torch( + self.rel_attn_table.raw_value).to('cuda') + relative_attention_bias_builder( + rel_attn_precomputed, + rel_attn_table, + self.config.num_attention_heads // self.mapping.tp_size, + build_config.max_seq_len, + self.config.num_buckets, + False, + self.config.max_distance, + ) + for layer_idx in range(self.num_layers): + self.decoder_layers[ + layer_idx].self_attention.set_rel_attn_table( + build_config.max_seq_len, rel_attn_precomputed) diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py index b701f245f6f..30e7d7da6df 100644 --- a/tensorrt_llm/runtime/model_runner_cpp.py +++ b/tensorrt_llm/runtime/model_runner_cpp.py @@ -460,7 +460,7 @@ def _check_inputs(self, batch_input_ids: List[List[int]], len(x) for x in encoder_input_ids ] if encoder_input_ids else [len(x) for x in batch_input_ids] max_length = max(input_lengths) - if max_length > self.max_input_len: + if max_length > self.max_input_len * len(self.model_config.vocab_sizes): raise RuntimeError( f"Maximum input length ({max_length}) exceeds the engine or specified limit ({self.max_input_len})" ) @@ -471,7 +471,8 @@ def _check_inputs(self, batch_input_ids: List[List[int]], f"Decoder prefix tokens ({decoder_max_length}) + maximum new tokens ({max_new_tokens}) exceeds the engine or specified limit ({self.max_seq_len})" ) else: - if max_length + max_new_tokens > self.max_seq_len: + if max_length + max_new_tokens > self.max_seq_len * len( + self.model_config.vocab_sizes): raise RuntimeError( f"Maximum input 
length ({max_length}) + maximum new tokens ({max_new_tokens}) exceeds the engine or specified limit ({self.max_seq_len})" ) @@ -537,6 +538,7 @@ def generate( encoder_input_features: List[ torch.Tensor] = None, # TODO: add to doc string encoder_output_lengths: List[int] = None, + decoder_context_features: List[torch.Tensor] = None, cross_attention_masks: List[ torch.Tensor] = None, # TODO: add to doc string mrope_params: Optional[MropeParams] = None, @@ -579,6 +581,8 @@ def generate( A list of encoder input feature tensors for multimodal encoder-decoder models (optional). Each tensor is of shape (sequence_length, feature_dim). encoder_output_lengths: (List[int]): A list of encoder output lengths (optional) if encoder output has different length from encoder input (due to convolution down-sampling, etc.) + decoder_context_features (List[torch.Tensor]): + A list of decoder context feature tensors for multimodal decoder-only models (optional). Each tensor is of shape (sequence_length, feature_dim). sampling_config (SamplingConfig): The sampling configuration to be used as base parametrization for the generation call. The passed **kwargs matching the sampling_config's attributes will override them. @@ -651,6 +655,7 @@ def generate( "num_return_sequences", "min_p", "beam_width_array", + "cfg_scale", ] rename_params = {"num_beams": "beam_width", "random_seed": "seed"} sampling_params = { @@ -758,6 +763,8 @@ def generate( if encoder_output_lengths is not None else None, encoder_input_features=encoder_input_features[i].contiguous() if encoder_input_features is not None else None, + decoder_context_features=decoder_context_features[i].contiguous( + ) if decoder_context_features is not None else None, position_ids=position_ids[i].tolist() if position_ids is not None else None, cross_attention_mask=cross_attention_masks[i].contiguous() if @@ -782,6 +789,9 @@ def generate( external_draft_tokens_config=external_draft_tokens_config, skip_cross_attn_blocks=skip_cross_attn_blocks, language_adapter_uid=language_adapter_uid, + num_vocabs=len(self.model_config.vocab_sizes) if + (hasattr(self.model_config, 'vocab_sizes') + and self.model_config.vocab_sizes) else 1, ) for i, (input_ids, stop_words, bad_words, prompt_tuning_config, mrope_config, lora_config, logits_post_processor_name, @@ -1083,7 +1093,8 @@ def fill_output_ids(result_token_ids, batch_idx, seq_idx): input_lengths = torch.tensor([x.size(0) for x in batch_input_ids], dtype=torch.int32, - device=cuda_device) + device=cuda_device) // len( + self.model_config.vocab_sizes) if output_sequence_lengths: outputs['sequence_lengths'] = torch.tensor(sequence_lengths,