Commit 78a4b44

Fix dpsk-r1-fp4 tp8 by reverting two commits (#13162 and #13341) (#13348)
Co-authored-by: Kangyan-Zhou <[email protected]>
1 parent f969664 commit 78a4b44

5 files changed: +14 -17 lines changed

python/sglang/srt/eplb/expert_location.py

Lines changed: 9 additions & 13 deletions
@@ -284,17 +284,9 @@ def update(
     # -------------------------------- usage ------------------------------------
 
     def logical_to_all_physical(
-        self,
-        layer_id: int,
-        logical_expert_id: int,
-        require_global_experts: bool = False,
+        self, layer_id: int, logical_expert_id: int
     ) -> List[int]:
         # Use CPU copy to avoid GPU→CPU sync on every call, which is expensive in update weights scenario
-        if require_global_experts:
-            num_physical_experts = self.logical_to_all_physical_map_cpu[layer_id].shape[
-                -1
-            ]
-            return list(torch.arange(0, num_physical_experts))
         return [
             physical_expert_id
             for physical_expert_id in self.logical_to_all_physical_map_cpu[
@@ -363,10 +355,14 @@ def _compute_logical_to_all_physical_map(
             )
 
             # Replace by the nearest physical expert
-            if nearest_expert != -1:
-                logical_to_all_physical_map[layer_id][logical_expert_id] = [
-                    nearest_expert
-                ]
+            mapped_physical_experts = logical_to_all_physical_map[layer_id][
+                logical_expert_id
+            ]
+            if (
+                nearest_expert != -1
+                and nearest_expert not in mapped_physical_experts
+            ):
+                mapped_physical_experts[0] = nearest_expert
 
     logical_to_all_physical_map = _pad_nested_array(
         logical_to_all_physical_map, pad_value=-1

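Note: the restored logic only overwrites the first mapped slot with the nearest physical expert, and only when a nearest expert exists and is not already mapped; the reverted version replaced the whole mapping with a single-element list. A minimal, self-contained sketch of that step (simplified from the diff above; the nested-list layout and example values are illustrative assumptions, not the verbatim SGLang code):

from typing import List

def replace_with_nearest_expert(
    logical_to_all_physical_map: List[List[List[int]]],
    layer_id: int,
    logical_expert_id: int,
    nearest_expert: int,
) -> None:
    # Overwrite the first mapped physical expert with the nearest one,
    # but only if a nearest expert exists (!= -1) and is not already mapped.
    mapped_physical_experts = logical_to_all_physical_map[layer_id][logical_expert_id]
    if nearest_expert != -1 and nearest_expert not in mapped_physical_experts:
        mapped_physical_experts[0] = nearest_expert

# Example: logical expert 0 of layer 0 currently maps to physical experts [3, 7];
# a nearest expert of 5 replaces the first slot, while 7 would be a no-op.
example_map = [[[3, 7]]]
replace_with_nearest_expert(example_map, layer_id=0, logical_expert_id=0, nearest_expert=5)
assert example_map == [[[5, 7]]]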
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 2 additions & 4 deletions
@@ -539,12 +539,9 @@ def weight_loader(
             # This is a shared expert.
             physical_expert_ids = [expert_id]
         else:
-            require_global_experts = getattr(
-                param, "_sglang_require_global_experts", False
-            )
             physical_expert_ids = (
                 global_expert_location_metadata.logical_to_all_physical(
-                    self.layer_id, expert_id, require_global_experts
+                    self.layer_id, expert_id
                 )
             )
 
@@ -1129,6 +1126,7 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
             local_expert_offset=self.moe_ep_rank * self.num_local_experts,
             local_num_experts=self.num_local_experts,
             routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
+            tile_tokens_dim=None,
             routing_method_type=RoutingMethodType.DeepSeekV3,
             do_finalize=True,
             output=symm_output,

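For reference, a sketch of the restored logical_to_all_physical lookup that the weight loader above now calls without the require_global_experts flag. The tensor layout and the -1 padding filter are assumptions inferred from the _pad_nested_array(..., pad_value=-1) call in expert_location.py; this is not the verbatim SGLang implementation:

from typing import List
import torch

def logical_to_all_physical(
    logical_to_all_physical_map_cpu: torch.Tensor,  # assumed shape: [num_layers, num_logical_experts, max_physical]
    layer_id: int,
    logical_expert_id: int,
) -> List[int]:
    # Read the CPU copy so the lookup avoids a GPU->CPU sync on every call.
    return [
        int(physical_expert_id)
        for physical_expert_id in logical_to_all_physical_map_cpu[layer_id, logical_expert_id]
        if physical_expert_id != -1  # skip padding slots
    ]

# Usage in the loader, as in the diff above:
# physical_expert_ids = global_expert_location_metadata.logical_to_all_physical(
#     self.layer_id, expert_id
# )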
python/sglang/srt/layers/quantization/fp8.py

Lines changed: 1 addition & 0 deletions
@@ -1245,6 +1245,7 @@ def apply_with_router_logits(
             routed_scaling_factor=(
                 routed_scaling_factor if routed_scaling_factor is not None else 1.0
             ),
+            tile_tokens_dim=None,
             routing_method_type=routing_method_type,
             use_shuffled_weight=False,
         )

python/sglang/srt/layers/quantization/modelopt_quant.py

Lines changed: 1 addition & 0 deletions
@@ -739,6 +739,7 @@ def apply(
                 else 1.0
             ),
             use_routing_scales_on_input=use_routing_scales_on_input,
+            tile_tokens_dim=None,
             routing_method_type=routing_method_type,
         )
 
python/sglang/srt/layers/quantization/mxfp4.py

Lines changed: 1 addition & 0 deletions
@@ -681,6 +681,7 @@ def apply(
             layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
             layer.num_local_experts,  # local num experts
             None,
+            None,  # tile_tokens_dim
             1,  # routing_method_type, renormalize
             True,  # do finalize
             output=symm_output,

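The remaining hunks (fused_moe_triton/layer.py, fp8.py, modelopt_quant.py, and mxfp4.py above) all restore an explicit tile_tokens_dim argument, passed as None, to the fused MoE kernel calls. As a hedged illustration only (this helper is not part of the commit, and the kernel handle below is a placeholder), a call site can be kept compatible with kernels that may or may not accept the argument:

import inspect
from typing import Any, Callable

def call_moe_kernel(kernel: Callable[..., Any], **kwargs: Any) -> Any:
    # Drop tile_tokens_dim when the installed kernel no longer accepts it,
    # so the same call site works across kernel versions.
    if "tile_tokens_dim" not in inspect.signature(kernel).parameters:
        kwargs.pop("tile_tokens_dim", None)
    return kernel(**kwargs)

# Usage sketch with keyword names taken from the diffs above (fused_moe_kernel is hypothetical):
# out = call_moe_kernel(
#     fused_moe_kernel,
#     routed_scaling_factor=routed_scaling_factor,
#     tile_tokens_dim=None,      # restored by this commit
#     routing_method_type=routing_method_type,
# )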