Commit 12f339f

[None][fix] Fix the aux_stream in Llama4MinLatencyFusedMoE (#9035)
Signed-off-by: Jinyang Yuan <[email protected]>
Parent: 9ef7eb7

2 files changed (+8, -5 lines)

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 2 additions & 1 deletion
@@ -41,7 +41,7 @@
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
 from ..speculative import SpecMetadata
-from ..utils import Fp4QuantizedTensor
+from ..utils import AuxStreamType, Fp4QuantizedTensor
 from .modeling_multimodal_utils import fuse_input_embeds
 from .modeling_speculative import SpecDecOneEngineForCausalLM
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -293,6 +293,7 @@ def __init__(
             weight_loading_mode=MoEWeightLoadingMode.FUSED_GATE_UP_PROJ,
             model_config=model_config,
             apply_router_weight_on_input=True,
+            aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream},
             layer_idx=layer_idx)
 
         self.router = Linear(
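For context on what the aux stream enables here: modules like this MoE can overlap independent GPU work on a secondary CUDA stream. Below is a minimal sketch of the general two-stream pattern using standard torch.cuda APIs; the helper name and structure are illustrative assumptions, not TensorRT-LLM's actual maybe_execute_in_parallel implementation.

import torch

def run_in_parallel(fn_a, fn_b, aux_stream: torch.cuda.Stream):
    # Illustrative two-stream overlap, assuming fn_a and fn_b are
    # independent GPU workloads whose inputs are already enqueued
    # on the current (main) stream.
    main_stream = torch.cuda.current_stream()
    # The aux stream must wait for the inputs produced on the main
    # stream before starting its own work.
    aux_stream.wait_stream(main_stream)
    out_a = fn_a()                      # runs on the main stream
    with torch.cuda.stream(aux_stream):
        out_b = fn_b()                  # overlaps with fn_a
    # The main stream must not consume out_b until the aux work is done.
    main_stream.wait_stream(aux_stream)
    # (Production code also needs tensor-lifetime care, e.g.
    # Tensor.record_stream, when memory crosses streams.)
    return out_a, out_b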

tensorrt_llm/_torch/models/modeling_llama_min_latency.py

Lines changed: 6 additions & 4 deletions
@@ -23,7 +23,7 @@
                                     WeightsLoadingConfig)
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..speculative import SpecMetadata
-from ..utils import Fp4QuantizedTensor
+from ..utils import AuxStreamType, Fp4QuantizedTensor
 from .modeling_llama import Llama4Attention, Llama4DecoderLayer, Llama4MoE
 
 # Perf heuristics thresholds.
@@ -438,7 +438,8 @@ def __init__(
         dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         model_config: ModelConfig = ModelConfig(),
-        aux_stream: torch.cuda.Stream = torch.cuda.Stream(),
+        aux_stream_dict: Optional[Dict[AuxStreamType,
+                                       torch.cuda.Stream]] = None,
         weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
         VANILLA,
         apply_router_weight_on_input: bool = False,
@@ -452,7 +453,7 @@ def __init__(
             dtype=dtype,
             reduce_results=reduce_results,
             model_config=model_config,
-            aux_stream=aux_stream,
+            aux_stream_dict=aux_stream_dict,
             weight_loading_mode=weight_loading_mode,
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
@@ -554,6 +555,7 @@ def __init__(
             weight_loading_mode=MoEWeightLoadingMode.FUSED_GATE_UP_PROJ,
             model_config=model_config,
             apply_router_weight_on_input=True,
+            aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream},
         )
 
         self.router = Llama4MinLatencyLinear(
@@ -801,7 +803,7 @@ def forward(
             or self.fusion_config.POST_MLP_FUSION
         if needs_post_allreduce and self.next_layer_layernorm is not None:
             if use_fp8_allreduce and self.next_attn is not None \
-                    and hasattr(elf.next_attn.qkv_proj, 'input_scale'):
+                    and hasattr(self.next_attn.qkv_proj, 'input_scale'):
                 hidden_states, residual = self.all_reduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
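One plausible reading of the signature change above, stated as an assumption rather than the commit's confirmed rationale: a default value like torch.cuda.Stream() is evaluated once, when the function is defined, so every instance relying on the default silently shares one stream created at import time. The plain-Python sketch below demonstrates that semantics; FakeStream is a hypothetical stand-in so the snippet runs without a GPU.

class FakeStream:
    """Stand-in for torch.cuda.Stream; same default-argument semantics."""

class MoE:
    def __init__(self, aux_stream: FakeStream = FakeStream()):
        # The FakeStream() default above was created exactly once, at
        # definition time, not once per call.
        self.aux_stream = aux_stream

a, b = MoE(), MoE()
print(a.aux_stream is b.aux_stream)  # True: both share the one default object

The diff's replacement follows the usual idiom for this pitfall: default the parameter to None (here, an Optional aux_stream_dict) and have the caller pass streams in explicitly.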
