Skip to content

Commit 04b52ae

Browse files
committed
[RL] support update_weights_from_tensor for mtp
1 parent 10285ec commit 04b52ae

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

python/sglang/srt/managers/scheduler_update_weights_mixin.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@ def update_weights_from_distributed(
7676

7777
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
7878
"""Update the online model parameter from tensors."""
79-
success, message = self.tp_worker.update_weights_from_tensor(recv_req)
79+
if self.draft_worker is not None:
80+
success, message = self.draft_worker.update_weights_from_tensor(recv_req)
81+
else:
82+
success, message = self.tp_worker.update_weights_from_tensor(recv_req)
8083
# TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
8184
if success:
8285
if recv_req.flush_cache:

python/sglang/srt/speculative/eagle_worker.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
1010
from sglang.srt.layers.moe.utils import speculative_moe_backend_context
1111
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
12+
from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
1213
from sglang.srt.managers.schedule_batch import ScheduleBatch
1314
from sglang.srt.managers.scheduler import GenerationBatchResult
1415
from sglang.srt.managers.tp_worker import TpModelWorker
@@ -50,13 +51,15 @@
5051
select_top_k_tokens,
5152
)
5253
from sglang.srt.utils import (
54+
MultiprocessingSerializer,
5355
empty_context,
5456
get_available_gpu_memory,
5557
get_bool_env_var,
5658
is_cuda,
5759
is_npu,
5860
next_power_of_2,
5961
)
62+
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
6063

6164
_is_npu = is_npu()
6265

@@ -984,6 +987,26 @@ def capture_for_decode(
984987
draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
985988
draft_input.hidden_states = logits_output.hidden_states
986989

990+
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
991+
992+
monkey_patch_torch_reductions()
993+
named_tensors = MultiprocessingSerializer.deserialize(
994+
recv_req.serialized_named_tensors[self.tp_rank]
995+
)
996+
success, message = self.model_runner.update_weights_from_tensor(
997+
named_tensors=named_tensors,
998+
load_format=recv_req.load_format,
999+
)
1000+
if not success:
1001+
return success, message
1002+
1003+
success, message = self.target_worker.model_runner.update_weights_from_tensor(
1004+
named_tensors=named_tensors,
1005+
load_format=recv_req.load_format,
1006+
)
1007+
1008+
return success, message
1009+
9871010

9881011
@torch.compile(dynamic=True, disable=_is_npu)
9891012
def get_last_loc_large_page_size_top_k_1(

0 commit comments

Comments
 (0)