Skip to content

Commit ba55434

Browse files
ShunkangShunkang
authored andcommitted
Add cancel request support
Signed-off-by: Shunkang <[email protected]>
1 parent 8aead22 commit ba55434

File tree

11 files changed

+446
-19
lines changed

11 files changed

+446
-19
lines changed

cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ class BaseCacheTransceiver
7171
virtual void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;
7272

7373
[[nodiscard]] virtual bool checkGenTransferComplete() const = 0;
74+
75+
virtual bool cancelRequest(LlmRequest* llmRequest) = 0;
7476
};
7577

7678
class CacheTransceiver : public BaseCacheTransceiver
@@ -111,6 +113,8 @@ class CacheTransceiver : public BaseCacheTransceiver
111113

112114
[[nodiscard]] bool checkGenTransferComplete() const override;
113115

116+
virtual bool cancelRequest(LlmRequest* llmRequest) override;
117+
114118
private:
115119
void initializeCommState();
116120

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,4 +572,17 @@ bool CacheTransceiver::checkGenTransferComplete() const
572572
return mRequesterFutures.empty();
573573
}
574574

575+
bool CacheTransceiver::cancelRequest(LlmRequest* llmRequest)
576+
{
577+
if (llmRequest->isContextOnlyRequest())
578+
{
579+
return mCacheSender->cancelRequest(*llmRequest);
580+
}
581+
else if (llmRequest->isGenerationOnlyRequest())
582+
{
583+
return mCacheReceiver->cancelRequest(*llmRequest);
584+
}
585+
return false;
586+
}
587+
575588
} // namespace tensorrt_llm::batch_manager

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 174 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,48 @@ class CacheSender::Impl
348348
mFormatter->format(session);
349349
}
350350

351+
bool cancelRequest(LlmRequest const& llmRequest)
352+
{
353+
bool isCancelled = false;
354+
std::unique_lock lkResp(mSenderMutex);
355+
auto it = mReadyResponses.find(llmRequest.mRequestId);
356+
// If the request is not the current request and already in the ready queue, we can cancel it.
357+
if (it != mReadyResponses.end() && (!isSending() || getCurrentRequestId() != llmRequest.mRequestId))
358+
{
359+
mCancelledRequests.insert(llmRequest.mRequestId);
360+
isCancelled = true;
361+
}
362+
else
363+
{
364+
TLLM_LOG_WARNING("Cannot cancel request %zu", llmRequest.mRequestId);
365+
}
366+
return isCancelled;
367+
}
368+
369+
void sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady)
370+
{
371+
auto it = mRequestToSession.find(requestId);
372+
TLLM_CHECK(it != mRequestToSession.end());
373+
auto& session = it->second;
374+
auto connections = session.getConnections();
375+
for (size_t i = 0; i < connections.size(); i++)
376+
{
377+
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
378+
if (agentConnectionManager != nullptr)
379+
{
380+
auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections.at(i));
381+
TLLM_CHECK(agentConnection != nullptr);
382+
agentConnection->sendReadySignal(
383+
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG}, isReady);
384+
}
385+
else
386+
{
387+
connections.at(i)->send(
388+
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG}, &isReady, sizeof(isReady));
389+
}
390+
}
391+
}
392+
351393
~Impl()
352394
{
353395
terminate();
@@ -391,20 +433,54 @@ class CacheSender::Impl
391433
{
392434
mRemainSendCount.erase(reqId);
393435

394-
// TODO(zhengd): pass the hashes directly instead of update llmRequest
395-
auto llmRequest = it->second.mRequest;
396-
llmRequest->setRequestedBlockHashes(std::move(blockHashes));
436+
// Check if the request is cancelled
437+
bool isReady = true;
438+
{
439+
std::unique_lock lk(mSenderMutex);
440+
if (mCancelledRequests.find(reqId) != mCancelledRequests.end())
441+
{
442+
isReady = false;
443+
}
444+
}
445+
sendReadySignal(reqId, isReady);
397446

398-
if (common::getEnvParallelCacheSend())
447+
if (isReady)
399448
{
400-
// TODO: Use a thread pool and check for thread safety.
401-
std::thread(&CacheSender::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second)).detach();
449+
// TODO(zhengd): pass the hashes directly instead of update llmRequest
450+
auto llmRequest = it->second.mRequest;
451+
llmRequest->setRequestedBlockHashes(std::move(blockHashes));
452+
453+
if (common::getEnvParallelCacheSend())
454+
{
455+
// TODO: Use a thread pool and check for thread safety.
456+
std::thread(&CacheSender::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
457+
.detach();
458+
}
459+
else
460+
{
461+
CacheSender::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
462+
}
463+
removeResponse(it);
402464
}
403465
else
404466
{
405-
CacheSender::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
467+
// TODO: if the generation does not require the kv cache, the request will
468+
// not be removed from mCancelledRequests. This should be handled by timeout.
469+
auto it = mReadyResponses.find(mCurrentRequest.value());
470+
{
471+
std::unique_lock lkResp(mSenderMutex);
472+
mReadyResponses.erase(it);
473+
mCancelledRequests.erase(mCurrentRequest.value());
474+
mRemainSendCount.erase(mCurrentRequest.value());
475+
}
476+
mCurrentRequest = std::nullopt;
477+
478+
if (mReadyResponses.empty())
479+
{
480+
std::unique_lock lk(mCondMutex);
481+
mAnyReady = false;
482+
}
406483
}
407-
removeResponse(it);
408484
}
409485
mCurrentRequest = std::nullopt;
410486
}
@@ -433,7 +509,11 @@ class CacheSender::Impl
433509
auto reqId = requestInfo.getRequestId();
434510
blockHashes = requestInfo.getBlockHashes();
435511

