
Conversation

Contributor

@Victarry Victarry commented Nov 11, 2025

What does this PR do ?

Design doc: https://docs.google.com/document/d/1whtnUiw1hpfdkjFss_g5P8fyIT9xBA5XJvmdklFes48/edit?usp=sharing
Changelog:

  • Add shared expert overlap for FlexDispatcher.
  • Add a stream wait for the case where CUDA_DEVICE_MAX_CONNECTIONS > 1, to prevent the shared expert GEMM from being launched too early (see the sketch after this list).
  • Change the fc1 location of the shared experts in the A2A dispatcher for better overlap.
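
A minimal sketch of the overlap pattern described above, assuming the shared-expert GEMM runs on a side CUDA stream concurrently with the token-dispatch all-to-all; the names (`dispatch_input`, `shared_expert_fc1`, `comm_group`) are placeholders for illustration only, not Megatron-core APIs:

```python
import torch

# Hypothetical illustration of overlapping the shared-expert GEMM with the
# dispatch all-to-all on a side CUDA stream.
shared_stream = torch.cuda.Stream()

def overlapped_forward(dispatch_input, shared_expert_fc1, comm_group):
    main_stream = torch.cuda.current_stream()

    # Make sure the side stream sees the fully produced dispatch_input before
    # the shared-expert GEMM starts. With CUDA_DEVICE_MAX_CONNECTIONS > 1 the
    # two streams can otherwise be scheduled out of order on the device.
    ready = torch.cuda.Event()
    ready.record(main_stream)

    with torch.cuda.stream(shared_stream):
        shared_stream.wait_event(ready)
        shared_out = shared_expert_fc1(dispatch_input)  # overlaps with the all-to-all below

    # Token dispatch (all-to-all) runs on the main stream concurrently.
    dispatched = torch.empty_like(dispatch_input)
    torch.distributed.all_to_all_single(dispatched, dispatch_input, group=comm_group)

    # Re-join: the main stream must not consume shared_out before it is done.
    main_stream.wait_stream(shared_stream)
    return dispatched, shared_out
```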

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers' reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either [email protected] or [email protected].

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@Victarry Victarry self-assigned this Nov 11, 2025
@Victarry Victarry requested review from a team as code owners November 11, 2025 09:22

copy-pr-bot bot commented Nov 11, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@kvareddy kvareddy requested a review from fanshiqing November 11, 2025 09:49
@kvareddy
Contributor

@fanshiqing can you please take a look at this MR?

@yanring yanring added the Expert Review label Nov 11, 2025

if self.stream is None:
    self.stream = torch.cuda.Stream()
if SharedExpertMLP.stream is None:
Contributor


Yikes, isn't this a ClassVar now? Will we never need two different streams for the same class?

Contributor Author


Yeah, it's intended. There will never be two shared experts overlapping with each other.
Moreover, if we created a new stream for each instance, PyTorch could run out of streams in its stream pool and start reusing existing streams, which may cause interference and unwanted behavior.
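
A minimal sketch of the pattern under discussion, assuming the intent is a single class-level stream shared by every `SharedExpertMLP` instance so the stream pool is not exhausted; the class name comes from the diff hunk above, and the forward body is illustrative only:

```python
import torch

class SharedExpertMLP(torch.nn.Module):
    # Class-level stream: created once and reused by every instance.
    stream = None

    def forward(self, hidden_states):
        if SharedExpertMLP.stream is None:
            SharedExpertMLP.stream = torch.cuda.Stream()
        with torch.cuda.stream(SharedExpertMLP.stream):
            # ... shared-expert GEMMs, overlapped with dispatch communication ...
            out = hidden_states
        # The consuming stream must wait for the shared-expert work to finish.
        torch.cuda.current_stream().wait_stream(SharedExpertMLP.stream)
        return out
```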

@Victarry Victarry requested review from a team as code owners November 21, 2025 01:36
@jaredcasper jaredcasper added this to the Core 0.16 milestone Nov 21, 2025
@jaredcasper
Contributor

/ok to test 3e308ac

        group=group,
    )
if use_nccl_stream:
    handle = torch.distributed.all_to_all_single(
Contributor


What is the difference between the if and else paths?
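
As context for the question, a minimal sketch of an asynchronous `all_to_all_single` issued with `async_op=True`; the names mirror the hunk above, and the PR's actual else branch is not shown in this excerpt:

```python
import torch.distributed as dist

def dispatch(output, input_, group, use_nccl_stream=True):
    if use_nccl_stream:
        # async_op=True returns a work handle; the collective overlaps with
        # work queued afterwards until handle.wait() is called.
        handle = dist.all_to_all_single(output, input_, group=group, async_op=True)
        return handle
    # Blocking variant shown only as a placeholder; the PR's alternative path
    # is not reproduced here.
    dist.all_to_all_single(output, input_, group=group)
    return None
```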


@yanring yanring changed the title [MoE] Improvement of shared expert overlap. [MoE] Improvement of shared expert overlap, support shared expert overlap for FlexDispatcher Nov 26, 2025

Labels

Expert Review, module: moe
