@@ -62,41 +62,49 @@ std::unique_ptr<BaseCacheTransceiver> CacheTransceiverFactory::createCacheTransc
6262 runtime::WorldConfig const & worldConfig, executor::kv_cache::CacheState::AttentionType attentionType,
6363 std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig)
6464{
65-
66- std::optional<CacheTransceiver::CommType> commType;
67- if (common::getEnvUseUCXKvCache ())
68- {
69- commType = CacheTransceiver::CommType::UCX;
70- TLLM_LOG_INFO (" Enable UCX KV cache transport." );
71- }
72- else if (common::getEnvUseNixlKvCache ())
65+ if (!cacheTransceiverConfig.has_value () || !cacheTransceiverConfig.value ().getBackendType ().has_value ())
7366 {
74- commType = CacheTransceiver::CommType::NIXL ;
75- TLLM_LOG_INFO ( " Enable NIXL KV cache transport. " ) ;
67+ TLLM_LOG_INFO ( " CacheTransceiver is disabled. " ) ;
68+ return nullptr ;
7669 }
77- else if (common::getEnvUseMPIKvCache ())
70+ auto backendType = cacheTransceiverConfig.value ().getBackendType ();
71+ if (backendType.value () == executor::CacheTransceiverConfig::BackendType::DEFAULT)
7872 {
79- commType = CacheTransceiver::CommType::MPI;
80- TLLM_LOG_INFO (" Enable MPI KV cache transport." );
73+ if (common::getEnvUseUCXKvCache ())
74+ {
75+ backendType = executor::CacheTransceiverConfig::BackendType::UCX;
76+ TLLM_LOG_INFO (" Enable UCX KV cache transport." );
77+ }
78+ else if (common::getEnvUseNixlKvCache ())
79+ {
80+ backendType = executor::CacheTransceiverConfig::BackendType::NIXL;
81+ TLLM_LOG_INFO (" Enable NIXL KV cache transport." );
82+ }
83+ else if (common::getEnvUseMPIKvCache ())
84+ {
85+ backendType = executor::CacheTransceiverConfig::BackendType::MPI;
86+ TLLM_LOG_INFO (" Enable MPI KV cache transport." );
87+ TLLM_LOG_WARNING (" MPI KV cache transport is deprecated, please use UCX or NIXL instead." );
88+ }
89+ else
90+ {
91+ backendType = executor::CacheTransceiverConfig::BackendType::UCX;
92+ }
8193 }
94+ cacheTransceiverConfig.value ().setBackendType (backendType);
8295
83- if (commType)
84- {
85- executor::kv_cache::CacheState::ModelConfig cacheStateCfg{
86- modelConfig.getNumKvHeadsPerLayer (), modelConfig.getSizePerHead (), modelConfig.getTokensPerBlock ()};
96+ executor::kv_cache::CacheState::ModelConfig cacheStateCfg{
97+ modelConfig.getNumKvHeadsPerLayer (), modelConfig.getSizePerHead (), modelConfig.getTokensPerBlock ()};
8798
88- return std::make_unique<CacheTransceiver>(cacheManager, commType.value (), cacheStateCfg, worldConfig,
89- modelConfig.getKvDataType (), attentionType, cacheTransceiverConfig);
90- }
91- return nullptr ;
99+ return std::make_unique<CacheTransceiver>(
100+ cacheManager, cacheStateCfg, worldConfig, modelConfig.getKvDataType (), attentionType, cacheTransceiverConfig);
92101}
93102
94- CacheTransceiver::CacheTransceiver (kv_cache_manager::BaseKVCacheManager* cacheManager, CommType commType,
103+ CacheTransceiver::CacheTransceiver (kv_cache_manager::BaseKVCacheManager* cacheManager,
95104 executor::kv_cache::CacheState::ModelConfig const & cacheStateModelCfg, runtime::WorldConfig const & worldConfig,
96105 nvinfer1::DataType dataType, executor::kv_cache::CacheState::AttentionType attentionType,
97106 std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig)
98- : mCommType {commType}
99- , mMpiGroupComm (std::addressof(tensorrt_llm::mpi::MpiComm::session()))
107+ : mMpiGroupComm (std::addressof(tensorrt_llm::mpi::MpiComm::session()))
100108 , mCacheTransceiverConfig {cacheTransceiverConfig}
101109{
102110 using tensorrt_llm::batch_manager::kv_cache_manager::CacheFormatter;
@@ -138,59 +146,59 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
138146 }
139147 }
140148 bool isMLA = attentionType == executor::kv_cache::CacheState::AttentionType::kMLA ;
141- if (mCommType == CommType::MPI || mCommType == CommType::UCX || mCommType == CommType::NIXL)
142- {
143- std::optional<size_t > maxNumTokens = std::nullopt ;
144- if (mCacheTransceiverConfig .has_value ())
145- {
146- maxNumTokens = mCacheTransceiverConfig .value ().getMaxNumTokens ();
147- }
148- mCacheTransBufferManager
149- = std::make_unique<kv_cache_manager::CacheTransBufferManager>(cacheManager, maxNumTokens);
150- if (mCommType == CommType::UCX)
151- {
152- std::lock_guard<std::mutex> lock (mDllMutex );
153- mWrapperLibHandle = dllOpen (UCX_WRAPPER_LIB_NAME);
154- TLLM_CHECK_WITH_INFO (mWrapperLibHandle != nullptr , " UCX wrapper library is not open correctly." );
155- auto load_sym = [](void * handle, char const * name)
156- {
157- void * ret = dllGetSym (handle, name);
158- TLLM_CHECK_WITH_INFO (ret != nullptr ,
159- " Unable to load UCX wrapper library symbol, possible cause is that TensorRT-LLM library is not "
160- " built with UCX support, please rebuild in UCX-enabled environment." );
161- return ret;
162- };
163- std::unique_ptr<tensorrt_llm::executor::kv_cache::ConnectionManager> (*makeUcxConnectionManager)();
164- *(void **) (&makeUcxConnectionManager) = load_sym (mWrapperLibHandle , " makeUcxConnectionManager" );
165- mManager = makeUcxConnectionManager ();
166- TLLM_LOG_INFO (" UCX Connection Manager created" );
167- }
168- else if (mCommType == CommType::NIXL)
169- {
170- mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
171- mCacheTransBufferManager .get ());
172- TLLM_LOG_INFO (" NIXL Connection Manager created" );
173- }
174- else
175- {
176- mMpiWorldComm = std::addressof (tensorrt_llm::mpi::MpiComm::world ());
177- mManager = std::make_unique<executor::kv_cache::MpiConnectionManager>(mMpiWorldComm );
178- TLLM_LOG_INFO (" MPI Connection Manager created" );
179- }
149+ TLLM_CHECK_WITH_INFO (mCacheTransceiverConfig .has_value (), " CacheTransceiverConfig is not set." );
150+ auto backendType = mCacheTransceiverConfig .value ().getBackendType ();
151+ TLLM_CHECK_WITH_INFO (
152+ backendType.has_value () && (backendType.value () != executor::CacheTransceiverConfig::BackendType::DEFAULT),
153+ " CacheTransceiverConfig::BackendType is not set." );
180154
181- using tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter;
182- auto makeFormatter = [cacheManager, isMLA, this ]()
183- { return createCacheFormatter (cacheManager, mCacheTransBufferManager .get (), isMLA); };
155+ std::optional<size_t > maxNumTokens = mCacheTransceiverConfig .value ().getMaxTokensInBuffer ();
184156
185- mDataResponder = std::make_unique<DataResponder>(
186- std::make_unique<DataSenderImpl>(mManager .get (), *mCacheState , worldConfig.getRank (), makeFormatter ()));
187- mDataRequester = std::make_unique<DataRequester>(
188- std::make_unique<DataReceiverImpl>(mManager .get (), *mCacheState , worldConfig.getRank (), makeFormatter ()));
157+ mCacheTransBufferManager = std::make_unique<kv_cache_manager::CacheTransBufferManager>(cacheManager, maxNumTokens);
158+ if (backendType.value () == executor::CacheTransceiverConfig::BackendType::UCX)
159+ {
160+ std::lock_guard<std::mutex> lock (mDllMutex );
161+ mWrapperLibHandle = dllOpen (UCX_WRAPPER_LIB_NAME);
162+ TLLM_CHECK_WITH_INFO (mWrapperLibHandle != nullptr , " UCX wrapper library is not open correctly." );
163+ auto load_sym = [](void * handle, char const * name)
164+ {
165+ void * ret = dllGetSym (handle, name);
166+ TLLM_CHECK_WITH_INFO (ret != nullptr ,
167+ " Unable to load UCX wrapper library symbol, possible cause is that TensorRT-LLM library is not "
168+ " built with UCX support, please rebuild in UCX-enabled environment." );
169+ return ret;
170+ };
171+ std::unique_ptr<tensorrt_llm::executor::kv_cache::ConnectionManager> (*makeUcxConnectionManager)();
172+ *(void **) (&makeUcxConnectionManager) = load_sym (mWrapperLibHandle , " makeUcxConnectionManager" );
173+ mManager = makeUcxConnectionManager ();
174+ TLLM_LOG_INFO (" UCX Connection Manager created" );
175+ }
176+ else if (backendType.value () == executor::CacheTransceiverConfig::BackendType::NIXL)
177+ {
178+ mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
179+ mCacheTransBufferManager .get ());
180+ TLLM_LOG_INFO (" NIXL Connection Manager created" );
181+ }
182+ else if (backendType.value () == executor::CacheTransceiverConfig::BackendType::MPI)
183+ {
184+ mMpiWorldComm = std::addressof (tensorrt_llm::mpi::MpiComm::world ());
185+ mManager = std::make_unique<executor::kv_cache::MpiConnectionManager>(mMpiWorldComm );
186+ TLLM_LOG_INFO (" MPI Connection Manager created" );
189187 }
190188 else
191189 {
192- TLLM_THROW (" Unsupported communication type. " );
190+ TLLM_THROW (" Unsupported cache transceiver backend type " );
193191 }
192+
193+ using tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter;
194+ auto makeFormatter = [cacheManager, isMLA, this ]()
195+ { return createCacheFormatter (cacheManager, mCacheTransBufferManager .get (), isMLA); };
196+
197+ mDataResponder = std::make_unique<DataResponder>(
198+ std::make_unique<DataSenderImpl>(mManager .get (), *mCacheState , worldConfig.getRank (), makeFormatter ()));
199+ mDataRequester = std::make_unique<DataRequester>(
200+ std::make_unique<DataReceiverImpl>(mManager .get (), *mCacheState , worldConfig.getRank (), makeFormatter ()));
201+
194202 initializeCommState ();
195203}
196204
0 commit comments