Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class BaseCacheTransceiver
virtual void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) = 0;

[[nodiscard]] virtual bool checkGenTransferComplete() const = 0;

virtual bool cancelRequest(LlmRequest* llmRequest) = 0;
};

class CacheTransceiver : public BaseCacheTransceiver
Expand Down Expand Up @@ -111,6 +113,8 @@ class CacheTransceiver : public BaseCacheTransceiver

[[nodiscard]] bool checkGenTransferComplete() const override;

/// @brief Try to cancel the cache transfer of @p llmRequest.
/// @param llmRequest The request whose transfer should be cancelled.
/// @return Whether the request was cancelled.
bool cancelRequest(LlmRequest* llmRequest) override;

private:
void initializeCommState();

Expand Down
13 changes: 13 additions & 0 deletions cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,4 +567,17 @@ bool CacheTransceiver::checkGenTransferComplete() const
return mRequesterFutures.empty();
}

/// @brief Route a cancellation to the component that handles this kind of request.
/// @param llmRequest The request to cancel. It is dereferenced unconditionally, so it must not be null.
/// @return The result of the sender (context-only request) or receiver (generation-only request);
///         false for any other kind of request.
bool CacheTransceiver::cancelRequest(LlmRequest* llmRequest)
{
    // Context-only requests are cancelled through the cache sender.
    if (llmRequest->isContextOnlyRequest())
    {
        return mCacheSender->cancelRequest(*llmRequest);
    }

    // Generation-only requests are cancelled through the cache receiver.
    if (llmRequest->isGenerationOnlyRequest())
    {
        return mCacheReceiver->cancelRequest(*llmRequest);
    }

    // Neither kind: nothing to cancel here.
    return false;
}

} // namespace tensorrt_llm::batch_manager
187 changes: 177 additions & 10 deletions cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ class CacheSender::Impl
auto future = promise.get_future();
{
{
std::unique_lock lkResp(mSenderMutex);
std::scoped_lock lkResp(mSenderMutex);
mReadyResponses.emplace(
llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)});
}
Expand Down Expand Up @@ -379,6 +379,49 @@ class CacheSender::Impl
mFormatter->format(*session);
}

/// @brief Mark a ready-but-not-current request as cancelled.
/// A request can be cancelled only if it has an entry in mReadyResponses and is not the request
/// recorded in mCurrentRequest. On success its ID is added to mCancelledRequests.
/// @param llmRequest The request to cancel; only its mRequestId is used.
/// @return true if the ID was recorded as cancelled, false (with a warning) otherwise.
bool cancelRequest(LlmRequest const& llmRequest)
{
    auto const requestId = llmRequest.mRequestId;

    // All state inspected and modified below is accessed under the sender mutex.
    std::scoped_lock guard(mSenderMutex);

    bool const isQueued = mReadyResponses.find(requestId) != mReadyResponses.end();
    // Only consult the current request when the ID is queued and a current request exists.
    bool const isCurrent = isQueued && mCurrentRequest.has_value() && getCurrentRequestId() == requestId;

    if (!isQueued || isCurrent)
    {
        TLLM_LOG_WARNING("Cannot cancel request %zu", requestId);
        return false;
    }

    mCancelledRequests.insert(requestId);
    return true;
}

/// @brief Send the ready flag to every connection of the request's transfer session.
/// @param requestId Key into mRequestToSession; a session must exist for it (enforced by TLLM_CHECK).
/// @param isReady Flag value delivered on each connection under the kREADY_SIGNAL_TAG data context.
void sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady)
{
    auto it = mRequestToSession.find(requestId);
    TLLM_CHECK(it != mRequestToSession.end());
    auto& session = it->second;
    auto const& connections = session.getConnections();

    // mManager does not change while iterating, so resolve its dynamic type once
    // instead of repeating the dynamic_cast for every connection.
    bool const useAgentConnections
        = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager) != nullptr;

    for (auto const* connection : connections)
    {
        if (useAgentConnections)
        {
            // With an agent connection manager every connection must be an AgentConnection.
            auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connection);
            TLLM_CHECK(agentConnection);
            agentConnection->sendReadySignal(
                executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG}, isReady);
        }
        else
        {
            // Other connections send the raw bool bytes.
            connection->send(
                executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG}, &isReady, sizeof(isReady));
        }
    }
}

