11import json
2+ import logging
23import multiprocessing as mp
34import os
45import random
@@ -28,10 +29,8 @@ def encode(self, prompts: List):
2829
2930def infer (chatbot , session_id : int , req_que : mp .Queue , res_que : mp .Queue ):
3031 stats = []
31- while not req_que .empty ():
32- prompt , input_seqlen , output_seqlen = req_que .get ()
33- print (f'request info: session { session_id } , '
34- f'input_seqlen { input_seqlen } , output_seqlen { output_seqlen } ' )
32+ for prompt , input_seqlen , output_seqlen in iter (req_que .get ,
33+ [None , None , None ]):
3534 timestamps = []
3635 tokens = []
3736 start = time .perf_counter ()
@@ -43,12 +42,13 @@ def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue):
4342 sequence_end = True ):
4443 timestamps .append (time .perf_counter ())
4544 tokens .append (token )
46- chatbot .reset_session ()
4745
48- first_token_latency = timestamps [1 ] - start
49- token_latency = timestamps [- 1 ] - timestamps [0 ]
46+ first_token_latency = np . round ( timestamps [1 ] - start , 3 )
47+ token_latency = np . round ( timestamps [- 1 ] - timestamps [0 ], 3 )
5048 token = tokens [- 1 ] - tokens [0 ]
5149 stats .append ([first_token_latency , token , token_latency ])
50+ print (f'session { session_id } : '
51+ f'input_seqlen { input_seqlen } , output_seqlen { output_seqlen } ' )
5252 res_que .put ((session_id , stats ))
5353
5454
@@ -73,6 +73,7 @@ def _infer(_chatbot, session_id):
7373 chatbots = [
7474 Chatbot (tritonserver_addr = tritonserver_addr ,
7575 ignore_eos = True ,
76+ log_level = logging .ERROR ,
7677 profile_generation = True ) for _ in range (concurrency )
7778 ]
7879 procs = []
@@ -87,7 +88,7 @@ def _infer(_chatbot, session_id):
8788
8889
8990def read_dataset (tokenizer_path : str , dataset_path : str , samples : int ,
90- session_len : int ):
91+ session_len : int , que : mp . Queue ):
9192 start = time .perf_counter ()
9293 with open (dataset_path ) as f :
9394 dataset = json .load (f )
@@ -119,12 +120,11 @@ def read_dataset(tokenizer_path: str, dataset_path: str, samples: int,
119120 if samples > 0 :
120121 filtered_dataset = random .sample (filtered_dataset , samples )
121122
122- que = mp .Queue ()
123123 for data in filtered_dataset :
124124 que .put (data )
125125 print (f'elapsed time for filtering: '
126126 f'{ round (time .perf_counter () - start , 2 )} s' )
127- return que , len (filtered_dataset )
127+ return len (filtered_dataset )
128128
129129
130130def main (tritonserver_addr : str ,
@@ -134,32 +134,39 @@ def main(tritonserver_addr: str,
134134 session_len : int = 2048 ,
135135 samples : int = 1000 ):
136136 warmup (tritonserver_addr , concurrency , session_len - 1 )
137- req_que , n_req = read_dataset (tokenizer_path , dataset_path , samples ,
138- session_len )
137+ req_que = mp .Queue ()
139138 res_que = mp .Queue ()
139+
140140 procs = []
141141 _start = time .perf_counter ()
142142 for i in range (concurrency ):
143143 chatbot = Chatbot (tritonserver_addr = tritonserver_addr ,
144144 display = False ,
145145 profile_serving = True ,
146- ignore_eos = True )
146+ ignore_eos = True ,
147+ log_level = logging .ERROR )
147148 proc = mp .Process (target = infer ,
148149 args = (chatbot , i + 1 , req_que , res_que ))
149150 procs .append (proc )
150151 proc .start ()
151- for proc in procs :
152- proc .join ()
153- _end = time .perf_counter ()
154- elapsed_time = _end - _start
152+
153+ # read data and put it to queue
154+ n_req = read_dataset (tokenizer_path , dataset_path , samples , session_len ,
155+ req_que )
156+ for i in range (concurrency ):
157+ req_que .put ([None , None , None ])
155158
156159 stats = []
157- while not res_que . empty ( ):
160+ for i in range ( concurrency ):
158161 session_id , _stats = res_que .get ()
159162 print (f'\n { "-" * 50 } \n '
160- f'session { session_id } stats: \n { _stats } \n { "-" * 50 } \n ' )
163+ f'session { session_id } : processed reqs { len (_stats )} , '
164+ f'stats: \n { _stats } \n { "-" * 50 } \n ' )
161165 stats .append (np .array (_stats ))
162166
167+ _end = time .perf_counter ()
168+ elapsed_time = _end - _start
169+
163170 stats = np .concatenate (stats ).reshape (- 1 , 3 )
164171
165172 first_token_latency_min = np .min (stats [:, 0 ], axis = 0 )
@@ -169,14 +176,17 @@ def main(tritonserver_addr: str,
169176 req_throughput = n_req / elapsed_time
170177
171178 print (f'\n { "-" * 50 } \n concurrency: { concurrency } \n '
172- f'elapsed_time: { elapsed_time :.2f } s\n '
179+ f'elapsed_time: { elapsed_time :.3f } s\n '
173180 f'first_token latency(min, max, ave): '
174- f'{ first_token_latency_min :.2f } s, { first_token_latency_max :.2f } s, '
175- f'{ first_token_latency_ave :.2f } s\n '
176- f'token throughput: { token_throughput :.2f } token/s\n '
177- f'req throughput: { req_throughput :.2f } req/s\n '
181+ f'{ first_token_latency_min :.3f } s, { first_token_latency_max :.3f } s, '
182+ f'{ first_token_latency_ave :.3f } s\n '
183+ f'token throughput: { token_throughput :.3f } token/s\n '
184+ f'req throughput: { req_throughput :.3f } req/s\n '
178185 f'{ "-" * 50 } \n ' )
179186
187+ for proc in procs :
188+ proc .join ()
189+
180190
181191if __name__ == '__main__' :
182192 fire .Fire (main )