
Commit abba6ad

Fix hidden_size % 128 != 0 in intranode kernels (#413)
* Fix hidden_size % 128 != 0
* Add `align_down()` function
* Use the full warp to wait for the TMA store
* Support arbitrary hidden sizes in fp8 cast
* lint
1 parent 2012e31 commit abba6ad

File tree

6 files changed, +62 -38 lines:

- csrc/config.hpp
- csrc/kernels/internode.cu
- csrc/kernels/internode_ll.cu
- csrc/kernels/intranode.cu
- csrc/kernels/utils.cuh
- tests/utils.py

csrc/config.hpp

Lines changed: 8 additions & 3 deletions
@@ -11,10 +11,15 @@ dtype_t ceil_div(dtype_t a, dtype_t b) {
 }
 
 template <typename dtype_t>
-dtype_t align(dtype_t a, dtype_t b) {
+dtype_t align_up(dtype_t a, dtype_t b) {
     return ceil_div<dtype_t>(a, b) * b;
 }
 
+template <typename dtype_t>
+dtype_t align_down(dtype_t a, dtype_t b) {
+    return a / b * b;
+}
+
 struct Config {
     int num_sms;
     int num_max_nvl_chunked_send_tokens;

@@ -36,7 +41,7 @@ struct Config {
         EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0);
 
         // Ceil up RDMA buffer size
-        this->num_max_rdma_chunked_recv_tokens = align<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
+        this->num_max_rdma_chunked_recv_tokens = align_up<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
         EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
         // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always have space to push
         EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2);

@@ -160,7 +165,7 @@ struct LowLatencyLayout {
         size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
         size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
         size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
-        size_t signaling_buffer_bytes_aligned = align<size_t>(signaling_buffer_bytes, 128);
+        size_t signaling_buffer_bytes_aligned = align_up<size_t>(signaling_buffer_bytes, 128);
         total_bytes += signaling_buffer_bytes_aligned * 2;
 
         // Assign pointers
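For reference, a minimal sketch of what the renamed `align_up` and the new `align_down` compute, written in Python to mirror the `align_up` helper added to tests/utils.py below (the concrete values are only illustrative):

```python
def align_up(a: int, b: int) -> int:
    # Round a up to the nearest multiple of b (the old `align`, now renamed).
    return (a + b - 1) // b * b


def align_down(a: int, b: int) -> int:
    # Round a down to the nearest multiple of b (the new helper).
    return a // b * b


assert align_up(4000, 128) == 4096
assert align_down(4000, 128) == 3968
assert align_up(4096, 128) == align_down(4096, 128) == 4096  # exact multiples are fixed points
```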

csrc/kernels/internode.cu

Lines changed: 2 additions & 2 deletions
@@ -43,7 +43,7 @@ int get_source_meta_bytes() {
 
 __host__ __device__ __forceinline__
 int get_num_bytes_per_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) {
-    return static_cast<int>(align(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), sizeof(int4)));
+    return static_cast<int>(align_up(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), sizeof(int4)));
 }
 
 __host__ __device__ __forceinline__

@@ -1516,8 +1516,8 @@ combine(int4* combined_x, float* combined_topk_weights,
             // Load data
             auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token;
             auto shifted_x = x + token_idx * hidden_int4;
+            tma_store_wait<0>();
             if (elect_one_sync()) {
-                tma_store_wait<0>();
                 tma_load_1d(tma_buffer, shifted_x, tma_mbarrier, hidden_bytes);
                 mbarrier_arrive_and_expect_tx(tma_mbarrier, hidden_bytes);
             }

csrc/kernels/internode_ll.cu

Lines changed: 2 additions & 2 deletions
@@ -263,7 +263,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
                               local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
     const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
     const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
-    const auto num_aligned_scales = align<int>(num_scales, sizeof(float) / sizeof(scale_t));
+    const auto num_aligned_scales = align_up<int>(num_scales, sizeof(float) / sizeof(scale_t));
     const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
 
     // Shared between sub-warps in warp groups

@@ -584,7 +584,7 @@ combine(void* combined_x,
     // Use different unroll factors for send and recv phases
     constexpr int kNumSendUnrolls = kHidden % (32 * 4 * sizeof(int4) / sizeof(nv_bfloat16)) == 0 ? 4 : 2;
     constexpr int kNumRecvUnrolls = 2;
-    constexpr int hidden_bf16_int4_pad = align(static_cast<int>(hidden_bf16_int4), 32 * kNumSendUnrolls);
+    constexpr int hidden_bf16_int4_pad = align_up(static_cast<int>(hidden_bf16_int4), 32 * kNumSendUnrolls);
     EP_STATIC_ASSERT(kHidden % (32 * 2 * sizeof(int4) / sizeof(nv_bfloat16)) == 0, "Invalid hidden");
     EP_STATIC_ASSERT(kNumSendUnrolls <= kNumMaxUnrolls and kNumRecvUnrolls <= kNumMaxUnrolls, "Invalid unrolls");
     EP_STATIC_ASSERT(hidden_bf16_int4 % kNumSendUnrolls == 0, "Invalid hidden");

csrc/kernels/intranode.cu

Lines changed: 27 additions & 24 deletions
@@ -399,10 +399,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
             auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
             auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
 #ifndef DISABLE_SM90_FEATURES
-            if (elect_one_sync()) {
-                #pragma unroll
-                for (int i = 0; i < 2; ++ i) {
-                    tma_store_wait<0>();
+            #pragma unroll
+            for (int i = 0; i < 2; ++ i) {
+                tma_store_wait<0>();
+                if (elect_one_sync()) {
                     tma_load_1d(tma_buffer, shifted_buffer_x_int4 + i * half_hidden_int4, tma_mbarrier, half_hidden_bytes);
                     mbarrier_arrive_and_expect_tx(tma_mbarrier, half_hidden_bytes);
                     mbarrier_wait(tma_mbarrier, tma_phase);

@@ -589,6 +589,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 
     constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
     int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4);
+    int hidden_int4_aligned = align_down(hidden_int4, 32);
     auto x_int4 = reinterpret_cast<const int4*>(x);
     auto bias_0_int4 = reinterpret_cast<const int4*>(bias_0);
     auto bias_1_int4 = reinterpret_cast<const int4*>(bias_1);

@@ -791,8 +792,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 
         // Wait shared memory release
 #ifndef DISABLE_SM90_FEATURES
-        if (elect_one_sync())
-            tma_store_wait<0>();
+        tma_store_wait<0>();
         __syncwarp();
 #endif
 

@@ -837,26 +837,29 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
                 out_dtypes[j] = static_cast<dtype_t>(values[j]);
 
 #ifndef DISABLE_SM90_FEATURES
-            // Wait TMA arrival
-            if (elect_one_sync())
+            if (i < hidden_int4_aligned) {
+                // Wait TMA arrival
                 tma_store_wait<kNumStages - 1>();
-            __syncwarp();
-
-            // Write into TMA buffer
-            auto tma_stage_idx = (i / 32) % kNumStages;
-            reinterpret_cast<int4*>(tma_buffer)[tma_stage_idx * 32 + lane_id] = out_int4;
-
-            // Issue TMA
-            tma_store_fence();
-            __syncwarp();
-            if (elect_one_sync()) {
-                auto tma_bytes = min(32, hidden_int4 - i) * static_cast<int>(sizeof(int4));
-                tma_store_1d(reinterpret_cast<int4*>(tma_buffer) + tma_stage_idx * 32,
-                             recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);
+                __syncwarp();
+
+                // Write into TMA buffer
+                auto tma_stage_idx = (i / 32) % kNumStages;
+                reinterpret_cast<int4*>(tma_buffer)[tma_stage_idx * 32 + lane_id] = out_int4;
+
+                // Issue TMA
+                tma_store_fence();
+                __syncwarp();
+                if (elect_one_sync()) {
+                    auto tma_bytes = min(32, hidden_int4 - i) * static_cast<int>(sizeof(int4));
+                    tma_store_1d(reinterpret_cast<int4*>(tma_buffer) + tma_stage_idx * 32,
+                                 recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);
+                }
+                __syncwarp();
+            } else {
+#endif
+                recv_int4[token_idx * hidden_int4 + i] = out_int4;
+#ifndef DISABLE_SM90_FEATURES
             }
-            __syncwarp();
-#else
-            recv_int4[token_idx * hidden_int4 + i] = out_int4;
 #endif
         }
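The new `hidden_int4_aligned = align_down(hidden_int4, 32)` is what lets the combine kernel accept hidden sizes whose int4 count is not a multiple of the 32-lane warp width: full 32-int4 chunks still go through the staged TMA path, while the remainder is written to global memory directly. A rough sketch of that split in Python, using a hypothetical hidden size of 2112 bf16 values (not a multiple of 128):

```python
def align_down(a: int, b: int) -> int:
    return a // b * b


# Hypothetical shape: 2112 bf16 values per token -> 2112 * 2 bytes / 16 bytes = 264 int4 elements.
hidden_int4 = 2112 * 2 // 16                        # 264
hidden_int4_aligned = align_down(hidden_int4, 32)   # 256 int4: covered by the TMA-staged path
tail_int4 = hidden_int4 - hidden_int4_aligned       # 8 int4: stored straight to global memory
assert (hidden_int4_aligned, tail_int4) == (256, 8)
```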

csrc/kernels/utils.cuh

Lines changed: 6 additions & 1 deletion
@@ -408,10 +408,15 @@ __host__ __device__ constexpr dtype_t ceil_div(dtype_t a, dtype_t b) {
 }
 
 template <typename dtype_t>
-__host__ __device__ constexpr dtype_t align(dtype_t a, dtype_t b) {
+__host__ __device__ constexpr dtype_t align_up(dtype_t a, dtype_t b) {
     return ceil_div<dtype_t>(a, b) * b;
 }
 
+template <typename dtype_t>
+__host__ __device__ constexpr dtype_t align_down(dtype_t a, dtype_t b) {
+    return a / b * b;
+}
+
 __forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id,
                                                        int& token_start_idx, int& token_end_idx) {
     int num_tokens_per_sm = ceil_div(num_tokens, num_sms);

tests/utils.py

Lines changed: 17 additions & 6 deletions
@@ -43,23 +43,34 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
     return (1 - sim).item()
 
 
+def align_up(x, y):
+    return (x + y - 1) // y * y
+
+
 def per_token_cast_to_fp8(x: torch.Tensor):
-    assert x.dim() == 2 and x.size(1) % 128 == 0
+    assert x.dim() == 2
     m, n = x.shape
-    x_view = x.view(m, -1, 128)
-    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
-    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
+    aligned_n = align_up(n, 128)
+    x_padded = torch.nn.functional.pad(x, (0, aligned_n - n), mode='constant', value=0)
+    x_padded_view = x_padded.view(m, -1, 128)
+    x_amax = x_padded_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
+    return (x_padded_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, aligned_n)[:, :n].contiguous(), (x_amax / 448.0).view(m, -1)
 
 
 def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
     if x_fp8.numel() == 0:
         return x_fp8.to(torch.bfloat16)
+
+    assert x_fp8.dim() == 2
+    m, n = x_fp8.shape
+    aligned_n = align_up(n, 128)
+    x_fp8_padded = torch.nn.functional.pad(x_fp8, (0, aligned_n - n), mode='constant', value=0)
     if x_scales.dtype == torch.int:
         x_scales = x_scales.view(dtype=torch.uint8).to(torch.int) << 23
         x_scales = x_scales.view(dtype=torch.float)
-    x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
+    x_fp32_padded = x_fp8_padded.to(torch.float32).view(x_fp8.size(0), -1, 128)
     x_scales = x_scales.view(x_fp8.size(0), -1, 1)
-    return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)
+    return (x_fp32_padded * x_scales).view(x_fp8_padded.shape).to(torch.bfloat16)[:, :n].contiguous()
 
 
 def inplace_unique(x: torch.Tensor, num_slots: int):
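With the zero-padding in place, `per_token_cast_to_fp8` / `per_token_cast_back` can round-trip tensors whose last dimension is not a multiple of 128. A small usage sketch, assuming a PyTorch build with float8 support; the 4 x 2112 shape and the `tests.utils` import path are only for illustration, and `calc_diff` is the helper defined earlier in this file:

```python
import torch

from tests.utils import calc_diff, per_token_cast_back, per_token_cast_to_fp8

x = torch.randn(4, 2112, dtype=torch.bfloat16)  # 2112 is not a multiple of 128
x_fp8, x_scales = per_token_cast_to_fp8(x)      # pads to 2176 internally, crops back to 4 x 2112
x_back = per_token_cast_back(x_fp8, x_scales)   # de-quantizes per 128-value group, crops to 4 x 2112

assert x_back.shape == x.shape and x_back.dtype == torch.bfloat16
print(calc_diff(x.float(), x_back.float()))     # small value, bounded by fp8 quantization error
```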