~Impl()
{
terminate();
Expand Down Expand Up @@ -463,8 +506,42 @@ class CacheSender::Impl
{
mRemainSendCount.erase(reqId);

asyncSendAndRemoveResponse(it->first, std::move(it->second));
removeResponse(it);
// Check if the request is cancelled
bool isReady = true;
{
std::scoped_lock lk(mSenderMutex);
if (mCancelledRequests.find(reqId) != mCancelledRequests.end())
{
isReady = false;
}
}
sendReadySignal(reqId, isReady);

if (isReady)
{
asyncSendAndRemoveResponse(it->first, std::move(it->second));
removeResponse(it);
}
else
{
// TODO: if the generation does not require the kv cache, the request will
// not be removed from mCancelledRequests. This should be handled by timeout.
auto it = mReadyResponses.find(mCurrentRequest.value());
TLLM_CHECK(it != mReadyResponses.end());
{
std::scoped_lock lkResp(mSenderMutex);
mReadyResponses.erase(it);
mCancelledRequests.erase(mCurrentRequest.value());
mRemainSendCount.erase(mCurrentRequest.value());
}
mCurrentRequest = std::nullopt;

if (mReadyResponses.empty())
{
std::unique_lock lk(mCondMutex);
mAnyReady = false;
}
}
}
mCurrentRequest = std::nullopt;
}
Expand All @@ -491,7 +568,11 @@ class CacheSender::Impl
auto const& requestInfo = recvRequestInfo();
auto reqId = requestInfo.getRequestId();

mCurrentRequest = reqId;
{
std::scoped_lock lk(mSenderMutex);
mCurrentRequest = reqId;
}

