Commit 4c9a79b

Merge branch 'main' into clean_cuda_graph

2 parents: ee1d352 + dddfcdd

File tree: 79 files changed (+3189, -1114 lines)

constraints.txt

Lines changed: 1 addition & 4 deletions
@@ -1,5 +1,2 @@
-# These vulnerabilities were inherited from the base image (pytorch:25.06-py3) and should be removed when the base image
+# These vulnerabilities were inherited from the base image (pytorch:25.10-py3) and should be removed when the base image
 # is updated.
-
-# WAR against https://github.com/advisories/GHSA-8qvm-5x2c-j2w7
-protobuf>=4.25.8

cpp/tensorrt_llm/common/customAllReduceUtils.h

Lines changed: 255 additions & 0 deletions
Large diffs are not rendered by default.

cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu

Lines changed: 5 additions & 1 deletion
@@ -134,8 +134,12 @@ public:
        // corresponding CTA has not been launched.
        for (int flag_idx = blockIdx.x; flag_idx < kBarrierFlagCount; flag_idx += gridDim.x)
        {
-           st_flag(m_target_flag + flag_idx * NRanks, m_flag_value);
+           asm volatile(
+               "st.global.relaxed.sys.b32 [%1], %0;" ::"r"(m_flag_value), "l"(m_target_flag + flag_idx * NRanks));
        }
+       // Single release fence
+       asm volatile("fence.release.sys;");
+
        while (ld_flag(m_current_flag) == prev_flag(m_flag_value))
        {
        }
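The change replaces one release-ordered store per barrier flag with plain relaxed stores followed by a single system-scope release fence, so the ordering cost is paid once per kernel pass instead of once per flag. A minimal sketch of that pattern, assuming an illustrative helper name and flag layout (not taken from the file):

#include <cstdint>

// Sketch: publish n flags with relaxed stores, then order them with one release fence.
__device__ void publish_flags_relaxed(uint32_t* flags, int n, uint32_t value)
{
    for (int i = blockIdx.x; i < n; i += gridDim.x)
    {
        // Relaxed system-scope store: no per-flag ordering cost.
        asm volatile("st.global.relaxed.sys.b32 [%1], %0;" ::"r"(value), "l"(flags + i));
    }
    // A single release fence orders all of the stores above before anything that follows.
    asm volatile("fence.release.sys;");
}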

cpp/tensorrt_llm/kernels/customAllReduceKernels.h

Lines changed: 24 additions & 0 deletions
@@ -106,6 +106,30 @@ inline std::string toString(AllReduceFusionOp op)
     return oss.str();
 }
 
+inline std::ostream& operator<<(std::ostream& os, AllReduceStrategyType op)
+{
+    switch (op)
+    {
+    case AllReduceStrategyType::NCCL: os << "NCCL"; break;
+    case AllReduceStrategyType::MIN_LATENCY: os << "MIN_LATENCY"; break;
+    case AllReduceStrategyType::UB: os << "UB"; break;
+    case AllReduceStrategyType::AUTO: os << "AUTO"; break;
+    case AllReduceStrategyType::ONESHOT: os << "ONESHOT"; break;
+    case AllReduceStrategyType::TWOSHOT: os << "TWOSHOT"; break;
+    case AllReduceStrategyType::LOWPRECISION: os << "LOWPRECISION"; break;
+    case AllReduceStrategyType::MNNVL: os << "MNNVL"; break;
+    case AllReduceStrategyType::NCCL_SYMMETRIC: os << "NCCL_SYMMETRIC"; break;
+    }
+    return os;
+}
+
+inline std::string toString(AllReduceStrategyType op)
+{
+    std::ostringstream oss;
+    oss << op;
+    return oss.str();
+}
+
 struct AllReduceFusionParams
 {
     AllReduceFusionParams()
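The new overloads give AllReduceStrategyType the same printable treatment AllReduceFusionOp already had. A short usage sketch; the logging function itself is illustrative, not from the diff:

#include <iostream>
#include <string>

void logStrategy(AllReduceStrategyType strategy)
{
    std::cout << "allreduce strategy: " << strategy << "\n"; // via the new operator<<
    std::string name = toString(strategy);                   // e.g. "TWOSHOT"
}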

cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -205,7 +205,7 @@ set_cuda_architectures(fb_gemm_src 89 90 100f 120f)
 # ${INSTANTIATION_GENERATION_DIR}/fp8_rowwise_gemm)
 
 add_library(fp8_blockscale_gemm_src STATIC ${FP8_BLOCKSCALE_GEMM_SRC_CU})
-set_cuda_architectures(fp8_blockscale_gemm_src 89 90 100f)
+set_cuda_architectures(fp8_blockscale_gemm_src 89 90 100f 120f)
 
 set(GEMM_SWIGLU_SM90_SRC_CU
     ${CMAKE_CURRENT_SOURCE_DIR}/fused_gated_gemm/gemm_swiglu_e4m3.cu)
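Adding 120f compiles the FP8 blockscale GEMM sources for SM 120 as well, which is what lets the runtime dispatch in the next file route that architecture onto the existing SM 89 path. A small illustrative wrapper, assuming only getSMVersion() from the source:

// Sketch: SM 120 reuses the SM 89 kernel path, which is only valid now that the
// sources are also compiled for architecture 120f.
inline bool useSm89StylePath()
{
    int arch = tensorrt_llm::common::getSMVersion();
    return arch == 89 || arch == 120;
}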

cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh

Lines changed: 8 additions & 9 deletions
@@ -1622,16 +1622,15 @@ void gemm_dispatch_sm89(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
     dim3 grid = dim3(grid_m, grid_n, grid_k);
     dim3 block = dim3(kThreadCount, 1, 1);
 
-    if (kSmemSize > (48 << 10))
-    {
-        cudaFuncSetAttribute(ada_blockwise_gemm::sm89_fp8_gemm_1d1d_impl<GemmKernel>,
-            cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
-        auto result = cudaGetLastError();
-        TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm89 gemm kernel cannot launch: %s", cudaGetErrorString(result));
-    }
+    auto result = cudaFuncSetAttribute(ada_blockwise_gemm::sm89_fp8_gemm_1d1d_impl<GemmKernel>,
+        cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
+    TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm89 gemm kernel cannot launch: %s", cudaGetErrorString(result));
 
     ada_blockwise_gemm::sm89_fp8_gemm_1d1d_impl<GemmKernel>
         <<<grid, block, kSmemSize, stream>>>(shape_m, shape_n, shape_k, mat_a, mat_b, mat_d, scales_a, scales_b);
+
+    result = cudaGetLastError();
+    TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm89 gemm kernel runtime error: %s", cudaGetErrorString(result));
 }
 
 void fp8_gemm_run(__nv_fp8_e4m3* mat_a, int ld_a, __nv_fp8_e4m3* mat_b, int ld_b, __nv_bfloat16* mat_d, int ld_d,

@@ -1643,7 +1642,7 @@ void fp8_gemm_run(__nv_fp8_e4m3* mat_a, int ld_a, __nv_fp8_e4m3* mat_b, int ld_b
     }
 #ifndef PLACEHOLDER_KERNELS
     int arch = tensorrt_llm::common::getSMVersion();
-    if (arch == 89)
+    if (arch == 89 || arch == 120)
     {
         gemm_dispatch_sm89(mat_a, mat_b, mat_d, scales_a, scales_b, shape_m, shape_n, shape_k, stream);
         return;

@@ -1883,7 +1882,7 @@ void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_ma
     }
 
     int arch = tensorrt_llm::common::getSMVersion();
-    if (arch == 89)
+    if (arch == 89 || arch == 120)
     {
         strided_batch_gemm_dispatch_sm89(fp8_mat_a, ld_a, stride_a, fp8_mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
             scales_a, stride_scales_a, scales_b, num_problems, shape_m, shape_n, shape_k, stream);
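Two error-handling changes accompany the SM 120 enablement above: the return code of cudaFuncSetAttribute is now checked directly and unconditionally, and cudaGetLastError is checked after the launch so launch failures are reported. A minimal self-contained sketch of that launch-and-check pattern, with an illustrative kernel standing in for the real GEMM:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void exampleKernel(int* out)
{
    extern __shared__ int smem[];
    smem[threadIdx.x] = threadIdx.x;
    __syncthreads();
    out[threadIdx.x] = smem[threadIdx.x];
}

void launchChecked(int* d_out, int smemBytes, cudaStream_t stream)
{
    // Raise the dynamic shared-memory limit and check the call's own return code.
    auto result = cudaFuncSetAttribute(exampleKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes);
    if (result != cudaSuccess)
        std::fprintf(stderr, "kernel cannot launch: %s\n", cudaGetErrorString(result));

    // Caller is expected to pass smemBytes >= 32 * sizeof(int) for this toy kernel.
    exampleKernel<<<1, 32, smemBytes, stream>>>(d_out);

    // Catch launch-time errors, e.g. asking for more shared memory than the device allows.
    result = cudaGetLastError();
    if (result != cudaSuccess)
        std::fprintf(stderr, "kernel runtime error: %s\n", cudaGetErrorString(result));
}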

cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3FusedAGemm.cu

Lines changed: 2 additions & 1 deletion
@@ -601,6 +601,8 @@ __global__ __launch_bounds__(256, 1) void fused_a_gemm_kernel(
         }
     }
     __syncthreads();
+    asm volatile("griddepcontrol.wait;");
+    asm volatile("griddepcontrol.launch_dependents;");
 
     if (warp_idx < 2)
     {

@@ -622,7 +624,6 @@
         mma_computer.issue_mainloop();
         mma_computer.epi();
     }
-    asm volatile("griddepcontrol.launch_dependents;");
 #endif
 }
 
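These are the programmatic dependent launch (PDL) handshake instructions: griddepcontrol.wait blocks until the upstream grid's writes are visible, and griddepcontrol.launch_dependents tells the scheduler the downstream grid may begin launching. Hoisting launch_dependents to just after the __syncthreads() lets the dependent grid's launch latency overlap with the rest of this kernel instead of waiting for its tail. A minimal sketch of the pattern; the kernel body is illustrative, the real kernel is a fused GEMM:

// Sketch: PDL handshake placed early in a secondary kernel.
__global__ void pdlStageKernel(float const* in, float* out, int n)
{
    // Wait until the upstream (primary) grid's memory writes are visible here.
    asm volatile("griddepcontrol.wait;");
    // Release the downstream grid immediately; its launch overlaps our remaining work.
    asm volatile("griddepcontrol.launch_dependents;");

    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        out[i] = in[i] * 2.0f;
    }
}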
