
Commit f98fa0c

[None][feat] Optimize kv cache transfer TEP (NVIDIA#7613)
Signed-off-by: Chuang Zhu <[email protected]>
1 parent 4c0f848 commit f98fa0c

File tree: 9 files changed, +113 −53 lines changed

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 4 additions & 3 deletions
@@ -90,9 +90,9 @@ bool CacheFormatter::needSendCache(
             = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
         selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup;
     }
+    int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;

-    // only TP rank % dupHeadFactor == 0 need to send cache.
-    return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0;
+    return (destDPRank % targetInfo.mDupHeadFactor) == (selfTpRankInDpGroup % targetInfo.mDupHeadFactor);
 }

 void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig,
@@ -140,11 +140,12 @@ std::vector<size_t> CacheFormatter::pickRecvConnections(
         return ret;
     }
     TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
+    int selfDPRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;

     std::vector<size_t> ret;
     for (int i = 0; i < targetInfo.mDomainTPSize; i++)
     {
-        if (i % targetInfo.mPeerDupHeadFactor == 0)
+        if ((i % targetInfo.mPeerDupHeadFactor) == (selfDPRank % targetInfo.mPeerDupHeadFactor))
         {
             for (int j = 0; j < targetInfo.mDomainPPSize; j++)
             {
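Net effect of these two hunks: instead of electing rank 0 of every duplicate-head group as the sole sender, the sender is chosen by matching the destination DP rank against the sender's position within the group, so sends are spread across the whole group. A minimal sketch of the new matching rule, with made-up ranks and a hypothetical `needSendCacheSketch` helper (not the real API):

```cpp
#include <cassert>

// Simplified model: with dupHeadFactor identical copies of each KV head across
// context TP ranks, a context rank sends to a generation DP rank only when
// their positions inside the duplicate-head group line up.
bool needSendCacheSketch(int selfTpRankInDpGroup, int destDPRank, int dupHeadFactor)
{
    return (destDPRank % dupHeadFactor) == (selfTpRankInDpGroup % dupHeadFactor);
}

int main()
{
    // dupHeadFactor == 2: context TP ranks {0, 1} hold identical heads.
    // Old rule: only rank 0 ever sent. New rule: rank 0 serves even DP ranks
    // and rank 1 serves odd DP ranks, spreading the send load.
    assert(needSendCacheSketch(/*selfTpRankInDpGroup*/ 0, /*destDPRank*/ 0, /*dupHeadFactor*/ 2));
    assert(!needSendCacheSketch(0, 1, 2));
    assert(needSendCacheSketch(1, 1, 2));
    assert(!needSendCacheSketch(1, 0, 2));
    return 0;
}
```

The matching `pickRecvConnections` change keeps the two sides consistent: each receiver selects the connection whose group position equals its own DP rank modulo the duplication factor.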

cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp

Lines changed: 2 additions & 2 deletions
@@ -219,7 +219,7 @@ CacheTransBufferManager::CacheTransBufferManager(
         = maxNumTokens.has_value() ? bufferSizeFromMaxNumToken : common::getEnvMemSizeForKVCacheTransferBuffer();
     mOnlyUseDynamicBuffer = mTransferBufferSize == 0;
     mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
-    mSendBufferCount = common::getEnvParallelCacheSend() ? common::getEnvKVCacheSendMaxConcurrenceNum() : 1;
+    mSendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum();
     mUseFabricMemory = !(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer())
         && FabricMemory::supportFbaricMemory();
     if (mUseFabricMemory)
@@ -269,7 +269,7 @@ size_t CacheTransBufferManager::preAllocBufferSize(
         TransferBufferSize = FabricMemory::getAlignedSize(TransferBufferSize);
     }
     size_t RecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
-    size_t SendBufferCount = common::getEnvParallelCacheSend() ? common::getEnvKVCacheSendMaxConcurrenceNum() : 1;
+    size_t SendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum();
     size_t PreAllocBufferSize = TransferBufferSize * (RecvBufferCount + SendBufferCount);
     return PreAllocBufferSize;
 }
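With `TRTLLM_PARALLEL_CACHE_SEND` gone, the send buffer count now always comes from `getEnvKVCacheSendMaxConcurrenceNum()`, and preallocation is simply `TransferBufferSize * (RecvBufferCount + SendBufferCount)`. A quick arithmetic sketch with illustrative values (the 512 MiB buffer size is an assumption for the example, not a shipped default):

```cpp
#include <cstddef>
#include <cstdio>

int main()
{
    // Illustrative values only: one recv buffer (TRTLLM_REQUEST_KV_CACHE_CONCURRENT
    // unset) and the new default of one send buffer
    // (TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM unset).
    std::size_t transferBufferSize = 512ull * 1024 * 1024; // assumed, for illustration
    std::size_t recvBufferCount = 1;
    std::size_t sendBufferCount = 1;
    std::size_t preAllocBufferSize = transferBufferSize * (recvBufferCount + sendBufferCount);
    std::printf("pre-allocated: %zu MiB\n", preAllocBufferSize / (1024 * 1024)); // 1024 MiB
    return 0;
}
```

Note that raising `TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM` now grows the preallocated pool directly, since the old gating flag no longer caps the send buffer count at 1.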

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 67 additions & 21 deletions
@@ -256,6 +256,12 @@ class CacheSender::Impl
         TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
         mCurrentRequest = std::nullopt;
         mResponseFuture = std::async(std::launch::async, &Impl::response, this);
+        int asyncSendThreadNum = common::getEnvKVCacheSendMaxConcurrenceNum();
+        for (int i = 0; i < asyncSendThreadNum; i++)
+        {
+            mAsyncSendFutures.emplace_back(
+                std::async(std::launch::async, &Impl::handleAsyncSend, this, std::ref(mAsyncSendResource)));
+        }
     }

     [[nodiscard]] std::future<void> sendAsync(LlmRequest& llmRequest)
@@ -294,9 +300,9 @@ class CacheSender::Impl

     void release(LlmRequest::RequestIdType requestId)
     {
+        std::unique_lock<std::mutex> lk(mMtxForMap);
         auto it = mRequestToSession.find(requestId);
         TLLM_CHECK(it != mRequestToSession.end());
-        std::unique_lock<std::mutex> lk(mMtxForMap);
         if (!common::getEnvKVCacheTransferOutputPath().empty())
         {
             if (!mMeasuresFile.is_open())
@@ -368,11 +374,15 @@ class CacheSender::Impl

     void sendSync(LlmRequest const& llmRequest)
     {
-        auto it = mRequestToSession.find(llmRequest.mRequestId);
-        TLLM_CHECK(it != mRequestToSession.end());
-        auto& session = it->second;
-        session.setLlmRequest(llmRequest);
-        mFormatter->format(session);
+        TransferSession* session = nullptr;
+        {
+            std::unique_lock<std::mutex> lk(mMtxForMap);
+            auto it = mRequestToSession.find(llmRequest.mRequestId);
+            TLLM_CHECK(it != mRequestToSession.end());
+            session = std::addressof(it->second);
+        }
+        session->setLlmRequest(llmRequest);
+        mFormatter->format(*session);
     }

     ~Impl()
@@ -387,6 +397,40 @@ class CacheSender::Impl
         std::promise<void> mPromise;
     };

+    struct AsyncSendResource
+    {
+        std::deque<Response> mSendQueue;
+        std::mutex mMtxForQueue;
+        std::condition_variable mCVforQueue;
+        std::atomic<bool> mTerminate{false};
+    };
+
+    void handleAsyncSend(AsyncSendResource& resource)
+    {
+        tensorrt_llm::common::setThreadName("dataTransAsyncSend");
+        while (!resource.mTerminate)
+        {
+            Response resp;
+            {
+                std::unique_lock lk(resource.mMtxForQueue);
+                resource.mCVforQueue.wait(
+                    lk, [&resource] { return !resource.mSendQueue.empty() || resource.mTerminate; });
+                if (resource.mTerminate)
+                {
+                    if (!resource.mSendQueue.empty())
+                    {
+                        TLLM_LOG_WARNING("There are still %zu requests in the mSendQueue, but encountered terminate.",
+                            resource.mSendQueue.size());
+                    }
+                    break;
+                }
+                resp = std::move(resource.mSendQueue.front());
+                resource.mSendQueue.pop_front();
+            }
+            sendAndRemoveResponse(resp.mRequest->mRequestId, std::move(resp));
+        }
+    }
+
     void sendAndRemoveResponse(RequestIdType id, Response resp) noexcept
     {
         try
@@ -409,6 +453,13 @@ class CacheSender::Impl
         }
     }

+    void asyncSendAndRemoveResponse(RequestIdType id, Response resp) noexcept
+    {
+        std::unique_lock lk(mAsyncSendResource.mMtxForQueue);
+        mAsyncSendResource.mSendQueue.emplace_back(std::move(resp));
+        mAsyncSendResource.mCVforQueue.notify_one();
+    }
+
     void sendResponse(std::vector<size_t> const& blockHashes, std::map<RequestIdType, Response>::iterator it)
     {
         auto reqId = mCurrentRequest.value();
@@ -422,15 +473,7 @@ class CacheSender::Impl
             auto llmRequest = it->second.mRequest;
             llmRequest->setRequestedBlockHashes(std::move(blockHashes));

-            if (common::getEnvParallelCacheSend())
-            {
-                // TODO: Use a thread pool and check for thread safety.
-                std::thread(&CacheSender::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second)).detach();
-            }
-            else
-            {
-                CacheSender::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
-            }
+            asyncSendAndRemoveResponse(it->first, std::move(it->second));
             removeResponse(it);
         }
         mCurrentRequest = std::nullopt;
@@ -454,7 +497,7 @@ class CacheSender::Impl
                 break;
             }
             std::vector<size_t> blockHashes;
-            if (!isSending() && !mReadyResponses.empty())
+            if (!mReadyResponses.empty())
             {
                 auto const& requestInfo = recvRequestInfo();
                 auto reqId = requestInfo.getRequestId();
@@ -507,6 +550,12 @@ class CacheSender::Impl
         // We don't have to wait for the future. If another thread is sending data, it won't pay attention
         // to the terminate flag.
         mSenderCv.notify_all();
+        mAsyncSendResource.mTerminate = true;
+        mAsyncSendResource.mCVforQueue.notify_all();
+        for (auto& future : mAsyncSendFutures)
+        {
+            future.get();
+        }
     }

     void removeResponse(std::map<RequestIdType, Response>::iterator it)
@@ -522,11 +571,6 @@ class CacheSender::Impl
         }
     }

-    [[nodiscard]] bool isSending() const
-    {
-        return mCurrentRequest.has_value();
-    }
-
     [[nodiscard]] RequestIdType getCurrentRequestId() const
    {
         return mCurrentRequest.value();
@@ -546,6 +590,8 @@ class CacheSender::Impl
     std::condition_variable mSenderCv;
     std::future<void> mResponseFuture;
     std::unordered_map<LlmRequest::RequestIdType, int> mRemainSendCount;
+    AsyncSendResource mAsyncSendResource;
+    std::vector<std::future<void>> mAsyncSendFutures;
     int mDeviceId{-1};

     executor::kv_cache::ConnectionManager* mManager;
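The core of this file's change replaces the old per-response detached `std::thread` (flagged by the removed `TODO: Use a thread pool and check for thread safety`) with a fixed pool of workers draining a mutex/condition-variable queue. Below is a self-contained sketch of the same pattern under simplified, hypothetical names; generic tasks stand in for the `Response` objects the real code enqueues:

```cpp
#include <atomic>
#include <condition_variable>
#include <deque>
#include <functional>
#include <future>
#include <mutex>
#include <vector>

// Sketch of the worker-pool pattern introduced by this commit (not the real API).
class AsyncSendPool
{
public:
    explicit AsyncSendPool(int numWorkers)
    {
        for (int i = 0; i < numWorkers; ++i)
        {
            mWorkers.emplace_back(std::async(std::launch::async, [this] { run(); }));
        }
    }

    void enqueue(std::function<void()> task)
    {
        std::unique_lock lk(mMtx);
        mQueue.push_back(std::move(task));
        mCv.notify_one();
    }

    ~AsyncSendPool()
    {
        mTerminate = true;
        mCv.notify_all();
        for (auto& worker : mWorkers)
        {
            worker.get(); // join every worker before the queue is destroyed
        }
    }

private:
    void run()
    {
        while (!mTerminate)
        {
            std::function<void()> task;
            {
                std::unique_lock lk(mMtx);
                mCv.wait(lk, [this] { return !mQueue.empty() || mTerminate; });
                if (mTerminate)
                {
                    break; // like the real code, queued work is dropped (with a warning) on terminate
                }
                task = std::move(mQueue.front());
                mQueue.pop_front();
            }
            task(); // run outside the lock so workers overlap their transfers
        }
    }

    std::deque<std::function<void()>> mQueue;
    std::mutex mMtx;
    std::condition_variable mCv;
    std::atomic<bool> mTerminate{false};
    std::vector<std::future<void>> mWorkers;
};
```

Compared with detached threads, the bounded pool caps thread creation under bursty load and gives the destructor a deterministic join point. That is presumably also why `isSending()` could be dropped from the response loop: concurrency is now limited by the pool size rather than by a single in-flight send.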

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp

Lines changed: 15 additions & 7 deletions
@@ -72,10 +72,12 @@ std::vector<size_t> MLACacheFormatter::pickRecvConnections(
     TLLM_CHECK(targetInfo.mDomainCPSize == 1);
     TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
     std::vector<size_t> ret;
-    // targetInfo , mRanks [tpranks, dpranks]
+    // targetInfo , mRanks [tpranks, ppranks]
+    int dpRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;
+
     for (int i = 0; i < targetInfo.mDomainPPSize; i++)
     {
-        ret.push_back(i);
+        ret.push_back(i + (dpRank % (targetInfo.mDomainTPSize)) * targetInfo.mDomainPPSize);
     }
     return ret;
 }
@@ -85,19 +87,24 @@ bool MLACacheFormatter::needSendCache(
 {
     int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;

+    int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
+        ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
+        : destConfig.getParallelConfig().mTensorParallelism;
+    int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;
+
     if (selfConfig.getParallelConfig().mEnableAttentionDP)
     {
         int selfTPNumInDPGroup
             = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
-        int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
-            ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
-            : destConfig.getParallelConfig().mTensorParallelism;
+
         int selfTPrankINDPGroup = selfTpRank % selfTPNumInDPGroup;
         if (selfTPNumInDPGroup <= destTPNumInDPGroup)
         {
             return true;
         }
-        return selfTPrankINDPGroup % (selfTPNumInDPGroup / destTPNumInDPGroup) == 0;
+
+        int dupHeadFactor = selfTPNumInDPGroup / destTPNumInDPGroup;
+        return selfTPrankINDPGroup % dupHeadFactor == destDPRank % dupHeadFactor;
     }

     int destTPNum = destConfig.getParallelConfig().mEnableAttentionDP
@@ -108,7 +115,8 @@ bool MLACacheFormatter::needSendCache(
     {
         return true;
     }
-    return selfTpRank % (selfTPNum / destTPNum) == 0;
+    int dupHeadFactor = selfTPNum / destTPNum;
+    return selfTpRank % dupHeadFactor == destDPRank % dupHeadFactor;
 }

 void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
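For MLA, the receive side mirrors the sender-selection change: the new index math in `pickRecvConnections` pairs each generation DP rank with a distinct context TP slice of the flattened `[tpranks, ppranks]` rank grid. A hypothetical sketch with small, made-up domain sizes (helper name and layout comments are illustrative):

```cpp
#include <cstdio>
#include <vector>

// Sketch of the new MLA receive-connection mapping: the DP rank selects which
// context TP slice of the flattened [tp][pp] connection grid it pulls from,
// instead of every DP rank pulling from slice 0.
std::vector<int> pickRecvConnectionsSketch(int domainPPSize, int domainTPSize, int dpRank)
{
    std::vector<int> ret;
    for (int i = 0; i < domainPPSize; i++)
    {
        ret.push_back(i + (dpRank % domainTPSize) * domainPPSize);
    }
    return ret;
}

int main()
{
    // domainTPSize == 2, domainPPSize == 2: connections are flattened as
    // [tp0pp0, tp0pp1, tp1pp0, tp1pp1]. DP rank 0 picks {0, 1}; DP rank 1
    // picks {2, 3}, so each generation DP rank reads from a different sender.
    for (int dpRank = 0; dpRank < 2; dpRank++)
    {
        for (int idx : pickRecvConnectionsSketch(/*pp*/ 2, /*tp*/ 2, dpRank))
        {
            std::printf("dpRank %d <- connection %d\n", dpRank, idx);
        }
    }
    return 0;
}
```

This pairing is what the flipped expectations in the `CacheStateContextDP` test further below encode: context rank 0 now serves generation rank 0 and context rank 1 serves generation rank 1, instead of rank 0 serving both.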

cpp/tensorrt_llm/common/envUtils.cpp

Lines changed: 1 addition & 7 deletions
@@ -324,12 +324,6 @@ bool getEnvDisableSelectiveCacheTransfer()
     return disableSelectiveCacheTransfer;
 }

-bool getEnvParallelCacheSend()
-{
-    static bool const parallelCacheSend = getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND");
-    return parallelCacheSend;
-}
-
 bool getEnvRequestKVCacheConcurrent()
 {
     static bool const requestKVCacheConcurrent = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_CONCURRENT");
@@ -414,7 +408,7 @@ bool getEnvKVCacheTransferUseSyncBuffer()
 size_t getEnvKVCacheSendMaxConcurrenceNum()
 {

-    static size_t const maxConcurrenceNum = getUInt64Env("TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM").value_or(2);
+    static size_t const maxConcurrenceNum = getUInt64Env("TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM").value_or(1);
     return maxConcurrenceNum;
 }

cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp

Lines changed: 3 additions & 7 deletions
@@ -108,9 +108,7 @@ TEST_F(CacheTransBufferTest, TestPreAllocBufferSize)
     size_t recvbufferCount = tensorrt_llm::common::getEnvRequestKVCacheConcurrent()
         ? tensorrt_llm::common::getEnvKVCacheRecvBufferCount()
         : 1;
-    size_t sendBufferCount = tensorrt_llm::common::getEnvParallelCacheSend()
-        ? tensorrt_llm::common::getEnvKVCacheSendMaxConcurrenceNum()
-        : 1;
+    size_t sendBufferCount = tensorrt_llm::common::getEnvKVCacheSendMaxConcurrenceNum();
     size_t cacheSizeBytesPerToken = kvCacheSizePerToken(4, 2, 64, CacheType::kSELFKONLY);
     std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow{
         {maxBlocksPerSeq * tokensPerBlock, cacheSizeBytesPerToken}};
@@ -152,9 +150,7 @@ TEST_F(CacheTransBufferTest, TestPreAllocBufferSize2)
     size_t recvbufferCount = tensorrt_llm::common::getEnvRequestKVCacheConcurrent()
         ? tensorrt_llm::common::getEnvKVCacheRecvBufferCount()
         : 1;
-    size_t sendBufferCount = tensorrt_llm::common::getEnvParallelCacheSend()
-        ? tensorrt_llm::common::getEnvKVCacheSendMaxConcurrenceNum()
-        : 1;
+    size_t sendBufferCount = tensorrt_llm::common::getEnvKVCacheSendMaxConcurrenceNum();
     size_t cacheSizeBytesPerToken = kvCacheSizePerToken(4, 2, 64, CacheType::kSELF);
     tensorrt_llm::executor::CacheTransceiverConfig cacheTransceiverConfig{
         tensorrt_llm::executor::CacheTransceiverConfig::BackendType::UCX, maxNumTokens};
@@ -260,7 +256,7 @@ TEST_F(CacheTransBufferTest, TestBufferIndexAssignment1)
     SizeType32 tokensPerBlock = 8;
     std::optional<size_t> maxNumTokens = maxBlocksPerSeq * tokensPerBlock;
     setenv("TRTLLM_REQUEST_KV_CACHE_CONCURRENT", "1", 1);
-    setenv("TRTLLM_PARALLEL_CACHE_SEND", "1", 1);
+    setenv("TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM", "2", 1);
     SetUpCacheTransBuffer(4, 2, 64, tokensPerBlock, CacheType::kSELF, maxNumTokens, maxBlocksPerSeq);
     auto bufferId = mTransBufferManager->assignBufferIndexForSend();
     EXPECT_TRUE(bufferId.has_value());

cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp

Lines changed: 19 additions & 2 deletions
@@ -1432,6 +1432,18 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA3, AsymmetricalCacheTestWi
     testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
     testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false)));

+INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA4, AsymmetricalCacheTestWithDP,
+    testing::Combine(testing::Values(2), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(1),
+        testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
+        testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
+        testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false)));
+
+INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA5, AsymmetricalCacheTestWithDP,
+    testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(2), testing::Values(1),
+        testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
+        testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
+        testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false)));
+
 INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA, AsymmetricalCacheTestWithDP,
     testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
         testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4),
@@ -1472,6 +1484,11 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate2, Asymmetrical
     testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(2), testing::Values(4),
     testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
     testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false)));
+INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate3, AsymmetricalCacheTestWithDP,
+    testing::Combine(testing::Values(2), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(1),
+        testing::Values(1), testing::Values(4), testing::Values(2), testing::Values(4), testing::Values(16),
+        testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
+        testing::Values(false), testing::Values(false), testing::Values(true), testing::Values(false)));

 INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate4, AsymmetricalCacheTestWithDP,
     testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1, 2),
@@ -1849,13 +1866,13 @@ TEST(targetTest, CacheStateContextDP)
         /*expectNeedSend*/ true);
     verifyContext(
         /*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
-        /*expectNeedSend*/ true);
+        /*expectNeedSend*/ false);
     verifyContext(
         /*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
         /*expectNeedSend*/ false);
     verifyContext(
         /*contextRank*/ 1, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
-        /*expectNeedSend*/ false);
+        /*expectNeedSend*/ true);
     verifyContext(
         /*contextRank*/ 2, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
         /*expectNeedSend*/ false);

docs/source/features/disagg-serving.md

Lines changed: 1 addition & 2 deletions
@@ -192,7 +192,6 @@ For more information on how to use Dynamo with TensorRT-LLM, please refer to [th

 TRT-LLM uses some environment variables to control the behavior of disaggregated service.

-* `TRTLLM_PARALLEL_CACHE_SEND`: If set to `1`, contextExecutor will attempt to send KV cache for multiple requests in parallel. The default value is `0`.

 * `TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP`: If set to `1`, generationExecutor will not overlap KV cache transfer with model inference. The default value is `0`.

@@ -206,7 +205,7 @@ TRT-LLM uses some environment variables to control the behavior of disaggregated

 * `TRTLLM_KVCACHE_TRANSFER_USE_ASYNC_BUFFER`: If set to `1`, TRT-LLM will use `cudaMallocAsync` to allocate buffers for KV cache transmission. The default value is `0`. This environment variable only takes effect when `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE` is greater than 0.

-* `TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM`: The maximum number of concurrent KV cache sends. The default value is `4`. This environment variable only takes effect when `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE` is greater than 0.
+* `TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM`: The maximum number of concurrent KV cache sends. The default value is `1`. This environment variable only takes effect when `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE` is greater than 0.

 There are some other useful environment variables that may help when encountering failures or performance issues.
