100 changes: 97 additions & 3 deletions lightllm/common/fused_moe/grouped_fused_moe.py
@@ -219,6 +219,97 @@ def moe_align1(
)


@triton.jit
def moe_align_fused_kernel(
    topk_ids_ptr,  # [token_num, topk]
    topk_weights_ptr,  # [token_num, topk]
    expert_to_index_ptr,  # [expert_num, token_num * topk]
    expert_to_weight_ptr,  # [expert_num, token_num * topk]
    expert_token_num_ptr,  # [expert_num]
    token_num,
    topk: tl.constexpr,


critical

The topk parameter is declared as tl.constexpr, but moe_align_fused passes it as a positional runtime value. According to the Triton documentation, constexpr arguments must be compile-time constants passed as keyword arguments, so this will cause a TypeError at runtime. Please remove the tl.constexpr annotation and treat topk as a regular runtime argument (a standalone sketch of the constexpr-vs-runtime distinction follows the kernel below). The autotuner already creates specializations for different topk values via _get_moe_align_fused_static_key, so the performance impact of this change should be minimal.

Suggested change
    topk: tl.constexpr,
    topk,

    BLOCK_TOK: tl.constexpr,
):
    token_block = tl.program_id(0)
    offs = token_block * BLOCK_TOK + tl.arange(0, BLOCK_TOK)
    mask = offs < token_num * topk

    # Iterate over topk
    # for k in range(topk):
    expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0)
    weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0)

    # Use atomic_add to allocate write positions for experts
    write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask)

    # Write index and weight in token order


medium

The comments in this kernel are in Chinese, while the rest of the codebase uses English. For consistency and better maintainability for a wider audience, please translate these comments to English.

For example:

  • Line 237: # 遍历 topk -> # Iterate over top-k experts
  • Line 242: # 用 atomic_add 给 expert 分配写位置 -> # Use atomic_add to allocate write positions for experts
  • Line 245: # 按 token 顺序写 index 和 weight -> # Write index and weight in token order

    tl.store(
        expert_to_index_ptr + expert_ids * (token_num * topk) + write_pos,
        offs,
        mask=mask,
    )
    tl.store(
        expert_to_weight_ptr + expert_ids * (token_num * topk) + write_pos,
        weights,
        mask=mask,
    )
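
To illustrate the constexpr point raised in the review comment above, here is a minimal, standalone sketch that is not part of this PR (scale_kernel, n and factor are illustrative names; Triton and a CUDA device are assumed). BLOCK is a compile-time tl.constexpr that selects the tile size and is passed as a keyword at launch, while n and factor are ordinary runtime arguments.

import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, n, factor, BLOCK: tl.constexpr):
    # BLOCK is specialized at compile time; n and factor are plain runtime values.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x * factor, mask=mask)


x = torch.randn(4096, device="cuda")
grid = (triton.cdiv(x.numel(), 1024),)
# The constexpr is passed as a keyword; changing factor does not trigger a recompile.
scale_kernel[grid](x, x.numel(), 2.0, BLOCK=1024)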


def _get_moe_align_fused_static_key(
    topk_weights: torch.Tensor,
) -> dict:
    topk = topk_weights.shape[1]
    return {
        "topk": topk,
    }


def _get_moe_align_fused_configs():
    return [
        {
            "BLOCK_TOK": bt,
            "num_warps": nw,
            "num_stages": ns,
        }
        for ns in [2, 3, 4, 5]
        for nw in [4, 8]
        for bt in [1024, 2048]
    ]


@autotune(
    kernel_name="moe_align_fused:v1",
    configs_gen_func=_get_moe_align_fused_configs,
    static_key_func=_get_moe_align_fused_static_key,
    run_key_func=lambda topk_ids: topk_ids.shape[0],
    mutates_args=["expert_to_index", "expert_to_weight", "expert_token_num"],
)
def moe_align_fused(
    expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, topk, run_config: Optional[dict] = None


medium

The topk parameter in the function signature is redundant because its value is derived from topk_ids.shape inside the function. The argument is also immediately overwritten by that unpacking, which is confusing and error-prone. Please remove this parameter from the function signature for clarity and simplicity. Note that this will require updating the call site in fused_experts_impl.

Suggested change
    expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, topk, run_config: Optional[dict] = None
    expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None

):
    token_num, topk = topk_ids.shape
    if run_config is None:
        run_config = {}
    BLOCK_TOK = run_config.get("BLOCK_TOK", 256)
    num_warps = run_config.get("num_warps", 4)
    num_stages = run_config.get("num_stages", 3)

    grid = (triton.cdiv(token_num * topk, BLOCK_TOK),)
    moe_align_fused_kernel[grid](
        topk_ids,
        topk_weights,
        expert_to_index,
        expert_to_weight,
        expert_token_num,
        token_num,
        topk,
        BLOCK_TOK=BLOCK_TOK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return expert_to_index, expert_to_weight, expert_token_num
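
For reference, a minimal usage sketch of the new helper as it stands in this commit (shapes are illustrative; a CUDA device is assumed, the import path follows the file being changed, and expert_token_num must be zero-initialized because the kernel counts with tl.atomic_add):

import torch
from lightllm.common.fused_moe.grouped_fused_moe import moe_align_fused

token_num, topk, expert_num = 8, 2, 4
topk_ids = torch.randint(0, expert_num, (token_num, topk), dtype=torch.int64, device="cuda")
topk_weights = torch.rand((token_num, topk), dtype=torch.float32, device="cuda")

expert_to_index = torch.empty((expert_num, token_num * topk), dtype=torch.int32, device="cuda")
expert_to_weight = torch.empty((expert_num, token_num * topk), dtype=torch.float32, device="cuda")
expert_token_num = torch.zeros((expert_num,), dtype=torch.int32, device="cuda")

moe_align_fused(expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, topk=topk)
# expert_token_num[e] now holds how many (token, k) slots were routed to expert e;
# expert_to_index[e, :expert_token_num[e]] holds their flattened positions in topk_ids,
# and expert_to_weight[e, :expert_token_num[e]] the corresponding routing weights.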


@triton.jit
def moe_align2_kernel(
    experts_token_num_ptr,  # [expert_num,]
@@ -719,9 +810,12 @@ def fused_experts_impl(

        expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda")
        expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda")
        moe_align(topk_ids=curr_topk_ids, out=expert_to_tokens)
        expert_to_token_num = torch.empty((E,), dtype=torch.int32, device="cuda")
        moe_align1(expert_to_tokens, curr_topk_weights, expert_to_weights, expert_to_token_num, topk=topk_num)
        # moe_align(topk_ids=curr_topk_ids, out=expert_to_tokens)
        expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda")
        # moe_align1(expert_to_tokens, curr_topk_weights, expert_to_weights, expert_to_token_num, topk=topk_num)
        moe_align_fused(
            expert_to_tokens, expert_to_weights, expert_to_token_num, curr_topk_ids, curr_topk_weights, topk=topk_num
        )


medium

To accompany the removal of the redundant topk parameter from the moe_align_fused function signature, this call should be updated to no longer pass the topk argument.

Suggested change
        moe_align_fused(
            expert_to_tokens, expert_to_weights, expert_to_token_num, curr_topk_ids, curr_topk_weights, topk=topk_num
        )
        moe_align_fused(
            expert_to_tokens, expert_to_weights, expert_to_token_num, curr_topk_ids, curr_topk_weights
        )


        reused_mblock_infos = grouped_matmul(
            curr_topk_ids.numel(),
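
Finally, a hedged sanity-check sketch for the fused path, usable for example on the buffers from the usage sketch above. The helper verify_moe_align_fused is hypothetical and not part of this PR; it relies only on what the kernel stores (flattened slot indices and their weights per expert), and since tl.atomic_add gives no intra-expert ordering guarantee, indices are compared as sorted sets.

import torch

def verify_moe_align_fused(topk_ids, topk_weights, expert_to_index, expert_to_weight, expert_token_num):
    flat_ids = topk_ids.reshape(-1).long()
    flat_w = topk_weights.reshape(-1)
    expert_num = expert_token_num.shape[0]

    # Per-expert counts must equal a plain bincount over the routing ids.
    ref_cnt = torch.bincount(flat_ids, minlength=expert_num).to(expert_token_num.dtype)
    assert torch.equal(expert_token_num, ref_cnt)

    for e in range(expert_num):
        cnt = int(expert_token_num[e])
        idx = expert_to_index[e, :cnt].long()
        ref_idx = (flat_ids == e).nonzero(as_tuple=True)[0]
        # No ordering guarantee within an expert row, so compare sorted indices.
        assert torch.equal(idx.sort().values, ref_idx.sort().values)
        # Each stored weight must match the weight of the slot it points to.
        assert torch.allclose(expert_to_weight[e, :cnt], flat_w[idx].to(expert_to_weight.dtype))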