if (mRemainSendCount.find(reqId) == mRemainSendCount.end())
{
mRemainSendCount[reqId] = getCounterpartsCount(reqId);
Expand Down Expand Up @@ -549,7 +630,7 @@ class CacheSender::Impl
void removeResponse(std::map<RequestIdType, Response>::iterator it)
{
{
std::unique_lock lkResp(mSenderMutex);
std::scoped_lock lkResp(mSenderMutex);
mReadyResponses.erase(it);
}
if (mReadyResponses.empty())
Expand All @@ -566,12 +647,13 @@ class CacheSender::Impl

/// @brief Look up the ready response that belongs to the current request.
/// The lookup runs under mSenderMutex; the lock is released when the function returns.
/// @return Iterator into mReadyResponses for getCurrentRequestId(), or mReadyResponses.end()
///         if no response is stored for that ID.
[[nodiscard]] std::map<RequestIdType, Response>::iterator getCurrentResponse()
{
    // Exactly one guard: declaring a second lock named `lk` on the same std::mutex would be a
    // redeclaration, and relocking a non-recursive mutex from the owning thread is undefined.
    std::scoped_lock lk(mSenderMutex);
    return mReadyResponses.find(getCurrentRequestId());
}

private:
std::optional<RequestIdType> mCurrentRequest;
std::set<LlmRequest::RequestIdType> mCancelledRequests;
std::map<RequestIdType, Response> mReadyResponses;
std::mutex mSenderMutex, mCondMutex;
std::atomic<bool> mAnyReady{false}, mTerminate{false};
Expand Down Expand Up @@ -619,7 +701,7 @@ class CacheReceiver::Impl
auto promise = std::make_unique<std::promise<void>>();
auto future = promise->get_future();
TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value());
std::string processInfo = "default";
std::string processInfo = kDefaultProcessInfo;
if (common::getEnvRequestKVCacheConcurrent())
{
processInfo = llmRequest.getDataTransceiverState().getCommState()->toString();
Expand Down Expand Up @@ -703,7 +785,7 @@ class CacheReceiver::Impl

auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
std::optional<size_t> cacheBufferId = std::nullopt;
if (agentConnectionManager != nullptr)
if (agentConnectionManager)
{
cacheBufferId = agentConnectionManager->getCacheTransBufferManager()->assignBufferIndexForRecv();
TLLM_CHECK(cacheBufferId.has_value());
Expand All @@ -726,7 +808,7 @@ class CacheReceiver::Impl
auto const* connection = counterPartConnections[i];
// if Manager is agentConnectionManager, then send request info to agent
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
if (agentConnectionManager != nullptr)
if (agentConnectionManager)
{
// TODO: index -> validConnectionIdx conversion
auto validConnectionIdx = std::find(pickUpIdx.begin(), pickUpIdx.end(), i) - pickUpIdx.begin();
Expand All @@ -751,7 +833,7 @@ class CacheReceiver::Impl
{
std::scoped_lock<std::mutex> lock(mProcessIoResouceMutex);
TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value());
std::string processString = "default";
std::string processString = kDefaultProcessInfo;
if (common::getEnvRequestKVCacheConcurrent())
{
processString = llmRequest.getDataTransceiverState().getCommState()->toString();
Expand All @@ -777,6 +859,62 @@ class CacheReceiver::Impl
connection->send(DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize);
}

/// @brief Try to cancel a request that is still waiting in its async request queue.
/// The queue is selected by process info: kDefaultProcessInfo, or the request's comm-state string
/// when common::getEnvRequestKVCacheConcurrent() is set. A request is cancellable only while an
/// entry with the same mRequestId is still in that queue.
/// @param llmRequest The request to cancel.
/// @return true if a matching queue entry was found and erased; false (with a warning) if there is
///         no async resource for the process info or no matching entry in its queue.
bool cancelRequest(LlmRequest const& llmRequest)
{
    std::string processInfo = kDefaultProcessInfo;
    if (common::getEnvRequestKVCacheConcurrent())
    {
        // Same precondition the other users of the comm state enforce before dereferencing it.
        TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value());
        processInfo = llmRequest.getDataTransceiverState().getCommState()->toString();
    }

    // Use find() rather than at(): with no async resource for this key there is nothing queued,
    // so report "not cancelled" instead of throwing std::out_of_range.
    auto const resourceIt = mInstanceToAsyncResource.find(processInfo);
    if (resourceIt == mInstanceToAsyncResource.end())
    {
        TLLM_LOG_WARNING("Cannot cancel request %zu", llmRequest.mRequestId);
        return false;
    }

    auto& asyncResource = resourceIt->second;
    std::unique_lock<std::mutex> lck(asyncResource->mMtxForQueue);
    auto& requestsQueue = asyncResource->mRequestsQueue;
    auto it = std::find_if(requestsQueue.begin(), requestsQueue.end(),
        [&llmRequest](RequestAndPromise const& requestAndPromise)
        { return requestAndPromise.mRequest->mRequestId == llmRequest.mRequestId; });
    if (it == requestsQueue.end())
    {
        TLLM_LOG_WARNING("Cannot cancel request %zu", llmRequest.mRequestId);
        return false;
    }

    // NOTE(review): the erased entry is destroyed here; presumably it owns the request's promise,
    // which is then never fulfilled - confirm callers do not block on the matching future.
    requestsQueue.erase(it);
    return true;
}

/// @brief Receive the ready flag from every connection of the session and combine them.
/// A flag is received from each connection even after one has reported false.
/// @param session The transfer session whose connections are read.
/// @return true only if every connection reported ready (true for a session with no connections).
bool receiveReadySignal(TransferSession& session)
{
    bool isReadyFinal = true;
    bool isReady = false;
    auto const& connections = session.getConnections();

    // mManager does not change while iterating, so resolve its dynamic type once
    // instead of repeating the dynamic_cast for every connection.
    bool const useAgentConnections
        = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager) != nullptr;

    for (auto const* connection : connections)
    {
        if (useAgentConnections)
        {
            // With an agent connection manager every connection must be an AgentConnection.
            auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connection);
            TLLM_CHECK(agentConnection);
            isReady = agentConnection->recvReadySignal(
                executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG});
        }
        else
        {
            // Other connections deliver the raw bool bytes.
            connection->recv(
                executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG}, &isReady, sizeof(isReady));
        }
        isReadyFinal &= isReady;
    }

    return isReadyFinal;
}

~Impl()
{
for (auto&& [processInfo, asyncResource] : mInstanceToAsyncResource)
Expand All @@ -799,6 +937,14 @@ class CacheReceiver::Impl
llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now());
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
auto session = sendRequestInfo(llmRequest);
bool isReady = receiveReadySignal(session);
if (!isReady)
{
// Reuse the error state for the cancelled request.
llmRequest.setState(LlmRequestState::kDISAGG_TRANS_ERROR);
llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());
return;
}
receiveSync(session);
llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());

