@@ -1114,13 +1114,18 @@ def _prepare_tp_inputs(
                 new_tokens_lens_device = new_tensors_device.new_tokens_lens  # [batch]
                 next_draft_tokens_device = new_tensors_device.next_draft_tokens  # [batch, draft_len]

-        # Requests with draft tokens are treated like extend requests.
+        # Requests with draft tokens are treated like extend requests. Dummy extend requests should be
+        # at the end of extend_requests.
         extend_requests = []
+        extend_dummy_requests = []
         generation_requests = []
         for request in scheduled_requests.generation_requests:
             if len(request.py_draft_tokens
                    ) > 0 or next_draft_tokens_device is not None:
-                extend_requests.append(request)
+                if request.is_dummy:
+                    extend_dummy_requests.append(request)
+                else:
+                    extend_requests.append(request)
             else:
                 generation_requests.append(request)

@@ -1130,6 +1135,7 @@ def _prepare_tp_inputs(
                     torch.tensor([mrope_position_deltas],
                                  dtype=torch.int32).to('cuda',
                                                        non_blocking=True))
+        extend_requests += extend_dummy_requests

         if not self._disable_overlap_scheduler and self.is_spec_decode:
             spec_dec_mode = self.spec_config.spec_dec_mode
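Note: the two hunks above change how generation requests are bucketed. Below is a minimal, self-contained sketch of that partitioning; the `SimpleNamespace` objects and the `has_draft_state` flag are stand-ins for the real `LlmRequest` objects and for `next_draft_tokens_device is not None`, and only the ordering property (dummy extend requests collected last) mirrors the diff.

```python
from types import SimpleNamespace

def partition_generation_requests(generation_requests, has_draft_state):
    """Split generation requests the way the hunks above do: real extend
    requests first, dummy extend requests appended last, plain generation
    requests kept separate."""
    extend_requests = []
    extend_dummy_requests = []
    plain_generation_requests = []
    for request in generation_requests:
        if len(request.py_draft_tokens) > 0 or has_draft_state:
            # Dummy requests (CUDA graph padding / attention DP) go to the
            # tail so real extend requests keep a contiguous prefix.
            if request.is_dummy:
                extend_dummy_requests.append(request)
            else:
                extend_requests.append(request)
        else:
            plain_generation_requests.append(request)
    extend_requests += extend_dummy_requests
    return extend_requests, plain_generation_requests

# Tiny usage example with made-up requests.
reqs = [
    SimpleNamespace(name="real-0", py_draft_tokens=[7, 8], is_dummy=False),
    SimpleNamespace(name="dummy-0", py_draft_tokens=[0, 0], is_dummy=True),
    SimpleNamespace(name="gen-0", py_draft_tokens=[], is_dummy=False),
]
extend, gen = partition_generation_requests(reqs, has_draft_state=False)
print([r.name for r in extend])  # ['real-0', 'dummy-0']  (dummy last)
print([r.name for r in gen])     # ['gen-0']
```

Keeping the dummy requests at the tail is what lets the later hunks address the "has previous batch" requests with a plain `[0:previous_batch_tokens]` slice instead of the old `pre_tokens_start_idx` arithmetic.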
@@ -1139,21 +1145,18 @@ def _prepare_tp_inputs(
         # will contain previous batch incices of generation requests
         previous_batch_indices = []
         previous_pos_indices = []
-        request_ids_with_previous_batch = []
-        num_extend_reqs_wo_previous_batch = 0
         for request in extend_requests:
-            if next_draft_tokens_device is None or request.py_batch_idx is None:
-                # the request has no previous device tensors:
-                # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
-                # (2) request.py_batch_idx is None, which means the request has no previous batch.
-                # the second condition includes dummy generation requests created for CUDA graph padding or
-                # attention DP. These dummy generation requests should be at the head of generation_requests.
-                # TODO: move the dummy generation requests to the end of generation_requests to align with
-                # the logic for those requests in generation_requests.
-                # get token ids, including input token ids and draft token ids
-                input_ids.append(request.get_last_tokens(0))
-                input_ids.extend(request.py_draft_tokens)
-                draft_tokens.extend(request.py_draft_tokens)
+            # the request has no previous tensor:
+            # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
+            # (2) a dummy request; or
+            # (3) the first step in the generation server of disaggregated serving
+            if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
+                # get token ids, including input token ids and draft token ids. For these dummy requests,
+                # no need to copy the token ids.
+                if not request.is_dummy:
+                    input_ids.append(request.get_last_tokens(0))
+                    input_ids.extend(request.py_draft_tokens)
+                    draft_tokens.extend(request.py_draft_tokens)
                 # get other ids and lengths
                 num_draft_tokens = len(request.py_draft_tokens)
                 past_seen_token_num = request.max_beam_num_tokens - 1
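For clarity, here is an illustrative (not authoritative) restatement of the new per-request branch as two hypothetical helpers; the names `takes_no_previous_tensor_path` and `collect_host_tokens` do not exist in the codebase and only paraphrase the lines added above.

```python
def takes_no_previous_tensor_path(request, next_draft_tokens_device):
    """Mirror of the new branch condition.

    A request has no previous device tensor when:
      (1) the overlap scheduler is disabled (next_draft_tokens_device is None), or
      (2) it is a dummy request (CUDA graph padding / attention DP), or
      (3) it has no previous batch index yet, e.g. the first step on the
          generation server of disaggregated serving.
    """
    return (next_draft_tokens_device is None or request.is_dummy
            or request.py_batch_idx is None)


def collect_host_tokens(request, input_ids, draft_tokens):
    """Dummy requests skip the host-side token copy; real requests append
    their last accepted token followed by their draft tokens."""
    if not request.is_dummy:
        input_ids.append(request.get_last_tokens(0))
        input_ids.extend(request.py_draft_tokens)
        draft_tokens.extend(request.py_draft_tokens)
```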
@@ -1173,7 +1176,6 @@ def _prepare_tp_inputs(
                 # update batch index
                 request.py_batch_idx = batch_idx
                 batch_idx += 1
-                num_extend_reqs_wo_previous_batch += 1
             else:
                 # update batch index
                 previous_batch_idx = request.py_batch_idx
@@ -1200,10 +1202,7 @@ def _prepare_tp_inputs(
                 num_cached_tokens_per_seq.append(past_seen_token_num +
                                                  self.max_draft_len + 1)
                 prompt_lengths.append(request.py_prompt_len)
-                request_ids_with_previous_batch.append(request.py_request_id)
-
-        # move requests with previous batch to the end of the list
-        request_ids.extend(request_ids_with_previous_batch)
+                request_ids.append(request.py_request_id)

         sequence_lengths.extend([1] * len(generation_requests))
         gather_ids.extend(
@@ -1238,6 +1237,7 @@ def _prepare_tp_inputs(
         num_tokens = len(input_ids)
         num_draft_tokens = len(draft_tokens)
         previous_batchs = len(previous_batch_indices)
+        num_requests = len(request_ids)
         # if exist requests that do not have previous batch, copy input_ids and draft_tokens
         if num_tokens > 0:
             input_ids = torch.tensor(input_ids,
@@ -1276,31 +1276,27 @@ def _prepare_tp_inputs(
                     non_blocking=True)
                 # prepare data for the preprocess inputs
                 kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
-                pre_tokens_start_idx = num_extend_reqs_wo_previous_batch * (
-                    1 + self.max_draft_len)
-                pre_tokens_end_idx = pre_tokens_start_idx + previous_batch_tokens
-                pre_batch_start_idx = num_extend_reqs_wo_previous_batch
-                pre_batch_end_idx = pre_batch_start_idx + previous_batchs
                 previous_pos_indices = torch.tensor(previous_pos_indices,
                                                     dtype=torch.int,
                                                     pin_memory=True)
-                self.previous_pos_indices_cuda[
-                    pre_tokens_start_idx:pre_tokens_end_idx].copy_(
-                        previous_pos_indices, non_blocking=True)
+                self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
+                    previous_pos_indices, non_blocking=True)
                 self.previous_pos_id_offsets_cuda[
-                    pre_tokens_start_idx:pre_tokens_end_idx].copy_(
+                    0:previous_batch_tokens].copy_(
                         new_tokens_lens_device[self.previous_pos_indices_cuda[
-                            pre_tokens_start_idx:pre_tokens_end_idx]],
-                        non_blocking=True)
-                self.previous_kv_lens_offsets_cuda[
-                    pre_batch_start_idx:pre_batch_end_idx].copy_(
-                        kv_len_offsets_device[
-                            self.previous_batch_indices_cuda[:previous_batchs]],
+                            0:previous_batch_tokens]],
                         non_blocking=True)
+                self.previous_kv_lens_offsets_cuda[0:previous_batchs].copy_(
+                    kv_len_offsets_device[
+                        self.previous_batch_indices_cuda[:previous_batchs]],
+                    non_blocking=True)
                 # for the requests that do not have previous batch, set the previous_pos_id_offsets and
                 # previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
-                self.previous_pos_id_offsets_cuda[:pre_tokens_start_idx] *= 0
-                self.previous_kv_lens_offsets_cuda[:pre_batch_start_idx] *= 0
+                self.previous_pos_id_offsets_cuda[
+                    previous_batch_tokens:num_requests *
+                    (1 + self.max_draft_len)] *= 0
+                self.previous_kv_lens_offsets_cuda[
+                    previous_batchs:num_requests] *= 0
             else:
                 # change the data to zeros to skip the value changes in _preprocess_inputs
                 self.previous_pos_id_offsets_cuda *= 0
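A small standalone sketch of the new offset-buffer layout, using made-up sizes and plain CPU torch tensors in place of the engine's preallocated `previous_pos_id_offsets_cuda` / `previous_kv_lens_offsets_cuda` buffers. It only demonstrates the slicing scheme from the hunk above: entries for requests that have a previous batch occupy the front of the buffers, and the tail up to `num_requests` worth of slots is zeroed.

```python
import torch

max_draft_len = 3                      # assumed value, for illustration
num_requests = 4                       # total extend requests in this batch
previous_batchs = 2                    # of which 2 have a previous batch
tokens_per_req = 1 + max_draft_len
previous_batch_tokens = previous_batchs * tokens_per_req

# Stand-ins for the engine's persistent device buffers.
previous_pos_id_offsets = torch.zeros(num_requests * tokens_per_req, dtype=torch.int32)
previous_kv_lens_offsets = torch.zeros(num_requests, dtype=torch.int32)

# Requests with a previous batch fill the front of both buffers ...
previous_pos_id_offsets[0:previous_batch_tokens] = 5   # fake accepted-token lens
previous_kv_lens_offsets[0:previous_batchs] = 1        # fake kv-len offsets

# ... and the tail (requests without a previous batch, including the dummies
# that now sit at the end of the batch) is explicitly zeroed so
# _preprocess_inputs leaves them untouched.
previous_pos_id_offsets[previous_batch_tokens:num_requests * tokens_per_req] *= 0
previous_kv_lens_offsets[previous_batchs:num_requests] *= 0

print(previous_pos_id_offsets.tolist())   # [5]*8 followed by [0]*8
print(previous_kv_lens_offsets.tolist())  # [1, 1, 0, 0]
```

Because the old `pre_tokens_start_idx` / `pre_batch_start_idx` bookkeeping assumed the no-previous-batch requests sat at the head of the batch, moving them to the tail lets the device copies start at offset 0 and the zeroing cover only the tail slice.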