Skip to content

Commit 0c1ba98

Browse files
pcastonguay and Shunkang
authored and committed
[None][feat] Support for cancelling requests with disaggregation (NVIDIA#8114)
Signed-off-by: Shunkang <[email protected]> Signed-off-by: Patrice Castonguay <[email protected]> Co-authored-by: Shunkang <[email protected]>
1 parent c0d747e commit 0c1ba98

File tree

12 files changed

+617
-53
lines changed

12 files changed

+617
-53
lines changed

cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ class BaseCacheTransceiver
7171
virtual void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;
7272

7373
[[nodiscard]] virtual bool checkGenTransferComplete() const = 0;
74+
75+
virtual bool cancelRequest(LlmRequest* llmRequest) = 0;
7476
};
7577

7678
class CacheTransceiver : public BaseCacheTransceiver
@@ -111,6 +113,8 @@ class CacheTransceiver : public BaseCacheTransceiver
111113

112114
[[nodiscard]] bool checkGenTransferComplete() const override;
113115

116+
virtual bool cancelRequest(LlmRequest* llmRequest) override;
117+
114118
private:
115119
void initializeCommState();
116120

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,4 +567,17 @@ bool CacheTransceiver::checkGenTransferComplete() const
567567
return mRequesterFutures.empty();
568568
}
569569

570+
bool CacheTransceiver::cancelRequest(LlmRequest* llmRequest)
571+
{
572+
if (llmRequest->isContextOnlyRequest())
573+
{
574+
return mCacheSender->cancelRequest(*llmRequest);
575+
}
576+
else if (llmRequest->isGenerationOnlyRequest())
577+
{
578+
return mCacheReceiver->cancelRequest(*llmRequest);
579+
}
580+
return false;
581+
}
582+
570583
} // namespace tensorrt_llm::batch_manager

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 141 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ class CacheSender::Impl
270270
auto future = promise.get_future();
271271
{
272272
{
273-
std::unique_lock lkResp(mSenderMutex);
273+
std::scoped_lock lkResp(mSenderMutex);
274274
mReadyResponses.emplace(
275275
llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)});
276276
}
@@ -380,6 +380,49 @@ class CacheSender::Impl
380380
mFormatter->format(*session);
381381
}
382382

