Skip to content

Commit 3de0dbb

Browse files
authored
Add non-stream inference api for chatbot (#200)
* add non-stream inference api for chatbot * update according to reviewer's comments
1 parent b7e7e66 commit 3de0dbb

File tree

3 files changed

+87
-11
lines changed

3 files changed

+87
-11
lines changed

lmdeploy/serve/client.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,22 @@ def input_prompt():
1313
return '\n'.join(iter(input, sentinel))
1414

1515

16-
def main(tritonserver_addr: str, session_id: int = 1):
16+
def main(tritonserver_addr: str,
17+
session_id: int = 1,
18+
stream_output: bool = True):
1719
"""An example to communicate with the inference server through the command line
1820
interface.
1921
2022
Args:
2123
tritonserver_addr (str): the address in format "ip:port" of
2224
triton inference server
2325
session_id (int): the identical id of a session
26+
stream_output (bool): indicator for streaming output or not
2427
"""
2528
log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING')
26-
chatbot = Chatbot(tritonserver_addr, log_level=log_level, display=True)
29+
chatbot = Chatbot(tritonserver_addr,
30+
log_level=log_level,
31+
display=stream_output)
2732
nth_round = 1
2833
while True:
2934
prompt = input_prompt()
@@ -33,12 +38,19 @@ def main(tritonserver_addr: str, session_id: int = 1):
3338
chatbot.end(session_id)
3439
else:
3540
request_id = f'{session_id}-{nth_round}'
36-
for status, res, n_token in chatbot.stream_infer(
37-
session_id,
38-
prompt,
39-
request_id=request_id,
40-
request_output_len=512):
41-
continue
41+
if stream_output:
42+
for status, res, n_token in chatbot.stream_infer(
43+
session_id,
44+
prompt,
45+
request_id=request_id,
46+
request_output_len=512):
47+
continue
48+
else:
49+
status, res, n_token = chatbot.infer(session_id,
50+
prompt,
51+
request_id=request_id,
52+
request_output_len=512)
53+
print(res)
4254
nth_round += 1
4355

4456

lmdeploy/serve/turbomind/chatbot.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,65 @@ def resume(self, session_id: int, *args, **kwargs):
294294
self._session.histories = histories
295295
return status
296296

297+
def infer(self,
298+
session_id: int,
299+
prompt: str,
300+
request_id: str = '',
301+
request_output_len: int = None,
302+
sequence_start: bool = False,
303+
sequence_end: bool = False,
304+
*args,
305+
**kwargs):
306+
"""Start a new round conversation of a session. Return the chat
307+
completions in non-stream mode.
308+
309+
Args:
310+
session_id (int): the identical id of a session
311+
prompt (str): user's prompt in this round conversation
312+
request_id (str): the identical id of this round conversation
313+
request_output_len (int): the expected generated token numbers
314+
sequence_start (bool): start flag of a session
315+
sequence_end (bool): end flag of a session
316+
Returns:
317+
tuple(Status, str, int): status, text/chat completion,
318+
generated token number
319+
"""
320+
assert isinstance(session_id, int), \
321+
f'INT session id is required, but got {type(session_id)}'
322+
323+
logger = get_logger(log_level=self.log_level)
324+
logger.info(f'session {session_id}, request_id {request_id}, '
325+
f'request_output_len {request_output_len}')
326+
327+
if self._session is None:
328+
sequence_start = True
329+
self._session = Session(session_id=session_id)
330+
elif self._session.status == 0:
331+
logger.error(f'session {session_id} has been ended. Please set '
332+
f'`sequence_start` be True if you want to restart it')
333+
return StatusCode.TRITON_SESSION_CLOSED, '', 0
334+
335+
self._session.status = 1
336+
self._session.request_id = request_id
337+
self._session.response = ''
338+
339+
self._session.prompt = self._get_prompt(prompt, sequence_start)
340+
status, res, tokens = None, '', 0
341+
for status, res, tokens in self._stream_infer(self._session,
342+
self._session.prompt,
343+
request_output_len,
344+
sequence_start,
345+
sequence_end):
346+
if status.value < 0:
347+
break
348+
if status.value == 0:
349+
self._session.histories = \
350+
self._session.histories + self._session.prompt + \
351+
self._session.response
352+
return status, res, tokens
353+
else:
354+
return status, res, tokens
355+
297356
def reset_session(self):
298357
"""reset session."""
299358
self._session = None

lmdeploy/turbomind/chat.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,17 @@ def valid_str(string, coding='utf-8'):
3232
def main(model_path,
3333
session_id: int = 1,
3434
repetition_penalty: float = 1.0,
35-
tp=1):
35+
tp=1,
36+
stream_output=True):
3637
"""An example to perform model inference through the command line
3738
interface.
3839
3940
Args:
4041
model_path (str): the path of the deployed model
4142
session_id (int): the identical id of a session
43+
repetition_penalty (float): parameter to penalize repetition
44+
tp (int): GPU number used in tensor parallelism
45+
stream_output (bool): indicator for streaming output or not
4246
"""
4347
tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
4448
tokenizer = Tokenizer(tokenizer_model_path)
@@ -62,7 +66,8 @@ def main(model_path,
6266
input_ids=[input_ids],
6367
request_output_len=512,
6468
sequence_start=False,
65-
sequence_end=True):
69+
sequence_end=True,
70+
stream_output=stream_output):
6671
pass
6772
nth_round = 1
6873
step = 0
@@ -80,7 +85,7 @@ def main(model_path,
8085
for outputs in generator.stream_infer(
8186
session_id=session_id,
8287
input_ids=[input_ids],
83-
stream_output=True,
88+
stream_output=stream_output,
8489
request_output_len=512,
8590
sequence_start=(nth_round == 1),
8691
sequence_end=False,

0 commit comments

Comments
 (0)