Skip to content

Commit 9bfe03c

Browse files
authored
Decode generated token_ids incrementally (#309)
* add incremental decoding for turbomind * update TIS * fix triton post processing * update doc * fix typo * SentencePieceTokenizer incremental decode, add qwen message prompt * docstring * update bot
1 parent 22e8b2c commit 9bfe03c

File tree

8 files changed

+135
-34
lines changed

8 files changed

+135
-34
lines changed

docs/en/restful_api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ Generate:
9090
curl http://{server_ip}:{server_port}/generate \
9191
-H "Content-Type: application/json" \
9292
-d '{
93-
"model": "internlm-chat-7b",
94-
"prompt": "Hello! Ho are you?",
93+
"prompt": "Hello! How are you?",
94+
"instance_id": 1,
9595
"sequence_start": true,
9696
"sequence_end": true
9797
}'

docs/zh_cn/restful_api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ curl http://{server_ip}:{server_port}/v1/models
9292
curl http://{server_ip}:{server_port}/generate \
9393
-H "Content-Type: application/json" \
9494
-d '{
95-
"model": "internlm-chat-7b",
96-
"prompt": "Hello! Ho are you?",
95+
"prompt": "Hello! How are you?",
96+
"instance_id": 1,
9797
"sequence_start": true,
9898
"sequence_end": true
9999
}'

lmdeploy/model.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,29 @@ def get_prompt(self, prompt, sequence_start=True):
256256
else:
257257
return f'\n{self.user}{prompt}{self.eoh}\n{self.assistant}'
258258

259+
def messages2prompt(self, messages, sequence_start=True):
260+
"""Return the prompt that is concatenated with other elements in the
261+
chat template.
262+
263+
Args:
264+
messages (str | List): user's input prompt
265+
sequence_start (bool): flag to start the sequence
266+
Returns:
267+
str: the concatenated prompt
268+
"""
269+
if isinstance(messages, str):
270+
return self.get_prompt(messages, sequence_start)
271+
system, users, assistants = self._translate_messages(messages)
272+
system = self.system if not system else system
273+
ret = f'<BOS>{system}{self.meta_instruction}{self.eosys}'
274+
for user, assistant in zip(users, assistants):
275+
if assistant:
276+
ret += f'\n{self.user}{user}{self.eoh}\n{self.assistant}' \
277+
f'{assistant}'
278+
else:
279+
ret += f'\n{self.user}{user}{self.eoh}\n{self.assistant}'
280+
return ret
281+
259282
@property
260283
def stop_words(self):
261284
"""Return the stop-words' token ids."""
@@ -360,6 +383,29 @@ def get_prompt(self, prompt, sequence_start=True):
360383
return f'\n{self.im_start}user\n{prompt}{self.im_end}' \
361384
f'\n{self.im_start}assistant\n'
362385

386+
def messages2prompt(self, messages, sequence_start=True):
387+
"""Return the prompt that is concatenated with other elements in the
388+
chat template.
389+
390+
Args:
391+
messages (str | List): user's input prompt
392+
Returns:
393+
str: the concatenated prompt
394+
"""
395+
if isinstance(messages, str):
396+
return self.get_prompt(messages, sequence_start)
397+
system, users, assistants = self._translate_messages(messages)
398+
system = self.system if not system else system
399+
ret = f'{self.im_start}system\n{system}{self.im_end}'
400+
for user, assistant in zip(users, assistants):
401+
if assistant:
402+
ret += f'\n{self.im_start}user\n{user}{self.im_end}' \
403+
f'\n{self.im_start}assistant\n{assistant}'
404+
else:
405+
ret += f'\n{self.im_start}user\n{user}{self.im_end}' \
406+
f'\n{self.im_start}assistant\n'
407+
return ret
408+
363409
@property
364410
def stop_words(self):
365411
"""Return the stop-words' token ids."""

lmdeploy/serve/async_engine.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,13 @@ async def generate(
138138
random_seed=seed if sequence_start else None):
139139
res, tokens = outputs[0]
140140
# decode res
141-
response = self.tokenizer.decode(res)[response_size:]
141+
response = self.tokenizer.decode(res.tolist(),
142+
offset=response_size)
142143
# response, history token len,
143144
# input token len, gen token len
144145
yield GenOut(response, self.steps[str(session_id)],
145146
len(input_ids), tokens, finish_reason)
146-
response_size += len(response)
147+
response_size = tokens
147148

148149
# update step
149150
self.steps[str(session_id)] += len(input_ids) + tokens
@@ -229,7 +230,8 @@ async def generate_openai(
229230
random_seed=seed if sequence_start else None):
230231
res, tokens = outputs[0]
231232
# decode res
232-
response = self.tokenizer.decode(res[response_size:])
233+
response = self.tokenizer.decode(res.tolist(),
234+
offset=response_size)
233235
# response, history token len, input token len, gen token len
234236
yield GenOut(response, self.steps[str(session_id)],
235237
len(input_ids), tokens, finish_reason)

