
Conversation

shihaobai
Collaborator

No description provided.


@gemini-code-assist (bot) left a comment


Summary of Changes

Hello @shihaobai, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new fused Triton kernel designed to optimize the Mixture of Experts (MoE) alignment process. By consolidating previous separate alignment steps into a single, autotuned operation, this change aims to significantly enhance the performance and efficiency of MoE computations within the system.

Highlights

  • New Fused MoE Alignment Kernel: A new Triton kernel, moe_align_fused_kernel, has been introduced to efficiently handle the Mixture of Experts (MoE) token-to-expert assignment process in a fused operation; a minimal sketch of the idea follows this list.
  • Performance Optimization with Autotuning: The new moe_align_fused function is wrapped with an autotune decorator, enabling dynamic optimization of kernel launch parameters such as BLOCK_TOK, num_warps, and num_stages for improved performance.
  • Integration into MoE Pipeline: The existing moe_align and moe_align1 calls within the fused_experts_impl function have been replaced by the new moe_align_fused function, streamlining the MoE processing pipeline.
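
For orientation, here is a minimal, illustrative sketch of the fused alignment idea. It is not the PR's actual implementation: the kernel name, wrapper, buffer names, and dtypes (int32 index/counter buffers) are assumptions based on the diff hunks quoted below, and the real code additionally wraps the launch in the repository's autotuner to pick BLOCK_TOK, num_warps, and num_stages.

import torch
import triton
import triton.language as tl

@triton.jit
def _moe_align_fused_sketch(
    topk_ids_ptr,          # [token_num, topk] expert id per (token, k) pair
    topk_weights_ptr,      # [token_num, topk] routing weight per (token, k) pair
    expert_to_index_ptr,   # [expert_num, token_num * topk] int32 flat indices grouped by expert
    expert_to_weight_ptr,  # [expert_num, token_num * topk] weights grouped by expert
    expert_token_num_ptr,  # [expert_num] int32 per-expert counters, zero-initialized
    total_elems,           # token_num * topk
    slots_per_expert,      # row stride of the per-expert output buffers
    BLOCK_TOK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK_TOK + tl.arange(0, BLOCK_TOK)
    mask = offs < total_elems

    expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0)
    weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0)

    # Atomically reserve the next free slot in the destination expert's row.
    write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask)

    dest = expert_ids * slots_per_expert + write_pos
    tl.store(expert_to_index_ptr + dest, offs, mask=mask)
    tl.store(expert_to_weight_ptr + dest, weights, mask=mask)

def moe_align_fused_sketch(expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, block_tok=128):
    token_num, topk = topk_ids.shape  # top-k width is implicit in the tensor shape
    total = token_num * topk
    expert_token_num.zero_()
    grid = (triton.cdiv(total, block_tok),)
    _moe_align_fused_sketch[grid](
        topk_ids, topk_weights,
        expert_to_index, expert_to_weight, expert_token_num,
        total, expert_to_index.shape[1],
        BLOCK_TOK=block_tok,
    )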


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a fused Triton kernel, moe_align_fused, to optimize the MoE expert alignment process by combining the logic of moe_align and moe_align1. This is a valuable performance enhancement. However, I've identified a critical issue in the new kernel's signature that will prevent it from executing, along with a few medium-severity issues related to code clarity and maintainability. My feedback includes suggestions to fix the critical bug, refactor a redundant function parameter, and improve code consistency.

expert_to_weight_ptr, # [expert_num, token_num * topk]
expert_token_num_ptr, # [expert_num]
token_num,
topk: tl.constexpr,


critical

The topk parameter is declared as tl.constexpr, but it is passed as a positional argument from moe_align_fused, which makes it a runtime value. According to the Triton documentation, constexpr arguments must be compile-time constants and passed as keyword arguments. This will cause a TypeError at runtime. Please remove the tl.constexpr annotation to treat it as a regular runtime argument. The autotuner is already creating specializations for different topk values via _get_moe_align_fused_static_key, so the performance impact of this change should be minimal.

Suggested change
topk: tl.constexpr,
topk,
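
As a standalone illustration of the distinction (not code from this PR): in the toy kernel below, scale is an ordinary runtime scalar that can change between launches without recompilation, while BLOCK is a tl.constexpr that is fixed per compiled variant and supplied at launch as a keyword argument. Requires a CUDA device.

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_kernel(x_ptr, out_ptr, n_elements, scale, BLOCK: tl.constexpr):
    # scale is a plain runtime value; BLOCK is baked into the compiled kernel.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * scale, mask=mask)

x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
_scale_kernel[grid](x, out, x.numel(), 2.0, BLOCK=256)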

Comment on lines 237 to 245
# 遍历 topk
# for k in range(topk):
expert_ids = tl.load(topk_ids_ptr + offs, mask=mask, other=0)
weights = tl.load(topk_weights_ptr + offs, mask=mask, other=0.0)

# 用 atomic_add 给 expert 分配写位置
write_pos = tl.atomic_add(expert_token_num_ptr + expert_ids, 1, mask=mask)

# 按 token 顺序写 index 和 weight


medium

The comments in this kernel are in Chinese, while the rest of the codebase uses English. For consistency and better maintainability for a wider audience, please translate these comments to English.

For example:

  • Line 237: # 遍历 topk -> # Iterate over top-k experts
  • Line 242: # 用 atomic_add 给 expert 分配写位置 -> # Use atomic_add to allocate write positions for experts
  • Line 245: # 按 token 顺序写 index 和 weight -> # Write index and weight in token order
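
For readers unfamiliar with this pattern, here is a tiny, purely illustrative PyTorch analogue (not taken from the PR) of what the translated comments describe: each (token, k) pair claims the next free slot in its expert's output row, which is what the atomic_add return value provides inside the kernel.

import torch

expert_num, token_num, topk = 4, 6, 2
topk_ids = torch.randint(0, expert_num, (token_num * topk,))
expert_token_num = torch.zeros(expert_num, dtype=torch.long)
expert_to_index = torch.full((expert_num, token_num * topk), -1, dtype=torch.long)

for flat_idx, e in enumerate(topk_ids.tolist()):
    write_pos = expert_token_num[e].item()  # value atomic_add would return
    expert_token_num[e] += 1                # increment atomic_add would perform
    expert_to_index[e, write_pos] = flat_idx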

mutates_args=["expert_to_index", "expert_to_weight", "expert_token_num"],
)
def moe_align_fused(
expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, topk, run_config: Optional[dict] = None


medium

The topk parameter in the function signature is redundant because its value is derived from topk_ids.shape within the function. The argument is also shadowed by a local variable, which is confusing and error-prone. Please remove this parameter from the function signature for clarity and simplicity. Note that this will require updating the call site in fused_experts_impl.

Suggested change
expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, topk, run_config: Optional[dict] = None
expert_to_index, expert_to_weight, expert_token_num, topk_ids, topk_weights, run_config: Optional[dict] = None
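
As a small, hypothetical illustration of why the argument is redundant (tensor shapes assumed from the comments in the quoted diff), the top-k width can be read straight off the routing tensors:

import torch

topk_ids = torch.randint(0, 8, (16, 4))  # [token_num, topk]
topk_weights = torch.rand(16, 4)         # [token_num, topk]
token_num, topk = topk_ids.shape
assert topk_weights.shape == (token_num, topk)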

Comment on lines 816 to 818
moe_align_fused(
expert_to_tokens, expert_to_weights, expert_to_token_num, curr_topk_ids, curr_topk_weights, topk=topk_num
)


medium

To accompany the removal of the redundant topk parameter from the moe_align_fused function signature, this call should be updated to no longer pass the topk argument.

Suggested change
moe_align_fused(
expert_to_tokens, expert_to_weights, expert_to_token_num, curr_topk_ids, curr_topk_weights, topk=topk_num
)
moe_align_fused(
expert_to_tokens, expert_to_weights, expert_to_token_num, curr_topk_ids, curr_topk_weights
)

@hiworldwzj merged commit 3956f92 into main on Sep 18, 2025
1 check passed
@hiworldwzj deleted the moe_align_fused branch on September 18, 2025 at 09:36
sufubao pushed a commit that referenced this pull request on Sep 18, 2025