@@ -508,7 +508,7 @@ def _dynamic_step_context_init(
508508 inference_wrapper_config .moe_pad_experts_for_cuda_graph_inference
509509 )
510510 if moe_pad_experts_for_cuda_graph_inference :
511- if context .is_decode_only ():
511+ if context .is_decode_only () or warmup_engine_mode is not None :
512512 capacity_factor = model_config .num_moe_experts / model_config .moe_router_topk
513513 set_decode_expert_padding (unwrapped_model , True , capacity_factor = capacity_factor )
514514 else :
@@ -565,18 +565,23 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor)
565565 )
566566 return logits
567567
568- def _dynamic_step_sample_bookkeeping (self , sampling_params : SamplingParams ):
568+ def _dynamic_step_sample_bookkeeping (
569+ self ,
570+ active_sampling_map : List [Tuple [SamplingParams , List [int ]]],
571+ ):
569572 """Perform bookkeeping necessary to sample logits for dynamic batching."""
570573 pass
571574
572575 def _dynamic_step_sample_logits (
573- self , logits : Tensor , sampling_params : SamplingParams
576+ self , logits : Tensor , active_sampling_map : List [ Tuple [ SamplingParams , List [ int ]]],
574577 ) -> Tensor :
575578 """Sample logits for dynamic batching.
576579
577580 Args:
578581 logits (Tensor): The logits from the forward step.
579- sampling_params (SamplingParams): Parameters for sampling logits.
582+ active_sampling_map (List[Tuple[SamplingParams, List[int]]]): A list of tuples
583+ matching each unique set of sampling params to the context array indices
584+ of the corresponding active requests.
580585
581586 Returns:
582587 new_sample (Tensor): The sampled tokens for each active request.
@@ -609,7 +614,10 @@ def _dynamic_step_log_probs_bookkeeping(self):
609614 pass
610615
611616 def _dynamic_step_calculate_log_probs (
612- self , logits : Tensor , new_sample : Tensor , sampling_params : SamplingParams
617+ self ,
618+ logits : Tensor ,
619+ new_sample : Tensor ,
620+ active_sampling_map : List [Tuple [SamplingParams , List [int ]]],
613621 ) -> Optional [Tensor ]:
614622 context = self .inference_wrapped_model .inference_context
615623 materialize_only_last_token_logits = context .materialize_only_last_token_logits
@@ -710,13 +718,13 @@ async def async_generate_output_tokens_dynamic_batch(
710718 if context .active_token_count == 0 :
711719 return None
712720
713- # This method only interacts with CPU tensors.
721+ # This method only performs computations using CPU tensors.
714722 input_ids , position_ids = self ._dynamic_step_context_init ()
715723 cuda_graph_request_count = (
716724 context .padded_active_request_count if context .is_decode_only () else None
717725 )
718726
719- # This method only interacts with GPU tensors.
727+ # This method only performs computations using GPU tensors.
720728 logits = self ._dynamic_step_forward_logits (input_ids , position_ids )
721729
722730 # This is the best place to yield control back to event loop.
@@ -728,17 +736,17 @@ async def async_generate_output_tokens_dynamic_batch(
728736 # NOTE [TDE]: This will be moved once CPU and GPU methods are separated.
729737 await asyncio .sleep (0 )
730738
731- # This method will only interact with CPU tensors in the future.
732- self ._dynamic_step_sample_bookkeeping (sampling_params )
733- # This method will only interact with GPU tensors in the future.
734- new_sample = self ._dynamic_step_sample_logits (logits , sampling_params )
739+ # This method will only perform computations using CPU tensors in the future.
740+ self ._dynamic_step_sample_bookkeeping (active_sampling_map )
741+ # This method will only perform computations using GPU tensors in the future.
742+ new_sample , termination_id = self ._dynamic_step_sample_logits (logits , active_sampling_map )
735743
736- # This method will only interact with CPU tensors in the future.
744+ # This method will only perform computations using CPU tensors in the future.
737745 self ._dynamic_step_log_probs_bookkeeping ()
738- # This method will only interact with GPU tensors in the future.
739- log_probs = self ._dynamic_step_calculate_log_probs (logits , new_sample , sampling_params )
746+ # This method will only perform computations using GPU tensors in the future.
747+ log_probs = self ._dynamic_step_calculate_log_probs (logits , new_sample , active_sampling_map )
740748
741- # This method only interacts with CPU tensors.
749+ # This method only performs computations using CPU tensors.
742750 if skip_bookkeeping :
743751 request_bookeeping = {}
744752 else :
0 commit comments