@@ -325,6 +325,27 @@ def _get_request_id(self):
325325 self ._next_req_id = (self ._next_req_id + 1 ) & ((1 << 64 ) - 1 )
326326 return self ._next_req_id
327327
328+ def _generate_child_request_ids (
329+ self , request : ExecutorRequest ) -> List [int ] | None :
330+ """ Generate child request IDs if needed. """
331+ child_req_ids = None
332+ sampling_config = request .sampling_config
333+ beam_width = (sampling_config .beam_width
334+ if sampling_config .beam_width else 1 )
335+ num_return_sequences = (sampling_config .num_return_sequences
336+ if sampling_config .num_return_sequences else 1 )
337+
338+ # Create child requests if beam_width == 1 and num_return_sequences > 1.
339+ if beam_width == 1 and num_return_sequences > 1 :
340+ child_req_ids = []
341+ for _ in range (num_return_sequences - 1 ):
342+ child_req_id = self ._get_request_id ()
343+ if self .enable_iter_perf_stats :
344+ self .start_times [child_req_id ] = time .time ()
345+ child_req_ids .append (child_req_id )
346+
347+ return child_req_ids
348+
328349 def enqueue_requests (self , requests : List [ExecutorRequest ]):
329350 """
330351 Enqueue new requests
@@ -339,21 +360,8 @@ def enqueue_requests(self, requests: List[ExecutorRequest]):
339360 if self .enable_iter_perf_stats :
340361 self .start_times [req_id ] = time .time ()
341362
342- # Generate child request IDs if needed
343- child_req_ids = None
344- sampling_config = request .sampling_config
345- beam_width = sampling_config .beam_width
346- num_return_sequences = sampling_config .num_return_sequences or beam_width
347-
348- if beam_width == 1 and num_return_sequences > 1 :
349- # Reserve request ids for child requests.
350- child_req_ids = []
351- for _ in range (num_return_sequences - 1 ):
352- child_req_id = self ._get_request_id ()
353- if self .enable_iter_perf_stats :
354- self .start_times [child_req_id ] = time .time ()
355- child_req_ids .append (child_req_id )
356-
363+ # Reserve child request ids if needed.
364+ child_req_ids = self ._generate_child_request_ids (request )
357365 self .request_queue .put (
358366 RequestQueueItem (req_id ,
359367 request ,
@@ -476,23 +484,8 @@ def enqueue_request(self,
476484 if self .enable_iter_perf_stats :
477485 self .start_times [req_id ] = time .time ()
478486
479- # Generate child request IDs if needed
480- child_req_ids = None
481- sampling_config = request .sampling_config
482- beam_width = (sampling_config .beam_width
483- if sampling_config .beam_width else 1 )
484- num_return_sequences = (sampling_config .num_return_sequences if
485- sampling_config .num_return_sequences else 1 )
486-
487- # Only create child requests if beam_width == 1 and num_return_sequences > 1
488- if beam_width == 1 and num_return_sequences > 1 :
489- child_req_ids = []
490- for i in range (num_return_sequences - 1 ):
491- child_req_id = self ._get_request_id ()
492- if self .enable_iter_perf_stats :
493- self .start_times [child_req_id ] = time .time ()
494- child_req_ids .append (child_req_id )
495-
487+ # Reserve child request ids if needed.
488+ child_req_ids = self ._generate_child_request_ids (request )
496489 self .request_queue .put (
497490 RequestQueueItem (req_id ,
498491 request ,
0 commit comments