Skip to content

Commit 1280cf3

Browse files
authored
Merge branch 'main' into refactor_loop
2 parents 7503c03 + 908f49a commit 1280cf3

39 files changed

+867
-144
lines changed

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1484,7 +1484,8 @@ class ExecutorConfig
14841484
std::optional<GuidedDecodingConfig> guidedDecodingConfig = std::nullopt,
14851485
std::optional<std::vector<AdditionalModelOutput>> additionalModelOutputs = std::nullopt,
14861486
std::optional<CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt,
1487-
bool gatherGenerationLogits = false, bool promptTableOffloading = false, bool enableTrtOverlap = false);
1487+
bool gatherGenerationLogits = false, bool promptTableOffloading = false, bool enableTrtOverlap = false,
1488+
bool failFastOnAttentionWindowTooLarge = false);
14881489

14891490
[[nodiscard]] SizeType32 getMaxBeamWidth() const;
14901491
[[nodiscard]] SchedulerConfig getSchedulerConfig() const;
@@ -1519,6 +1520,7 @@ class ExecutorConfig
15191520
[[nodiscard]] bool getPromptTableOffloading() const;
15201521
[[nodiscard]] std::optional<CacheTransceiverConfig> getCacheTransceiverConfig() const;
15211522
[[nodiscard]] bool getEnableTrtOverlap() const;
1523+
[[nodiscard]] bool getFailFastOnAttentionWindowTooLarge() const;
15221524

15231525
void setMaxBeamWidth(SizeType32 maxBeamWidth);
15241526
void setMaxBatchSize(SizeType32 maxBatchSize);
@@ -1548,6 +1550,7 @@ class ExecutorConfig
15481550
void setPromptTableOffloading(bool promptTableOffloading);
15491551
void setCacheTransceiverConfig(CacheTransceiverConfig const& cacheTransceiverConfig);
15501552
void setEnableTrtOverlap(bool enableTrtOverlap);
1553+
void setFailFastOnAttentionWindowTooLarge(bool failFastOnAttentionWindowTooLarge);
15511554

15521555
private:
15531556
friend class Serialization;
@@ -1634,6 +1637,10 @@ class ExecutorConfig
16341637

16351638
/// @brief Controls whether preparation and TRT engine execution should be overlapped.
16361639
bool mEnableTrtOverlap{false};
1640+
1641+
/// @brief Controls whether to fail fast (raise an error) when the attention window is too large to fit even a single
1642+
/// sequence in the KV cache, instead of shrinking the attention window to fit.
1643+
bool mFailFastOnAttentionWindowTooLarge{false};
16371644
};
16381645

16391646
struct KVCacheCreatedData

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -296,27 +296,27 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer
296296

297297
auto const [freePrimaryMemBytes, freeSecondaryMemBytes]
298298
= BaseKVCacheManager::calculateFreeMemBytes(mRuntime->getBufferManager(), kvCacheConfig);
299-
300299
if (mModelConfig.useCrossAttention())
301300
{
302301
TLLM_CHECK_WITH_INFO(kvCacheConfig.getCrossKvCacheFraction().has_value(),
303302
"Must set crossKvCacheFraction for encoder-decoder model");
304303
auto const crossKvCacheFraction = kvCacheConfig.getCrossKvCacheFraction().value();
305304
mKvCacheManager = createKvCacheManager(kvCacheConfig, KvCacheType::kSELF,
306305
freePrimaryMemBytes * (1.0f - crossKvCacheFraction),
307-
freeSecondaryMemBytes * (1.0f - crossKvCacheFraction), cacheTransPreAllocaSize);
308-
mCrossKvCacheManager
309-
= createKvCacheManager(kvCacheConfig, KvCacheType::kCROSS, freePrimaryMemBytes * crossKvCacheFraction,
310-
freeSecondaryMemBytes * crossKvCacheFraction, cacheTransPreAllocaSize);
306+
freeSecondaryMemBytes * (1.0f - crossKvCacheFraction), cacheTransPreAllocaSize,
307+
executorConfig.getFailFastOnAttentionWindowTooLarge());
308+
mCrossKvCacheManager = createKvCacheManager(kvCacheConfig, KvCacheType::kCROSS,
309+
freePrimaryMemBytes * crossKvCacheFraction, freeSecondaryMemBytes * crossKvCacheFraction,
310+
cacheTransPreAllocaSize, executorConfig.getFailFastOnAttentionWindowTooLarge());
311311
TLLM_LOG_INFO("This is an Encoder-Decoder model, set %0.1f cross KV cache fraction based on the config.",
312312
crossKvCacheFraction);
313313
}
314314
else
315315
{
316316
TLLM_CHECK_WITH_INFO(!kvCacheConfig.getCrossKvCacheFraction().has_value(),
317317
"Do not set crossKvCacheFraction for decoder-only model");
318-
mKvCacheManager = createKvCacheManager(
319-
kvCacheConfig, KvCacheType::kSELF, freePrimaryMemBytes, freeSecondaryMemBytes, cacheTransPreAllocaSize);
318+
mKvCacheManager = createKvCacheManager(kvCacheConfig, KvCacheType::kSELF, freePrimaryMemBytes,
319+
freeSecondaryMemBytes, cacheTransPreAllocaSize, executorConfig.getFailFastOnAttentionWindowTooLarge());
320320
}
321321