lmdeploy/serve/turbomind/chatbot.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -599,14 +599,12 @@ def stream_consumer(postprocess, res_queue, session, n_input_token,
599599
Yields:
600600
tuple: status, text, generated token number
601601
"""
602-
offset = n_input_token + preseq_length
603602
status, res, n_token = None, '', 0
604603
while True:
605604
result = res_queue.get()
606605
if result is None:
607606
status = StatusCode.TRITON_STREAM_END
608607
res = session.response
609-
n_token = session.sequence_length - offset
610608
session.status = StatusCode.TRITON_STREAM_END
611609
break
612610
if 'errcode' in result:
@@ -629,30 +627,29 @@ def stream_consumer(postprocess, res_queue, session, n_input_token,
629627
output_ids = result.as_numpy('output_ids')
630628

631629
session.sequence_length = sequence_length.squeeze()
632-
sequence_length = sequence_length - offset
633-
last_token_id = output_ids[-1][-1][session.sequence_length - 1]
630+
output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
631+
output_ids = output_ids[:, :, n_input_token +
632+
preseq_length:sequence_length.squeeze(
633+
)]
634+
last_token_id = output_ids[-1, -1, -1]
634635
if last_token_id == eos_id:
635636
session.sequence_length = session.sequence_length - 1
636-
sequence_length = sequence_length - 1
637-
638-
output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
639-
sequence_length = sequence_length.reshape(
640-
(1, sequence_length.shape[-1]))
637+
output_ids = output_ids[:, :, :-1]
641638

642639
if profile_generation:
643640
yield (StatusCode.TRITON_STREAM_ING,
644641
'postprocessing is ignored during profiling '
645-
'token generation', sequence_length.squeeze())
642+
'token generation', output_ids.shape[-1])
646643
continue
647-
output_str = postprocess(output_ids[:, :, offset:],
648-
sequence_length)
644+
output_str = postprocess(
645+
output_ids, np.array([[n_token]], dtype=np.uint32))
646+
n_token = output_ids.shape[-1]
649647
text = output_str[0].decode()
650648
if display:
651-
new_text = text[len(session.response):]
652-
print(new_text, end='', flush=True)
653-
session.response = text
649+
print(text, end='', flush=True)
650+
session.response += text
654651
yield (StatusCode.TRITON_STREAM_ING, session.response,
655-
sequence_length.squeeze())
652+
output_ids.shape[-1])
656653
except Exception as e:
657654
logger.error(f'catch exception: {e}')
658655

lmdeploy/serve/turbomind/triton_models/postprocessing/1/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def _postprocessing(self, tokens_batch, sequence_length):
123123
outputs = []
124124
for beam_tokens, beam_len in zip(tokens_batch, sequence_length):
125125
for tokens, _len in zip(beam_tokens, beam_len):
126-
output = self.tokenizer.decode(tokens[:_len])
126+
output = self.tokenizer.decode(tokens, _len)
127127
output = output.encode('utf8')
128128
outputs.append(output)
129129
return outputs

lmdeploy/turbomind/chat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ def main(model_path,
9999
random_seed=seed if nth_round == 1 else None):
100100
res, tokens = outputs[0]
101101
# decode res
102-
response = tokenizer.decode(res)[response_size:]
102+
response = tokenizer.decode(res.tolist(), offset=response_size)
103103
response = valid_str(response)
104104
print(f'{response}', end='', flush=True)
105-
response_size += len(response)
105+
response_size = tokens
106106

107107
# update step
108108
step += len(input_ids) + tokens

lmdeploy/turbomind/tokenizer.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import json
33
import os.path as osp
4-
from typing import Sequence, Union
4+
from typing import Optional, Sequence, Union
55

66
import torch
77

@@ -16,6 +16,7 @@ class SentencePieceTokenizer:
1616
def __init__(self, model_file: str):
1717
from sentencepiece import SentencePieceProcessor
1818
self.model = SentencePieceProcessor(model_file=model_file)
19+
self._no_prefix_space_tokens = None
1920

2021
@property
2122
def vocab_size(self):
@@ -32,6 +33,24 @@ def eos_token_id(self):
3233
"""end of the sentence token id."""
3334
return self.model.eos_id()
3435

36+
@property
37+
def no_prefix_space_tokens(self):
38+
"""tokens without prefix space."""
39+
if self._no_prefix_space_tokens is None:
40+
vocab = self.model.IdToPiece(list(range(self.vocab_size)))
41+
self._no_prefix_space_tokens = {
42+
i
43+
for i, tok in enumerate(vocab) if not tok.startswith('▁')
44+
}
45+
return self._no_prefix_space_tokens
46+
47+
def _maybe_add_prefix_space(self, tokens, decoded):
48+
"""maybe add prefix space for incremental decoding."""
49+
if len(tokens) and tokens[0] not in self.no_prefix_space_tokens:
50+
return ' ' + decoded
51+
else:
52+
return decoded
53+
3554
def encode(self, s: str):
3655
"""Tokenize a prompt.
3756
@@ -50,17 +69,23 @@ def encode(self, s: str):
5069
add_eos = True
5170
return self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)
5271

