@@ -1988,9 +1988,12 @@ def sample_async(
         if beam_width > 1:
             self._update_cache_indirection_buffer(scheduled_requests)
 
+        decoder_input_buffers = self.store["decoder_input_buffers"][self.micro_batch_idx]
+        decoder_state = self.store["decoder_state"]
+
         make_decoding_batch_input(
-            self.store["decoder_input_buffers"][self.micro_batch_idx],
-            self.store["decoder_state"],
+            decoder_input_buffers,
+            decoder_state,
             scheduled_requests.context_requests,
             scheduled_requests.generation_requests,
             model_outputs["logits"],
@@ -2000,35 +2003,40 @@ def sample_async(
         )
 
         self.algs.decoder.forward_async(
-            self.store["decoder_state"],
-            self.store["decoder_input_buffers"][self.micro_batch_idx],
+            decoder_state,
+            decoder_input_buffers,
         )
 
+        finished_context_requests = [
+            req for req in scheduled_requests.context_requests if req.is_last_context_chunk
+        ]
+        sampling_requests = finished_context_requests + scheduled_requests.generation_requests
+
         finalize_events = {}
         gathered_ids = None
         if beam_width > 1:
-            finished_sum_device = self.store["decoder_state"].finished_sum
 
-            for request in scheduled_requests.all_requests():
+            finished_sum_device = decoder_state.finished_sum
+
+            for request in sampling_requests:
                 if request.is_context_init_state:
                     continue
                 if finished_sum_device[request.seq_slot] == beam_width:
                     finalize_events[request.request_id] = self._finalize_request(request, False)
                 elif request.streaming:
                     finalize_events[request.request_id] = self._finalize_request(request, True)
-            gathered_ids = self.store["decoder_state"].gathered_ids.to("cpu", non_blocking=True)
-        new_output_tokens = self.store["decoder_state"].all_new_tokens.to("cpu", non_blocking=True)
-        finished_sum = self.store["decoder_state"].finished_sum.to("cpu", non_blocking=True)
-        finish_reasons = self.store["decoder_state"].finish_reasons.to("cpu", non_blocking=True)
-        sequence_lengths = self.store["decoder_state"].sequence_lengths.to("cpu", non_blocking=True)
+            gathered_ids = decoder_state.gathered_ids.to("cpu", non_blocking=True)
+        new_output_tokens = decoder_state.all_new_tokens.to("cpu", non_blocking=True)
+        finished_sum = decoder_state.finished_sum.to("cpu", non_blocking=True)
+        finish_reasons = decoder_state.finish_reasons.to("cpu", non_blocking=True)
+        sequence_lengths = decoder_state.sequence_lengths.to("cpu", non_blocking=True)
 
         log_probs = None
         cum_log_probs = None
-        if any(request.py_return_log_probs for request in scheduled_requests.all_requests()):
-            log_probs = self.store["decoder_state"].log_probs.to("cpu", non_blocking=True)
-            cum_log_probs = self.store["decoder_state"].cum_log_probs.to("cpu", non_blocking=True)
+        if any(request.py_return_log_probs for request in sampling_requests):
+            log_probs = decoder_state.log_probs.to("cpu", non_blocking=True)
+            cum_log_probs = decoder_state.cum_log_probs.to("cpu", non_blocking=True)
 
-        device = SampleStateTensors(new_tokens=self.store["decoder_state"].all_new_tokens)
+        device = SampleStateTensors(new_tokens=decoder_state.all_new_tokens)
 
         host = SampleStateTensorsHostTRTLLM(
             new_tokens=new_output_tokens,
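
Note on the hunk above: the sampler now builds `sampling_requests`, the set of requests that can actually produce a token this iteration — context requests that have ingested their final prompt chunk (`is_last_context_chunk`), plus all generation requests. Context requests still mid-prefill emit no logits to sample, so the finalize loop and the log-probs check no longer iterate over them. A minimal sketch of the selection logic, assuming only the request attributes visible in the diff:

    def select_sampling_requests(scheduled_requests):
        # Context requests on earlier prompt chunks produce no new token
        # this step, so only last-chunk ones need sampling bookkeeping.
        finished_context = [
            req
            for req in scheduled_requests.context_requests
            if req.is_last_context_chunk
        ]
        return finished_context + scheduled_requests.generation_requests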
@@ -2046,7 +2054,7 @@ def sample_async(
         self.micro_batch_idx = (self.micro_batch_idx + 1) % self.num_micro_batches
 
         return SampleStateTRTLLM(
-            scheduled_requests=scheduled_requests,
+            requests=sampling_requests,
             device=device,
             host=host,
             sampler_event=sampler_event,
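
`SampleStateTRTLLM` is correspondingly constructed from the flat `requests` list instead of the whole `scheduled_requests` batch, so `update_requests` consumes exactly the requests that were sampled. A rough sketch of the presumed state shape; only `requests`, `device`, `host`, and `sampler_event` are confirmed by the constructor call above, and `finalize_events` is inferred from its use in `update_requests_multiple_beams_or_drafting` further down:

    from dataclasses import dataclass, field
    from typing import Any, Optional

    @dataclass
    class SampleStateTRTLLM:
        requests: list                        # was: scheduled_requests
        device: Any                           # SampleStateTensors (GPU-side new tokens)
        host: Any                             # SampleStateTensorsHostTRTLLM (async CPU copies)
        sampler_event: Optional[Any] = None   # event synchronized before reading host tensors
        finalize_events: dict = field(default_factory=dict)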
@@ -2062,13 +2070,13 @@ def update_requests(
     ):
         # resource_manager will not be used in this function, just for interface consistency.
         assert isinstance(state, SampleStateTRTLLM)
-        if state.scheduled_requests.batch_size == 0:
+        if len(state.requests) == 0:
             return
 
         if state.sampler_event:
             state.sampler_event.synchronize()
 
-        beam_width = self.beam_width(state.scheduled_requests.all_requests())
+        beam_width = self.beam_width(state.requests)
 
         if beam_width == 1 and self.MAX_DECODING_TOKENS == 1:
             self.update_requests_single_beam_single_step(state)
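
One subtle consequence of the new guard: `len(state.requests)` counts only sampling requests, so an iteration whose batch contains nothing but mid-prefill context chunks now returns early here, whereas `scheduled_requests.batch_size` would have been nonzero in that case.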
@@ -2087,13 +2095,7 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
             state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None
         )
 
-        reqs = [
-            r for r in state.scheduled_requests.context_requests if not r.is_context_init_state
-        ] + [
-            r
-            for r in state.scheduled_requests.generation_requests
-            if not r.is_generation_complete_state
-        ]
+        reqs = [r for r in state.requests if not r.is_generation_complete_state]
 
         reqs_with_new_tokens = [
             r for r in reqs if (sequence_lengths_host_data[r.py_seq_slot] > r.get_num_tokens(0))
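
Dropping the `is_context_init_state` filter here (and in the identical hunk below) appears safe because `state.requests` was pre-filtered in `sample_async`: any context request it contains has already consumed its last prompt chunk, so only the completion-state check remains meaningful.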
@@ -2148,13 +2150,7 @@ def update_requests_multiple_beams_or_drafting(
         log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None
         finalize_events = state.finalize_events
 
-        reqs = [
-            r for r in state.scheduled_requests.context_requests if not r.is_context_init_state
-        ] + [
-            r
-            for r in state.scheduled_requests.generation_requests
-            if not r.is_generation_complete_state
-        ]
+        reqs = [r for r in state.requests if not r.is_generation_complete_state]
 
         for request in reqs:
             seq_slot = request.py_seq_slot