322322
mCacheTransceiver
@@ -550,7 +550,8 @@ void TrtGptModelInflightBatching::reshapeKvTensors(OffsetTableDimensions const&
550550
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
551551

552552
std::pair<BlocksPerWindow, std::vector<SizeType32>>
553-
TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(BlocksPerWindow const& blocksPerWindow)
553+
TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(
554+
BlocksPerWindow const& blocksPerWindow, bool const failFastOnAttentionWindowTooLarge)
554555
{
555556
// At this point, we can only validate that the cheapest sequence in terms of kv-cache resources still fits. More
556557
// validation is needed on a per-request basis, once the prompt / output lengths and the actual beam width are
@@ -591,6 +592,16 @@ TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(BlocksPerWi
591592
}
592593
TLLM_LOG_WARNING("maxAttentionWindowVec too large to fit at least one sequence in kvCache. Old: %s, New: %s",
593594
common::vec2str(getMaxAttentionWindowVec()).c_str(), common::vec2str(newMaxAttentionWindowVec).c_str());
595+
596+
if (failFastOnAttentionWindowTooLarge)
597+
{
598+
throw std::runtime_error(
599+
"Attention window too large to fit even a single sequence in the KV cache. Failing fast rather than "
600+
"attempting an adjustment of the window sizes. "
601+
"Old: "
602+
+ common::vec2str(getMaxAttentionWindowVec()) + ", New: " + common::vec2str(newMaxAttentionWindowVec));
603+
}
604+
594605
setMaxAttentionWindowVec(newMaxAttentionWindowVec);
595606
if (getMaxSequenceLen() > getMaxAttentionWindow())
596607
{
@@ -613,7 +624,7 @@ TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(BlocksPerWi
613624

614625
std::unique_ptr<kv_cache_manager::KVCacheManager> TrtGptModelInflightBatching::createKvCacheManager(
615626
KvCacheConfig const& kvCacheConfig, KvCacheType kvCacheType, uint64_t freePrimaryMemBytes,
616-
uint64_t freeSecondaryMemBytes, size_t extraCostMemory)
627+
uint64_t freeSecondaryMemBytes, size_t extraCostMemory, bool const failFastOnAttentionWindowTooLarge)
617628
{
618629
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
619630
bool isCrossAttention = kvCacheType == KvCacheType::kCROSS;
@@ -657,7 +668,8 @@ std::unique_ptr<kv_cache_manager::KVCacheManager> TrtGptModelInflightBatching::c
657668
// and user also didn't provide maxAttentionWindow, which leads it to be equal to maxSeqLen
658669
if (kvCacheType == KvCacheType::kSELF)
659670
{
660-
std::tie(blocksPerWindow, maxAttentionWindowVec) = clampWindowSizesToFitAtLeastOneSequence(blocksPerWindow);
671+
std::tie(blocksPerWindow, maxAttentionWindowVec)
672+
= clampWindowSizesToFitAtLeastOneSequence(blocksPerWindow, failFastOnAttentionWindowTooLarge);
661673
}
662674

663675
kv_cache_manager::TempAttentionWindowInputs tempAttentionWindowInputs;

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,8 @@ class TrtGptModelInflightBatching : public TrtGptModel
280280
void createBuffers(executor::DecodingConfig const& decodingConfig,
281281
std::optional<std::vector<executor::AdditionalModelOutput>> const& additionalModelOutputs);
282282
std::unique_ptr<KVCacheManager> createKvCacheManager(KvCacheConfig const& kvCacheConfig, KvCacheType kvCacheType,
283-
uint64_t freePrimaryMemBytes, uint64_t freeSecondaryMemBytes, size_t extraCostMemory);
283+
uint64_t freePrimaryMemBytes, uint64_t freeSecondaryMemBytes, size_t extraCostMemory,
284+
bool const failFastOnAttentionWindowTooLarge = false);
284285
void createRnnStateManager();
285286
void createCustomAllReduceWorkspace();
286287
void createRuntimePerfKnobsTensor(executor::ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig);
@@ -378,9 +379,11 @@ class TrtGptModelInflightBatching : public TrtGptModel
378379
/// window.
379380
///
380381
/// @param blocksPerWindow map of window size to number of blocks.
382+
/// @param failFastOnAttentionWindowTooLarge if true, the function throws a std::runtime_error when the attention
383+
/// window is too large to fit even a single sequence in the KV cache, instead of clamping the window sizes.
381384
/// @return pair of new blocks per window and new maxAttentionWindowVec
382385
[[nodiscard]] std::pair<BlocksPerWindow, std::vector<SizeType32>> clampWindowSizesToFitAtLeastOneSequence(
383-
BlocksPerWindow const& blocksPerWindow);
386+
BlocksPerWindow const& blocksPerWindow, bool const failFastOnAttentionWindowTooLarge = false);
384387

385388
/// @brief Change the speculative decoding mode.
386389
void changeSpecDecMode(ScheduledRequests const& scheduledRequests);

cpp/tensorrt_llm/executor/executorConfig.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ ExecutorConfig::ExecutorConfig(SizeType32 maxBeamWidth, SchedulerConfig schedule
3434
std::optional<SpeculativeDecodingConfig> specDecConfig, std::optional<GuidedDecodingConfig> guidedDecodingConfig,
3535
std::optional<std::vector<AdditionalModelOutput>> additionalModelOutputs,
3636
std::optional<CacheTransceiverConfig> cacheTransceiverConfig, bool gatherGenerationLogits,
37-
bool promptTableOffloading, bool enableTrtOverlap)
37+
bool promptTableOffloading, bool enableTrtOverlap, bool failFastOnAttentionWindowTooLarge)
3838
: mMaxBeamWidth(maxBeamWidth)
3939
, mSchedulerConfig(std::move(schedulerConfig))
4040
, mKvCacheConfig(std::move(kvCacheConfig))
@@ -63,6 +63,7 @@ ExecutorConfig::ExecutorConfig(SizeType32 maxBeamWidth, SchedulerConfig schedule
6363
, mGatherGenerationLogits(gatherGenerationLogits)
6464
, mPromptTableOffloading(promptTableOffloading)
6565
, mEnableTrtOverlap(enableTrtOverlap)
66+
, mFailFastOnAttentionWindowTooLarge(failFastOnAttentionWindowTooLarge)
6667
{
6768
TLLM_CHECK(iterStatsMaxIterations >= 0);
6869
TLLM_CHECK(requestStatsMaxIterations >= 0);
@@ -222,6 +223,11 @@ bool ExecutorConfig::getEnableTrtOverlap() const
222223
return mEnableTrtOverlap;
223224
}
224225

226+
bool ExecutorConfig::getFailFastOnAttentionWindowTooLarge() const
227+
{
228+
return mFailFastOnAttentionWindowTooLarge;
229+
}
230+
225231
// setters
226232

227233
void ExecutorConfig::setMaxBeamWidth(SizeType32 maxBeamWidth)
@@ -371,4 +377,9 @@ void ExecutorConfig::setEnableTrtOverlap(bool enableTrtOverlap)
371377
mEnableTrtOverlap = enableTrtOverlap;
372378
}
373379

380+
void ExecutorConfig::setFailFastOnAttentionWindowTooLarge(bool failFastOnAttentionWindowTooLarge)
381+
{
382+
mFailFastOnAttentionWindowTooLarge = failFastOnAttentionWindowTooLarge;
383+
}
384+
374385
} // namespace tensorrt_llm::executor

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,10 @@ XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParam
5252
unsigned int kernel_m_tilesize
5353
= getKernelMTileSize(num_q_heads_over_kv, xqaParams.multi_query_tokens, qSeqLen, isXqaJit);
5454

