@@ -57,6 +57,7 @@
 class RequestQueueItem:
     id: int
     request: Optional[ExecutorRequest] = None
+    child_req_ids: Optional[list] = None
     is_canceled_request: bool = False
     query: Optional[list] = None  # only used in `StarAttention`
 
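To make the new field concrete, here is a minimal, self-contained sketch (a plain `object` stands in for `ExecutorRequest`) showing that `child_req_ids` sits positionally between `request` and `is_canceled_request`, which is why positional constructor calls elsewhere in this diff must now pass a placeholder list:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class RequestQueueItem:
    id: int
    request: Optional[object] = None  # ExecutorRequest in the real module
    child_req_ids: Optional[list] = None
    is_canceled_request: bool = False
    query: Optional[list] = None  # only used in `StarAttention`

# Positional callers must account for the new third field:
item = RequestQueueItem(7, None, [8, 9])
assert item.child_req_ids == [8, 9] and not item.is_canceled_request
```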
@@ -88,6 +89,13 @@ def _get_from_request_queue(
     return items
 
 
+def _get_num_child_requests(request: ExecutorRequest) -> int:
+    sampling_config = request.sampling_config
+    logger.debug(sampling_config)
+    return 0 if sampling_config.beam_width > 1 else (
+        sampling_config.num_return_sequences or 1) - 1
+
+
 def _get_from_waiting_queue(
     waiting_queue: deque[RequestQueueItem],
     max_req_count: int,
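The helper's arithmetic, restated with a stand-in `SamplingConfig` (the real one comes from the executor bindings, so treat this as a sketch): beam search returns all beams from the parent request itself, while sampling with `num_return_sequences = n` spawns `n - 1` child requests.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SamplingConfig:  # stand-in for the bindings' config
    beam_width: int = 1
    num_return_sequences: Optional[int] = None

def num_children(cfg: SamplingConfig) -> int:
    # Beam search emits all sequences from the parent request itself.
    return 0 if cfg.beam_width > 1 else (cfg.num_return_sequences or 1) - 1

assert num_children(SamplingConfig()) == 0                        # default: no children
assert num_children(SamplingConfig(num_return_sequences=4)) == 3  # n - 1 children
assert num_children(SamplingConfig(beam_width=4, num_return_sequences=4)) == 0
```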
@@ -108,8 +116,9 @@ def _get_from_waiting_queue(
     items = []
     req_count = 0
     while req_count < max_req_count and waiting_queue:
-        items.append(waiting_queue.popleft())
-        req_count += 1
+        req_item = waiting_queue.popleft()
+        items.append(req_item)
+        req_count += 1 + _get_num_child_requests(req_item.request)
     return items
 
 
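An illustrative model of the new budget accounting (the names here are hypothetical): each popped item now consumes `1 + num_children` slots, so `max_req_count` bounds sequences rather than queue entries. Note that the budget check happens before the pop, so a single family can overshoot the limit.

```python
from collections import deque

def take(queue: deque, max_req_count: int, children_of) -> list:
    items, req_count = [], 0
    while req_count < max_req_count and queue:
        item = queue.popleft()
        items.append(item)
        req_count += 1 + children_of(item)
    return items

q = deque(["a", "b", "c"])
# "a" expands to 3 sequences (2 children), so a budget of 4 admits only "a" and "b".
taken = take(q, 4, lambda item: 2 if item == "a" else 0)
assert taken == ["a", "b"] and list(q) == ["c"]
```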
@@ -359,9 +368,16 @@ def enqueue_requests(self, requests: List[ExecutorRequest]):
             start_time = time.time()
             for request in requests:
                 self.start_times[self.next_req_id] = start_time
-                self.request_queue.put(
-                    RequestQueueItem(self.next_req_id, request))
+                req_id = self.next_req_id
                 req_ids.append(self.next_req_id)
+
+                child_req_ids = []
+                num_child_requests = _get_num_child_requests(request)
+                for _ in range(num_child_requests):
+                    self.next_req_id += 1
+                    child_req_ids.append(self.next_req_id)
+                self.request_queue.put(
+                    RequestQueueItem(req_id, request, child_req_ids))
                 self.next_req_id += 1
         finally:
             self.enqueue_lock.release()
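The ID allocation pattern, extracted into a standalone sketch (assuming `next_req_id` is a monotonically increasing counter): the parent keeps the current id, children take the immediately following ids, and the counter finally steps past the parent's own slot.

```python
def allocate_ids(next_req_id: int, num_children: int):
    parent_id = next_req_id
    child_ids = []
    for _ in range(num_children):
        next_req_id += 1
        child_ids.append(next_req_id)
    next_req_id += 1  # advance past the parent's slot as well
    return parent_id, child_ids, next_req_id

parent, children, nxt = allocate_ids(10, 3)
assert (parent, children, nxt) == (10, [11, 12, 13], 14)
```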
@@ -472,14 +488,23 @@ def enqueue_request(self,
         try:
             self.enqueue_lock.acquire()
             assert self.active, "PyExecutor has already been shutdown."
+            logger.info(
+                f"Enqueuing new Executor request with id {self.next_req_id}")
             req_id = self.next_req_id
             if self.enable_iter_perf_stats:
                 self.start_times[req_id] = time.time()
 
             if query is not None:
-                self.request_queue.put(RequestQueueItem(req_id, request, query))
+                self.request_queue.put(
+                    RequestQueueItem(req_id, request, [], False, query))
             else:
-                self.request_queue.put(RequestQueueItem(req_id, request))
+                child_req_ids = []
+                num_child_requests = _get_num_child_requests(request)
+                for _ in range(num_child_requests):
+                    self.next_req_id += 1
+                    child_req_ids.append(self.next_req_id)
+                self.request_queue.put(
+                    RequestQueueItem(req_id, request, child_req_ids))
             self.next_req_id += 1
         finally:
             self.enqueue_lock.release()
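Note that the `query is not None` branch (the `StarAttention` path) passes an empty `child_req_ids` list, so queries appear to be assumed never to request multiple return sequences; the explicit `[]` and `False` placeholders are needed because `query` has moved to the fifth positional slot of `RequestQueueItem`.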
@@ -1506,12 +1531,15 @@ def _merge_requests(self, new_requests: list[RequestQueueItem]):
             else:
                 raise NotImplementedError(f'unsupported cp type {cp_type}')
         else:
-            return [
-                executor_request_to_llm_request(
-                    req_item.id, req_item.request,
+            req_with_children = []
+            for req_item in new_requests:
+                req = executor_request_to_llm_request(
+                    req_item.id, req_item.request, req_item.child_req_ids,
                     self._should_exclude_last_generation_logits())
-                for req_item in new_requests
-            ]
+                req_with_children.append(req)
+                for child in req.children:
+                    req_with_children.append(child)
+            return req_with_children
 
     @nvtx_range("_schedule")
     def _schedule(self):
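How the flattening behaves, in a self-contained sketch (`Req` here is hypothetical; in the real code `executor_request_to_llm_request` is expected to populate `children` from `child_req_ids`): each parent is immediately followed by its children, so the scheduler receives one flat list of requests.

```python
class Req:
    def __init__(self, rid, child_ids):
        self.rid = rid
        self.children = [Req(c, []) for c in child_ids]

def merge(items):
    flat = []
    for rid, child_ids in items:
        req = Req(rid, child_ids)
        flat.append(req)
        flat.extend(req.children)  # children follow their parent directly
    return flat

merged = merge([(10, [11, 12]), (13, [])])
assert [r.rid for r in merged] == [10, 11, 12, 13]
```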
@@ -1977,7 +2005,7 @@ def _handle_canceled_requests(self):
             if req.id not in self.canceled_req_ids)
 
         for request in self.active_requests:
-            req_id = request.py_request_id
+            req_id = request.py_request_id if not request.is_child else request.parent_request_id
             if req_id in self.canceled_req_ids:
                 # Mark requests as finished; then we reuse all existing code
                 # to clean up the KV cache resources.
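Cancellation bookkeeping stays keyed by the parent's id. A minimal sketch with hypothetical attributes mirroring the diff (`is_child`, `parent_request_id`) shows that cancelling one id now sweeps up the whole family:

```python
class R:
    def __init__(self, rid, parent=None):
        self.py_request_id = rid
        self.is_child = parent is not None
        self.parent_request_id = parent

def cancel_key(req):
    # Children are matched against their parent's id.
    return req.py_request_id if not req.is_child else req.parent_request_id

reqs = [R(10), R(11, parent=10), R(12, parent=10), R(20)]
canceled = {10}
assert [r.py_request_id for r in reqs if cancel_key(r) in canceled] == [10, 11, 12]
```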
@@ -1991,7 +2019,7 @@ def _handle_canceled_requests(self):
         self.canceled_req_ids.clear()
 
     @nvtx_range("_enqueue_responses")
-    def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
+    def _enqueue_responses(self, responses: List[Tuple[int, LlmResponse]]):
         if 0 not in self.dist.mapping.tp_group and not self.gather_all_responses:
             return
 
@@ -2003,18 +2031,18 @@ def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
         else:
             responses_list = self.dist.allgather(responses)
             if self.dist.rank == 0 or self.gather_all_responses:
-                gather_responses = {}
+                gather_responses = []
                 if responses_list is not None:
                     for resp in responses_list:
                         if resp is not None:
-                            gather_responses.update(resp)
+                            gather_responses.extend(resp)
                 responses = gather_responses
         logger.debug(
             f'after gather, rank = {self.dist.rank}, responses = {responses}')
 
         if self.dist.rank == 0 or self.gather_all_responses:
             with self.response_cv:
-                for req_id, resp in responses.items():
+                for req_id, resp in responses:
                     if req_id in self.responses.keys():
                         self.responses[req_id].append(resp)
                     else:
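The container change from `Dict[int, LlmResponse]` to `List[Tuple[int, LlmResponse]]` matters because parent and child sequences can report under the same request id in one step; a dict keyed by id keeps only the last write. A small illustration:

```python
responses = [(10, "parent seq"), (10, "child seq 1"), (10, "child seq 2")]

as_dict = dict(responses)       # old shape: last write wins
assert as_dict == {10: "child seq 2"}

grouped: dict[int, list] = {}
for req_id, resp in responses:  # new shape: every response survives
    grouped.setdefault(req_id, []).append(resp)
assert grouped == {10: ["parent seq", "child seq 1", "child seq 2"]}
```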
@@ -2023,20 +2051,20 @@ def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
 
     @nvtx_range("_handle_first_token_response")
     def _handle_first_token_response(self, scheduled_batch):
-        new_responses = {}
+        new_responses = []
         for req in scheduled_batch.generation_requests:
             if req.py_decoding_iter == 1:
                 logger.debug(
                     f'Send first token response for request {req.py_request_id}'
                 )
                 response = req.create_response(False, self.dist.rank)
-                new_responses.update({req.py_request_id: response})
+                new_responses.append((req.py_request_id, response))
 
         self._enqueue_responses(new_responses)
 
     @nvtx_range("_handle_responses")
     def _handle_responses(self):
-        new_responses = {}
+        new_responses = []
         requests_to_terminate = []
         new_active_requests = []
         logger.debug(
@@ -2070,14 +2098,17 @@ def _handle_responses(self):
                     request.py_decoding_iter % self.stream_interval == 0:
                 response = request.create_response(False, self.dist.rank)
                 if response:
-                    request_done = response.result.is_final
-                    new_responses.update({req_id: response})
+                    request_done = request.is_finished
+                    new_responses.append((req_id, response))
 
             if request_done:
                 if request.is_disagg_context_transmission_state:
                     self.ctx_in_transmission_requests.append(request)
                 else:
-                    requests_to_terminate.append(request)
+                    if response.result.is_final:
+                        requests_to_terminate.append(request)
+                        for child in request.children:
+                            requests_to_terminate.append(child)
             else:
                 new_active_requests.append(request)
         self.active_requests = new_active_requests
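The revised termination logic distinguishes a request that has finished its own stream (`request.is_finished`) from the family being fully drained (`response.result.is_final`); resources are released only on the final response, and parent and children are terminated together. A sketch with hypothetical flags mirroring the diff:

```python
class Req:
    def __init__(self, children=()):
        self.is_finished = True
        self.children = list(children)

def to_terminate(request, response_is_final):
    done = []
    if request.is_finished and response_is_final:
        done.append(request)        # free the parent...
        done.extend(request.children)  # ...and its children together
    return done

parent = Req(children=[Req(), Req()])
assert to_terminate(parent, response_is_final=False) == []     # keep KV cache alive
assert len(to_terminate(parent, response_is_final=True)) == 3  # free the whole family
```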