
Commit fe0bce0

[Docs]: Add guide for update weights (#4151)
* add guide
* resolve comments
1 parent c775937 commit fe0bce0

File tree

6 files changed: +169 −0 lines changed


docs/en/advance/update_weights.md

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@

# Update Weights

LMDeploy supports updating model weights online, which is useful for scenarios such as RL training. Here are the steps to do so.

## Step 1: Launch the server

For the PyTorch backend, you have to add `--distributed-executor-backend ray`.

```shell
# for the PyTorch backend
lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 --distributed-executor-backend ray
```
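
Once the server is up, you can optionally verify that it is reachable before proceeding. A minimal sketch using the OpenAI-compatible `/v1/models` endpoint (`BASE_URL` and `headers` match the values defined in Step 2):

```python
import requests

BASE_URL = 'http://0.0.0.0:23333'
headers = {"Authorization": "Bearer sk-xxx"}

# list the served models to confirm the server is reachable
response = requests.get(f"{BASE_URL}/v1/models", headers=headers)
assert response.status_code == 200, response.status_code
print(response.json())
```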

## Step 2: Offload weights & kv cache

Before updating the model weights, the server should offload its weights and kv cache.

```python
import requests

from lmdeploy.utils import serialize_state_dict  # used in Step 3

BASE_URL = 'http://0.0.0.0:23333'
api_key = 'sk-xxx'

headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {api_key}",
}

# offload weights and kv cache with level=2
response = requests.post(f"{BASE_URL}/sleep", headers=headers, params=dict(tags=['weights', 'kv_cache'], level=2))
assert response.status_code == 200, response.status_code

# wake up the weights; the server is then ready for a weight update
response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['weights']))
assert response.status_code == 200, response.status_code
```
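
Presumably, `level` controls how aggressively memory is released during sleep: `level=2` appears to drop the weight buffers entirely rather than parking them in host memory, which costs nothing here since the weights are about to be replaced. The exact level semantics are an assumption, not something this guide spells out.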

## Step 3: Update weights

Split the model weights into multiple segments and upload them through the `update_weights` endpoint. The snippet below continues the session from Step 2 (`requests`, `BASE_URL`, `headers`, and `serialize_state_dict` are already in scope).

```python
from typing import Dict, List

import torch

segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
num_segment = len(segmented_state_dict)
for seg_idx in range(num_segment):
    serialized_data = serialize_state_dict(segmented_state_dict[seg_idx])
    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment - 1)
    response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data)
    assert response.status_code == 200, f"response.status_code = {response.status_code}"
```
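
How `segmented_state_dict` is produced is up to the caller. As a hypothetical sketch (the helper below is not part of lmdeploy), one could greedily pack tensors into segments under a byte budget:

```python
from typing import Dict, List

import torch


def split_state_dict(state_dict: Dict[str, torch.Tensor],
                     max_bytes: int = 512 * 2**20) -> List[Dict[str, torch.Tensor]]:
    """Greedily pack tensors into segments of at most ``max_bytes`` each.

    Hypothetical helper: any split works, as long as the final request
    is sent with ``finished=True``.
    """
    segments: List[Dict[str, torch.Tensor]] = [{}]
    size = 0
    for name, tensor in state_dict.items():
        nbytes = tensor.numel() * tensor.element_size()
        # start a new segment when the current one would overflow
        if segments[-1] and size + nbytes > max_bytes:
            segments.append({})
            size = 0
        segments[-1][name] = tensor
        size += nbytes
    return segments
```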

**Note**: For the PyTorch backend, lmdeploy also supports flattened bucket tensors:

```python
from lmdeploy.utils import FlattenedTensorBucket, serialize_state_dict

segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
num_segment = len(segmented_state_dict)
for seg_idx in range(num_segment):
    named_tensors = list(segmented_state_dict[seg_idx].items())
    bucket = FlattenedTensorBucket(named_tensors=named_tensors)
    metadata = bucket.get_metadata()
    flattened_tensor_data = dict(flattened_tensor=bucket.get_flattened_tensor(), metadata=metadata)
    serialized_data = serialize_state_dict(flattened_tensor_data)
    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment - 1, load_format='flattened_bucket')
    response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data)
    assert response.status_code == 200, f"response.status_code = {response.status_code}"
```
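
Packing each segment into a single contiguous buffer trades an extra copy on the client for fewer per-tensor serialization and transfer round trips, which can pay off when a checkpoint contains thousands of small tensors.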

## Step 4: Wake up the server

After the model weights have been updated, the server should onload the kv cache again and resume serving with the updated weights.

```python
response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['kv_cache']))
assert response.status_code == 200, response.status_code
```
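
To sanity-check that serving has resumed with the new weights, send an ordinary request. A minimal sketch against the OpenAI-compatible chat endpoint (the model name here is assumed to match the served model):

```python
response = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    headers=headers,
    json=dict(model='internlm/internlm2_5-7b-chat',
              messages=[dict(role='user', content='Hello!')]),
)
assert response.status_code == 200, response.status_code
print(response.json()['choices'][0]['message']['content'])
```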

docs/en/index.rst

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@ Documentation
    advance/metrics.md
    advance/context_parallel.md
    advance/spec_decoding.md
+   advance/update_weights.md

 .. toctree::
    :maxdepth: 1
docs/zh_cn/advance/update_weights.md

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@

# Update Weights

LMDeploy supports updating model weights online, which is convenient for scenarios such as RL training. The steps are as follows:

## Step 1: Launch the server

For the PyTorch backend, you have to add `--distributed-executor-backend ray`.

```shell
# for the PyTorch backend
lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 --distributed-executor-backend ray
```
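
Optionally, confirm that the server is reachable before putting it to sleep. A minimal sketch using the OpenAI-compatible `/v1/models` endpoint (`BASE_URL` and `headers` as defined in Step 2):

```python
import requests

BASE_URL = 'http://0.0.0.0:23333'
headers = {"Authorization": "Bearer sk-xxx"}

# list the served models to confirm the server is up
response = requests.get(f"{BASE_URL}/v1/models", headers=headers)
assert response.status_code == 200, response.status_code
```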

## Step 2: Offload weights and kv cache

Before updating the weights, call the API to offload the weights and kv cache so that the inference engine is in an updatable state:

```python
import requests

from lmdeploy.utils import serialize_state_dict  # used in Step 3

BASE_URL = 'http://0.0.0.0:23333'
api_key = 'sk-xxx'

headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {api_key}",
}

# offload weights and kv cache with level=2
response = requests.post(f"{BASE_URL}/sleep", headers=headers, params=dict(tags=['weights', 'kv_cache'], level=2))
assert response.status_code == 200, response.status_code

# wake up the weights; the server is then ready for a weight update
response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['weights']))
assert response.status_code == 200, response.status_code
```
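
As in the English guide, `level=2` appears to release the weight buffers entirely rather than keeping a host-memory copy; since the weights are about to be replaced, nothing is lost. The precise level semantics are an assumption here.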

## Step 3: Update weights

Split the model weights into segments and update them through the `update_weights` endpoint (the snippet continues the session from Step 2).

```python
from typing import Dict, List

import torch

segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
num_segment = len(segmented_state_dict)
for seg_idx in range(num_segment):
    serialized_data = serialize_state_dict(segmented_state_dict[seg_idx])
    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment - 1)
    response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data)
    assert response.status_code == 200, f"response.status_code = {response.status_code}"
```
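
How the state dict is segmented is up to the caller; as one hypothetical sketch (not part of lmdeploy), tensors can be distributed round-robin over a fixed number of segments:

```python
from typing import Dict, List

import torch


def split_state_dict(state_dict: Dict[str, torch.Tensor],
                     num_segment: int = 8) -> List[Dict[str, torch.Tensor]]:
    """Hypothetical helper: distribute tensors round-robin over segments."""
    segments: List[Dict[str, torch.Tensor]] = [{} for _ in range(num_segment)]
    for idx, (name, tensor) in enumerate(state_dict.items()):
        segments[idx % num_segment][name] = tensor
    # drop empty segments for small state dicts
    return [seg for seg in segments if seg]
```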

**Note**: For the PyTorch backend, lmdeploy also supports transferring flattened bucket tensors:

```python
from lmdeploy.utils import FlattenedTensorBucket, serialize_state_dict

segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
num_segment = len(segmented_state_dict)
for seg_idx in range(num_segment):
    named_tensors = list(segmented_state_dict[seg_idx].items())
    bucket = FlattenedTensorBucket(named_tensors=named_tensors)
    metadata = bucket.get_metadata()
    flattened_tensor_data = dict(flattened_tensor=bucket.get_flattened_tensor(), metadata=metadata)
    serialized_data = serialize_state_dict(flattened_tensor_data)
    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment - 1, load_format='flattened_bucket')
    response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data)
    assert response.status_code == 200, f"response.status_code = {response.status_code}"
```
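
Flattening a segment into one contiguous buffer reduces per-tensor serialization and transfer overhead, which helps when the checkpoint holds many small tensors.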

## Step 4: Wake up the engine

After the weights have been updated, call the API to rebuild the kv cache and wake up the engine so that it serves inference again.

```python
response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['kv_cache']))
assert response.status_code == 200, response.status_code
```
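
To verify that serving has resumed, a simple request against the OpenAI-compatible chat endpoint works (the model name is assumed to match the served model):

```python
response = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    headers=headers,
    json=dict(model='internlm/internlm2_5-7b-chat',
              messages=[dict(role='user', content='Hello!')]),
)
assert response.status_code == 200, response.status_code
```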

docs/zh_cn/index.rst

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@ LMDeploy 工具箱提供以下核心功能:
    advance/metrics.md
    advance/context_parallel.md
    advance/spec_decoding.md
+   advance/update_weights.md

 .. toctree::
    :maxdepth: 1

lmdeploy/cli/serve.py

Lines changed: 1 addition & 0 deletions
@@ -99,6 +99,7 @@ def add_parser_api_server():
         ArgumentHelper.dllm_denoising_steps(pt_group)
         ArgumentHelper.dllm_confidence_threshold(pt_group)
         ArgumentHelper.enable_return_routed_experts(pt_group)
+        ArgumentHelper.distributed_executor_backend(pt_group)

         # common engine args
         dtype_act = ArgumentHelper.dtype(pt_group)

lmdeploy/cli/utils.py

Lines changed: 10 additions & 0 deletions
@@ -699,6 +699,7 @@ def enable_return_routed_experts(parser):
                             default=False,
                             help='Whether to output routed expert ids for replay')

+    @staticmethod
     def add_spec_group(parser):
         spec_group = parser.add_argument_group('Speculative decoding arguments')
         spec_group.add_argument('--speculative-algorithm',
@@ -719,6 +720,15 @@ def add_spec_group(parser):

         return spec_group

+    @staticmethod
+    def distributed_executor_backend(parser):
+        """Add the --distributed-executor-backend argument."""
+        return parser.add_argument('--distributed-executor-backend',
+                                   type=str,
+                                   default=None,
+                                   choices=['uni', 'mp', 'ray'],
+                                   help='The distributed executor backend for pytorch engine.')
+

 # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py
 class FlexibleArgumentParser(argparse.ArgumentParser):
