Skip to content

Commit 86cfb3e

Browse files
authored
[None][feat] Update TRTLLM MoE cubins; reduce mxfp4 weight padding requirement; tighten TMA bound (NVIDIA#9025)
Signed-off-by: Anthony Chang <[email protected]>
1 parent 6dc70aa commit 86cfb3e

File tree

1,434 files changed

+22312
-9892
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

1,434 files changed

+22312
-9892
lines changed

cpp/include/tensorrt_llm/common/cudaUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
2020
#include "tensorrt_llm/common/cudaDriverWrapper.h"
2121
#include "tensorrt_llm/common/cudaFp8Utils.h"
22+
#if ENABLE_FP4
23+
#include <cuda_fp4.h>
24+
#endif
2225
#include "tensorrt_llm/common/logger.h"
2326
#include "tensorrt_llm/common/tllmException.h"
2427
#include <algorithm>
@@ -545,6 +548,9 @@ template void printArrayInfo(__nv_bfloat16 const* ptr, uint64_t nElement, std::s
545548
#ifdef ENABLE_FP8
546549
template void printArrayInfo(__nv_fp8_e4m3 const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
547550
#endif
551+
#ifdef ENABLE_FP4
552+
template void printArrayInfo(__nv_fp4_e2m1 const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
553+
#endif
548554
template void printArrayInfo(uint32_t const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
549555
template void printArrayInfo(uint64_t const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
550556
template void printArrayInfo(int const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp

Lines changed: 252 additions & 58 deletions
Large diffs are not rendered by default.

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -68,50 +68,54 @@ class TrtllmGenBatchedGemmRunner
6868
int32_t configIndex) const;
6969

7070
// Generic GEMM interface
71-
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
72-
int32_t numBatches, int32_t maxNumCtasInBatchDim, void const* a, void const* sfA, void const* b,
73-
void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC,
74-
float const* scaleGateC, float const* bias, float const* swiGluAlpha, float const* swiGluBeta,
75-
float const* clampLimit, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
76-
int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
77-
void* workspace, CUstream stream, int device, int32_t configIndex);
71+
void run(int32_t m, int32_t n, int32_t k, int32_t validM, int32_t validN, int32_t validK,
72+
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
73+
void const* a, void const* sfA, void const* b, void const* sfB, void const* perTokensSfA,
74+
void const* perTokensSfB, float const* scaleC, float const* scaleGateC, float const* bias,
75+
float const* swiGluAlpha, float const* swiGluBeta, float const* clampLimit, void* c, void* outSfC,
76+
int32_t const* routeMap, int32_t const* totalNumPaddedTokens, int32_t const* ctaIdxXyToBatchIdx,
77+
int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas, void* workspace, CUstream stream,
78+
int device, int32_t configIndex);
7879

7980
// Block-scaling GEMM
8081
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
8182
void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device,
82-
int32_t configIndex);
83+
int32_t configIndex, int32_t validM = -1, int32_t validN = -1, int32_t validK = -1);
8384

8485
// Block-scaling GEMM with SwiGLU activation
8586
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
8687
void const* b, void const* sfB, float const* bias, float const* swiGluAlpha, float const* swiGluBeta,
8788
float const* clampLimit, void* c, void* outSfC, void* workspace, CUstream stream, int device,
88-
int32_t configIndex);
89+
int32_t configIndex, int32_t validM = -1, int32_t validN = -1, int32_t validK = -1);
8990

9091
// FP8 per-tensor scaling GEMM
9192
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* b,
9293
float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device,
93-
int32_t configIndex);
94+
int32_t configIndex, int32_t validM = -1, int32_t validN = -1, int32_t validK = -1);
9495

9596
// Get the list of configs that passed the validation based on the constructor options
9697
[[nodiscard]] std::vector<int64_t> getPassingConfigIndices() const
9798
{
9899
return mPassingConfigIndices;
99100
}
100101

102+
// Get the kernel name from the config index
103+
[[nodiscard]] std::string getKernelNameFromConfigIndex(int32_t configIndex) const;
104+
101105
// Get the list of config indices that are valid for the given problem shape
102106
[[nodiscard]] std::vector<int64_t> getValidConfigIndices(int32_t m, int32_t n, int32_t k,
103-
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
104-
int32_t maxNumCtasInBatchDim) const;
107+
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
108+
int32_t validM = -1, int32_t validN = -1, int32_t validK = -1) const;
105109

106110
// Get a default config index that is valid for the given problem shape
107111
// This will be used as the fallback config if using auto-tuning
108112
[[nodiscard]] int64_t getDefaultValidConfigIndex(int32_t m, int32_t n, int32_t k,
109-
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
110-
int32_t maxNumCtasInBatchDim) const;
113+
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
114+
int32_t validM = -1, int32_t validN = -1, int32_t validK = -1) const;
111115

