|
9 | 9 | from sglang.srt.layers.logits_processor import LogitsProcessorOutput |
10 | 10 | from sglang.srt.layers.moe.utils import speculative_moe_backend_context |
11 | 11 | from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs |
| 12 | +from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput |
12 | 13 | from sglang.srt.managers.schedule_batch import ScheduleBatch |
13 | 14 | from sglang.srt.managers.scheduler import GenerationBatchResult |
14 | 15 | from sglang.srt.managers.tp_worker import TpModelWorker |
|
50 | 51 | select_top_k_tokens, |
51 | 52 | ) |
52 | 53 | from sglang.srt.utils import ( |
| 54 | + MultiprocessingSerializer, |
53 | 55 | empty_context, |
54 | 56 | get_available_gpu_memory, |
55 | 57 | get_bool_env_var, |
56 | 58 | is_cuda, |
57 | 59 | is_npu, |
58 | 60 | next_power_of_2, |
59 | 61 | ) |
| 62 | +from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions |
60 | 63 |
|
61 | 64 | _is_npu = is_npu() |
62 | 65 |
|
@@ -984,6 +987,26 @@ def capture_for_decode( |
984 | 987 | draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1) |
985 | 988 | draft_input.hidden_states = logits_output.hidden_states |
986 | 989 |
|
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Apply a tensor weight update to this worker's runner, then the target's.

    The entry of ``recv_req.serialized_named_tensors`` at index
    ``self.tp_rank`` is deserialized once. The result is passed first to
    ``self.model_runner`` and then, only if that call reported success, to
    ``self.target_worker.model_runner``. Both calls receive the same
    ``recv_req.load_format``.

    Args:
        recv_req: Request carrying ``serialized_named_tensors`` (indexed by
            ``self.tp_rank``) and ``load_format``.

    Returns:
        A ``(success, message)`` tuple: the pair reported by
        ``self.model_runner`` when it fails, otherwise the pair reported by
        ``self.target_worker.model_runner``.
    """
    # NOTE(review): called before deserialization; presumably the patch is
    # needed for the serialized tensors to be rebuilt correctly — confirm.
    monkey_patch_torch_reductions()

    rank_payload = recv_req.serialized_named_tensors[self.tp_rank]
    named_tensors = MultiprocessingSerializer.deserialize(rank_payload)

    own_ok, own_msg = self.model_runner.update_weights_from_tensor(
        named_tensors=named_tensors,
        load_format=recv_req.load_format,
    )
    if own_ok:
        # Only touch the target worker once this worker's own update succeeded.
        target_runner = self.target_worker.model_runner
        target_ok, target_msg = target_runner.update_weights_from_tensor(
            named_tensors=named_tensors,
            load_format=recv_req.load_format,
        )
        return target_ok, target_msg

    return own_ok, own_msg
| 1009 | + |
987 | 1010 |
|
988 | 1011 | @torch.compile(dynamic=True, disable=_is_npu) |
989 | 1012 | def get_last_loc_large_page_size_top_k_1( |
|
0 commit comments