55+
// Precompiled XQA kernels do not use is_fp8_output as part of the hash key; only the XQA JIT path does.
5556
return {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, kernel_m_tilesize,
5657
xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0, xqaParams.paged_kv_cache,
57-
xqaParams.multi_query_tokens, xqaParams.is_fp8_output};
58+
xqaParams.multi_query_tokens, isXqaJit ? xqaParams.is_fp8_output : false};
5859
}
5960

6061
} // namespace tensorrt_llm::kernels

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,11 @@ class XQAKernelList
124124
m_tilesize = num_q_heads_over_kv;
125125
}
126126

127+
// Precompiled XQA kernels do not support is_fp8_output as part of the hash key, so that field is fixed to 0 here.
127128
XQAKernelRuntimeHashKey hash_key
128129
= {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, m_tilesize,
129130
xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0,
130-
xqaParams.paged_kv_cache, xqaParams.multi_query_tokens, xqaParams.is_fp8_output};
131+
xqaParams.paged_kv_cache, xqaParams.multi_query_tokens, 0 /* xqa jit param is_fp8_output */};
131132
auto const findIter = mFunctions.find(hash_key);
132133
return findIter != mFunctions.end();
133134
}

cpp/tensorrt_llm/pybind/executor/executorConfig.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ void initConfigBindings(pybind11::module_& m)
459459
c.getExtendedRuntimePerfKnobConfig(), c.getDebugConfig(), c.getRecvPollPeriodMs(),
460460
c.getMaxSeqIdleMicroseconds(), c.getSpecDecConfig(), c.getGuidedDecodingConfig(),
461461
c.getAdditionalModelOutputs(), c.getCacheTransceiverConfig(), c.getGatherGenerationLogits(),
462-
c.getPromptTableOffloading(), c.getEnableTrtOverlap());
462+
c.getPromptTableOffloading(), c.getEnableTrtOverlap(), c.getFailFastOnAttentionWindowTooLarge());
463463
auto pickle_tuple = py::make_tuple(cpp_states, py::getattr(self, "__dict__"));
464464
return pickle_tuple;
465465
};
@@ -472,7 +472,7 @@ void initConfigBindings(pybind11::module_& m)
472472

