
Commit a3d83e8

pcastonguay and Shunkang authored and committed
[None][feat] Support for cancelling requests with disaggregation (NVIDIA#8114)
Signed-off-by: Shunkang <[email protected]>
Signed-off-by: Patrice Castonguay <[email protected]>
Co-authored-by: Shunkang <[email protected]>
Signed-off-by: Faradawn Yang <[email protected]>
1 parent acd6dbf commit a3d83e8

12 files changed: 653 additions, 55 deletions

cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h

Lines changed: 4 additions & 0 deletions

@@ -71,6 +71,8 @@ class BaseCacheTransceiver
     virtual void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;

     [[nodiscard]] virtual bool checkGenTransferComplete() const = 0;
+
+    virtual bool cancelRequest(LlmRequest* llmRequest) = 0;
 };

 class CacheTransceiver : public BaseCacheTransceiver
@@ -111,6 +113,8 @@ class CacheTransceiver : public BaseCacheTransceiver

     [[nodiscard]] bool checkGenTransferComplete() const override;

+    virtual bool cancelRequest(LlmRequest* llmRequest) override;
+
 private:
     void initializeCommState();

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 13 additions & 0 deletions

@@ -567,4 +567,17 @@ bool CacheTransceiver::checkGenTransferComplete() const
     return mRequesterFutures.empty();
 }

+bool CacheTransceiver::cancelRequest(LlmRequest* llmRequest)
+{
+    if (llmRequest->isContextOnlyRequest())
+    {
+        return mCacheSender->cancelRequest(*llmRequest);
+    }
+    else if (llmRequest->isGenerationOnlyRequest())
+    {
+        return mCacheReceiver->cancelRequest(*llmRequest);
+    }
+    return false;
+}
+
 } // namespace tensorrt_llm::batch_manager
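CacheTransceiver::cancelRequest simply routes the call: context-only requests go to the cache sender, generation-only requests go to the cache receiver, and anything else reports false. A self-contained sketch of the same dispatch pattern, with ToyRequest, ToySender, and ToyReceiver as invented placeholders for the real classes:

```cpp
// Illustrative dispatch sketch; these toy types are placeholders, not TensorRT-LLM classes.
#include <iostream>

enum class Phase { kContextOnly, kGenerationOnly, kOther };

struct ToyRequest
{
    unsigned long long id;
    Phase phase;
};

struct ToySender
{
    bool cancelRequest(ToyRequest const&) { return true; } // pretend the context side can cancel
};

struct ToyReceiver
{
    bool cancelRequest(ToyRequest const&) { return true; } // pretend the generation side can cancel
};

bool cancelRequest(ToyRequest const& r, ToySender& sender, ToyReceiver& receiver)
{
    if (r.phase == Phase::kContextOnly)
    {
        return sender.cancelRequest(r); // context-only: drop the pending response on the sender
    }
    if (r.phase == Phase::kGenerationOnly)
    {
        return receiver.cancelRequest(r); // generation-only: drop the queued transfer on the receiver
    }
    return false; // neither phase: nothing to cancel here
}

int main()
{
    ToySender sender;
    ToyReceiver receiver;
    std::cout << cancelRequest({1, Phase::kContextOnly}, sender, receiver) << '\n'; // prints 1
    std::cout << cancelRequest({2, Phase::kOther}, sender, receiver) << '\n';       // prints 0
    return 0;
}
```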

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 177 additions & 10 deletions

@@ -270,7 +270,7 @@ class CacheSender::Impl
         auto future = promise.get_future();
         {
             {
-                std::unique_lock lkResp(mSenderMutex);
+                std::scoped_lock lkResp(mSenderMutex);
                 mReadyResponses.emplace(
                     llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)});
             }
@@ -379,6 +379,49 @@ class CacheSender::Impl
         mFormatter->format(*session);
     }

+    bool cancelRequest(LlmRequest const& llmRequest)
+    {
+        bool isCancelled = false;
+        std::scoped_lock lkResp(mSenderMutex);
+        auto it = mReadyResponses.find(llmRequest.mRequestId);
+        // If the request is not the current request and already in the ready queue, we can cancel it.
+        if (it != mReadyResponses.end()
+            && (!mCurrentRequest.has_value() || getCurrentRequestId() != llmRequest.mRequestId))
+        {
+            mCancelledRequests.insert(llmRequest.mRequestId);
+            isCancelled = true;
+        }
+        else
+        {
+            TLLM_LOG_WARNING("Cannot cancel request %zu", llmRequest.mRequestId);
+        }
+        return isCancelled;
+    }
+
+    void sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady)
+    {
+        auto it = mRequestToSession.find(requestId);
+        TLLM_CHECK(it != mRequestToSession.end());
+        auto& session = it->second;
+        auto const& connections = session.getConnections();
+        for (size_t i = 0; i < connections.size(); i++)
+        {
+            auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
+            if (agentConnectionManager)
+            {
+                auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections.at(i));
+                TLLM_CHECK(agentConnection);
+                agentConnection->sendReadySignal(
+                    executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG}, isReady);
+            }
+            else
+            {
+                connections.at(i)->send(
+                    executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG}, &isReady, sizeof(isReady));
+            }
+        }
+    }
+
     ~Impl()
     {
         terminate();
@@ -463,8 +506,42 @@ class CacheSender::Impl
             {
                 mRemainSendCount.erase(reqId);

-                asyncSendAndRemoveResponse(it->first, std::move(it->second));
-                removeResponse(it);
+                // Check if the request is cancelled
+                bool isReady = true;
+                {
+                    std::scoped_lock lk(mSenderMutex);
+                    if (mCancelledRequests.find(reqId) != mCancelledRequests.end())
+                    {
+                        isReady = false;
+                    }
+                }
+                sendReadySignal(reqId, isReady);
+
+                if (isReady)
+                {
+                    asyncSendAndRemoveResponse(it->first, std::move(it->second));
+                    removeResponse(it);
+                }
+                else
+                {
+                    // TODO: if the generation does not require the kv cache, the request will
+                    // not be removed from mCancelledRequests. This should be handled by timeout.
+                    auto it = mReadyResponses.find(mCurrentRequest.value());
+                    TLLM_CHECK(it != mReadyResponses.end());
+                    {
+                        std::scoped_lock lkResp(mSenderMutex);
+                        mReadyResponses.erase(it);
+                        mCancelledRequests.erase(mCurrentRequest.value());
+                        mRemainSendCount.erase(mCurrentRequest.value());
+                    }
+                    mCurrentRequest = std::nullopt;
+
+                    if (mReadyResponses.empty())
+                    {
+                        std::unique_lock lk(mCondMutex);
+                        mAnyReady = false;
+                    }
+                }
             }
             mCurrentRequest = std::nullopt;
         }
@@ -491,7 +568,11 @@ class CacheSender::Impl
         auto const& requestInfo = recvRequestInfo();
         auto reqId = requestInfo.getRequestId();

-        mCurrentRequest = reqId;
+        {
+            std::scoped_lock lk(mSenderMutex);
+            mCurrentRequest = reqId;
+        }
+
         if (mRemainSendCount.find(reqId) == mRemainSendCount.end())
         {
             mRemainSendCount[reqId] = getCounterpartsCount(reqId);
@@ -549,7 +630,7 @@ class CacheSender::Impl
     void removeResponse(std::map<RequestIdType, Response>::iterator it)
     {
         {
-            std::unique_lock lkResp(mSenderMutex);
+            std::scoped_lock lkResp(mSenderMutex);
             mReadyResponses.erase(it);
         }
         if (mReadyResponses.empty())
@@ -566,12 +647,13 @@ class CacheSender::Impl

     [[nodiscard]] std::map<RequestIdType, Response>::iterator getCurrentResponse()
     {
-        std::unique_lock lk(mSenderMutex);
+        std::scoped_lock lk(mSenderMutex);
         return mReadyResponses.find(getCurrentRequestId());
     }

 private:
     std::optional<RequestIdType> mCurrentRequest;
+    std::set<LlmRequest::RequestIdType> mCancelledRequests;
     std::map<RequestIdType, Response> mReadyResponses;
     std::mutex mSenderMutex, mCondMutex;
     std::atomic<bool> mAnyReady{false}, mTerminate{false};
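On the sender (context) side, a request is only cancellable while it still sits in the ready queue and is not the one currently being transferred; cancelled IDs are kept in mCancelledRequests under mSenderMutex, and when the request's turn arrives the peer is told via sendReadySignal whether the KV cache will actually follow. Below is a minimal sketch of that bookkeeping, assuming a toy class with a std::set and a caller-supplied callback in place of the real sessions and connections. The remaining hunks in this file add the matching receiver-side handling.

```cpp
// Sketch of the sender-side decision; ToySenderState is invented for illustration.
// cancel() succeeds only for queued, not-in-flight requests; onTurn() reports
// readiness to the peer via a callback, then clears the per-request bookkeeping.
#include <cstdint>
#include <functional>
#include <iostream>
#include <mutex>
#include <optional>
#include <set>

class ToySenderState
{
public:
    void enqueue(std::uint64_t requestId)
    {
        std::scoped_lock lock(mMutex);
        mQueued.insert(requestId);
    }

    bool cancel(std::uint64_t requestId)
    {
        std::scoped_lock lock(mMutex);
        bool const queued = mQueued.count(requestId) > 0;
        bool const inFlight = mCurrent.has_value() && *mCurrent == requestId;
        if (queued && !inFlight)
        {
            mCancelled.insert(requestId);
            return true;
        }
        return false; // too late: the transfer is already running (or was never queued)
    }

    void onTurn(std::uint64_t requestId, std::function<void(bool)> const& sendReadySignal)
    {
        bool isReady;
        {
            std::scoped_lock lock(mMutex);
            mCurrent = requestId;
            isReady = mCancelled.count(requestId) == 0;
        }
        sendReadySignal(isReady); // tell the peer whether KV blocks will follow
        std::scoped_lock lock(mMutex);
        mQueued.erase(requestId);
        mCancelled.erase(requestId);
        mCurrent.reset();
    }

private:
    std::mutex mMutex;
    std::set<std::uint64_t> mQueued, mCancelled;
    std::optional<std::uint64_t> mCurrent;
};

int main()
{
    ToySenderState sender;
    sender.enqueue(7);
    std::cout << "cancel accepted: " << sender.cancel(7) << '\n';                        // 1: still queued
    sender.onTurn(7, [](bool ready) { std::cout << "ready signal: " << ready << '\n'; }); // 0: cancelled
    return 0;
}
```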
@@ -619,7 +701,7 @@ class CacheReceiver::Impl
         auto promise = std::make_unique<std::promise<void>>();
         auto future = promise->get_future();
         TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value());
-        std::string processInfo = "default";
+        std::string processInfo = kDefaultProcessInfo;
         if (common::getEnvRequestKVCacheConcurrent())
         {
             processInfo = llmRequest.getDataTransceiverState().getCommState()->toString();
@@ -703,7 +785,7 @@ class CacheReceiver::Impl

         auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
         std::optional<size_t> cacheBufferId = std::nullopt;
-        if (agentConnectionManager != nullptr)
+        if (agentConnectionManager)
         {
             cacheBufferId = agentConnectionManager->getCacheTransBufferManager()->assignBufferIndexForRecv();
             TLLM_CHECK(cacheBufferId.has_value());
@@ -726,7 +808,7 @@ class CacheReceiver::Impl
             auto const* connection = counterPartConnections[i];
             // if Manager is agentConnectionManager, then send request info to agent
             auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
-            if (agentConnectionManager != nullptr)
+            if (agentConnectionManager)
             {
                 // TODO: index -> validConnectionIdx conversion
                 auto validConnectionIdx = std::find(pickUpIdx.begin(), pickUpIdx.end(), i) - pickUpIdx.begin();
@@ -751,7 +833,7 @@ class CacheReceiver::Impl
     {
         std::scoped_lock<std::mutex> lock(mProcessIoResouceMutex);
         TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value());
-        std::string processString = "default";
+        std::string processString = kDefaultProcessInfo;
         if (common::getEnvRequestKVCacheConcurrent())
         {
             processString = llmRequest.getDataTransceiverState().getCommState()->toString();
@@ -777,6 +859,62 @@ class CacheReceiver::Impl
         connection->send(DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize);
     }

+    bool cancelRequest(LlmRequest const& llmRequest)
+    {
+
+        std::string processInfo = kDefaultProcessInfo;
+        if (common::getEnvRequestKVCacheConcurrent())
+        {
+            processInfo = llmRequest.getDataTransceiverState().getCommState()->toString();
+        }
+
+        bool isCancelled = false;
+        auto& asyncResource = mInstanceToAsyncResource.at(processInfo);
+        {
+            std::unique_lock<std::mutex> lck(asyncResource->mMtxForQueue);
+            auto it = std::find_if(asyncResource->mRequestsQueue.begin(), asyncResource->mRequestsQueue.end(),
+                [&llmRequest](RequestAndPromise const& requestAndPromise)
+                { return requestAndPromise.mRequest->mRequestId == llmRequest.mRequestId; });
+            if (it != asyncResource->mRequestsQueue.end())
+            {
+                asyncResource->mRequestsQueue.erase(it);
+                isCancelled = true;
+            }
+            else
+            {
+                TLLM_LOG_WARNING("Cannot cancel request %zu", llmRequest.mRequestId);
+            }
+        }
+        return isCancelled;
+    }
+
+    bool receiveReadySignal(TransferSession& session)
+    {
+        bool isReadyFinal = true;
+        bool isReady = false;
+        auto const& connections = session.getConnections();
+
+        for (size_t i = 0; i < connections.size(); i++)
+        {
+            auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
+            if (agentConnectionManager)
+            {
+                auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections.at(i));
+                TLLM_CHECK(agentConnection);
+                isReady = agentConnection->recvReadySignal(
+                    executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG});
+            }
+            else
+            {
+                connections.at(i)->recv(
+                    executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG}, &isReady, sizeof(isReady));
+            }
+            isReadyFinal &= isReady;
+        }
+
+        return isReadyFinal;
+    }
+
     ~Impl()
     {
         for (auto&& [processInfo, asyncResource] : mInstanceToAsyncResource)
@@ -799,6 +937,14 @@ class CacheReceiver::Impl
         llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now());
         TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
         auto session = sendRequestInfo(llmRequest);
+        bool isReady = receiveReadySignal(session);
+        if (!isReady)
+        {
+            // Reuse the error state for the cancelled request.
+            llmRequest.setState(LlmRequestState::kDISAGG_TRANS_ERROR);
+            llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());
+            return;
+        }
         receiveSync(session);
         llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());

@@ -915,6 +1061,7 @@ class CacheReceiver::Impl
     }

     int mDeviceId{-1};
+    static constexpr char const* kDefaultProcessInfo = "default";
     std::vector<std::future<void>> mRequestFutures;
     std::unordered_map<std::string, std::unique_ptr<AsyncResource>> mInstanceToAsyncResource;
     executor::kv_cache::ConnectionManager* mManager;
@@ -970,6 +1117,16 @@ RequestInfo CacheSender::recvRequestInfo()
     return mImpl->recvRequestInfo();
 }

+bool CacheSender::cancelRequest(LlmRequest const& llmRequest)
+{
+    return mImpl->cancelRequest(llmRequest);
+}
+
+void CacheSender::sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady)
+{
+    mImpl->sendReadySignal(requestId, isReady);
+}
+
 CacheReceiver::CacheReceiver(executor::kv_cache::ConnectionManager* manager,
     executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
     : mImpl{std::unique_ptr<Impl, ImplDeleter>(new Impl(manager, selfCacheState, selfIndex, std::move(formatter)))}
@@ -993,4 +1150,14 @@ void CacheReceiver::receiveSync(TransferSession& session)
     mImpl->receiveSync(session);
 }

+bool CacheReceiver::cancelRequest(LlmRequest const& llmRequest)
+{
+    return mImpl->cancelRequest(llmRequest);
+}
+
+bool CacheReceiver::receiveReadySignal(TransferSession& session)
+{
+    return mImpl->receiveReadySignal(session);
+}
+
 } // namespace tensorrt_llm::batch_manager
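On the generation side, the receiver now waits for one ready flag per connection after sending its request info, ANDs them, and only proceeds with receiveSync when every peer reports ready; otherwise the request is marked as a transfer error and the KV-cache transfer is skipped. The sketch below models that handshake with in-process stand-ins (std::promise/std::future replace the real per-connection recv; nothing here is the commit's transport code):

```cpp
// Toy model of the ready-signal handshake; futures stand in for per-connection recv().
#include <future>
#include <iostream>
#include <thread>
#include <vector>

// Receiver side: AND the ready flags from all peers, as receiveReadySignal does.
bool receiveReadySignal(std::vector<std::future<bool>>& connections)
{
    bool isReadyFinal = true;
    for (auto& connection : connections)
    {
        isReadyFinal &= connection.get(); // blocks until the peer answers
    }
    return isReadyFinal;
}

int main()
{
    std::promise<bool> peer0, peer1;
    std::vector<std::future<bool>> connections;
    connections.push_back(peer0.get_future());
    connections.push_back(peer1.get_future());

    // Context-side peers answer asynchronously; peer1 has cancelled the request.
    std::thread sender0([&] { peer0.set_value(true); });
    std::thread sender1([&] { peer1.set_value(false); });

    bool const isReady = receiveReadySignal(connections);
    sender0.join();
    sender1.join();

    if (!isReady)
    {
        // The real code marks the request as a transfer error and skips receiveSync().
        std::cout << "request cancelled upstream; skipping KV-cache transfer\n";
    }
    return 0;
}
```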

cpp/tensorrt_llm/batch_manager/dataTransceiver.h

Lines changed: 22 additions & 0 deletions

@@ -144,6 +144,7 @@ struct TransceiverTag
     static constexpr int32_t kID_TAG{19};
     static constexpr int32_t kINFO_SIZE_TAG{22};
     static constexpr int32_t kINFO_TAG{32};
+    static constexpr int32_t kREADY_SIGNAL_TAG{42};
 };

 // Used to store the information that needs to be sent to the context executor to ensure the generation
@@ -240,6 +241,16 @@ class CacheSender
     /// @param llmRequest The request object to which the data belongs.
     virtual RequestInfo recvRequestInfo();

+    /// @brief Cancel the request.
+    /// @param requestId The ID used in the context phase of the current request.
+    /// @return Whether the request is cancelled.
+    virtual bool cancelRequest(LlmRequest const& llmRequest);
+
+    /// @brief Send ready signal.
+    /// @param requestId The ID used in the context phase of the current request.
+    /// @param isReady Whether the request is ready to be received.
+    virtual void sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady);
+
     /// @brief Destructor.
     virtual ~CacheSender();

@@ -272,6 +283,17 @@ class CacheReceiver
     virtual TransferSession sendRequestInfo(LlmRequest const& llmRequest);

     virtual void receiveSync(TransferSession& session);
+
+    /// @brief Cancel the request.
+    /// @param llmRequest Request object.
+    /// @return Whether the request is cancelled.
+    virtual bool cancelRequest(LlmRequest const& llmRequest);
+
+    /// @brief Receive ready signal.
+    /// @param session The session object.
+    /// @return Whether the request is ready to be received.
+    virtual bool receiveReadySignal(TransferSession& session);
+
     /// @brief Destructor.
     virtual ~CacheReceiver();