383+
bool cancelRequest(LlmRequest const& llmRequest)
384+
{
385+
bool isCancelled = false;
386+
std::scoped_lock lkResp(mSenderMutex);
387+
auto it = mReadyResponses.find(llmRequest.mRequestId);
388+
// If the request is not the current request and already in the ready queue, we can cancel it.
389+
if (it != mReadyResponses.end()
390+
&& (!mCurrentRequest.has_value() || getCurrentRequestId() != llmRequest.mRequestId))
391+
{
392+
mCancelledRequests.insert(llmRequest.mRequestId);
393+
isCancelled = true;
394+
}
395+
else
396+
{
397+
TLLM_LOG_WARNING("Cannot cancel request %zu", llmRequest.mRequestId);
398+
}
399+
return isCancelled;
400+
}
401+
402+
void sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady)
403+
{
404+
auto it = mRequestToSession.find(requestId);
405+
TLLM_CHECK(it != mRequestToSession.end());
406+
auto& session = it->second;
407+
auto const& connections = session.getConnections();
408+
for (size_t i = 0; i < connections.size(); i++)
409+
{
410+
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
411+
if (agentConnectionManager)
412+
{
413+
auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections.at(i));
414+
TLLM_CHECK(agentConnection);
415+
agentConnection->sendReadySignal(
416+
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG}, isReady);
417+
}
418+
else
419+
{
420+
connections.at(i)->send(
421+
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG}, &isReady, sizeof(isReady));
422+
}
423+
}
424+
}
425+
383426
~Impl()
384427
{
385428
terminate();
@@ -506,7 +549,11 @@ class CacheSender::Impl
506549
auto const& requestInfo = recvRequestInfo();
507550
auto reqId = requestInfo.getRequestId();
508551

509-
mCurrentRequest = reqId;
552+
{
553+
std::scoped_lock lk(mSenderMutex);
554+
mCurrentRequest = reqId;
555+
}
556+
510557
if (mRemainSendCount.find(reqId) == mRemainSendCount.end())
511558
{
512559
mRemainSendCount[reqId] = getCounterpartsCount(reqId);
@@ -564,7 +611,7 @@ class CacheSender::Impl
564611
void removeResponse(std::map<RequestIdType, Response>::iterator it)
565612
{
566613
{
567-
std::unique_lock lkResp(mSenderMutex);
614+
std::scoped_lock lkResp(mSenderMutex);
568615
mReadyResponses.erase(it);
569616
}
570617
if (mReadyResponses.empty())
@@ -581,12 +628,13 @@ class CacheSender::Impl
581628

582629
[[nodiscard]] std::map<RequestIdType, Response>::iterator getCurrentResponse()
583630
{
584-
std::unique_lock lk(mSenderMutex);
631+
std::scoped_lock lk(mSenderMutex);
585632
return mReadyResponses.find(getCurrentRequestId());
586633
}
587634

588635
private:
589636
std::optional<RequestIdType> mCurrentRequest;
637+
std::set<LlmRequest::RequestIdType> mCancelledRequests;
590638
std::map<RequestIdType, Response> mReadyResponses;
591639
std::mutex mSenderMutex, mCondMutex;
592640
std::atomic<bool> mAnyReady{false}, mTerminate{false};
@@ -634,7 +682,7 @@ class CacheReceiver::Impl
634682
auto promise = std::make_unique<std::promise<void>>();
635683
auto future = promise->get_future();
636684
TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value());
637-
std::string processInfo = "default";
685+
std::string processInfo = kDefaultProcessInfo;
638686
if (common::getEnvRequestKVCacheConcurrent())
639687
{
640688
processInfo = llmRequest.getDataTransceiverState().getCommState()->toString();
@@ -718,7 +766,7 @@ class CacheReceiver::Impl
718766

719767
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
720768
std::optional<size_t> cacheBufferId = std::nullopt;
721-
if (agentConnectionManager != nullptr)
769+
if (agentConnectionManager)
722770
{
723771
cacheBufferId = agentConnectionManager->getCacheTransBufferManager()->assignBufferIndexForRecv();
724772
TLLM_CHECK(cacheBufferId.has_value());
@@ -741,7 +789,7 @@ class CacheReceiver::Impl
741789
auto const* connection = counterPartConnections[i];
742790
// if Manager is agentConnectionManager, then send request info to agent
743791
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
744-
if (agentConnectionManager != nullptr)
792+
if (agentConnectionManager)
745793
{
746794
// TODO: index -> validConnectionIdx conversion
747795
auto validConnectionIdx = std::find(pickUpIdx.begin(), pickUpIdx.end(), i) - pickUpIdx.begin();
@@ -766,7 +814,7 @@ class CacheReceiver::Impl
766814
{
767815
std::scoped_lock<std::mutex> lock(mProcessIoResouceMutex);
768816
TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value());
769-
std::string processString = "default";
817+
std::string processString = kDefaultProcessInfo;
770818
if (common::getEnvRequestKVCacheConcurrent())
771819
{
772820
processString = llmRequest.getDataTransceiverState().getCommState()->toString();
@@ -792,6 +840,62 @@ class CacheReceiver::Impl
792840
connection->send(DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize);
793841
}
794842

843+
bool cancelRequest(LlmRequest const& llmRequest)
844+
{
845+
846+
std::string processInfo = kDefaultProcessInfo;
847+
if (common::getEnvRequestKVCacheConcurrent())
848+
{
849+
processInfo = llmRequest.getDataTransceiverState().getCommState()->toString();
850+
}
851+
852+
bool isCancelled = false;
853+
auto& asyncResource = mInstanceToAsyncResource.at(processInfo);
854+
{
855+
std::unique_lock<std::mutex> lck(asyncResource->mMtxForQueue);
856+
auto it = std::find_if(asyncResource->mRequestsQueue.begin(), asyncResource->mRequestsQueue.end(),
857+
[&llmRequest](RequestAndPromise const& requestAndPromise)
858+
{ return requestAndPromise.mRequest->mRequestId == llmRequest.mRequestId; });
859+
if (it != asyncResource->mRequestsQueue.end())
860+
{
861+
asyncResource->mRequestsQueue.erase(it);
862+
isCancelled = true;
863+
}
864+
else
865+
{
866+
TLLM_LOG_WARNING("Cannot cancel request %zu", llmRequest.mRequestId);
867+
}
868+
}
869+
return isCancelled;
870+
}
871+
872+
bool receiveReadySignal(TransferSession& session)
873+
{
874+
bool isReadyFinal = true;
875+
bool isReady = false;
876+
auto const& connections = session.getConnections();
877+
878+
for (size_t i = 0; i < connections.size(); i++)
879+
{
880+
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
881+
if (agentConnectionManager)
882+
{
883+
auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections.at(i));
884+
TLLM_CHECK(agentConnection);
885+
isReady = agentConnection->recvReadySignal(
886+
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG});
887+
}
888+
else
889+
{
890+
connections.at(i)->recv(
891+
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG}, &isReady, sizeof(isReady));
892+
}
893+
isReadyFinal &= isReady;
894+
}
895+
896+
return isReadyFinal;
897+
}
898+
795899
~Impl()
796900
{
797901
for (auto&& [processInfo, asyncResource] : mInstanceToAsyncResource)
@@ -814,6 +918,14 @@ class CacheReceiver::Impl
814918
llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now());
815919
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
816920
auto session = sendRequestInfo(llmRequest);
921+
bool isReady = receiveReadySignal(session);
922+
if (!isReady)
923+
{
924+
// Reuse the error state for the cancelled request.
925+
llmRequest.setState(LlmRequestState::kDISAGG_TRANS_ERROR);
926+
llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());
927+
return;
928+
}
817929
receiveSync(session);
818930
llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());
819931

