Commit edb7c6e

Fix profile_serving hung issue (#344)
* read data after starting the worker processes
* fix hang
* fix exceptions when request_output_len is 0
1 parent 9bfe03c commit edb7c6e
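
Root cause, in brief: the old benchmark filled a multiprocessing queue before starting the workers and had each worker loop on `while not req_que.empty()`. Since `Queue.empty()` is only a racy snapshot, two workers could both see a non-empty queue and race for the last item, leaving the loser blocked forever in `get()` and the parent hung in `join()`. The fix starts the workers first, feeds the queue afterwards, ends each worker's loop with a `[None, None, None]` sentinel consumed via `iter(req_que.get, ...)`, and joins the workers only after draining their results.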

File tree

2 files changed (+38, -25 lines)

benchmark/profile_serving.py

Lines changed: 34 additions & 24 deletions
@@ -1,4 +1,5 @@
 import json
+import logging
 import multiprocessing as mp
 import os
 import random
@@ -28,10 +29,8 @@ def encode(self, prompts: List):
 
 def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue):
     stats = []
-    while not req_que.empty():
-        prompt, input_seqlen, output_seqlen = req_que.get()
-        print(f'request info: session {session_id}, '
-              f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}')
+    for prompt, input_seqlen, output_seqlen in iter(req_que.get,
+                                                    [None, None, None]):
         timestamps = []
         tokens = []
         start = time.perf_counter()
@@ -43,12 +42,13 @@ def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue):
                 sequence_end=True):
             timestamps.append(time.perf_counter())
             tokens.append(token)
-        chatbot.reset_session()
 
-        first_token_latency = timestamps[1] - start
-        token_latency = timestamps[-1] - timestamps[0]
+        first_token_latency = np.round(timestamps[1] - start, 3)
+        token_latency = np.round(timestamps[-1] - timestamps[0], 3)
         token = tokens[-1] - tokens[0]
         stats.append([first_token_latency, token, token_latency])
+        print(f'session {session_id}: '
+              f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}')
     res_que.put((session_id, stats))
 
 
@@ -73,6 +73,7 @@ def _infer(_chatbot, session_id):
     chatbots = [
         Chatbot(tritonserver_addr=tritonserver_addr,
                 ignore_eos=True,
+                log_level=logging.ERROR,
                 profile_generation=True) for _ in range(concurrency)
     ]
     procs = []
@@ -87,7 +88,7 @@ def _infer(_chatbot, session_id):
 
 
 def read_dataset(tokenizer_path: str, dataset_path: str, samples: int,
-                 session_len: int):
+                 session_len: int, que: mp.Queue):
     start = time.perf_counter()
     with open(dataset_path) as f:
         dataset = json.load(f)
@@ -119,12 +120,11 @@ def read_dataset(tokenizer_path: str, dataset_path: str, samples: int,
     if samples > 0:
         filtered_dataset = random.sample(filtered_dataset, samples)
 
-    que = mp.Queue()
     for data in filtered_dataset:
         que.put(data)
     print(f'elapsed time for filtering: '
           f'{round(time.perf_counter() - start, 2)} s')
-    return que, len(filtered_dataset)
+    return len(filtered_dataset)
 
 
 def main(tritonserver_addr: str,
@@ -134,32 +134,39 @@ def main(tritonserver_addr: str,
          session_len: int = 2048,
          samples: int = 1000):
     warmup(tritonserver_addr, concurrency, session_len - 1)
-    req_que, n_req = read_dataset(tokenizer_path, dataset_path, samples,
-                                  session_len)
+    req_que = mp.Queue()
     res_que = mp.Queue()
+
     procs = []
     _start = time.perf_counter()
     for i in range(concurrency):
         chatbot = Chatbot(tritonserver_addr=tritonserver_addr,
                           display=False,
                           profile_serving=True,
-                          ignore_eos=True)
+                          ignore_eos=True,
+                          log_level=logging.ERROR)
         proc = mp.Process(target=infer,
                           args=(chatbot, i + 1, req_que, res_que))
         procs.append(proc)
         proc.start()
-    for proc in procs:
-        proc.join()
-    _end = time.perf_counter()
-    elapsed_time = _end - _start
+
+    # read data and put it to queue
+    n_req = read_dataset(tokenizer_path, dataset_path, samples, session_len,
+                         req_que)
+    for i in range(concurrency):
+        req_que.put([None, None, None])
 
     stats = []
-    while not res_que.empty():
+    for i in range(concurrency):
         session_id, _stats = res_que.get()
         print(f'\n{"-" * 50}\n'
-              f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n')
+              f'session {session_id}: processed reqs {len(_stats)}, '
+              f'stats: \n{_stats}\n{"-" * 50}\n')
         stats.append(np.array(_stats))
 
+    _end = time.perf_counter()
+    elapsed_time = _end - _start
+
     stats = np.concatenate(stats).reshape(-1, 3)
 
     first_token_latency_min = np.min(stats[:, 0], axis=0)
@@ -169,14 +176,17 @@
     req_throughput = n_req / elapsed_time
 
     print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
-          f'elapsed_time: {elapsed_time:.2f}s\n'
+          f'elapsed_time: {elapsed_time:.3f}s\n'
          f'first_token latency(min, max, ave): '
-          f'{first_token_latency_min:.2f}s, {first_token_latency_max:.2f}s, '
-          f'{first_token_latency_ave:.2f}s\n'
-          f'token throughput: {token_throughput:.2f} token/s\n'
-          f'req throughput: {req_throughput:.2f} req/s\n'
+          f'{first_token_latency_min:.3f}s, {first_token_latency_max:.3f}s, '
+          f'{first_token_latency_ave:.3f}s\n'
+          f'token throughput: {token_throughput:.3f} token/s\n'
+          f'req throughput: {req_throughput:.3f} req/s\n'
           f'{"-" * 50}\n')
 
+    for proc in procs:
+        proc.join()
+
 
 if __name__ == '__main__':
     fire.Fire(main)
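
For reference, a minimal, self-contained sketch of the producer/consumer shape the patched benchmark adopts; the names (worker, SENTINEL) and payloads here are illustrative, not part of the benchmark itself:

import multiprocessing as mp

SENTINEL = [None, None, None]  # one sentinel per worker marks end of input


def worker(wid: int, req_que: mp.Queue, res_que: mp.Queue):
    results = []
    # iter(get, SENTINEL) blocks in get() and stops cleanly on the sentinel
    # (compared by equality), unlike polling `while not req_que.empty()`,
    # which races against the other consumers.
    for prompt, input_len, output_len in iter(req_que.get, SENTINEL):
        results.append((len(prompt), input_len, output_len))  # stand-in work
    res_que.put((wid, results))


if __name__ == '__main__':
    req_que, res_que = mp.Queue(), mp.Queue()
    procs = [
        mp.Process(target=worker, args=(i, req_que, res_que))
        for i in range(4)
    ]
    for p in procs:
        p.start()
    # feed only after the workers are running, then one sentinel per worker
    for item in [('hello world', 3, 8)] * 16:
        req_que.put(item)
    for _ in procs:
        req_que.put(SENTINEL)
    # drain one result per worker *before* join(), so no child is left
    # blocked flushing its queue buffer while the parent waits in join()
    for _ in procs:
        print(res_que.get())
    for p in procs:
        p.join()

Because the sentinels are enqueued after all real requests, every worker is guaranteed to drain the shared queue before it sees its end-of-input marker.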

lmdeploy/serve/turbomind/chatbot.py

Lines changed: 4 additions & 1 deletion
@@ -631,7 +631,8 @@ def stream_consumer(postprocess, res_queue, session, n_input_token,
                 output_ids = output_ids[:, :, n_input_token +
                                         preseq_length:sequence_length.squeeze(
                                         )]
-                last_token_id = output_ids[-1, -1, -1]
+                last_token_id = None if output_ids.shape[
+                    -1] == 0 else output_ids[-1, -1, -1]
                 if last_token_id == eos_id:
                     session.sequence_length = session.sequence_length - 1
                     output_ids = output_ids[:, :, :-1]
@@ -652,6 +653,8 @@ def stream_consumer(postprocess, res_queue, session, n_input_token,
                                  output_ids.shape[-1])
             except Exception as e:
                 logger.error(f'catch exception: {e}')
+                logger.error(
+                    f'session {session.session_id}: prompt: {session.prompt}')
 
         # put session back to queue so that `_stream_infer` can update it in
         # `self.sessions`
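
The chatbot.py change covers the request_output_len = 0 case: after the prompt tokens are sliced away, output_ids can have an empty last axis, and output_ids[-1, -1, -1] then raises an IndexError inside stream_consumer. A tiny sketch of the failure mode and the guard, using an illustrative shape rather than the server's real tensors:

import numpy as np

# (beam, batch, generated_tokens) with zero generated tokens, as happens
# when request_output_len is 0 and the prompt is sliced away
output_ids = np.empty((1, 1, 0), dtype=np.int64)

# indexing the empty axis would raise IndexError; fall back to None,
# mirroring the patched line in stream_consumer
last_token_id = None if output_ids.shape[-1] == 0 else output_ids[-1, -1, -1]
assert last_token_id is None  # `last_token_id == eos_id` is then simply False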
