
Commit 6dc70aa

[https://nvbugs/5613089][fix] Fix the rank to access all_rank_chunk_size_list when chunked MoE is used (NVIDIA#8723)
Signed-off-by: Jinyang Yuan <[email protected]>
1 parent d16b1a8 commit 6dc70aa

File tree

6 files changed: +10 −9 lines


tensorrt_llm/_torch/autotuner.py

Lines changed: 1 addition & 1 deletion
@@ -727,10 +727,10 @@ def choose_one(
         new_tuning_failure_occured = False
 
         for p in profiles:
-            tensors = self._prepare_input_tensors(p, inputs)
             is_cache_hit, *_ = self.profiling_cache.search_cache(
                 custom_op, runners, p.get_opt_shapes(), tuning_config)
             if not is_cache_hit:
+                tensors = self._prepare_input_tensors(p, inputs)
                 # Initialize runner and tactic as None in case of no valid tactic or runners are found
                 best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners(
                     custom_op, runners, tensors, p, tuning_config, **kwargs)
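
The autotuner change moves input-tensor preparation inside the cache-miss branch, so a cache hit no longer pays for allocating profiling inputs. A minimal sketch of that pattern; the cache, profiles, and helper functions below are hypothetical stand-ins, not the real AutoTuner API:

```python
# Hypothetical sketch of "prepare inputs only on a cache miss".
cache = {}

def prepare_tensors(profile):
    # Stands in for the expensive _prepare_input_tensors step.
    return [f"tensor{shape}" for shape in profile]

def profile_runners(tensors):
    # Stands in for benchmarking all runners/tactics; returns a "best id".
    return len(tensors) - 1

profiles = [((8, 16),), ((32, 16),), ((8, 16),)]
for p in profiles:
    if p in cache:                      # cache hit: skip tensor preparation
        best_runner_id = cache[p]
    else:                               # cache miss: prepare, then profile
        tensors = prepare_tensors(p)
        best_runner_id = cache[p] = profile_runners(tensors)
```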

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 3 additions & 3 deletions
@@ -669,7 +669,7 @@ def forward_impl(
                 all_rank_num_tokens_list = [[
                     val[idx_chunk] for val in all_rank_chunk_size_list
                 ] for idx_chunk in range(num_chunks)]
-                chunk_size_list = all_rank_chunk_size_list[self.rank]
+                chunk_size_list = all_rank_chunk_size_list[self.parallel_rank]
             else:
                 all_rank_num_tokens_list = [None] * num_chunks
                 chunk_size_list = self.split_chunk(x.shape[0], num_chunks)
@@ -735,7 +735,7 @@ def _reducescatter_or_allreduce(x_, idx):
             outputs = torch.cat(outputs_list)
 
         if self.use_dp and self.parallel_size > 1:
-            rank = self.mapping.tp_rank
+            rank = self.parallel_rank
             outputs = outputs[:all_rank_num_tokens[rank]]
         self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
         return outputs
@@ -765,7 +765,7 @@ def forward_fake(
         is_nvfp4_input = isinstance(x, Fp4QuantizedTensor)
         data_type = output_dtype if is_nvfp4_input else x.dtype
         num_tokens = all_rank_num_tokens[
-            self.mapping.tp_rank] if all_rank_num_tokens else x.shape[0]
+            self.parallel_rank] if all_rank_num_tokens else x.shape[0]
         hidden_size = x.shape[1] * (2 if is_nvfp4_input else 1)
         top_k = self.routing_method.experts_per_token
         return x.new_empty((num_tokens, top_k, hidden_size),
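
Under attention DP, `all_rank_chunk_size_list` holds one chunk-size list per DP (TP) rank, so it must be indexed with the TP rank that the base class now exposes as `self.parallel_rank`. A toy reconstruction of that bookkeeping; the `split_chunk` helper here is an even-split assumption, not necessarily the module's exact splitting logic:

```python
# Toy reconstruction: one chunk-size list per attention-DP rank.
def split_chunk(num_tokens: int, num_chunks: int) -> list[int]:
    # Assumed even split, spreading the remainder over the first chunks.
    base, rem = divmod(num_tokens, num_chunks)
    return [base + (1 if i < rem else 0) for i in range(num_chunks)]

all_rank_num_tokens = [7, 5, 9, 6]          # tokens per DP (TP) rank
num_chunks = 2
all_rank_chunk_size_list = [split_chunk(n, num_chunks)
                            for n in all_rank_num_tokens]
# -> [[4, 3], [3, 2], [5, 4], [3, 3]]

parallel_rank = 2                            # this process's TP rank
chunk_size_list = all_rank_chunk_size_list[parallel_rank]   # [5, 4]
# Indexing with any other rank id (e.g. a global rank >= tp_size) would pick
# another rank's chunk sizes or fall out of range, which is what the commit fixes.
```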

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 2 additions & 2 deletions
@@ -706,7 +706,7 @@ def forward_impl(
                 all_rank_num_tokens_list = [[
                     val[idx_chunk] for val in all_rank_chunk_size_list
                 ] for idx_chunk in range(num_chunks)]
-                chunk_size_list = all_rank_chunk_size_list[self.rank]
+                chunk_size_list = all_rank_chunk_size_list[self.parallel_rank]
             else:
                 all_rank_num_tokens_list = [None] * num_chunks
                 chunk_size_list = self.split_chunk(x.shape[0], num_chunks)
@@ -778,6 +778,6 @@ def _reducescatter_or_allreduce(x_, idx):
             outputs = torch.cat(outputs_list)
 
         if self.use_dp and self.parallel_size > 1:
-            rank = self.mapping.tp_rank
+            rank = self.parallel_rank
             outputs = outputs[:all_rank_num_tokens[rank]]
         return outputs
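
After chunked execution, the per-chunk outputs are concatenated and then sliced back to this rank's real token count via `all_rank_num_tokens[rank]`, so the rank used for the slice must match the rank used to build the list. A toy illustration of that trim step, with made-up shapes rather than the module's actual flow:

```python
import torch

# Toy illustration: concatenate per-chunk outputs, then keep only this rank's
# real tokens. Values and shapes are made up for the example.
all_rank_num_tokens = [7, 5, 9, 6]
rank, hidden = 0, 4

# Pretend each chunk was processed (and possibly padded) independently.
outputs_list = [torch.zeros(4, hidden), torch.zeros(4, hidden)]
outputs = torch.cat(outputs_list)               # 8 rows

outputs = outputs[:all_rank_num_tokens[rank]]   # keep the 7 real tokens
assert outputs.shape[0] == 7
```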

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 1 addition & 1 deletion
@@ -805,7 +805,7 @@ def forward_impl(
             self._load_balancer_done_set_cpu_stage(is_last_call)
 
         if use_dp_padding:
-            rank = self.mapping.tp_rank
+            rank = self.parallel_rank
             final_hidden_states = final_hidden_states[:
                                                       all_rank_num_tokens[rank]]
 
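
With `use_dp_padding`, each rank's activations are padded to a common length for the cross-rank collective, and the MoE output is then sliced back to the local token count, again indexed by the DP/TP rank. A small, assumed illustration of that pad-then-slice shape handling (names and the padding scheme are for the example only):

```python
import torch

# Assumed illustration of DP padding: pad to the max token count across
# ranks, run the (stubbed) MoE, then slice back to the local count.
all_rank_num_tokens = [7, 5, 9, 6]
max_tokens = max(all_rank_num_tokens)            # 9
rank, hidden = 1, 4

x = torch.randn(all_rank_num_tokens[rank], hidden)
x_padded = torch.nn.functional.pad(x, (0, 0, 0, max_tokens - x.shape[0]))

final_hidden_states = x_padded                   # stand-in for the MoE output
final_hidden_states = final_hidden_states[:all_rank_num_tokens[rank]]
assert final_hidden_states.shape[0] == 5
```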

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 2 additions & 2 deletions
@@ -762,7 +762,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
            ] for idx_chunk in range(num_chunks)]
            all_rank_max_num_tokens_list = split_chunk(all_rank_max_num_tokens,
                                                       num_chunks)
-           chunk_size_list = all_rank_chunk_size_list[self.rank]
+           chunk_size_list = all_rank_chunk_size_list[self.parallel_rank]
            if use_all_to_all:
                all_rank_num_tokens_list = [[
                    1 if val == 0 else val for val in val_list
@@ -850,7 +850,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                self.event_dict[EventType.MoeChunkingOverlap].record()
                self.event_dict[EventType.MoeChunkingOverlap].wait()
            outputs = torch.cat(outputs_list)
-           rank = self.mapping.tp_rank
+           rank = self.parallel_rank
            outputs = outputs[:all_rank_num_tokens[rank]]
        self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
        return outputs
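
The wide-EP path keeps some extra bookkeeping around the same list: a per-chunk view of the per-rank token counts, and a clamp of zero-sized entries to 1 so every rank still participates in the all-to-all. A toy sketch of that bookkeeping with made-up values, not the module's exact code:

```python
# Toy sketch of the per-chunk bookkeeping in the wide-EP path (values made up).
all_rank_chunk_size_list = [[2, 2], [0, 0], [3, 3]]   # per-rank chunk sizes
num_chunks = 2

# Transpose into one token-count list per chunk.
all_rank_num_tokens_list = [[val[i] for val in all_rank_chunk_size_list]
                            for i in range(num_chunks)]   # [[2, 0, 3], [2, 0, 3]]

# For the all-to-all path, zero counts are bumped to 1 so every rank still
# takes part in the exchange.
all_rank_num_tokens_list = [[1 if v == 0 else v for v in vals]
                            for vals in all_rank_num_tokens_list]

parallel_rank = 2
chunk_size_list = all_rank_chunk_size_list[parallel_rank]  # this rank's chunks
```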

tensorrt_llm/_torch/modules/fused_moe/interface.py

Lines changed: 1 addition & 0 deletions
@@ -184,6 +184,7 @@ def __init__(
 
         # All ranks participate in allreduce regardless of EP/TP combination
         self.mapping = model_config.mapping
+        self.parallel_rank = self.mapping.tp_rank
         self.parallel_size = self.mapping.tp_size
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
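
The interface change caches the TP rank once as `self.parallel_rank`, alongside the existing `self.parallel_size`, so every fused-MoE backend indexes the per-rank lists through the same attribute instead of reaching back into `self.mapping`. A stripped-down sketch of that pattern; the `FakeMapping` class and `trim` method below are simplified stand-ins, not the real TensorRT-LLM classes:

```python
from dataclasses import dataclass

import torch

@dataclass
class FakeMapping:               # simplified stand-in for the Mapping object
    tp_rank: int
    tp_size: int

class MoEBase:
    def __init__(self, mapping: FakeMapping):
        self.mapping = mapping
        # Single source of truth for "which entry of the all-rank lists is mine".
        self.parallel_rank = self.mapping.tp_rank
        self.parallel_size = self.mapping.tp_size

    def trim(self, outputs: torch.Tensor,
             all_rank_num_tokens: list[int]) -> torch.Tensor:
        # Every backend trims with the same rank attribute.
        return outputs[:all_rank_num_tokens[self.parallel_rank]]

moe = MoEBase(FakeMapping(tp_rank=1, tp_size=4))
out = moe.trim(torch.zeros(8, 16), all_rank_num_tokens=[7, 5, 9, 6])
assert out.shape[0] == 5
```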
