18 changes: 6 additions & 12 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -52,8 +52,7 @@
 from ..modules.attention import MLA
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
-from ..modules.fused_moe import (CutlassFusedMoE, DeepSeekV3MoeRoutingMethod,
-                                 create_moe)
+from ..modules.fused_moe import DeepSeekV3MoeRoutingMethod, create_moe
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -502,23 +501,18 @@ def _compute_shared_expert_tp_size(self, intermediate_size: int,
     def compute_routed_output(self, hidden_states, hidden_states_fp4,
                               all_rank_num_tokens, do_finalize):
         # max-throughput
-        use_dp_padding = False
         if self.use_dp and self.mapping.tp_size > 1:
+            max_num_token = max(all_rank_num_tokens)
+            hidden_states = torch.nn.functional.pad(
+                hidden_states,
+                (0, 0, 0, max_num_token - hidden_states.shape[0]))
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
             if disable_fp4_allgather() and not self.experts.enable_alltoall:
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           dim=0,
                                           sizes=all_rank_num_tokens)
-            elif not isinstance(self.experts, CutlassFusedMoE) or (
-                    not self.experts.has_fp8_qdq and self.experts.has_nvfp4):
-                # Use padding when not using the cutlass path or when x_sf in self.experts is not None
-                use_dp_padding = True
-                max_num_token = max(all_rank_num_tokens)
-                hidden_states = torch.nn.functional.pad(
-                    hidden_states,
-                    (0, 0, 0, max_num_token - hidden_states.shape[0]))

         router_logits = self.gate(hidden_states)

@@ -528,7 +522,7 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
             do_finalize=do_finalize,
             output_dtype=hidden_states.dtype,
             all_rank_num_tokens=all_rank_num_tokens,
-            use_dp_padding=use_dp_padding,
+            use_dp_padding=True,
         )

         return routed_output
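For context on the padding that the diff now applies unconditionally under attention DP: each rank pads its local hidden_states up to the largest token count across ranks, so every rank hands the fused MoE a tensor of the same shape, and use_dp_padding=True together with all_rank_num_tokens presumably lets the MoE discount the padded rows. Below is a minimal, self-contained sketch of just the padding step; the shapes and per-rank token counts are made up for illustration, and only torch is assumed.

# Illustrative sketch only; the real logic lives in compute_routed_output
# in modeling_deepseekv3.py.
import torch
import torch.nn.functional as F

hidden_size = 8                     # made-up model width
all_rank_num_tokens = [3, 5, 2]     # hypothetical token counts per attention-DP rank
max_num_token = max(all_rank_num_tokens)

# Pretend this rank holds the first entry's worth of tokens.
hidden_states = torch.randn(all_rank_num_tokens[0], hidden_size)

# F.pad reads the pad spec from the last dimension backwards: (0, 0) leaves the
# hidden dimension untouched, (0, max_num_token - num_tokens) appends zero rows
# so every rank ends up with the same [max_num_token, hidden_size] shape.
hidden_states = F.pad(hidden_states,
                      (0, 0, 0, max_num_token - hidden_states.shape[0]))

assert hidden_states.shape == (max_num_token, hidden_size)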