@@ -930,6 +1042,7 @@ class CacheReceiver::Impl
9301042
}
9311043

9321044
int mDeviceId{-1};
1045+
static constexpr char const* kDefaultProcessInfo = "default";
9331046
std::vector<std::future<void>> mRequestFutures;
9341047
std::unordered_map<std::string, std::unique_ptr<AsyncResource>> mInstanceToAsyncResource;
9351048
executor::kv_cache::ConnectionManager* mManager;
@@ -985,6 +1098,16 @@ RequestInfo CacheSender::recvRequestInfo()
9851098
return mImpl->recvRequestInfo();
9861099
}
9871100

1101+
bool CacheSender::cancelRequest(LlmRequest const& llmRequest)
1102+
{
1103+
return mImpl->cancelRequest(llmRequest);
1104+
}
1105+
1106+
void CacheSender::sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady)
1107+
{
1108+
mImpl->sendReadySignal(requestId, isReady);
1109+
}
1110+
9881111
CacheReceiver::CacheReceiver(executor::kv_cache::ConnectionManager* manager,
9891112
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
9901113
: mImpl{std::unique_ptr<Impl, ImplDeleter>(new Impl(manager, selfCacheState, selfIndex, std::move(formatter)))}
@@ -1008,4 +1131,14 @@ void CacheReceiver::receiveSync(TransferSession& session)
10081131
mImpl->receiveSync(session);
10091132
}
10101133

1134+
bool CacheReceiver::cancelRequest(LlmRequest const& llmRequest)
1135+
{
1136+
return mImpl->cancelRequest(llmRequest);
1137+
}
1138+
1139+
bool CacheReceiver::receiveReadySignal(TransferSession& session)
1140+
{
1141+
return mImpl->receiveReadySignal(session);
1142+
}
1143+
10111144
} // namespace tensorrt_llm::batch_manager

cpp/tensorrt_llm/batch_manager/dataTransceiver.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ struct TransceiverTag
144144
static constexpr int32_t kID_TAG{19};
145145
static constexpr int32_t kINFO_SIZE_TAG{22};
146146
static constexpr int32_t kINFO_TAG{32};
147+
static constexpr int32_t kREADY_SIGNAL_TAG{42};
147148
};
148149

149150
// Used to store the information that needs to be sent to the context executor to ensure the generation
@@ -240,6 +241,16 @@ class CacheSender
240241
/// @param llmRequest The request object to which the data belongs.
241242
virtual RequestInfo recvRequestInfo();
242243

244+
/// @brief Cancel the request.
245+
/// @param requestId The ID used in the context phase of the current request.
246+
/// @return Whether the request is cancelled.
247+
virtual bool cancelRequest(LlmRequest const& llmRequest);
248+
249+
/// @brief Send ready signal.
250+
/// @param requestId The ID used in the context phase of the current request.
251+
/// @param isReady Whether the request is ready to be received.
252+
virtual void sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady);
253+
243254
/// @brief Destructor.
244255
virtual ~CacheSender();
245256

@@ -272,6 +283,17 @@ class CacheReceiver
272283
virtual TransferSession sendRequestInfo(LlmRequest const& llmRequest);
273284

274285
virtual void receiveSync(TransferSession& session);
286+
287+
/// @brief Cancel the request.
288+
/// @param llmRequest Request object.
289+
/// @return Whether the request is cancelled.
290+
virtual bool cancelRequest(LlmRequest const& llmRequest);
291+
292+
/// @brief Receive ready signal.
293+
/// @param session The session object.
294+
/// @return Whether the request is ready to be received.
295+
virtual bool receiveReadySignal(TransferSession& session);
296+
275297
/// @brief Destructor.
276298
virtual ~CacheReceiver();
277299

0 commit comments

Comments
 (0)