Conversation

amirkl94 (Contributor) commented Oct 21, 2025

Purpose

This PR enables the FusedMoE FP8 cutlass path for models that use the non-gated relu2 activation. The new path gives a performance gain of around 20% in output token throughput over the triton path.
This PR requires flashinfer 0.5.0.
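
For context, here is a minimal PyTorch sketch of the two activation styles (illustrative only: sizes and names are made up, and this is not the vLLM kernel code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, inter = 64, 128  # illustrative sizes only
x = torch.randn(4, hidden)
w2 = nn.Linear(inter, hidden, bias=False)

# Gated expert MLP ("act_and_mul"): w13 produces gate and up halves,
# and the activated gate multiplies the up projection.
w13 = nn.Linear(hidden, 2 * inter, bias=False)
gate, up = w13(x).chunk(2, dim=-1)
gated_out = w2(F.silu(gate) * up)

# Non-gated relu2 expert MLP: a single up projection followed by squared ReLU.
# This is the activation the FlashInfer CUTLASS FP8 path now covers.
w1 = nn.Linear(hidden, inter, bias=False)
relu2_out = w2(F.relu(w1(x)) ** 2)
```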

Tests

New parameterization in test_flashinfer.py::test_flashinfer_cutlass_moe_fp8_no_graph to verify the new non-gated activation.

Performance tests:
Ran on a single H100. Started the server (once with VLLM_USE_FLASHINFER_MOE_FP8=1 and once with VLLM_USE_FLASHINFER_MOE_FP8=0):

python3 -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8080 --model <NemotronNanov3-modelopt-fp8-ckpt> --tensor-parallel-size 1 --pipeline-parallel-size 1 --swap-space 16 --max-num-seqs 512 --trust-remote-code --max-model-len 3800 --gpu-memory-utilization 0.9 --max-num-batched-tokens 8192 --enable-chunked-prefill --disable-log-requests

Benchmarked using:

vllm bench serve --backend vllm --host 0.0.0.0 --port $port --model $model_path --num-prompts 1280 --trust-remote-code --ignore-eos --max-concurrency 256 --random-input-len 1024 --random-output-len 1024 --no-stream --dataset-name random --num-warmups 20

triton path yielded:

Request throughput (req/s):              3.47
Output token throughput (tok/s):         3551.76
Peak output token throughput (tok/s):    6116.00

cutlass path yielded:

Request throughput (req/s):              5.15
Output token throughput (tok/s):         5270.17
Peak output token throughput (tok/s):    7168.00

~17% perf gain for peak output token throughput (decode), ~48% perf gain for average output token throughput (includes prefill).
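
For reference, these percentages follow directly from the numbers above:

```python
# Quick check of the reported gains, computed from the benchmark output above.
peak_gain = 7168.00 / 6116.00 - 1   # ~0.17 -> ~17% (peak, decode)
avg_gain = 5270.17 / 3551.76 - 1    # ~0.48 -> ~48% (average, includes prefill)
print(f"peak: {peak_gain:.1%}, average: {avg_gain:.1%}")
```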

@amirkl94 amirkl94 force-pushed the feat/relu2-cutlass branch 2 times, most recently from 98b3df9 to ed78454 Compare November 2, 2025 14:23
mergify bot commented Nov 5, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @amirkl94.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 5, 2025
Signed-off-by: Amir Klein <[email protected]>
@amirkl94 amirkl94 requested a review from tomeras91 November 5, 2025 11:30
@mergify mergify bot removed the needs-rebase label Nov 5, 2025
Signed-off-by: Amir Klein <[email protected]>
Signed-off-by: Amir Klein <[email protected]>
chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 352 to 356
  self.cutlass_fp8_supported = cutlass_fp8_supported()
  self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
- if (
-     envs.VLLM_USE_FLASHINFER_MOE_FP8
-     and has_flashinfer_moe()
-     and self.moe.is_act_and_mul
- ):
+ if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
      self.flashinfer_moe_backend = get_flashinfer_moe_backend()
      logger.info_once(

P1: Avoid enabling TensorRT flashinfer for relu2 activations

Removing the self.moe.is_act_and_mul guard means a flashinfer backend is now enabled whenever VLLM_USE_FLASHINFER_MOE_FP8 is set, regardless of the activation. If the user selects the latency backend (TensorRT‑LLM) and runs a relu2_no_mul model, apply() will hit the hard assertion activation == "silu" and abort instead of falling back to the existing non‑flashinfer path, which previously worked (albeit slower). Consider only enabling flashinfer when either the model is gated or the chosen backend is CUTLASS; otherwise leave flashinfer_moe_backend as None so non‑gated models continue to run.
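
A rough sketch of the guard being suggested, reusing the names from the snippet above (illustrative, not the final code):

```python
# Sketch: only select a FlashInfer backend when the activation is compatible
# with it; otherwise leave flashinfer_moe_backend as None so non-gated models
# keep using the existing non-FlashInfer path.
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
    backend = get_flashinfer_moe_backend()
    if self.moe.is_act_and_mul or backend == FlashinferMoeBackend.CUTLASS:
        self.flashinfer_moe_backend = backend
```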


Comment on lines 570 to 575
- g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
+ g1_alphas=layer.output1_scales_gate_scalar.squeeze()
Member

It looks like output1_scales_gate_scalar and output2_scales_scalar are only used in flashinfer_trtllm_moe. It's not clear from the vLLM source code how these are getting set. Does flashinfer add them to the layer's member variables?

I think we should try to keep the quantization format decoupled from the kernels used for the implementation

Contributor Author

These factors are registered into the layer during process_weights_after_loading (line 566).
They are registered to the layer when we're using the cutlass or trtllm backends, using register_moe_scaling_factors.
The problem I see with trying to decouple this is that the cutlass path uses the layer's get_fused_moe_quant_config to get the relevant scaling factors here.
There are two options I see here; let me know which one you prefer:

  1. I can remove the if else when setting the quantization and set it the same for all paths.
  2. I can build the needed quantization for the flashinfer cutlass kernel here, and not change the quantization in the ModelOptFusedMoEFP8 object.

Member

> I can remove the if else when setting the quantization and set it the same for all paths.

I think I'd prefer to just provide the same information in all cases. The kernel can choose to ignore it.
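
If it helps to visualize that, here is a rough, hypothetical sketch (the function name and dict layout are mine, not vLLM's API) of computing the same scaling information once and letting each kernel ignore what it does not use:

```python
import torch


def build_moe_alphas(
    w13_weight_scale: torch.Tensor,
    w13_input_scale: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> dict:
    # Combined per-expert alphas as consumed by the cutlass path; a kernel
    # that derives its own scalars during weight post-processing can simply
    # ignore these entries.
    return {
        "g1_alphas": (w13_weight_scale * w13_input_scale).squeeze(),
        "g2_alphas": (w2_weight_scale * w2_input_scale).squeeze(),
    }
```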

@amirkl94 amirkl94 requested a review from tlrmchlsmth November 10, 2025 09:20
Comment on lines +359 to +366
if (
    self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
    and not self.moe.is_act_and_mul
):
    logger.info_once(
        "Non-gated MoE is not supported for min-latency mode, "
        "falling back to high-throughput mode"
    )
Member

It seems you are missing the override of self.flashinfer_moe_backend here
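
Presumably something along these lines (a sketch of the missing override, not the exact final code):

```python
if (
    self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
    and not self.moe.is_act_and_mul
):
    logger.info_once(
        "Non-gated MoE is not supported for min-latency mode, "
        "falling back to high-throughput mode"
    )
    # The missing piece: actually switch the backend before continuing.
    self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
```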

Comment on lines +359 to +366
Member

It looks like the self.flashinfer_moe_backend override was left out

Signed-off-by: Amir Klein <[email protected]>
@amirkl94 amirkl94 requested a review from mgoin November 12, 2025 08:50
Signed-off-by: Amir Klein <[email protected]>
@github-project-automation github-project-automation bot moved this to In review in NVIDIA Nov 13, 2025
@mgoin mgoin added the ready label Nov 13, 2025
mgoin (Member) left a comment

PTAL at the failing blackwell test

Comment on lines 85 to 86
if activation != "relu2_no_mul":
    is_gated = False
Member

This doesn't seem right as it breaks test_flashinfer_per_tensor_moe_fp8_no_graph on blackwell
https://buildkite.com/vllm/ci/builds/38920/steps/canvas?jid=019a7fad-270b-4d40-8820-e3a1e75dc35e#019a7fad-270b-4d40-8820-e3a1e75dc35e/102-2387

Contributor Author

Yeah, it should be if activation == "relu2_no_mul": — I originally wrote it as a one-liner, but the pre-commit hook complained and I fixed it incorrectly. Changed it back, and this should be correct now.
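
So the guard becomes (reconstructed from the quoted lines and the reply above):

```python
if activation == "relu2_no_mul":
    is_gated = False
```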

Signed-off-by: Amir Klein <[email protected]>
mergify bot commented Nov 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @amirkl94.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 15, 2025
@mergify mergify bot removed the needs-rebase label Nov 15, 2025
@amirkl94 amirkl94 requested a review from mgoin November 16, 2025 11:11
mgoin (Member) left a comment

Thank you!

@mgoin mgoin merged commit 03ee481 into vllm-project:main Nov 16, 2025
53 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in NVIDIA Nov 16, 2025
bwasti pushed a commit to bwasti/vllm that referenced this pull request Nov 17, 2025

Labels

nvidia, ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done

4 participants