53-
def decode(self, t: Sequence[int]):
72+
def decode(self, t: Sequence[int], offset: Optional[int] = None):
5473
"""De-tokenize.
5574
5675
Args:
5776
t (List[int]): a list of token ids
77+
offset (int): for incrementally decoding. Default to None, which
78+
means not applied.
5879
Returns:
5980
str: text of decoding tokens
6081
"""
6182
if isinstance(t, torch.Tensor):
6283
t = t.tolist()
63-
return self.model.Decode(t)
84+
t = t[offset:]
85+
out_string = self.model.Decode(t)
86+
if offset:
87+
out_string = self._maybe_add_prefix_space(t, out_string)
88+
return out_string
6489

6590
def __call__(self, s: Union[str, Sequence[str]]):
6691
"""Tokenize prompts.
@@ -86,7 +111,7 @@ class HuggingFaceTokenizer:
86111
"""
87112

88113
def __init__(self, model_dir: str):
89-
from transformers import AutoTokenizer
114+
from transformers import AutoTokenizer, LlamaTokenizerFast
90115
model_file = osp.join(model_dir, 'tokenizer.model')
91116
backend_tokenizer_file = osp.join(model_dir, 'tokenizer.json')
92117
model_file_exists = osp.exists(model_file)
@@ -95,6 +120,8 @@ def __init__(self, model_dir: str):
95120
'It may take long time to initialize the tokenizer.')
96121
self.model = AutoTokenizer.from_pretrained(model_dir,
97122
trust_remote_code=True)
123+
self.need_padding = isinstance(self.model, LlamaTokenizerFast)
124+
self._no_prefix_space_tokens = None
98125
# save tokenizer.json to reuse
99126
if not osp.exists(backend_tokenizer_file) and model_file_exists:
100127
if hasattr(self.model, 'backend_tokenizer'):
@@ -122,6 +149,26 @@ def eos_token_id(self):
122149
"""end of the sentence token id."""
123150
return self.model.eos_token_id
124151

152+
@property
153+
def no_prefix_space_tokens(self):
154+
"""tokens without prefix space."""
155+
if self._no_prefix_space_tokens is None:
156+
vocab = self.model.convert_ids_to_tokens(
157+
list(range(self.vocab_size)))
158+
self._no_prefix_space_tokens = {
159+
i
160+
for i, tok in enumerate(vocab) if not tok.startswith('▁')
161+
}
162+
return self._no_prefix_space_tokens
163+
164+
def _maybe_add_prefix_space(self, tokens, decoded):
165+
"""maybe add prefix space for incremental decoding."""
166+
if self.need_padding and len(
167+
tokens) and tokens[0] not in self.no_prefix_space_tokens:
168+
return ' ' + decoded
169+
else:
170+
return decoded
171+
125172
def encode(self, s: str):
126173
"""Tokenize a prompt.
127174
@@ -139,16 +186,23 @@ def encode(self, s: str):
139186
add_special_tokens = True
140187
return self.model.encode(s, add_special_tokens=add_special_tokens)
141188

142-
def decode(self, t: Sequence[int]):
189+
def decode(self, t: Sequence[int], offset: Optional[int] = None):
143190
"""De-tokenize.
144191
145192
Args:
146193
t (List[int]): a list of token ids
194+
offset (int): for incrementally decoding. Default to None, which
195+
means not applied.
147196
Returns:
148197
str: text of decoding tokens
149198
"""
150199
skip_special_tokens = True
151-
return self.model.decode(t, skip_special_tokens=skip_special_tokens)
200+
t = t[offset:]
201+
out_string = self.model.decode(t,
202+
skip_special_tokens=skip_special_tokens)
203+
if offset:
204+
out_string = self._maybe_add_prefix_space(t, out_string)
205+
return out_string
152206

153207
def __call__(self, s: Union[str, Sequence[str]]):
154208
"""Tokenize prompts.
@@ -211,15 +265,17 @@ def encode(self, s: str):
211265
"""
212266
return self.model.encode(s)
213267

214-
def decode(self, t: Sequence[int]):
268+
def decode(self, t: Sequence[int], offset: Optional[int] = None):
215269
"""De-tokenize.
216270
217271
Args:
218272
t (List[int]): a list of token ids
273+
offset (int): for incrementally decoding. Default to None, which
274+
means not applied.
219275
Returns:
220276
str: text of decoding tokens
221277
"""
222-
return self.model.decode(t)
278+
return self.model.decode(t, offset)
223279

224280
def __call__(self, s: Union[str, Sequence[str]]):
225281
"""Tokenize prompts.

0 commit comments

Comments (0)