@@ -270,7 +270,7 @@ class CacheSender::Impl
270270 auto future = promise.get_future ();
271271 {
272272 {
273- std::unique_lock lkResp (mSenderMutex );
273+ std::scoped_lock lkResp (mSenderMutex );
274274 mReadyResponses .emplace (
275275 llmRequest.mRequestId , Response{std::addressof (llmRequest), std::move (promise)});
276276 }
@@ -379,6 +379,49 @@ class CacheSender::Impl
379379 mFormatter ->format (*session);
380380 }
381381
382+ bool cancelRequest (LlmRequest const & llmRequest)
383+ {
384+ bool isCancelled = false ;
385+ std::scoped_lock lkResp (mSenderMutex );
386+ auto it = mReadyResponses .find (llmRequest.mRequestId );
387+ // If the request is not the current request and already in the ready queue, we can cancel it.
388+ if (it != mReadyResponses .end ()
389+ && (!mCurrentRequest .has_value () || getCurrentRequestId () != llmRequest.mRequestId ))
390+ {
391+ mCancelledRequests .insert (llmRequest.mRequestId );
392+ isCancelled = true ;
393+ }
394+ else
395+ {
396+ TLLM_LOG_WARNING (" Cannot cancel request %zu" , llmRequest.mRequestId );
397+ }
398+ return isCancelled;
399+ }
400+
401+ void sendReadySignal (LlmRequest::RequestIdType requestId, bool isReady)
402+ {
403+ auto it = mRequestToSession .find (requestId);
404+ TLLM_CHECK (it != mRequestToSession .end ());
405+ auto & session = it->second ;
406+ auto const & connections = session.getConnections ();
407+ for (size_t i = 0 ; i < connections.size (); i++)
408+ {
409+ auto * agentConnectionManager = dynamic_cast <executor::kv_cache::AgentConnectionManager*>(mManager );
410+ if (agentConnectionManager)
411+ {
412+ auto * agentConnection = dynamic_cast <executor::kv_cache::AgentConnection const *>(connections.at (i));
413+ TLLM_CHECK (agentConnection);
414+ agentConnection->sendReadySignal (
415+ executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG }, isReady);
416+ }
417+ else
418+ {
419+ connections.at (i)->send (
420+ executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG }, &isReady, sizeof (isReady));
421+ }
422+ }
423+ }
424+
382425 ~Impl ()
383426 {
384427 terminate ();
@@ -463,8 +506,42 @@ class CacheSender::Impl
463506 {
464507 mRemainSendCount .erase (reqId);
465508
466- asyncSendAndRemoveResponse (it->first , std::move (it->second ));
467- removeResponse (it);
509+ // Check if the request is cancelled
510+ bool isReady = true ;
511+ {
512+ std::scoped_lock lk (mSenderMutex );
513+ if (mCancelledRequests .find (reqId) != mCancelledRequests .end ())
514+ {
515+ isReady = false ;
516+ }
517+ }
518+ sendReadySignal (reqId, isReady);
519+
520+ if (isReady)
521+ {
522+ asyncSendAndRemoveResponse (it->first , std::move (it->second ));
523+ removeResponse (it);
524+ }
525+ else
526+ {
527+ // TODO: if the generation does not require the kv cache, the request will
528+ // not be removed from mCancelledRequests. This should be handled by timeout.
529+ auto it = mReadyResponses .find (mCurrentRequest .value ());
530+ TLLM_CHECK (it != mReadyResponses .end ());
531+ {
532+ std::scoped_lock lkResp (mSenderMutex );
533+ mReadyResponses .erase (it);
534+ mCancelledRequests .erase (mCurrentRequest .value ());
535+ mRemainSendCount .erase (mCurrentRequest .value ());
536+ }
537+ mCurrentRequest = std::nullopt ;
538+
539+ if (mReadyResponses .empty ())
540+ {
541+ std::unique_lock lk (mCondMutex );
542+ mAnyReady = false ;
543+ }
544+ }
468545 }
469546 mCurrentRequest = std::nullopt ;
470547 }
@@ -491,7 +568,11 @@ class CacheSender::Impl
491568 auto const & requestInfo = recvRequestInfo ();
492569 auto reqId = requestInfo.getRequestId ();
493570
494- mCurrentRequest = reqId;
571+ {
572+ std::scoped_lock lk (mSenderMutex );
573+ mCurrentRequest = reqId;
574+ }
575+
495576 if (mRemainSendCount .find (reqId) == mRemainSendCount .end ())
496577 {
497578 mRemainSendCount [reqId] = getCounterpartsCount (reqId);
@@ -549,7 +630,7 @@ class CacheSender::Impl
549630 void removeResponse (std::map<RequestIdType, Response>::iterator it)
550631 {
551632 {
552- std::unique_lock lkResp (mSenderMutex );
633+ std::scoped_lock lkResp (mSenderMutex );
553634 mReadyResponses .erase (it);
554635 }
555636 if (mReadyResponses .empty ())
@@ -566,12 +647,13 @@ class CacheSender::Impl
566647
567648 [[nodiscard]] std::map<RequestIdType, Response>::iterator getCurrentResponse ()
568649 {
569- std::unique_lock lk (mSenderMutex );
650+ std::scoped_lock lk (mSenderMutex );
570651 return mReadyResponses .find (getCurrentRequestId ());
571652 }
572653
573654private:
574655 std::optional<RequestIdType> mCurrentRequest ;
656+ std::set<LlmRequest::RequestIdType> mCancelledRequests ;
575657 std::map<RequestIdType, Response> mReadyResponses ;
576658 std::mutex mSenderMutex , mCondMutex ;
577659 std::atomic<bool > mAnyReady {false }, mTerminate {false };
@@ -619,7 +701,7 @@ class CacheReceiver::Impl
619701 auto promise = std::make_unique<std::promise<void >>();
620702 auto future = promise->get_future ();
621703 TLLM_CHECK (llmRequest.getDataTransceiverState ().getCommState ().has_value ());
622- std::string processInfo = " default " ;
704+ std::string processInfo = kDefaultProcessInfo ;
623705 if (common::getEnvRequestKVCacheConcurrent ())
624706 {
625707 processInfo = llmRequest.getDataTransceiverState ().getCommState ()->toString ();
@@ -703,7 +785,7 @@ class CacheReceiver::Impl
703785
704786 auto * agentConnectionManager = dynamic_cast <executor::kv_cache::AgentConnectionManager*>(mManager );
705787 std::optional<size_t > cacheBufferId = std::nullopt ;
706- if (agentConnectionManager != nullptr )
788+ if (agentConnectionManager)
707789 {
708790 cacheBufferId = agentConnectionManager->getCacheTransBufferManager ()->assignBufferIndexForRecv ();
709791 TLLM_CHECK (cacheBufferId.has_value ());
@@ -726,7 +808,7 @@ class CacheReceiver::Impl
726808 auto const * connection = counterPartConnections[i];
727809 // if Manager is agentConnectionManager, then send request info to agent
728810 auto * agentConnectionManager = dynamic_cast <executor::kv_cache::AgentConnectionManager*>(mManager );
729- if (agentConnectionManager != nullptr )
811+ if (agentConnectionManager)
730812 {
731813 // TODO: index -> validConnectionIdx conversion
732814 auto validConnectionIdx = std::find (pickUpIdx.begin (), pickUpIdx.end (), i) - pickUpIdx.begin ();
@@ -751,7 +833,7 @@ class CacheReceiver::Impl
751833 {
752834 std::scoped_lock<std::mutex> lock (mProcessIoResouceMutex );
753835 TLLM_CHECK (llmRequest.getDataTransceiverState ().getCommState ().has_value ());
754- std::string processString = " default " ;
836+ std::string processString = kDefaultProcessInfo ;
755837 if (common::getEnvRequestKVCacheConcurrent ())
756838 {
757839 processString = llmRequest.getDataTransceiverState ().getCommState ()->toString ();
@@ -777,6 +859,62 @@ class CacheReceiver::Impl
777859 connection->send (DataContext{TransceiverTag::kINFO_TAG }, serializedInfo.data (), infoSize);
778860 }
779861
862+ bool cancelRequest (LlmRequest const & llmRequest)
863+ {
864+
865+ std::string processInfo = kDefaultProcessInfo ;
866+ if (common::getEnvRequestKVCacheConcurrent ())
867+ {
868+ processInfo = llmRequest.getDataTransceiverState ().getCommState ()->toString ();
869+ }
870+
871+ bool isCancelled = false ;
872+ auto & asyncResource = mInstanceToAsyncResource .at (processInfo);
873+ {
874+ std::unique_lock<std::mutex> lck (asyncResource->mMtxForQueue );
875+ auto it = std::find_if (asyncResource->mRequestsQueue .begin (), asyncResource->mRequestsQueue .end (),
876+ [&llmRequest](RequestAndPromise const & requestAndPromise)
877+ { return requestAndPromise.mRequest ->mRequestId == llmRequest.mRequestId ; });
878+ if (it != asyncResource->mRequestsQueue .end ())
879+ {
880+ asyncResource->mRequestsQueue .erase (it);
881+ isCancelled = true ;
882+ }
883+ else
884+ {
885+ TLLM_LOG_WARNING (" Cannot cancel request %zu" , llmRequest.mRequestId );
886+ }
887+ }
888+ return isCancelled;
889+ }
890+
891+ bool receiveReadySignal (TransferSession& session)
892+ {
893+ bool isReadyFinal = true ;
894+ bool isReady = false ;
895+ auto const & connections = session.getConnections ();
896+
897+ for (size_t i = 0 ; i < connections.size (); i++)
898+ {
899+ auto * agentConnectionManager = dynamic_cast <executor::kv_cache::AgentConnectionManager*>(mManager );
900+ if (agentConnectionManager)
901+ {
902+ auto * agentConnection = dynamic_cast <executor::kv_cache::AgentConnection const *>(connections.at (i));
903+ TLLM_CHECK (agentConnection);
904+ isReady = agentConnection->recvReadySignal (
905+ executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG });
906+ }
907+ else
908+ {
909+ connections.at (i)->recv (
910+ executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG }, &isReady, sizeof (isReady));
911+ }
912+ isReadyFinal &= isReady;
913+ }
914+
915+ return isReadyFinal;
916+ }
917+
780918 ~Impl ()
781919 {
782920 for (auto && [processInfo, asyncResource] : mInstanceToAsyncResource )
@@ -799,6 +937,14 @@ class CacheReceiver::Impl
799937 llmRequest.setKvCacheTransferStart (std::chrono::steady_clock::now ());
800938 TLLM_CUDA_CHECK (cudaSetDevice (mDeviceId ));
801939 auto session = sendRequestInfo (llmRequest);
940+ bool isReady = receiveReadySignal (session);
941+ if (!isReady)
942+ {
943+ // Reuse the error state for the cancelled request.
944+ llmRequest.setState (LlmRequestState::kDISAGG_TRANS_ERROR );
945+ llmRequest.setKvCacheTransferEnd (std::chrono::steady_clock::now ());
946+ return ;
947+ }
802948 receiveSync (session);
803949 llmRequest.setKvCacheTransferEnd (std::chrono::steady_clock::now ());
804950
@@ -915,6 +1061,7 @@ class CacheReceiver::Impl
9151061 }
9161062
9171063 int mDeviceId {-1 };
1064+ static constexpr char const * kDefaultProcessInfo = " default" ;
9181065 std::vector<std::future<void >> mRequestFutures ;
9191066 std::unordered_map<std::string, std::unique_ptr<AsyncResource>> mInstanceToAsyncResource ;
9201067 executor::kv_cache::ConnectionManager* mManager ;
@@ -970,6 +1117,16 @@ RequestInfo CacheSender::recvRequestInfo()
9701117 return mImpl ->recvRequestInfo ();
9711118}
9721119
1120+ bool CacheSender::cancelRequest (LlmRequest const & llmRequest)
1121+ {
1122+ return mImpl ->cancelRequest (llmRequest);
1123+ }
1124+
1125+ void CacheSender::sendReadySignal (LlmRequest::RequestIdType requestId, bool isReady)
1126+ {
1127+ mImpl ->sendReadySignal (requestId, isReady);
1128+ }
1129+
9731130CacheReceiver::CacheReceiver (executor::kv_cache::ConnectionManager* manager,
9741131 executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
9751132 : mImpl {std::unique_ptr<Impl, ImplDeleter>(new Impl (manager, selfCacheState, selfIndex, std::move (formatter)))}
@@ -993,4 +1150,14 @@ void CacheReceiver::receiveSync(TransferSession& session)
9931150 mImpl ->receiveSync (session);
9941151}
9951152
1153+ bool CacheReceiver::cancelRequest (LlmRequest const & llmRequest)
1154+ {
1155+ return mImpl ->cancelRequest (llmRequest);
1156+ }
1157+
1158+ bool CacheReceiver::receiveReadySignal (TransferSession& session)
1159+ {
1160+ return mImpl ->receiveReadySignal (session);
1161+ }
1162+
9961163} // namespace tensorrt_llm::batch_manager
0 commit comments