436-
mCurrentRequest = reqId;
512+
{
513+
std::unique_lock lk(mSenderMutex);
514+
mCurrentRequest = reqId;
515+
}
516+
437517
if (mRemainSendCount.find(reqId) == mRemainSendCount.end())
438518
{
439519
mRemainSendCount[reqId] = getCounterpartsCount(reqId);
@@ -513,6 +593,7 @@ class CacheSender::Impl
513593

514594
private:
515595
std::optional<RequestIdType> mCurrentRequest;
596+
std::set<LlmRequest::RequestIdType> mCancelledRequests;
516597
std::map<RequestIdType, Response> mReadyResponses;
517598
std::mutex mSenderMutex, mCondMutex;
518599
std::atomic<bool> mAnyReady{false}, mTerminate{false};
@@ -685,6 +766,62 @@ class CacheReceiver::Impl
685766
connection->send(executor::kv_cache::DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize);
686767
}
687768

769+
bool cancelRequest(LlmRequest const& llmRequest)
770+
{
771+
772+
std::string processInfo = "default";
773+
if (common::getEnvRequestKVCacheConcurrent())
774+
{
775+
processInfo = llmRequest.getDataTransceiverState().getCommState()->toString();
776+
}
777+
778+
bool isCancelled = false;
779+
auto& asyncResource = mInstanceToAsyncResource.at(processInfo);
780+
{
781+
std::unique_lock<std::mutex> lck(asyncResource->mMtxForQueue);
782+
auto it = std::find_if(asyncResource->mRequestsQueue.begin(), asyncResource->mRequestsQueue.end(),
783+
[&llmRequest](RequestAndPromise const& requestAndPromise)
784+
{ return requestAndPromise.mRequest->mRequestId == llmRequest.mRequestId; });
785+
if (it != asyncResource->mRequestsQueue.end())
786+
{
787+
asyncResource->mRequestsQueue.erase(it);
788+
isCancelled = true;
789+
}
790+
else
791+
{
792+
TLLM_LOG_WARNING("Cannot cancel request %zu", llmRequest.mRequestId);
793+
}
794+
}
795+
return isCancelled;
796+
}
797+
798+
bool receiveReadySignal(TransferSession& session)
799+
{
800+
bool isReadyFinal = true;
801+
bool isReady = false;
802+
auto connections = session.getConnections();
803+
804+
for (size_t i = 0; i < connections.size(); i++)
805+
{
806+
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
807+
if (agentConnectionManager != nullptr)
808+
{
809+
auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections.at(i));
810+
TLLM_CHECK(agentConnection != nullptr);
811+
isReady = agentConnection->recvReadySignal(
812+
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG});
813+
}
814+
else
815+
{
816+
connections.at(i)->recv(
817+
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG}, &isReady, sizeof(isReady));
818+
}
819+
isReadyFinal &= isReady;
820+
}
821+
822+
return isReadyFinal;
823+
}
824+
688825
~Impl()
689826
{
690827
for (auto&& [processInfo, asyncResource] : mInstanceToAsyncResource)
@@ -707,6 +844,14 @@ class CacheReceiver::Impl
707844
llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now());
708845
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
709846
auto session = sendRequestInfo(llmRequest);
847+
bool isReady = receiveReadySignal(session);
848+
if (!isReady)
849+
{
850+
// Reuse the error state for the cancelled request.
851+
llmRequest.setState(LlmRequestState::kDISAGG_TRANS_ERROR);
852+
llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());
853+
return;
854+
}
710855
receiveSync(session);
711856
llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());
712857

