@@ -23,7 +23,7 @@
                             WeightsLoadingConfig)
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..speculative import SpecMetadata
-from ..utils import Fp4QuantizedTensor
+from ..utils import AuxStreamType, Fp4QuantizedTensor
 from .modeling_llama import Llama4Attention, Llama4DecoderLayer, Llama4MoE
 
 # Perf heuristics thresholds.
@@ -438,7 +438,8 @@ def __init__(
         dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         model_config: ModelConfig = ModelConfig(),
-        aux_stream: torch.cuda.Stream = torch.cuda.Stream(),
+        aux_stream_dict: Optional[Dict[AuxStreamType,
+                                       torch.cuda.Stream]] = None,
         weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
         VANILLA,
         apply_router_weight_on_input: bool = False,
@@ -452,7 +453,7 @@ def __init__(
             dtype=dtype,
             reduce_results=reduce_results,
             model_config=model_config,
-            aux_stream=aux_stream,
+            aux_stream_dict=aux_stream_dict,
             weight_loading_mode=weight_loading_mode,
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
@@ -554,6 +555,7 @@ def __init__(
             weight_loading_mode=MoEWeightLoadingMode.FUSED_GATE_UP_PROJ,
             model_config=model_config,
             apply_router_weight_on_input=True,
+            aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream},
         )
 
         self.router = Llama4MinLatencyLinear(
@@ -801,7 +803,7 @@ def forward(
             or self.fusion_config.POST_MLP_FUSION
         if needs_post_allreduce and self.next_layer_layernorm is not None:
             if use_fp8_allreduce and self.next_attn is not None \
-                    and hasattr(elf.next_attn.qkv_proj, 'input_scale'):
+                    and hasattr(self.next_attn.qkv_proj, 'input_scale'):
                 hidden_states, residual = self.all_reduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
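
For reference, a minimal sketch of how a caller could build the new aux_stream_dict argument, mirroring the call updated in the hunk at line 555 above. Only the AuxStreamType.MoeChunkingOverlap key appears in this diff; the absolute import path below is an assumption based on the relative "..utils" import.

from typing import Dict, Optional

import torch

# Assumed absolute path for the relative "..utils" import in the diff.
from tensorrt_llm._torch.utils import AuxStreamType

# One auxiliary CUDA stream, keyed by the stream type the diff uses to
# overlap MoE chunk work with the main stream.
aux_stream = torch.cuda.Stream()
aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = {
    AuxStreamType.MoeChunkingOverlap: aux_stream,
}

# Passing aux_stream_dict=None is also valid, since the new parameter is
# Optional and defaults to None.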