add moe_align_fused #1054
@@ -219,6 +219,97 @@ def moe_align1(
    )


@triton.jit
def moe_align_fused_kernel(
    topk_ids_ptr,  # [token_num, topk]
    topk_weights_ptr,  # [token_num, topk]
    expert_to_index_ptr,  # [expert_num, token_num * topk]
    expert_to_weight_ptr,  # [expert_num, token_num * topk]
    expert_token_num_ptr,  # [expert_num]
    token_num,
    topk: tl.constexpr,
    BLOCK_TOK: tl.constexpr,
):
    token_block = tl.program_id(0)
    offs = token_block * BLOCK_TOK + tl.arange(0, BLOCK_TOK)
    mask = offs < token_num * topk

    # process all flattened (token, topk) entries in this block at once
    expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0)
    weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0)

    # use atomic_add to reserve a write position in each entry's expert row
    write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask)

    # write the flattened index and the weight at the reserved position
    tl.store(
        expert_to_index_ptr + expert_ids * (token_num * topk) + write_pos,
        offs,
        mask=mask,
    )
    tl.store(
        expert_to_weight_ptr + expert_ids * (token_num * topk) + write_pos,
        weights,
        mask=mask,
    )


def _get_moe_align_fused_static_key(
    topk_weights: torch.Tensor,
) -> dict:
    topk = topk_weights.shape[1]
    return {
        "topk": topk,
    }


def _get_moe_align_fused_configs():
    return [
        {
            "BLOCK_TOK": bt,
            "num_warps": nw,
            "num_stages": ns,
        }
        for ns in [2, 3, 4, 5]
        for nw in [4, 8]
        for bt in [1024, 2048]
    ]


@autotune(
    kernel_name="moe_align_fused:v1",
    configs_gen_func=_get_moe_align_fused_configs,
    static_key_func=_get_moe_align_fused_static_key,
    run_key_func=lambda topk_ids: topk_ids.shape[0],
    mutates_args=["expert_to_index", "expert_to_weight", "expert_token_num"],
)
def moe_align_fused(
    expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, topk, run_config: Optional[dict] = None
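For context, here is a minimal pure-PyTorch sketch of the grouping the fused kernel performs. It is illustrative only and not part of this PR; the helper name moe_align_fused_reference and the expert_num argument are hypothetical, and because the kernel reserves slots with tl.atomic_add, the ordering of entries within an expert's row may differ from this sequential reference.

```python
# Hypothetical reference (not in this PR): what moe_align_fused_kernel computes,
# expressed sequentially. Buffer shapes follow the pointer comments in the kernel.
import torch


def moe_align_fused_reference(topk_ids: torch.Tensor, topk_weights: torch.Tensor, expert_num: int):
    token_num, topk = topk_ids.shape
    flat_ids = topk_ids.reshape(-1)          # [token_num * topk]
    flat_weights = topk_weights.reshape(-1)  # [token_num * topk]

    expert_to_index = torch.zeros(expert_num, token_num * topk, dtype=torch.int64)
    expert_to_weight = torch.zeros(expert_num, token_num * topk, dtype=flat_weights.dtype)
    expert_token_num = torch.zeros(expert_num, dtype=torch.int64)

    for pos in range(token_num * topk):      # flattened (token, k) position
        e = int(flat_ids[pos])
        slot = int(expert_token_num[e])      # next free slot in expert e's row
        expert_to_index[e, slot] = pos
        expert_to_weight[e, slot] = flat_weights[pos]
        expert_token_num[e] += 1

    return expert_to_index, expert_to_weight, expert_token_num
```

Note that the Triton kernel uses expert_token_num as the per-expert write cursor via tl.atomic_add, so the wrapper presumably zero-initializes it before launch (it is listed in mutates_args).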
Review suggestion on the moe_align_fused signature (drop the redundant topk parameter):

    - expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, topk, run_config: Optional[dict] = None
    + expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None
Review comment on the call site: to accompany the removal of the redundant topk parameter from the moe_align_fused function signature, this call should be updated to no longer pass the topk argument.

Suggested change:

    - moe_align_fused(
    -     expert_to_tokens, expert_to_weights, expert_to_token_num, curr_topk_ids, curr_topk_weights, topk=topk_num
    - )
    + moe_align_fused(
    +     expert_to_tokens, expert_to_weights, expert_to_token_num, curr_topk_ids, curr_topk_weights
    + )
Review comment on the topk kernel parameter: the topk parameter is declared as tl.constexpr, but it is passed as a positional argument from moe_align_fused, which makes it a runtime value. According to the Triton documentation, constexpr arguments must be compile-time constants and passed as keyword arguments. This will cause a TypeError at runtime. Please remove the tl.constexpr annotation to treat it as a regular runtime argument. The autotuner is already creating specializations for different topk values via _get_moe_align_fused_static_key, so the performance impact of this change should be minimal.
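A minimal sketch of the signature the reviewer is suggesting, with topk as a plain runtime argument. The kernel body and the launch below are illustrative assumptions, not code from this diff; only the annotation change is taken from the review.

```python
# Sketch only: topk without tl.constexpr, per the review suggestion.
# BLOCK_TOK stays constexpr because tl.arange needs a compile-time block size.
import triton
import triton.language as tl


@triton.jit
def moe_align_fused_kernel(
    topk_ids_ptr,
    topk_weights_ptr,
    expert_to_index_ptr,
    expert_to_weight_ptr,
    expert_token_num_ptr,
    token_num,
    topk,                      # plain runtime value, no tl.constexpr
    BLOCK_TOK: tl.constexpr,
):
    ...  # body unchanged from the PR


# Hypothetical launch from the moe_align_fused wrapper (its body is not shown in this diff):
# grid = (triton.cdiv(token_num * topk, BLOCK_TOK),)
# moe_align_fused_kernel[grid](
#     topk_ids, topk_weights, expert_to_index, expert_to_weight, expert_token_num,
#     token_num, topk, BLOCK_TOK=run_config["BLOCK_TOK"],
#     num_warps=run_config["num_warps"], num_stages=run_config["num_stages"],
# )
```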