@@ -876,6 +1021,16 @@ RequestInfo CacheSender::recvRequestInfo()
8761021
return mImpl->recvRequestInfo();
8771022
}
8781023

1024+
bool CacheSender::cancelRequest(LlmRequest const& llmRequest)
1025+
{
1026+
return mImpl->cancelRequest(llmRequest);
1027+
}
1028+
1029+
void CacheSender::sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady)
1030+
{
1031+
mImpl->sendReadySignal(requestId, isReady);
1032+
}
1033+
8791034
CacheReceiver::CacheReceiver(executor::kv_cache::ConnectionManager* manager,
8801035
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
8811036
: mImpl{std::unique_ptr<Impl, ImplDeleter>(new Impl(manager, selfCacheState, selfIndex, std::move(formatter)))}
@@ -899,4 +1054,14 @@ void CacheReceiver::receiveSync(TransferSession& session)
8991054
mImpl->receiveSync(session);
9001055
}
9011056

1057+
bool CacheReceiver::cancelRequest(LlmRequest const& llmRequest)
1058+
{
1059+
return mImpl->cancelRequest(llmRequest);
1060+
}
1061+
1062+
bool CacheReceiver::receiveReadySignal(TransferSession& session)
1063+
{
1064+
return mImpl->receiveReadySignal(session);
1065+
}
1066+
9021067
} // namespace tensorrt_llm::batch_manager

cpp/tensorrt_llm/batch_manager/dataTransceiver.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ struct TransceiverTag
122122
static constexpr int32_t kID_TAG{19};
123123
static constexpr int32_t kINFO_SIZE_TAG{22};
124124
static constexpr int32_t kINFO_TAG{32};
125+
static constexpr int32_t kREADY_SIGNAL_TAG{42};
125126
};
126127

127128
// Used to store the information that needs to be sent to the context executor to ensure the generation
@@ -207,6 +208,16 @@ class CacheSender
207208
/// @param llmRequest The request object to which the data belongs.
208209
virtual RequestInfo recvRequestInfo();
209210

211+
/// @brief Cancel the request.
212+
/// @param requestId The ID used in the context phase of the current request.
213+
/// @return Whether the request is cancelled.
214+
virtual bool cancelRequest(LlmRequest const& llmRequest);
215+
216+
/// @brief Send ready signal.
217+
/// @param requestId The ID used in the context phase of the current request.
218+
/// @param isReady Whether the request is ready to be received.
219+
virtual void sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady);
220+
210221
/// @brief Destructor.
211222
virtual ~CacheSender();
212223

@@ -239,6 +250,17 @@ class CacheReceiver
239250
virtual TransferSession sendRequestInfo(LlmRequest const& llmRequest);
240251

241252
virtual void receiveSync(TransferSession& session);
253+
254+
/// @brief Cancel the request.
255+
/// @param llmRequest Request object.
256+
/// @return Whether the request is cancelled.
257+
virtual bool cancelRequest(LlmRequest const& llmRequest);
258+
259+
/// @brief Receive ready signal.
260+
/// @param session The session object.
261+
/// @return Whether the request is ready to be received.
262+
virtual bool receiveReadySignal(TransferSession& session);
263+
242264
/// @brief Destructor.
243265
virtual ~CacheReceiver();
244266

0 commit comments

Comments
 (0)