112116
[[nodiscard]] bool isValidConfigIndex(int32_t configIndex, int32_t m, int32_t n, int32_t k,
113-
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
114-
int32_t maxNumCtasInBatchDim) const;
117+
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
118+
int32_t validM = -1, int32_t validN = -1, int32_t validK = -1) const;
115119

116120
private:
117121
void selectGemmConfig(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
---
2+
AccessModifierOffset: -4
3+
AlignAfterOpenBracket: DontAlign
4+
AlignConsecutiveAssignments: None
5+
AlignConsecutiveDeclarations: None
6+
AlignOperands: false
7+
AlignTrailingComments: true
8+
AllowAllParametersOfDeclarationOnNextLine: true
9+
AllowShortBlocksOnASingleLine: Empty
10+
AllowShortCaseLabelsOnASingleLine: true
11+
AllowShortFunctionsOnASingleLine: Empty
12+
AllowShortIfStatementsOnASingleLine: false
13+
AllowShortLoopsOnASingleLine: false
14+
AlwaysBreakAfterDefinitionReturnType: None
15+
AlwaysBreakAfterReturnType: None
16+
AlwaysBreakBeforeMultilineStrings: true
17+
AlwaysBreakTemplateDeclarations: Yes
18+
BasedOnStyle: None
19+
BinPackArguments: true
20+
BinPackParameters: true
21+
BreakBeforeBinaryOperators: All
22+
BreakBeforeBraces: Allman
23+
BreakBeforeTernaryOperators: true
24+
BreakConstructorInitializersBeforeComma: true
25+
ColumnLimit: 120
26+
CommentPragmas: '^ IWYU pragma:'
27+
ConstructorInitializerAllOnOneLineOrOnePerLine: false
28+
ConstructorInitializerIndentWidth: 4
29+
ContinuationIndentWidth: 4
30+
Cpp11BracedListStyle: true
31+
DerivePointerAlignment: false
32+
DisableFormat: false
33+
ExperimentalAutoDetectBinPacking: false
34+
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
35+
IncludeBlocks: Preserve
36+
IncludeCategories:
37+
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
38+
Priority: 2
39+
- Regex: '^(<|"(gtest|isl|json)/)'
40+
Priority: 3
41+
- Regex: '.*'
42+
Priority: 1
43+
IndentCaseLabels: false
44+
IndentWidth: 4
45+
IndentWrappedFunctionNames: false
46+
KeepEmptyLinesAtTheStartOfBlocks: true
47+
Language: Cpp
48+
MacroBlockBegin: ''
49+
MacroBlockEnd: ''
50+
MaxEmptyLinesToKeep: 1
51+
NamespaceIndentation: None
52+
ObjCBlockIndentWidth: 4
53+
ObjCSpaceAfterProperty: true
54+
ObjCSpaceBeforeProtocolList: true
55+
PenaltyBreakBeforeFirstCallParameter: 19
56+
PenaltyBreakComment: 300
57+
PenaltyBreakFirstLessLess: 120
58+
PenaltyBreakString: 1000
59+
PenaltyExcessCharacter: 1000000
60+
PenaltyReturnTypeOnItsOwnLine: 60
61+
PointerAlignment: Left
62+
QualifierAlignment: Right
63+
ReflowComments: true
64+
SeparateDefinitionBlocks: Always
65+
SortIncludes: false
66+
SpaceAfterCStyleCast: true
67+
SpaceBeforeAssignmentOperators: true
68+
SpaceBeforeParens: ControlStatements
69+
SpaceInEmptyParentheses: false
70+
SpacesBeforeTrailingComments: 1
71+
SpacesInAngles: false
72+
SpacesInCStyleCastParentheses: false
73+
SpacesInContainerLiterals: true
74+
SpacesInParentheses: false
75+
SpacesInSquareBrackets: false
76+
Standard: c++14
77+
TabWidth: 4
78+
UseTab: Never

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmEnums.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
*/
1717
#pragma once
1818

19-
#include <cassert>
2019
#include <string>
20+
#include <cassert>
2121

2222
namespace batchedGemm
2323
{
@@ -34,7 +34,9 @@ enum class RouteImpl
3434
// Use LDGSTS to do the routing
3535
Ldgsts = 1,
3636
// Use UTMALDG.GATHER4 to do the routing
37-
Tma = 2
37+
Tma = 2,
38+
// Use LDG+STS to do the routing
39+
LdgPlusSts = 3
3840
};
3941

4042
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -60,6 +62,13 @@ inline bool doesRouteImplUseTma(RouteImpl mode)
6062

6163
////////////////////////////////////////////////////////////////////////////////////////////////////
6264

65+
inline bool doesRouteImplUseLdgPlusSts(RouteImpl mode)
66+
{
67+
return (mode == RouteImpl::LdgPlusSts);
68+
}
69+
70+
////////////////////////////////////////////////////////////////////////////////////////////////////
71+
6372
} // namespace batchedGemm
6473

6574
////////////////////////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)