Commit 4bd0b48
[Refactor] Support multi-session chat (#178)
- add some dist utils
- add model utils
- add termio and basicstreamer
- typo
- fix world size
- refactor chat and tested llama1
- add internlm adapter and support stopping criteria
- concat with id for internlm
- update docstring
- update and support llama2
- typo
- move docs to docs
- update docstring of session manager
- update docstring
- update docs
- fix accel none in model
- fix and add test for tensor broadcast
- fix session using typing to check type
- add docstrings and comprehensive condition test
- unit test for dist
- fix session
- split unittests of utils
- typo
- update control flow of accel
- move test model
- remove main in unittest
- remove some log
- remove some comments
1 parent c80f3e4 commit 4bd0b48

File tree: 14 files changed, +1081 −262 lines

README.md

Lines changed: 7 additions & 5 deletions
@@ -122,11 +122,7 @@ For the deployment of other supported models, such as LLaMA, LLaMA-2, vicuna and
 
 ### Inference with PyTorch
 
-You have to install deepspeed first before running with PyTorch.
-
-```
-pip install deepspeed
-```
+For detailed instructions on inference with PyTorch models, see [here](docs/en/pytorch.md).
 
 #### Single GPU
 
@@ -149,6 +145,12 @@ deepspeed --module --num_gpus 2 lmdeploy.pytorch.chat \
     --seed 0
 ```
 
+You need to install deepspeed first to use this feature.
+
+```
+pip install deepspeed
+```
+
 ## Quantization
 
 In fp16 mode, kv_cache int8 quantization can be enabled, and a single card can serve more users.

docs/en/pytorch.md

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
# PyTorch

## Chat in command line

LMDeploy supports chatting with PyTorch models through the submodule `lmdeploy.pytorch.chat`.

This submodule allows users to chat with a language model through the command line, and optionally accelerate the model with backends such as DeepSpeed.

**Example 1**: Chat with default settings

```shell
python -m lmdeploy.pytorch.chat $PATH_TO_HF_MODEL
```

**Example 2**: Disable sampling and chat history

```shell
python -m lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --temperature 0 --max-histroy 0
```

**Example 3**: Accelerate with deepspeed inference

```shell
python -m lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --accel deepspeed
```

Note: to use DeepSpeed, you need to install it first; to accelerate InternLM, you need a customized version: <https://github.com/wangruohui/DeepSpeed/tree/support_internlm_0.10.0>

**Example 4**: Run the model with tensor parallelism on 2 GPUs

```shell
deepspeed --module --num_gpus 2 lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --accel deepspeed
```

This module also accepts the following control commands to change generation behavior during chat.

- `exit`: terminate and exit chat
- `config set key=value`: change generation config `key` to `value`, e.g. `config temperature=0` disables sampling for the following chats
- `clear`: clear chat history

### Simple diagram of components

```mermaid
graph LR;
    subgraph model specific adapter
        p((user_input))-->tokenize-->id((input_ids))-->decorate
        tmpl_ids((template_ids))-->decorate;
    end
    subgraph generate
        model[CausalLM_model.generate]-->gen_result(("gen_result"))
        gen_result-->hid
        gen_result-->attn((attention))
    end
    subgraph streamer
        model-->s[streamer]--value-->decode_single--token-->output
    end
    subgraph session_manager
        prepend_history-->fullid((complete_ids));
        trim-->prepend_history
    end
    decorate-->prepend_history
    hid((history_ids))-->trim;
    attn-->trim;
    fullid-->model
    tokenizer((tokenizer))-->decode_single
    tokenizer-->tokenize
    p-->genconfig(GenConfig)-->model
```
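
Read as prose, the diagram traces one chat turn: the adapter tokenizes and decorates the user input, the session manager prepends (trimmed) history, the model generates, and the streamer decodes tokens one by one. The sketch below restates that flow in Python; every object and method name in it is an illustrative placeholder, not the actual API introduced by this commit.

```python
# Illustrative sketch of one chat turn as drawn in the diagram above.
# All objects and method names here are placeholders, not lmdeploy APIs.
def chat_once(adapter, session, model, streamer, user_input, gen_config):
    # model-specific adapter: tokenize and wrap the prompt in its template
    input_ids = adapter.encode_and_decorate(user_input)
    # session manager: trim old history and prepend it to the new prompt
    full_ids = session.prepend_history(input_ids)
    # generate, pushing tokens to the streamer, which decodes them one by one
    output_ids = model.generate(full_ids,
                                generation_config=gen_config,
                                streamer=streamer)
    # remember the complete ids for the next turn
    session.update(output_ids)
    return output_ids
```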
lmdeploy/pytorch/adapters/__init__.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.

import logging

import torch.nn as nn

from .base import BasicAdapter, BasicAdapterFast
from .internlm import InternLMAdapter
from .llama2 import Llama2Adapter

logger = logging.getLogger(__name__)


def _get_default_adapter(tokenizer):
    if tokenizer.is_fast:
        return BasicAdapterFast
    else:
        return BasicAdapter


def init_adapter(model: nn.Module, tokenizer, adapter=None):
    if adapter is None:
        for v in model.modules():
            if 'InternLMModel' in v.__class__.__name__:
                Adapter = InternLMAdapter
                break
            elif 'LlamaModel' in v.__class__.__name__:
                Adapter = Llama2Adapter
                break
        else:
            Adapter = _get_default_adapter(tokenizer)
    elif adapter == 'llama1':
        Adapter = _get_default_adapter(tokenizer)
    else:
        raise ValueError(f'Adapter {adapter} is not allowed.')

    logger.info(f'Using adapter {Adapter.__name__}')

    return Adapter(tokenizer)
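
For context, a hedged usage sketch of `init_adapter` follows. The checkpoint name and the `AutoModelForCausalLM`/`AutoTokenizer` loading calls are assumptions for illustration only, and the import path assumes the file above is the adapters package `__init__`.

```python
# Hypothetical usage of init_adapter; the checkpoint name is only an example
# and the import path assumes the file above is adapters/__init__.py.
from transformers import AutoModelForCausalLM, AutoTokenizer

from lmdeploy.pytorch.adapters import init_adapter

tokenizer = AutoTokenizer.from_pretrained('internlm/internlm-chat-7b',
                                          trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained('internlm/internlm-chat-7b',
                                             trust_remote_code=True)

# adapter=None lets init_adapter inspect the model's module classes:
# InternLM models get InternLMAdapter, Llama models get Llama2Adapter,
# anything else falls back to the basic (fast or slow) adapter.
adapter = init_adapter(model, tokenizer)
input_ids = adapter.encode_and_decorate('Hello!')
```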

lmdeploy/pytorch/adapters/base.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Basic adapter suitable for general HuggingFace models."""

import logging
import re

from transformers import (PreTrainedTokenizer, PreTrainedTokenizerBase,
                          PreTrainedTokenizerFast)

logger = logging.getLogger(__name__)


class BaseAdapter:
    """Base class for all adapters.

    Note:
        Adapters coordinate with the session manager to prepare input_ids.
        The full sequence fed to the model is as follows:

        ```
        adapter.start_ids
        adapter.encode_and_decorate(user_input_1)
        output_1_generated_by_model
        adapter.sep_ids
        adapter.encode_and_decorate(user_input_2)
        output_2_generated_by_model
        adapter.sep_ids
        adapter.encode_and_decorate(user_input_3)
        ```

        Thus the adapter is responsible for providing model specific
        ``start_ids``, ``sep_ids``, and a method to encode a single prompt.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer

    def encode_and_decorate(self, prompt, add_special_tokens=False):
        """Model specific method to encode and decorate prompt."""
        raise NotImplementedError

    def decode(self, value):
        """Model specific method to decode single value to string."""
        raise NotImplementedError

    @property
    def stopping_criteria(self):
        """Model specific stopping criteria for generation."""
        return None

    @property
    def start_ids(self):
        """Model specific start ids."""
        return [self.tokenizer.bos_token_id]

    @property
    def sep_ids(self):
        """Model specific separation ids."""
        return [self.tokenizer.bos_token_id]


class BasicAdapter(BaseAdapter):
    """Basic adapter for slow tokenizers."""

    def encode_and_decorate(self, prompt, add_special_tokens=False):
        """Encode prompt.

        Note:
            we leave <bos> to session manager to add.
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=add_special_tokens,
            return_tensors='pt',
        )
        logger.debug(f'Encode {prompt} to {input_ids}')
        return input_ids

    def decode(self, value):
        """Fallback when tokenizer is not fast."""

        self.tokenizer: PreTrainedTokenizer

        tok = self.tokenizer.decode(value)
        return tok + ' '


class BasicAdapterFast(BaseAdapter):
    """Basic adapter for fast tokenizers."""

    hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')

    def encode_and_decorate(self, prompt, add_special_tokens=False):
        """Encode prompt.

        Note:
            we leave <bos> to session manager to add.
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=add_special_tokens,
            return_tensors='pt',
        )
        logger.debug(f'Encode {prompt} to {input_ids}')
        return input_ids

    def decode(self, value):
        """Decode with fast tokenizers."""

        self.tokenizer: PreTrainedTokenizerFast

        tok = self.tokenizer._convert_id_to_token(value)
        if tok.startswith('▁'):  # sentencepiece
            space = ' '
            tok = tok[1:]
        else:
            space = ''
        if res := self.hex_regex.match(tok):
            tok = chr(int(res.group(1), 16))
        if tok == '</s>' or tok == '\r':
            tok = '\n'

        tok = space + tok

        logger.debug(f'Decode {value} to {repr(tok)}')

        return tok
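
The Note in `BaseAdapter`'s docstring describes how a session's token sequence is laid out. The sketch below shows one way a session manager could assemble that layout for a new user turn; it is an editor's illustration under that layout, not the session manager code from this commit.

```python
# Sketch (not code from this commit): assembling the sequence layout that
# BaseAdapter's docstring describes, for a new user turn in a session.
import torch


def build_full_ids(adapter, history_ids, user_input):
    """Return start_ids + history + sep_ids + decorated new prompt."""
    new_ids = adapter.encode_and_decorate(user_input)  # shape (1, L)
    if history_ids is None:
        # first turn: only the model-specific start ids precede the prompt
        prefix = torch.tensor([adapter.start_ids])
    else:
        # later turns: previous ids (which already begin with start_ids),
        # then the model-specific separator ids
        prefix = torch.cat([history_ids, torch.tensor([adapter.sep_ids])],
                           dim=1)
    return torch.cat([prefix, new_ids], dim=1)
```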
lmdeploy/pytorch/adapters/internlm.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import re

import torch
from transformers import (PreTrainedTokenizerFast, StoppingCriteria,
                          StoppingCriteriaList)

from .base import BaseAdapter

logger = logging.getLogger(__name__)


class InternLMStoppingCriteria(StoppingCriteria):
    """Stopping criteria for HF version of InternLM."""

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        return input_ids[0, -1] in [2, 103028]


class InternLMAdapter(BaseAdapter):
    r"""Adapter for InternLM.

    InternLM uses the following template, in which '\n' should be token 13.

    <bos> (no actual newline here, just for better readability)
    <|User|>:{prompt}<eoh>\n
    <|Bot|>:{model_output}<eoa>\n
    <|User|>:{prompt}<eoh>\n
    <|Bot|>:{model_output}<eoa>\n
    ...
    <eos>
    """

    hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')
    # ids of '<|User|>:'
    B_USER_ID = torch.tensor([[333, 352, 1621, 352, 27232]])
    # ids of '<eoh>\n<|Bot|>:'
    E_USER_ID = torch.tensor([[103027, 13, 333, 352, 23845, 352, 27232]])
    # ids of '<bos>'
    start_ids = [1]
    # ids of '\n'
    sep_ids = [13]

    def __init__(self, tokenizer: PreTrainedTokenizerFast):
        self.tokenizer = tokenizer

    def encode_and_decorate(self, prompt):
        r"""Encode prompt and decorate with template.

        Note:
            we leave <bos> and chat history for session manager to add,
            so we will decorate input_ids to '<|User|>:{prompt}<eoh>\n<|Bot|>:'
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=False,
            return_tensors='pt',
        )
        # This is f'<|User|>:{prompt}<eoh>\n<|Bot|>:'
        # but force \n to 13 instead of 364
        input_ids = torch.cat([self.B_USER_ID, input_ids, self.E_USER_ID],
                              dim=1)
        return input_ids

    def decode(self, value):
        """Decode generated tokens for InternLM."""

        tok = self.tokenizer.decode(value)
        if res := self.hex_regex.match(tok):
            tok = chr(int(res.group(1), 16))
        if tok == '</s>' or tok == '<eoa>' or tok == '\r':
            tok = '\n'

        logger.debug(f'Decode {value} to {repr(tok)}')

        return tok

    @property
    def stopping_criteria(self):
        return StoppingCriteriaList([InternLMStoppingCriteria()])
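
To see how the adapter fits together with HuggingFace `generate()` outside the chat CLI, here is a hedged, single-turn sketch. The checkpoint name and `max_new_tokens` are illustrative assumptions, and the import path follows the file location inferred above.

```python
# Illustrative single-turn generation with InternLMAdapter; the checkpoint
# name and max_new_tokens are assumptions, not values from this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from lmdeploy.pytorch.adapters.internlm import InternLMAdapter

tokenizer = AutoTokenizer.from_pretrained('internlm/internlm-chat-7b',
                                          trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained('internlm/internlm-chat-7b',
                                             trust_remote_code=True)

adapter = InternLMAdapter(tokenizer)
# <bos> followed by '<|User|>:{prompt}<eoh>\n<|Bot|>:'
prompt_ids = torch.cat(
    [torch.tensor([adapter.start_ids]),
     adapter.encode_and_decorate('Hello, who are you?')], dim=1)

# generation stops when the last token is <eos> (2) or <eoa> (103028)
output = model.generate(prompt_ids,
                        max_new_tokens=128,
                        stopping_criteria=adapter.stopping_criteria)
print(''.join(adapter.decode(t) for t in output[0, prompt_ids.shape[1]:]))
```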
