
Commit bdb9fe2

Merge branch 'main' into user/nzmora/add_mem_logs
2 parents: 547c718 + a6d20f6

File tree

125 files changed: +8221 / -1411 lines


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -47,6 +47,9 @@ tensorrt_llm/deep_gemm/
 tensorrt_llm/deep_gemm_cpp_tllm.*.so
 tensorrt_llm/deep_gemm_cpp_tllm.pyi
 tensorrt_llm/pg_utils_bindings.*.so
+tensorrt_llm/flash_mla/
+tensorrt_llm/flash_mla_cpp_tllm.*.so
+tensorrt_llm/flash_mla_cpp_tllm.pyi
 *docs/cpp_docs*
 *docs/source/_cpp_gen*
 docs/source/**/*.rst

.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -30,3 +30,6 @@
 	path = 3rdparty/DeepGEMM
 	url = https://github.com/ruoqianguo/DeepGEMM.git
 	branch = swapab_sm100
+[submodule "3rdparty/flash-mla"]
+	path = 3rdparty/flash-mla
+	url = https://github.com/deepseek-ai/FlashMLA.git

3rdparty/flash-mla

Submodule flash-mla added at 1408756

cpp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ option(BUILD_TESTS "Build Google tests" ON)
 option(BUILD_BENCHMARKS "Build benchmarks" ON)
 option(BUILD_DEEP_EP "Build the Deep EP module" ON)
 option(BUILD_DEEP_GEMM "Build the DeepGEMM module" ON)
+option(BUILD_FLASH_MLA "Build the FlashMLA module" ON)
 option(BUILD_MICRO_BENCHMARKS "Build C++ micro benchmarks" OFF)
 option(NVTX_DISABLE "Disable all NVTX features" ON)
 option(WARNING_IS_ERROR "Treat all warnings as errors" OFF)

cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h

Lines changed: 125 additions & 55 deletions
@@ -17,33 +17,66 @@
 #pragma once
 
 #include "tensorrt_llm/batch_manager/kvCacheManager.h"
+#include "tensorrt_llm/runtime/iTensor.h"
 
 namespace tensorrt_llm::batch_manager::kv_cache_manager
 {
 
 class BlockIterator;
 
-class BlockRange
+class BlockRangeForWindow
 {
 public:
-    // C++20 std::default_sentinel_t equivalent
+    BlockRangeForWindow(BaseKVCacheManager const* cacheManager, SizeType32 windowSize, std::vector<SizeType32> blockIds,
+        runtime::ITensor::SharedPtr pool)
+        : mCacheManager(cacheManager)
+        , mWindowSize(windowSize)
+        , mBlockIds(std::move(blockIds))
+        , mPool(std::move(pool))
+    {
+    }
+
     struct Sentinel
     {
     };
 
-    static BlockRange fromAllBlockIds(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId,
-        SizeType32 beam = kFIRST_AND_ONLY_BEAM)
+    friend class BlockIterator;
+    BlockIterator begin() const;
+
+    [[nodiscard]] Sentinel end() const
+    {
+        return {};
+    }
+
+    [[nodiscard]] size_t size() const
+    {
+        return mBlockIds.size();
+    }
+
+private:
+    BaseKVCacheManager const* mCacheManager;
+    SizeType32 mWindowSize;
+    std::vector<SizeType32> mBlockIds;
+    runtime::ITensor::SharedPtr mPool;
+};
+
+class BlockRange
+{
+public:
+    static BlockRange fromAllBlockIds(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
     {
-        assert(kFIRST_AND_ONLY_BEAM == beam);
-        auto const windowSize = firstWindowSize(cacheManager);
-        auto const blockIds = cacheManager.getSequence(requestId).getCacheBlockIds(windowSize).at(kFIRST_AND_ONLY_BEAM);
-        return BlockRange(cacheManager, blockIds, requestId);
+
+        return BlockRange(cacheManager, requestId);
     }
 
     static BlockRange fromReuseTree(
         BaseKVCacheManager& cacheManager, BlockKey const& lastBlockKey, int32_t indexFromEnd)
     {
-        auto const windowSize = firstWindowSize(cacheManager);
+
+        auto poolNum = cacheManager.getNumPools();
+        TLLM_CHECK_WITH_INFO(poolNum == 1, "Reuse tree is not supported for multiple pools or variable window size");
+
+        auto windowSize = cacheManager.getBlockManager().getWindowSizesMetadata().begin()->first;
         // Find the last block in the reuse tree for the provided full sequence of block keys
         auto lastBlock = cacheManager.findBlocksInReuseTreeByBlockKey(lastBlockKey, windowSize);
         // TODO: handle the case where the last block is not found
@@ -65,78 +98,104 @@ class BlockRange
         }
         // Reverse to chronological order: oldest to newest
         std::reverse(blockIds.begin(), blockIds.end());
-        return BlockRange(cacheManager, blockIds, 0);
-    }
-
-    BlockRange(runtime::ITensor::SharedPtr pool, std::vector<SizeType32> const& blockIds) // Only used in tests
-        : mManager{nullptr}
-        , mPool{std::move(pool)}
-        , mWindowSize{0}
-        , mRequestId{0}
-        , mBlockIds{blockIds}
-    {
-        TLLM_CHECK(mPool);
+        std::unordered_map<SizeType32, std::vector<SizeType32>> blockIdsPerWindow;
+        blockIdsPerWindow[windowSize] = blockIds;
+        return BlockRange(cacheManager, blockIdsPerWindow, 0);
     }
 
-    [[nodiscard]] BlockIterator begin() const;
-
-    [[nodiscard]] Sentinel end() const
+    void setBlockIdsForWindow(SizeType32 windowSize, std::vector<SizeType32> blockIds)
     {
-        return {};
+        TLLM_CHECK_WITH_INFO(mBlockIdsPerWindow.find(windowSize) != mBlockIdsPerWindow.end(),
+            "Window size %d should exists", windowSize);
+        mBlockIdsPerWindow[windowSize] = std::move(blockIds);
     }
 
-    [[nodiscard]] size_t size() const
+    void setBlockIdsForAllWindows(std::unordered_map<SizeType32, std::vector<SizeType32>> blockIdsPerWindow)
     {
-        return mBlockIds.size();
+        for (auto const& [windowSize, blockIds] : blockIdsPerWindow)
+        {
+            TLLM_CHECK_WITH_INFO(
+                mPoolsPerWindow.find(windowSize) != mPoolsPerWindow.end(), "Window size %d should exists", windowSize);
+        }
+        mBlockIdsPerWindow = std::move(blockIdsPerWindow);
     }
 
-    [[nodiscard]] std::vector<SizeType32> const& getBlockIds() const
+    [[nodiscard]] std::unordered_map<SizeType32, std::vector<size_t>> getBlockHashesPerWindow() const
     {
-        return mBlockIds;
+        TLLM_CHECK(mManager);
+        std::unordered_map<SizeType32, std::vector<size_t>> blockHashesPerWindow;
+        auto& blockManager = mManager->getBlockManager();
+        for (auto const& [windowSize, blockIds] : mBlockIdsPerWindow)
+        {
+            for (auto const& blockId : blockIds)
+            {
+                blockHashesPerWindow[windowSize].emplace_back(
+                    blockManager.getBlockById(blockId, windowSize)->getHash());
+            }
+        }
+        return blockHashesPerWindow;
     }
 
-    void setBlockIds(std::vector<SizeType32> blockIds)
+    BlockRangeForWindow getBlockRangeForWindow(SizeType32 windowSize) const
     {
-        mBlockIds = std::move(blockIds);
+        TLLM_CHECK_WITH_INFO(
+            mPoolsPerWindow.find(windowSize) != mPoolsPerWindow.end(), "Window size %d not found", windowSize);
+        auto pool = mPoolsPerWindow.at(windowSize).front();
+        auto blockIds = mBlockIdsPerWindow.at(windowSize);
+        return BlockRangeForWindow(mManager, windowSize, std::move(blockIds), std::move(pool));
     }
 
-    void updatePoolIdx(SizeType32 poolIdx)
+    std::vector<SizeType32> getWindowSizes() const
     {
-        TLLM_CHECK(mManager);
-        mPool = mManager->getBlockManager().getPrimaryPool(poolIdx);
-        auto const newWindowSize = mManager->getBlockManager().getPoolWindowSize(poolIdx);
-        if (newWindowSize != mWindowSize)
+        std::vector<SizeType32> windowSizes;
+        for (auto const& [windowSize, _] : mPoolsPerWindow)
        {
-            mWindowSize = newWindowSize;
-            mBlockIds = mManager->getSequence(mRequestId).getCacheBlockIds(mWindowSize).at(kFIRST_AND_ONLY_BEAM);
+            windowSizes.push_back(windowSize);
         }
+        return windowSizes;
     }
 
-    friend class BlockIterator;
+    std::unordered_map<SizeType32, std::vector<SizeType32>> const& getBlockIdsPerWindow() const
+    {
+        return mBlockIdsPerWindow;
+    }
 
 private:
-    BlockRange(
-        BaseKVCacheManager const& cacheManager, std::vector<SizeType32> blockIds, LlmRequest::RequestIdType requestId)
+    BlockRange(BaseKVCacheManager const& cacheManager,
+        std::unordered_map<SizeType32, std::vector<SizeType32>> blockIdsPerWindow, LlmRequest::RequestIdType requestId)
         : mManager(&cacheManager)
-        , mPool(cacheManager.getBlockManager().getPrimaryPool(kFIRST_POOL_INDEX))
-        , mWindowSize(firstWindowSize(cacheManager))
         , mRequestId(requestId)
-        , mBlockIds(std::move(blockIds))
+        , mBlockIdsPerWindow(std::move(blockIdsPerWindow))
     {
+
+        // cacheManager.getBlockManager.getPrimaryPool(0);
+        auto poolNum = mManager->getNumPools();
+        for (SizeType32 poolIdx = 0; poolIdx < poolNum; ++poolIdx)
+        {
+            auto windowSize = cacheManager.getBlockManager().getPoolWindowSize(poolIdx);
+            mPoolsPerWindow[windowSize].push_back(cacheManager.getBlockManager().getPrimaryPool(poolIdx));
+        }
     }
 
-    static SizeType32 firstWindowSize(BaseKVCacheManager const& cacheManager)
+    BlockRange(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
+        : mManager(&cacheManager)
+        , mRequestId(requestId)
     {
-        constexpr SizeType32 FIRST_POOL_IDX = 0;
-        return cacheManager.getBlockManager().getPoolWindowSize(FIRST_POOL_IDX);
+        auto poolNum = mManager->getNumPools();
+        for (SizeType32 poolIdx = 0; poolIdx < poolNum; ++poolIdx)
+        {
+            auto windowSize = cacheManager.getBlockManager().getPoolWindowSize(poolIdx);
+            mPoolsPerWindow[windowSize].push_back(cacheManager.getBlockManager().getPrimaryPool(poolIdx));
+            mBlockIdsPerWindow[windowSize]
+                = cacheManager.getSequence(mRequestId).getCacheBlockIds(windowSize).at(kFIRST_AND_ONLY_BEAM);
+        }
     }
 
 private:
     BaseKVCacheManager const* mManager;
-    runtime::ITensor::SharedPtr mPool;
-    SizeType32 mWindowSize;
-    const LlmRequest::RequestIdType mRequestId;
-    std::vector<SizeType32> mBlockIds;
+    LlmRequest::RequestIdType const mRequestId;
+    std::unordered_map<SizeType32, std::vector<SizeType32>> mBlockIdsPerWindow;
+    std::unordered_map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> mPoolsPerWindow;
 
     static constexpr SizeType32 kFIRST_AND_ONLY_BEAM = 0;
     static constexpr SizeType32 kFIRST_POOL_INDEX = 0;
@@ -151,7 +210,7 @@ class BlockIterator
     using reference = value_type&;
     using SizeType32 = tensorrt_llm::runtime::SizeType32;
 
-    BlockIterator(BlockRange const* range, size_t idx)
+    BlockIterator(BlockRangeForWindow const* range, size_t idx)
         : mRange{range}
         , mIdx{idx}
     {
@@ -194,7 +253,7 @@ class BlockIterator
         return mIdx == other.mIdx && mRange == other.mRange;
     }
 
-    [[nodiscard]] bool operator==(BlockRange::Sentinel other) const
+    [[nodiscard]] bool operator==(BlockRangeForWindow::Sentinel other) const
     {
         return mIdx == mRange->mBlockIds.size();
     }
@@ -210,16 +269,27 @@ class BlockIterator
     {
         if (mIdx < mRange->mBlockIds.size())
        {
-            mCurrent = runtime::ITensor::slice(mRange->mPool, mRange->mBlockIds.at(mIdx), 1);
+            if (mRange->mCacheManager != nullptr)
+            {
+                BlockPtr const& block = mRange->mCacheManager->getBlockManager().getBlockById(
+                    mRange->mBlockIds.at(mIdx), mRange->mWindowSize);
+                TLLM_CHECK_WITH_INFO(block->isPrimary(), "cache transceiver only supports primary blocks");
+                auto const blockOffset = block->getMemoryPoolBlockIndex();
+                mCurrent = runtime::ITensor::slice(mRange->mPool, blockOffset, 1);
+            }
+            else
+            {
+                mCurrent = runtime::ITensor::slice(mRange->mPool, mRange->mBlockIds.at(mIdx), 1);
+            }
         }
     }
 
-    BlockRange const* mRange;
+    BlockRangeForWindow const* mRange;
     runtime::ITensor::SharedPtr mCurrent;
     size_t mIdx;
 };
 
-inline BlockIterator BlockRange::begin() const
+inline BlockIterator BlockRangeForWindow::begin() const
 {
     return {this, 0};
 }
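
For orientation, a minimal caller-side sketch of the per-window iteration that the reworked BlockRange / BlockRangeForWindow API enables, assuming a live BaseKVCacheManager and a valid request id; the wrapper function forEachPrimaryBlock and the way each block is consumed are illustrative only and not part of this commit.

#include "tensorrt_llm/batch_manager/kvCacheUtils.h"

namespace kvc = tensorrt_llm::batch_manager::kv_cache_manager;

// Walk every primary-pool KV-cache block of one request, grouped by attention-window size.
void forEachPrimaryBlock(kvc::BaseKVCacheManager const& cacheManager,
    tensorrt_llm::batch_manager::LlmRequest::RequestIdType requestId)
{
    // One BlockRange now spans all pools; block ids are tracked per window size internally.
    auto range = kvc::BlockRange::fromAllBlockIds(cacheManager, requestId);
    for (auto const windowSize : range.getWindowSizes())
    {
        // A BlockRangeForWindow pairs one window's block ids with that window's primary pool.
        auto const windowRange = range.getBlockRangeForWindow(windowSize);
        for (auto it = windowRange.begin(); !(it == windowRange.end()); ++it)
        {
            auto const& block = *it; // ITensor view of one primary-pool block for this window
            (void) block;            // e.g. hand the block to the cache transceiver or a host copy
        }
    }
}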

cpp/tensorrt_llm/CMakeLists.txt

Lines changed: 4 additions & 22 deletions
@@ -147,24 +147,6 @@ add_subdirectory(runtime)
 add_subdirectory(testing)
 add_subdirectory(executor_worker)
 
-if(ENABLE_CUFILE)
-  find_library(
-    CUFILE_LIBRARY cufile HINTS ${CUDAToolkit_LIBRARY_DIR}
-    /usr/lib/${TARGET_ARCH} /usr/local/lib)
-  if(NOT CUFILE_LIBRARY)
-    # FATAL_ERROR if user explicitly requests with GDS if CUDA's libcufile.so is
-    # not found.
-    message(
-      FATAL_ERROR
-        "cuFile library not found. Set -DENABLE_CUFILE=OFF if cufile isn't required."
-    )
-  else()
-    message(STATUS "Linking with cufile: ${CUFILE_LIBRARY}")
-  endif()
-else()
-  message(STATUS "ENABLE_CUFILE=OFF, skipping GDS linkage.")
-endif()
-
 set(BATCH_MANAGER_TARGET tensorrt_llm_batch_manager_static)
 set(BATCH_MANAGER_TARGET_ARCH ${TARGET_ARCH})
 add_subdirectory(batch_manager)
@@ -263,10 +245,6 @@ set_target_properties(
 
 target_link_libraries(${SHARED_TARGET} PUBLIC ${TRTLLM_LINK_LIBS})
 
-if(ENABLE_CUFILE)
-  target_link_libraries(${SHARED_TARGET} PUBLIC ${CUFILE_LIBRARY})
-endif()
-
 target_link_libraries(
   ${SHARED_TARGET}
   PRIVATE $<LINK_LIBRARY:WHOLE_ARCHIVE,${BATCH_MANAGER_TARGET}>
@@ -320,4 +298,8 @@ if(BUILD_DEEP_GEMM)
   add_subdirectory(deep_gemm)
 endif()
 
+if(BUILD_FLASH_MLA)
+  add_subdirectory(flash_mla)
+endif()
+
 add_subdirectory(plugins)

cpp/tensorrt_llm/batch_manager/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
@@ -94,10 +94,6 @@ set(TOP_LEVEL_DIR "${PROJECT_SOURCE_DIR}/..")
 target_compile_definitions(${BATCH_MANAGER_STATIC_TARGET}
                            PUBLIC TOP_LEVEL_DIR="${TOP_LEVEL_DIR}")
 
-if(ENABLE_CUFILE)
-  target_link_libraries(${BATCH_MANAGER_STATIC_TARGET} PUBLIC ${CUFILE_LIBRARY})
-endif()
-
 if(ENABLE_UCX)
   find_package(ucx REQUIRED)
   find_package(ucxx REQUIRED)