Expand Down Expand Up @@ -915,6 +1061,7 @@ class CacheReceiver::Impl
}

int mDeviceId{-1};
static constexpr char const* kDefaultProcessInfo = "default";
std::vector<std::future<void>> mRequestFutures;
std::unordered_map<std::string, std::unique_ptr<AsyncResource>> mInstanceToAsyncResource;
executor::kv_cache::ConnectionManager* mManager;
Expand Down Expand Up @@ -970,6 +1117,16 @@ RequestInfo CacheSender::recvRequestInfo()
return mImpl->recvRequestInfo();
}

/// @brief Forward the cancellation to the sender implementation.
/// @param llmRequest The request to cancel.
/// @return The implementation's result: whether the request was cancelled.
bool CacheSender::cancelRequest(LlmRequest const& llmRequest)
{
    bool const cancelled = mImpl->cancelRequest(llmRequest);
    return cancelled;
}

/// @brief Forward the ready signal to the sender implementation.
/// @param requestId ID of the request the signal refers to.
/// @param isReady Flag value to send.
void CacheSender::sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady)
{
    auto& impl = *mImpl;
    impl.sendReadySignal(requestId, isReady);
}

CacheReceiver::CacheReceiver(executor::kv_cache::ConnectionManager* manager,
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
: mImpl{std::unique_ptr<Impl, ImplDeleter>(new Impl(manager, selfCacheState, selfIndex, std::move(formatter)))}
Expand All @@ -993,4 +1150,14 @@ void CacheReceiver::receiveSync(TransferSession& session)
mImpl->receiveSync(session);
}

/// @brief Forward the cancellation to the receiver implementation.
/// @param llmRequest The request to cancel.
/// @return The implementation's result: whether the request was cancelled.
bool CacheReceiver::cancelRequest(LlmRequest const& llmRequest)
{
    bool const cancelled = mImpl->cancelRequest(llmRequest);
    return cancelled;
}

/// @brief Forward to the receiver implementation to collect the session's ready signal.
/// @param session The transfer session to read from.
/// @return The implementation's result: whether the request is ready to be received.
bool CacheReceiver::receiveReadySignal(TransferSession& session)
{
    bool const ready = mImpl->receiveReadySignal(session);
    return ready;
}

} // namespace tensorrt_llm::batch_manager
22 changes: 22 additions & 0 deletions cpp/tensorrt_llm/batch_manager/dataTransceiver.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ struct TransceiverTag
static constexpr int32_t kID_TAG{19};
static constexpr int32_t kINFO_SIZE_TAG{22};
static constexpr int32_t kINFO_TAG{32};
static constexpr int32_t kREADY_SIGNAL_TAG{42};
};

// Used to store the information that needs to be sent to the context executor to ensure the generation
Expand Down Expand Up @@ -240,6 +241,16 @@ class CacheSender
/// @param llmRequest The request object to which the data belongs.
virtual RequestInfo recvRequestInfo();

/// @brief Cancel the request.
/// @param llmRequest The request object to cancel.
/// @return Whether the request is cancelled.
virtual bool cancelRequest(LlmRequest const& llmRequest);

/// @brief Send ready signal.
/// @param requestId The ID used in the context phase of the current request.
/// @param isReady Whether the request is ready to be received.
virtual void sendReadySignal(LlmRequest::RequestIdType requestId, bool isReady);

/// @brief Destructor.
virtual ~CacheSender();

Expand Down Expand Up @@ -272,6 +283,17 @@ class CacheReceiver
virtual TransferSession sendRequestInfo(LlmRequest const& llmRequest);

virtual void receiveSync(TransferSession& session);

/// @brief Cancel the request.
/// @param llmRequest Request object.
/// @return Whether the request is cancelled.
virtual bool cancelRequest(LlmRequest const& llmRequest);

/// @brief Receive ready signal.
/// @param session The session object.
/// @return Whether the request is ready to be received.
virtual bool receiveReadySignal(TransferSession& session);

/// @brief Destructor.
virtual ~CacheReceiver();

Expand Down
Loading
Loading