Skip to content

Commit 4dd7778

Browse files
committed
Fix merge conflicts
1 parent a544170 commit 4dd7778

File tree

1 file changed

+23
-15
lines changed

1 file changed

+23
-15
lines changed

megatron/core/inference/text_generation_controllers/text_generation_controller.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ def _dynamic_step_context_init(
508508
inference_wrapper_config.moe_pad_experts_for_cuda_graph_inference
509509
)
510510
if moe_pad_experts_for_cuda_graph_inference:
511-
if context.is_decode_only():
511+
if context.is_decode_only() or warmup_engine_mode is not None:
512512
capacity_factor = model_config.num_moe_experts / model_config.moe_router_topk
513513
set_decode_expert_padding(unwrapped_model, True, capacity_factor=capacity_factor)
514514
else:
@@ -565,18 +565,23 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor)
565565
)
566566
return logits
567567

568-
def _dynamic_step_sample_bookkeeping(self, sampling_params: SamplingParams):
568+
def _dynamic_step_sample_bookkeeping(
569+
self,
570+
active_sampling_map: List[Tuple[SamplingParams, List[int]]],
571+
):
569572
"""Perform bookkeeping necessary to sample logits for dynamic batching."""
570573
pass
571574

572575
def _dynamic_step_sample_logits(
573-
self, logits: Tensor, sampling_params: SamplingParams
576+
self, logits: Tensor, active_sampling_map: List[Tuple[SamplingParams, List[int]]],
574577
) -> Tensor:
575578
"""Sample logits for dynamic batching.
576579
577580
Args:
578581
logits (Tensor): The logits from the forward step.
579-
sampling_params (SamplingParams): Parameters for sampling logits.
582+
active_sampling_map (List[Tuple[SamplingParams, List[int]]]): A list of tuples
583+
matching each unique set of sampling params to the context array indices
584+
of the corresponding active requests.
580585
581586
Returns:
582587
new_sample (Tensor): The sampled tokens for each active request.
@@ -609,7 +614,10 @@ def _dynamic_step_log_probs_bookkeeping(self):
609614
pass
610615

611616
def _dynamic_step_calculate_log_probs(
612-
self, logits: Tensor, new_sample: Tensor, sampling_params: SamplingParams
617+
self,
618+
logits: Tensor,
619+
new_sample: Tensor,
620+
active_sampling_map: List[Tuple[SamplingParams, List[int]]],
613621
) -> Optional[Tensor]:
614622
context = self.inference_wrapped_model.inference_context
615623
materialize_only_last_token_logits = context.materialize_only_last_token_logits
@@ -710,13 +718,13 @@ async def async_generate_output_tokens_dynamic_batch(
710718
if context.active_token_count == 0:
711719
return None
712720

713-
# This method only interacts with CPU tensors.
721+
# This method only performs computations using CPU tensors.
714722
input_ids, position_ids = self._dynamic_step_context_init()
715723
cuda_graph_request_count = (
716724
context.padded_active_request_count if context.is_decode_only() else None
717725
)
718726

719-
# This method only interacts with GPU tensors.
727+
# This method only performs computations using GPU tensors.
720728
logits = self._dynamic_step_forward_logits(input_ids, position_ids)
721729

722730
# This is the best place to yield control back to event loop.
@@ -728,17 +736,17 @@ async def async_generate_output_tokens_dynamic_batch(
728736
# NOTE [TDE]: This will be moved once CPU and GPU methods are separated.
729737
await asyncio.sleep(0)
730738

731-
# This method will only interact with CPU tensors in the future.
732-
self._dynamic_step_sample_bookkeeping(sampling_params)
733-
# This method will only interact with GPU tensors in the future.
734-
new_sample = self._dynamic_step_sample_logits(logits, sampling_params)
739+
# This method will only perform computations using CPU tensors in the future.
740+
self._dynamic_step_sample_bookkeeping(active_sampling_map)
741+
# This method will only perform computations using GPU tensors in the future.
742+
new_sample, termination_id = self._dynamic_step_sample_logits(logits, active_sampling_map)
735743

736-
# This method will only interact with CPU tensors in the future.
744+
# This method will only perform computations using CPU tensors in the future.
737745
self._dynamic_step_log_probs_bookkeeping()
738-
# This method will only interact with GPU tensors in the future.
739-
log_probs = self._dynamic_step_calculate_log_probs(logits, new_sample, sampling_params)
746+
# This method will only perform computations using GPU tensors in the future.
747+
log_probs = self._dynamic_step_calculate_log_probs(logits, new_sample, active_sampling_map)
740748

741-
# This method only interacts with CPU tensors.
749+
# This method only performs computations using CPU tensors.
742750
if skip_bookkeeping:
743751
request_bookeeping = {}
744752
else:

0 commit comments

Comments
 (0)