@@ -348,6 +348,48 @@ class CacheSender::Impl
348348 mFormatter ->format (session);
349349 }
350350
351+ bool cancelRequest (LlmRequest const & llmRequest)
352+ {
353+ bool isCancelled = false ;
354+ std::unique_lock lkResp (mSenderMutex );
355+ auto it = mReadyResponses .find (llmRequest.mRequestId );
356+ // If the request is not the current request and already in the ready queue, we can cancel it.
357+ if (it != mReadyResponses .end () && (!isSending () || getCurrentRequestId () != llmRequest.mRequestId ))
358+ {
359+ mCancelledRequests .insert (llmRequest.mRequestId );
360+ isCancelled = true ;
361+ }
362+ else
363+ {
364+ TLLM_LOG_WARNING (" Cannot cancel request %zu" , llmRequest.mRequestId );
365+ }
366+ return isCancelled;
367+ }
368+
369+ void sendReadySignal (LlmRequest::RequestIdType requestId, bool isReady)
370+ {
371+ auto it = mRequestToSession .find (requestId);
372+ TLLM_CHECK (it != mRequestToSession .end ());
373+ auto & session = it->second ;
374+ auto connections = session.getConnections ();
375+ for (size_t i = 0 ; i < connections.size (); i++)
376+ {
377+ auto * agentConnectionManager = dynamic_cast <executor::kv_cache::AgentConnectionManager*>(mManager );
378+ if (agentConnectionManager != nullptr )
379+ {
380+ auto * agentConnection = dynamic_cast <executor::kv_cache::AgentConnection const *>(connections.at (i));
381+ TLLM_CHECK (agentConnection != nullptr );
382+ agentConnection->sendReadySignal (
383+ executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG }, isReady);
384+ }
385+ else
386+ {
387+ connections.at (i)->send (
388+ executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG }, &isReady, sizeof (isReady));
389+ }
390+ }
391+ }
392+
351393 ~Impl ()
352394 {
353395 terminate ();
@@ -391,20 +433,54 @@ class CacheSender::Impl
391433 {
392434 mRemainSendCount .erase (reqId);
393435
394- // TODO(zhengd): pass the hashes directly instead of update llmRequest
395- auto llmRequest = it->second .mRequest ;
396- llmRequest->setRequestedBlockHashes (std::move (blockHashes));
436+ // Check if the request is cancelled
437+ bool isReady = true ;
438+ {
439+ std::unique_lock lk (mSenderMutex );
440+ if (mCancelledRequests .find (reqId) != mCancelledRequests .end ())
441+ {
442+ isReady = false ;
443+ }
444+ }
445+ sendReadySignal (reqId, isReady);
397446
398- if (common::getEnvParallelCacheSend () )
447+ if (isReady )
399448 {
400- // TODO: Use a thread pool and check for thread safety.
401- std::thread (&CacheSender::Impl::sendAndRemoveResponse, this , it->first , std::move (it->second )).detach ();
449+ // TODO(zhengd): pass the hashes directly instead of update llmRequest
450+ auto llmRequest = it->second .mRequest ;
451+ llmRequest->setRequestedBlockHashes (std::move (blockHashes));
452+
453+ if (common::getEnvParallelCacheSend ())
454+ {
455+ // TODO: Use a thread pool and check for thread safety.
456+ std::thread (&CacheSender::Impl::sendAndRemoveResponse, this , it->first , std::move (it->second ))
457+ .detach ();
458+ }
459+ else
460+ {
461+ CacheSender::Impl::sendAndRemoveResponse (it->first , std::move (it->second ));
462+ }
463+ removeResponse (it);
402464 }
403465 else
404466 {
405- CacheSender::Impl::sendAndRemoveResponse (it->first , std::move (it->second ));
467+ // TODO: if the generation does not require the kv cache, the request will
468+ // not be removed from mCancelledRequests. This should be handled by timeout.
469+ auto it = mReadyResponses .find (mCurrentRequest .value ());
470+ {
471+ std::unique_lock lkResp (mSenderMutex );
472+ mReadyResponses .erase (it);
473+ mCancelledRequests .erase (mCurrentRequest .value ());
474+ mRemainSendCount .erase (mCurrentRequest .value ());
475+ }
476+ mCurrentRequest = std::nullopt ;
477+
478+ if (mReadyResponses .empty ())
479+ {
480+ std::unique_lock lk (mCondMutex );
481+ mAnyReady = false ;
482+ }
406483 }
407- removeResponse (it);
408484 }
409485 mCurrentRequest = std::nullopt ;
410486 }
@@ -433,7 +509,11 @@ class CacheSender::Impl
433509 auto reqId = requestInfo.getRequestId ();
434510 blockHashes = requestInfo.getBlockHashes ();
435511
436- mCurrentRequest = reqId;
512+ {
513+ std::unique_lock lk (mSenderMutex );
514+ mCurrentRequest = reqId;
515+ }
516+
437517 if (mRemainSendCount .find (reqId) == mRemainSendCount .end ())
438518 {
439519 mRemainSendCount [reqId] = getCounterpartsCount (reqId);
@@ -513,6 +593,7 @@ class CacheSender::Impl
513593
514594private:
515595 std::optional<RequestIdType> mCurrentRequest ;
596+ std::set<LlmRequest::RequestIdType> mCancelledRequests ;
516597 std::map<RequestIdType, Response> mReadyResponses ;
517598 std::mutex mSenderMutex , mCondMutex ;
518599 std::atomic<bool > mAnyReady {false }, mTerminate {false };
@@ -685,6 +766,62 @@ class CacheReceiver::Impl
685766 connection->send (executor::kv_cache::DataContext{TransceiverTag::kINFO_TAG }, serializedInfo.data (), infoSize);
686767 }
687768
769+ bool cancelRequest (LlmRequest const & llmRequest)
770+ {
771+
772+ std::string processInfo = " default" ;
773+ if (common::getEnvRequestKVCacheConcurrent ())
774+ {
775+ processInfo = llmRequest.getDataTransceiverState ().getCommState ()->toString ();
776+ }
777+
778+ bool isCancelled = false ;
779+ auto & asyncResource = mInstanceToAsyncResource .at (processInfo);
780+ {
781+ std::unique_lock<std::mutex> lck (asyncResource->mMtxForQueue );
782+ auto it = std::find_if (asyncResource->mRequestsQueue .begin (), asyncResource->mRequestsQueue .end (),
783+ [&llmRequest](RequestAndPromise const & requestAndPromise)
784+ { return requestAndPromise.mRequest ->mRequestId == llmRequest.mRequestId ; });
785+ if (it != asyncResource->mRequestsQueue .end ())
786+ {
787+ asyncResource->mRequestsQueue .erase (it);
788+ isCancelled = true ;
789+ }
790+ else
791+ {
792+ TLLM_LOG_WARNING (" Cannot cancel request %zu" , llmRequest.mRequestId );
793+ }
794+ }
795+ return isCancelled;
796+ }
797+
798+ bool receiveReadySignal (TransferSession& session)
799+ {
800+ bool isReadyFinal = true ;
801+ bool isReady = false ;
802+ auto connections = session.getConnections ();
803+
804+ for (size_t i = 0 ; i < connections.size (); i++)
805+ {
806+ auto * agentConnectionManager = dynamic_cast <executor::kv_cache::AgentConnectionManager*>(mManager );
807+ if (agentConnectionManager != nullptr )
808+ {
809+ auto * agentConnection = dynamic_cast <executor::kv_cache::AgentConnection const *>(connections.at (i));
810+ TLLM_CHECK (agentConnection != nullptr );
811+ isReady = agentConnection->recvReadySignal (
812+ executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG });
813+ }
814+ else
815+ {
816+ connections.at (i)->recv (
817+ executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG }, &isReady, sizeof (isReady));
818+ }
819+ isReadyFinal &= isReady;
820+ }
821+
822+ return isReadyFinal;
823+ }
824+
688825 ~Impl ()
689826 {
690827 for (auto && [processInfo, asyncResource] : mInstanceToAsyncResource )
@@ -707,6 +844,14 @@ class CacheReceiver::Impl
707844 llmRequest.setKvCacheTransferStart (std::chrono::steady_clock::now ());
708845 TLLM_CUDA_CHECK (cudaSetDevice (mDeviceId ));
709846 auto session = sendRequestInfo (llmRequest);
847+ bool isReady = receiveReadySignal (session);
848+ if (!isReady)
849+ {
850+ // Reuse the error state for the cancelled request.
851+ llmRequest.setState (LlmRequestState::kDISAGG_TRANS_ERROR );
852+ llmRequest.setKvCacheTransferEnd (std::chrono::steady_clock::now ());
853+ return ;
854+ }
710855 receiveSync (session);
711856 llmRequest.setKvCacheTransferEnd (std::chrono::steady_clock::now ());
712857
@@ -876,6 +1021,16 @@ RequestInfo CacheSender::recvRequestInfo()
8761021 return mImpl ->recvRequestInfo ();
8771022}
8781023
1024+ bool CacheSender::cancelRequest (LlmRequest const & llmRequest)
1025+ {
1026+ return mImpl ->cancelRequest (llmRequest);
1027+ }
1028+
1029+ void CacheSender::sendReadySignal (LlmRequest::RequestIdType requestId, bool isReady)
1030+ {
1031+ mImpl ->sendReadySignal (requestId, isReady);
1032+ }
1033+
8791034CacheReceiver::CacheReceiver (executor::kv_cache::ConnectionManager* manager,
8801035 executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
8811036 : mImpl {std::unique_ptr<Impl, ImplDeleter>(new Impl (manager, selfCacheState, selfIndex, std::move (formatter)))}
@@ -899,4 +1054,14 @@ void CacheReceiver::receiveSync(TransferSession& session)
8991054 mImpl ->receiveSync (session);
9001055}
9011056
1057+ bool CacheReceiver::cancelRequest (LlmRequest const & llmRequest)
1058+ {
1059+ return mImpl ->cancelRequest (llmRequest);
1060+ }
1061+
1062+ bool CacheReceiver::receiveReadySignal (TransferSession& session)
1063+ {
1064+ return mImpl ->receiveReadySignal (session);
1065+ }
1066+
9021067} // namespace tensorrt_llm::batch_manager
0 commit comments