@@ -270,7 +270,7 @@ class CacheSender::Impl
270270 auto future = promise.get_future ();
271271 {
272272 {
273- std::unique_lock lkResp (mSenderMutex );
273+ std::scoped_lock lkResp (mSenderMutex );
274274 mReadyResponses .emplace (
275275 llmRequest.mRequestId , Response{std::addressof (llmRequest), std::move (promise)});
276276 }
@@ -380,6 +380,49 @@ class CacheSender::Impl
380380 mFormatter ->format (*session);
381381 }
382382
383+ bool cancelRequest (LlmRequest const & llmRequest)
384+ {
385+ bool isCancelled = false ;
386+ std::scoped_lock lkResp (mSenderMutex );
387+ auto it = mReadyResponses .find (llmRequest.mRequestId );
388+ // If the request is not the current request and already in the ready queue, we can cancel it.
389+ if (it != mReadyResponses .end ()
390+ && (!mCurrentRequest .has_value () || getCurrentRequestId () != llmRequest.mRequestId ))
391+ {
392+ mCancelledRequests .insert (llmRequest.mRequestId );
393+ isCancelled = true ;
394+ }
395+ else
396+ {
397+ TLLM_LOG_WARNING (" Cannot cancel request %zu" , llmRequest.mRequestId );
398+ }
399+ return isCancelled;
400+ }
401+
402+ void sendReadySignal (LlmRequest::RequestIdType requestId, bool isReady)
403+ {
404+ auto it = mRequestToSession .find (requestId);
405+ TLLM_CHECK (it != mRequestToSession .end ());
406+ auto & session = it->second ;
407+ auto const & connections = session.getConnections ();
408+ for (size_t i = 0 ; i < connections.size (); i++)
409+ {
410+ auto * agentConnectionManager = dynamic_cast <executor::kv_cache::AgentConnectionManager*>(mManager );
411+ if (agentConnectionManager)
412+ {
413+ auto * agentConnection = dynamic_cast <executor::kv_cache::AgentConnection const *>(connections.at (i));
414+ TLLM_CHECK (agentConnection);
415+ agentConnection->sendReadySignal (
416+ executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG }, isReady);
417+ }
418+ else
419+ {
420+ connections.at (i)->send (
421+ executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG }, &isReady, sizeof (isReady));
422+ }
423+ }
424+ }
425+
383426 ~Impl ()
384427 {
385428 terminate ();
@@ -506,7 +549,11 @@ class CacheSender::Impl
506549 auto const & requestInfo = recvRequestInfo ();
507550 auto reqId = requestInfo.getRequestId ();
508551
509- mCurrentRequest = reqId;
552+ {
553+ std::scoped_lock lk (mSenderMutex );
554+ mCurrentRequest = reqId;
555+ }
556+
510557 if (mRemainSendCount .find (reqId) == mRemainSendCount .end ())
511558 {
512559 mRemainSendCount [reqId] = getCounterpartsCount (reqId);
@@ -564,7 +611,7 @@ class CacheSender::Impl
564611 void removeResponse (std::map<RequestIdType, Response>::iterator it)
565612 {
566613 {
567- std::unique_lock lkResp (mSenderMutex );
614+ std::scoped_lock lkResp (mSenderMutex );
568615 mReadyResponses .erase (it);
569616 }
570617 if (mReadyResponses .empty ())
@@ -581,12 +628,13 @@ class CacheSender::Impl
581628
582629 [[nodiscard]] std::map<RequestIdType, Response>::iterator getCurrentResponse ()
583630 {
584- std::unique_lock lk (mSenderMutex );
631+ std::scoped_lock lk (mSenderMutex );
585632 return mReadyResponses .find (getCurrentRequestId ());
586633 }
587634
588635private:
589636 std::optional<RequestIdType> mCurrentRequest ;
637+ std::set<LlmRequest::RequestIdType> mCancelledRequests ;
590638 std::map<RequestIdType, Response> mReadyResponses ;
591639 std::mutex mSenderMutex , mCondMutex ;
592640 std::atomic<bool > mAnyReady {false }, mTerminate {false };
@@ -634,7 +682,7 @@ class CacheReceiver::Impl
634682 auto promise = std::make_unique<std::promise<void >>();
635683 auto future = promise->get_future ();
636684 TLLM_CHECK (llmRequest.getDataTransceiverState ().getCommState ().has_value ());
637- std::string processInfo = " default " ;
685+ std::string processInfo = kDefaultProcessInfo ;
638686 if (common::getEnvRequestKVCacheConcurrent ())
639687 {
640688 processInfo = llmRequest.getDataTransceiverState ().getCommState ()->toString ();
@@ -718,7 +766,7 @@ class CacheReceiver::Impl
718766
719767 auto * agentConnectionManager = dynamic_cast <executor::kv_cache::AgentConnectionManager*>(mManager );
720768 std::optional<size_t > cacheBufferId = std::nullopt ;
721- if (agentConnectionManager != nullptr )
769+ if (agentConnectionManager)
722770 {
723771 cacheBufferId = agentConnectionManager->getCacheTransBufferManager ()->assignBufferIndexForRecv ();
724772 TLLM_CHECK (cacheBufferId.has_value ());
@@ -741,7 +789,7 @@ class CacheReceiver::Impl
741789 auto const * connection = counterPartConnections[i];
742790 // if Manager is agentConnectionManager, then send request info to agent
743791 auto * agentConnectionManager = dynamic_cast <executor::kv_cache::AgentConnectionManager*>(mManager );
744- if (agentConnectionManager != nullptr )
792+ if (agentConnectionManager)
745793 {
746794 // TODO: index -> validConnectionIdx conversion
747795 auto validConnectionIdx = std::find (pickUpIdx.begin (), pickUpIdx.end (), i) - pickUpIdx.begin ();
@@ -766,7 +814,7 @@ class CacheReceiver::Impl
766814 {
767815 std::scoped_lock<std::mutex> lock (mProcessIoResouceMutex );
768816 TLLM_CHECK (llmRequest.getDataTransceiverState ().getCommState ().has_value ());
769- std::string processString = " default " ;
817+ std::string processString = kDefaultProcessInfo ;
770818 if (common::getEnvRequestKVCacheConcurrent ())
771819 {
772820 processString = llmRequest.getDataTransceiverState ().getCommState ()->toString ();
@@ -792,6 +840,62 @@ class CacheReceiver::Impl
792840 connection->send (DataContext{TransceiverTag::kINFO_TAG }, serializedInfo.data (), infoSize);
793841 }
794842
843+ bool cancelRequest (LlmRequest const & llmRequest)
844+ {
845+
846+ std::string processInfo = kDefaultProcessInfo ;
847+ if (common::getEnvRequestKVCacheConcurrent ())
848+ {
849+ processInfo = llmRequest.getDataTransceiverState ().getCommState ()->toString ();
850+ }
851+
852+ bool isCancelled = false ;
853+ auto & asyncResource = mInstanceToAsyncResource .at (processInfo);
854+ {
855+ std::unique_lock<std::mutex> lck (asyncResource->mMtxForQueue );
856+ auto it = std::find_if (asyncResource->mRequestsQueue .begin (), asyncResource->mRequestsQueue .end (),
857+ [&llmRequest](RequestAndPromise const & requestAndPromise)
858+ { return requestAndPromise.mRequest ->mRequestId == llmRequest.mRequestId ; });
859+ if (it != asyncResource->mRequestsQueue .end ())
860+ {
861+ asyncResource->mRequestsQueue .erase (it);
862+ isCancelled = true ;
863+ }
864+ else
865+ {
866+ TLLM_LOG_WARNING (" Cannot cancel request %zu" , llmRequest.mRequestId );
867+ }
868+ }
869+ return isCancelled;
870+ }
871+
872+ bool receiveReadySignal (TransferSession& session)
873+ {
874+ bool isReadyFinal = true ;
875+ bool isReady = false ;
876+ auto const & connections = session.getConnections ();
877+
878+ for (size_t i = 0 ; i < connections.size (); i++)
879+ {
880+ auto * agentConnectionManager = dynamic_cast <executor::kv_cache::AgentConnectionManager*>(mManager );
881+ if (agentConnectionManager)
882+ {
883+ auto * agentConnection = dynamic_cast <executor::kv_cache::AgentConnection const *>(connections.at (i));
884+ TLLM_CHECK (agentConnection);
885+ isReady = agentConnection->recvReadySignal (
886+ executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG });
887+ }
888+ else
889+ {
890+ connections.at (i)->recv (
891+ executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG }, &isReady, sizeof (isReady));
892+ }
893+ isReadyFinal &= isReady;
894+ }
895+
896+ return isReadyFinal;
897+ }
898+
795899 ~Impl ()
796900 {
797901 for (auto && [processInfo, asyncResource] : mInstanceToAsyncResource )
@@ -814,6 +918,14 @@ class CacheReceiver::Impl
814918 llmRequest.setKvCacheTransferStart (std::chrono::steady_clock::now ());
815919 TLLM_CUDA_CHECK (cudaSetDevice (mDeviceId ));
816920 auto session = sendRequestInfo (llmRequest);
921+ bool isReady = receiveReadySignal (session);
922+ if (!isReady)
923+ {
924+ // Reuse the error state for the cancelled request.
925+ llmRequest.setState (LlmRequestState::kDISAGG_TRANS_ERROR );
926+ llmRequest.setKvCacheTransferEnd (std::chrono::steady_clock::now ());
927+ return ;
928+ }
817929 receiveSync (session);
818930 llmRequest.setKvCacheTransferEnd (std::chrono::steady_clock::now ());
819931
@@ -930,6 +1042,7 @@ class CacheReceiver::Impl
9301042 }
9311043
9321044 int mDeviceId {-1 };
1045+ static constexpr char const * kDefaultProcessInfo = " default" ;
9331046 std::vector<std::future<void >> mRequestFutures ;
9341047 std::unordered_map<std::string, std::unique_ptr<AsyncResource>> mInstanceToAsyncResource ;
9351048 executor::kv_cache::ConnectionManager* mManager ;
@@ -985,6 +1098,16 @@ RequestInfo CacheSender::recvRequestInfo()
9851098 return mImpl ->recvRequestInfo ();
9861099}
9871100
1101+ bool CacheSender::cancelRequest (LlmRequest const & llmRequest)
1102+ {
1103+ return mImpl ->cancelRequest (llmRequest);
1104+ }
1105+
1106+ void CacheSender::sendReadySignal (LlmRequest::RequestIdType requestId, bool isReady)
1107+ {
1108+ mImpl ->sendReadySignal (requestId, isReady);
1109+ }
1110+
9881111CacheReceiver::CacheReceiver (executor::kv_cache::ConnectionManager* manager,
9891112 executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
9901113 : mImpl {std::unique_ptr<Impl, ImplDeleter>(new Impl (manager, selfCacheState, selfIndex, std::move (formatter)))}
@@ -1008,4 +1131,14 @@ void CacheReceiver::receiveSync(TransferSession& session)
10081131 mImpl ->receiveSync (session);
10091132}
10101133
1134+ bool CacheReceiver::cancelRequest (LlmRequest const & llmRequest)
1135+ {
1136+ return mImpl ->cancelRequest (llmRequest);
1137+ }
1138+
1139+ bool CacheReceiver::receiveReadySignal (TransferSession& session)
1140+ {
1141+ return mImpl ->receiveReadySignal (session);
1142+ }
1143+
10111144} // namespace tensorrt_llm::batch_manager
0 commit comments