473473
// Restore C++ data
474474
auto cpp_states = state[0].cast<py::tuple>();
475-
if (cpp_states.size() != 28)
475+
if (cpp_states.size() != 29)
476476
{
477477
throw std::runtime_error("Invalid cpp_states!");
478478
}
@@ -505,7 +505,8 @@ void initConfigBindings(pybind11::module_& m)
505505
cpp_states[24].cast<std::optional<tle::CacheTransceiverConfig>>(), // CacheTransceiverConfig
506506
cpp_states[25].cast<bool>(), // GatherGenerationLogits
507507
cpp_states[26].cast<bool>(), // PromptTableOffloading
508-
cpp_states[27].cast<bool>() // EnableTrtOverlap
508+
cpp_states[27].cast<bool>(), // EnableTrtOverlap
509+
cpp_states[28].cast<bool>() // FailFastOnAttentionWindowTooLarge
509510
);
510511

511512
auto py_state = state[1].cast<py::dict>();
@@ -542,7 +543,8 @@ void initConfigBindings(pybind11::module_& m)
542543
std::optional<tle::CacheTransceiverConfig>, // CacheTransceiverConfig
543544
bool, // GatherGenerationLogits
544545
bool, // PromptTableOffloading
545-
bool // EnableTrtOverlap
546+
bool, // EnableTrtOverlap
547+
bool // FailFastOnAttentionWindowTooLarge
546548
>(),
547549
py::arg("max_beam_width") = 1, py::arg_v("scheduler_config", tle::SchedulerConfig(), "SchedulerConfig()"),
548550
py::arg_v("kv_cache_config", tle::KvCacheConfig(), "KvCacheConfig()"),
@@ -563,7 +565,7 @@ void initConfigBindings(pybind11::module_& m)
563565
py::arg("spec_dec_config") = py::none(), py::arg("guided_decoding_config") = py::none(),
564566
py::arg("additional_model_outputs") = py::none(), py::arg("cache_transceiver_config") = py::none(),
565567
py::arg("gather_generation_logits") = false, py::arg("mm_embedding_offloading") = false,
566-
py::arg("enable_trt_overlap") = false)
568+
py::arg("enable_trt_overlap") = false, py::arg("fail_fast_on_attention_window_too_large") = false)
567569
.def_property("max_beam_width", &tle::ExecutorConfig::getMaxBeamWidth, &tle::ExecutorConfig::setMaxBeamWidth)
568570
.def_property("max_batch_size", &tle::ExecutorConfig::getMaxBatchSize, &tle::ExecutorConfig::setMaxBatchSize)
569571
.def_property("max_num_tokens", &tle::ExecutorConfig::getMaxNumTokens, &tle::ExecutorConfig::setMaxNumTokens)
@@ -613,6 +615,9 @@ void initConfigBindings(pybind11::module_& m)
613615
&tle::ExecutorConfig::setPromptTableOffloading)
614616
.def_property(
615617
"enable_trt_overlap", &tle::ExecutorConfig::getEnableTrtOverlap, &tle::ExecutorConfig::setEnableTrtOverlap)
618+
.def_property("fail_fast_on_attention_window_too_large",
619+
&tle::ExecutorConfig::getFailFastOnAttentionWindowTooLarge,
620+
&tle::ExecutorConfig::setFailFastOnAttentionWindowTooLarge)
616621
.def(py::pickle(executorConfigGetState, executorConfigSetState));
617622
}
618623

103 KB
Loading
150 KB
Loading
10.5 KB
Loading

0 commit comments

Comments
 (0)