diff --git a/examples/inference/gpt/gpt_dynamic_inference.py b/examples/inference/gpt/gpt_dynamic_inference.py index 1a53787002..0ec1c428b0 100644 --- a/examples/inference/gpt/gpt_dynamic_inference.py +++ b/examples/inference/gpt/gpt_dynamic_inference.py @@ -7,6 +7,7 @@ import os import pickle import sys +import warnings import torch from argparse import ArgumentParser from collections import defaultdict @@ -56,20 +57,6 @@ from megatron.core.utils import configure_nvtx_profiling -import json - -from examples.inference.gpt.utils import ( - Request, - add_common_inference_args, - build_dynamic_engine_setup_prefix, - build_requests, - get_curr_time, -) -from megatron.training.checkpointing import load_checkpoint - -from model_provider import model_provider -from gpt_builders import gpt_builder - torch.serialization.add_safe_globals([io.BytesIO]) torch.serialization.add_safe_globals([megatron.core.rerun_state_machine.RerunState]) torch.serialization.add_safe_globals([megatron.core.rerun_state_machine.RerunDiagnostic]) @@ -188,7 +175,7 @@ def get_inference_context( buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb, max_tokens=args.inference_dynamic_batching_max_tokens, tensor_model_parallel_size=args.tensor_model_parallel_size, - materialize_only_last_token_logits=not args.return_log_probs, + materialize_only_last_token_logits=not (args.return_log_probs or args.return_prompt_top_n_logprobs), mamba_inference_state_config=mamba_inference_state_config, cache_mla_latent=args.multi_latent_attention and args.cache_mla_latents, kv_lora_rank=args.kv_lora_rank if args.multi_latent_attention else None, @@ -389,9 +376,15 @@ def _add_request(): # Log probs. if finished_request.sampling_params.return_log_probs: + if not finished_request.prompt_log_probs: + finished_request.prompt_log_probs = [] request.log_probs = ( finished_request.prompt_log_probs + finished_request.generated_log_probs ) + if finished_request.sampling_params.top_n_logprobs > 0: + request.generated_top_n_logprobs = finished_request.generated_top_n_logprobs + if finished_request.sampling_params.return_prompt_top_n_logprobs: + request.prompt_top_n_logprobs = finished_request.prompt_top_n_logprobs num_requests_finished += 1 output_times.append(get_curr_time() - output_start) @@ -434,9 +427,12 @@ def main(): temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, + skip_prompt_log_probs=args.skip_prompt_log_probs, return_log_probs=args.return_log_probs, num_tokens_to_generate=args.num_tokens_to_generate, termination_id=args.termination_id if args.termination_id is not None else tokenizer.eod, + top_n_logprobs=args.top_n_logprobs, + return_prompt_top_n_logprobs=args.return_prompt_top_n_logprobs, ) model = get_model() @@ -553,6 +549,7 @@ def escape_str(s): # Write every 'n' requests, plus the final request. 
for i, req in enumerate(requests): if i % args.output_every_n_results == 0 or i == len(requests) - 1: + print(f' Attributes of request {i}: {req.__dict__}') result_dict = { "input_prompt": req.prompt_text, "generated_text": req.output_text, @@ -560,6 +557,8 @@ def escape_str(s): "latency": req.time_end - req.time_start, "cuda_graph_request_count_map" : result["cuda_graph_request_count_map"], "step_count" : engine.step_count, + "top_n_logprobs" : getattr(req, 'generated_top_n_logprobs', None), + "prompt_top_n_logprobs" : getattr(req, 'prompt_top_n_logprobs', None), } if req.sampling_params.return_log_probs: response_logprobs = req.log_probs @@ -569,6 +568,7 @@ def escape_str(s): # Track system-level throughput as a test / debug metric json_results["throughput"] = throughputs + print(f' Saving results to {args.output_path}') with open(args.output_path, "w") as fp: json.dump(json_results, fp, indent=1) @@ -622,4 +622,4 @@ def escape_str(s): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/examples/inference/gpt/utils.py b/examples/inference/gpt/utils.py index efd4fdab4f..fc05ec0c49 100644 --- a/examples/inference/gpt/utils.py +++ b/examples/inference/gpt/utils.py @@ -66,6 +66,12 @@ def add_common_inference_args(parser: ArgumentParser) -> ArgumentParser: default=0, help='Return the top n logprobs for the generated tokens and their corresponding token as a dictionary', ) + group.add_argument( + "--return-prompt-top-n-logprobs", + action='store_true', + default=False, + help='Return the top n logprobs for the prompt tokens and their corresponding token as a dictionary', + ) group.add_argument( "--incoming-requests-per-step", type=int, default=None, @@ -96,6 +102,12 @@ def add_common_inference_args(parser: ArgumentParser) -> ArgumentParser: default="gpt", help="Model provider", ) + group.add_argument( + "--skip-prompt-log-probs", + action='store_true', + default=False, + help='Skip prompt log probs.', + ) group.add_argument( "--output-path", type=str, @@ -300,7 +312,8 @@ def get_requests_from_file( # Load prompts. n_prompts = sum(1 for _ in open(args.prompt_file)) prompts = [] - sampling_params = get_default_sampling_params(tokenizer.eod) + if sampling_params is None: + sampling_params = get_default_sampling_params(tokenizer.eod) sampling_params_list = [] with open(args.prompt_file) as f: for line in tqdm(f.readlines(), "read prompt file", total=n_prompts): diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 43b3823b6a..7162c58f1f 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1745,7 +1745,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T def calculate_log_probs( self, logits: Tensor, new_tokens: Tensor, only_last_token_logits: Optional[bool] = False - ) -> List[List[float]]: + ) -> Tuple[List[List[float]], Tensor]: """Calculate log probs for all active requests and return them. TODO: @wdykas support top-n log probs. @@ -1758,14 +1758,16 @@ def calculate_log_probs( Returns: List of lists where each inner list contains log probs for a request in the same order as the active requests (from paused_request_count to total_request_count). + log_probs (Tensor): Used to compute top n logprobs later if required. 
""" + # Calculate log_probs (sequence_length x vocab_size) log_probs = F.log_softmax(logits.squeeze(0).float(), dim=-1) if only_last_token_logits or self.is_decode_only(): seq_idx = torch.arange(len(new_tokens), dtype=torch.int32, device=logits.device) selected_log_probs = log_probs[seq_idx, new_tokens] - return [[lp] for lp in selected_log_probs.flatten().tolist()] + return [[lp] for lp in selected_log_probs.flatten().tolist()], log_probs # Get the selected token ids for all tokens. # We shift the active token window left by one to remove the first prompt token for @@ -1808,7 +1810,7 @@ def calculate_log_probs( ) # Convert each log prob tensor into a list - return [lp.tolist() for lp in selected_log_probs_list] + return [lp.tolist() for lp in selected_log_probs_list], log_probs def get_kvcache_utilization_stats(self) -> dict: """Compute KV cache buffer utilization stats for the current step. diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 5fad136930..8bc9446c30 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -687,6 +687,14 @@ def _add_request( request.sampling_params.num_tokens_to_generate is None or request.sampling_params.num_tokens_total is None ) + if request.sampling_params.return_prompt_top_n_logprobs: + assert ( + request.sampling_params.return_log_probs + ), "return_prompt_top_n_logprobs requires sampling_params.return_log_probs to be True" + if request.sampling_params.top_n_logprobs > 0: + assert ( + request.sampling_params.return_log_probs + ), "top_n_logprobs requires sampling_params.return_log_probs to be True" if request.sampling_params.num_tokens_total is not None: request.sampling_params.num_tokens_to_generate = ( request.sampling_params.num_tokens_total - len(request.prompt_tokens) @@ -785,6 +793,7 @@ def post_process_requests( step_time: float, sample: torch.Tensor, log_probs: torch.Tensor, + top_n_logprobs: Optional[Dict[int, List[Tuple[torch.Tensor, torch.Tensor]]]] = None, ) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest]]: """ Handles post-processing for requests after a step. @@ -795,6 +804,8 @@ def post_process_requests( step_time (float): The latency of the last step sample: (torch.Tensor): The newly generated tokens for each request log_probs: (List): Log probs for each request + top_n_logprobs: (Dict): Top-n log probs for each request. Maps request_idx to + list of (top_n_logprobs, top_n_indices) tuples. 
Returns: A list of active requests and completed requests as `DynamicInferenceRequest` objects @@ -806,8 +817,8 @@ def post_process_requests( log_probs_iter = log_probs if log_probs else repeat(None) - for request_id, token, request_log_probs in zip( - request_ids.tolist(), sample.tolist(), log_probs_iter + for req_idx, (request_id, token, request_log_probs) in enumerate( + zip(request_ids.tolist(), sample.tolist(), log_probs_iter) ): request: DynamicInferenceRequest = self.get_request(request_id) if request_id != self.context.chunked_prefill_request_id: @@ -823,19 +834,19 @@ def post_process_requests( request.generated_log_probs = [] # If the request log probs span > 1 token we are in prefill if len(request_log_probs) > 1: - request.prompt_log_probs.extend(request_log_probs) + # Add all but the last logprob to prompt_log_probs (last is for first generated token) + request.prompt_log_probs.extend(request_log_probs[:-1]) + # Add the last logprob to generated_log_probs (first generated token) + request.generated_log_probs.extend(request_log_probs[-1:]) else: if ( # If it is a chunked prefill request len(request.prompt_log_probs) > 0 # And we are missing the last token for prefill - and len(request.prompt_log_probs) < len(request.prompt_tokens) + and len(request.prompt_log_probs) < len(request.prompt_tokens) - 1 # And we need to track full prefill and not self.context.materialize_only_last_token_logits ): - assert ( - len(request.prompt_log_probs) == len(request.prompt_tokens) - 1 - ), "Prompt log probs length is not equal to prompt tokens length - 1" request.prompt_log_probs.extend(request_log_probs) else: request.generated_log_probs.extend(request_log_probs) @@ -874,7 +885,43 @@ def post_process_requests( request.prompt_log_probs = [] request.prompt_log_probs.extend(request_log_probs) request.generated_log_probs = [] - active_request_ids.append(request_id) + + active_request_ids.append(request_id) + + # Process top_n_logprobs if available (unified for both regular and chunked prefill) + if top_n_logprobs is not None and req_idx in top_n_logprobs: + # Initialize lists if they don't exist + if request.prompt_top_n_logprobs is None: + request.prompt_top_n_logprobs = [] + if request.generated_top_n_logprobs is None: + request.generated_top_n_logprobs = [] + + top_n_data_list = top_n_logprobs[req_idx] + prompt_length = len(request.prompt_tokens) + + # Process each token's top-n logprobs + for top_n_values, top_n_indices in top_n_data_list: + logit_dict = {} + for logprob, logprob_index in zip( + top_n_values.cpu().tolist(), top_n_indices.cpu().tolist() + ): + key = self.controller.tokenizer.detokenize([logprob_index]) + logit_dict[key] = logprob + + # Simple decision: check total count accumulated so far + total_accumulated = len(request.prompt_top_n_logprobs) + len( + request.generated_top_n_logprobs + ) + + # If return_prompt_top_n_logprobs is True and we haven't reached prompt end, + # append to prompt_top_n_logprobs. Otherwise append to generated_top_n_logprobs. + if ( + request.sampling_params.return_prompt_top_n_logprobs + and total_accumulated < prompt_length - 1 + ): + request.prompt_top_n_logprobs.append(logit_dict) + else: + request.generated_top_n_logprobs.append(logit_dict) return active_request_ids, finished_request_records @@ -1059,12 +1106,14 @@ async def async_bookkeep( """ # Increment finished_request_count. 
cuda_graph_request_count = None + if step_result is not None: active_request_ids = step_result["active_request_ids"] newly_paused_request_ids = step_result["newly_paused_request_ids"] finished_request_ids = step_result["finished_request_ids"] sample = step_result["sample"] log_probs = step_result["log_probs"] + top_n_logprobs = step_result.get("top_n_logprobs", None) cuda_graph_request_count = step_result["cuda_graph_request_count"] # Add paused events. @@ -1074,10 +1123,14 @@ async def async_bookkeep( # Mark requests finished. [self.get_request(i).add_event_finish() for i in finished_request_ids.tolist()] - # Add finished events. - active_request_ids, finished_request_records = self.post_process_requests( - active_request_ids, finished_request_ids, step_time, sample, log_probs + (active_request_ids, finished_request_records) = self.post_process_requests( + active_request_ids, + finished_request_ids, + step_time, + sample, + log_probs, + top_n_logprobs, ) else: diff --git a/megatron/core/inference/inference_request.py b/megatron/core/inference/inference_request.py index b58fac1b28..ffbab5e02b 100644 --- a/megatron/core/inference/inference_request.py +++ b/megatron/core/inference/inference_request.py @@ -325,6 +325,8 @@ def get_metadata_labels() -> Dict[str, int]: "termination_id", "return_log_probs", "skip_prompt_log_probs", + "top_n_logprobs", + "return_prompt_top_n_logprobs", ] return {k: v for v, k in enumerate(ret)} diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 2b44b41874..c0198e47d7 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -46,6 +46,7 @@ HAVE_TE = False +# pylint: disable=line-too-long class TextGenerationController: """The text generation controller (the main sampling loop) @@ -103,6 +104,10 @@ def _init_dynamic_sampling_tensors(self): self.termination_id_cuda = torch.empty(max_requests, dtype=torch.int64, device=device) self.return_log_probs_cuda = torch.empty(max_requests, dtype=torch.bool, device=device) self.skip_prompt_log_probs_cuda = torch.empty(max_requests, dtype=torch.bool, device=device) + self.top_n_logprobs_cuda = torch.empty(max_requests, dtype=torch.int32, device=device) + self.return_prompt_top_n_logprobs_cuda = torch.empty( + max_requests, dtype=torch.bool, device=device + ) # Used for inefficient torch sampling. self.torch_sampling_buckets: List[Tensor] = [] @@ -624,6 +629,12 @@ def _dynamic_step_sample_bookkeeping( self.skip_prompt_log_probs_cuda[:active_request_count] = request_metadata[ :, request_metadata_labels["skip_prompt_log_probs"] ].to(dtype=torch.bool, copy=True, non_blocking=True) + self.top_n_logprobs_cuda[:active_request_count] = request_metadata[ + :, request_metadata_labels["top_n_logprobs"] + ].to(dtype=torch.int32, copy=True, non_blocking=True) + self.return_prompt_top_n_logprobs_cuda[:active_request_count] = request_metadata[ + :, request_metadata_labels["return_prompt_top_n_logprobs"] + ].to(dtype=torch.bool, copy=True, non_blocking=True) if backend == "torch": # Bucketize the core sampling parameters. 
@@ -699,15 +710,35 @@ def _dynamic_step_log_probs_bookkeeping(self) -> bool: active_request_count = context.total_request_count - context.paused_request_count - to_check = self.return_log_probs_cuda[:active_request_count] + # Create a copy to avoid modifying the original tensor with in-place operations + to_check = self.return_log_probs_cuda[:active_request_count].clone() to_check &= ~self.skip_prompt_log_probs_cuda[:active_request_count] assert not ( to_check.any() and materialize_only_last_token_logits - ), "Prompt log probs cannot be calculated if only last token logits are materialized." + ), "Prompt log probs cannot be calculated if only last token logits are materialized. Set materialize_only_last_token_logits to False in DynamicInferenceContext or skip_prompt_log_probs to True in SamplingParams." return self.return_log_probs_cuda[:active_request_count].any() + def _dynamic_step_top_n_logprobs_bookkeeping(self) -> bool: + """Perform bookkeeping necessary to compute top-n log probs for dynamic batching.""" + context = self.inference_wrapped_model.inference_context + materialize_only_last_token_logits = context.materialize_only_last_token_logits + + active_request_count = context.total_request_count - context.paused_request_count + + # Check if any request wants prompt top-n logprobs (top_n > 0 AND return_prompt_top_n = True) + # Create a copy to avoid modifying the original tensor with in-place operations + to_check = (self.top_n_logprobs_cuda[:active_request_count] > 0).clone() + to_check &= self.return_prompt_top_n_logprobs_cuda[:active_request_count] + + assert not ( + to_check.any() and materialize_only_last_token_logits + ), "Prompt top-n logprobs cannot be calculated if only last token logits are materialized. Set materialize_only_last_token_logits to False in DynamicInferenceContext or set return_prompt_top_n_logprobs to False in SamplingParams." + + # Check if any request has top_n_logprobs > 0 + return (self.top_n_logprobs_cuda[:active_request_count] > 0).any() + def _dynamic_step_calculate_log_probs(self, logits: Tensor) -> Optional[Tensor]: """Calculate log probs from logits.""" context = self.inference_wrapped_model.inference_context @@ -715,12 +746,94 @@ def _dynamic_step_calculate_log_probs(self, logits: Tensor) -> Optional[Tensor]: active_request_count = context.total_request_count - context.paused_request_count - ret = context.calculate_log_probs( + return context.calculate_log_probs( logits, self.sampled_tokens_cuda[:active_request_count], only_last_token_logits=materialize_only_last_token_logits, ) - return ret + + def _dynamic_step_calculate_top_n_logprobs( + self, logits: Tensor, log_probs_tensor: Optional[Tensor] = None + ) -> Optional[Dict[int, List[Tuple[Tensor, Tensor]]]]: + """Calculate top-n log probs from logits for dynamic batching. + + Args: + logits (Tensor): The logits to compute top-n log probs from. + log_probs_tensor (Optional[Tensor]): Pre-computed log probabilities tensor. + If provided, avoids recomputing log_softmax. Should be the tensor + returned by calculate_log_probs. + + Returns: + A dictionary mapping request_idx to list of (top_n_logprobs, top_n_indices) tuples. + Each tuple in the list represents one token position. + """ + assert log_probs_tensor is not None, ( + "log_probs_tensor must be provided. This should be guaranteed by the calling code " + "computing log_probs when return_top_n_logprobs is True." 
+ ) + + context = self.inference_wrapped_model.inference_context + materialize_only_last_token_logits = context.materialize_only_last_token_logits + + active_request_count = context.total_request_count - context.paused_request_count + + # Handle decode-only mode (only last token) + if materialize_only_last_token_logits or context.is_decode_only(): + # In decode mode or when only last token logits are materialized, + # logits already represent only the last tokens + log_probs = log_probs_tensor[:active_request_count] + + top_n_results = {} + for req_idx in range(active_request_count): + top_n = int(self.top_n_logprobs_cuda[req_idx].item()) + if top_n > 0: + # Get top-n logprobs and indices for this request (single token) + top_n_logits = torch.topk(log_probs[req_idx], k=top_n) + top_n_results[req_idx] = [ + (top_n_logits.values.cpu(), top_n_logits.indices.cpu()) + ] + return top_n_results if top_n_results else None + + # Handle prefill mode - need to extract top-n for tokens per request + # This follows the same pattern as calculate_log_probs in dynamic_context.py + # Note: logits may be padded, so we only take the first active_token_count tokens + log_probs = log_probs_tensor[: context.active_token_count] + + active_query_lengths = context.request_query_lengths[ + context.paused_request_count : context.total_request_count + ] + + # Split log_probs across request boundaries + # log_probs has shape [active_token_count, vocab_size] + log_probs_per_request = log_probs.split(active_query_lengths.tolist(), dim=0) + + top_n_results = {} + for req_idx in range(active_request_count): + top_n = int(self.top_n_logprobs_cuda[req_idx].item()) + if top_n > 0: + request_log_probs = log_probs_per_request[ + req_idx + ] # [num_tokens_for_request, vocab_size] + return_prompt_top_n = bool(self.return_prompt_top_n_logprobs_cuda[req_idx].item()) + + # If return_prompt_top_n_logprobs is False, only compute for last token + if not return_prompt_top_n and request_log_probs.size(0) > 1: + # Only compute top-n for the last token (first generated token) + top_n_logits = torch.topk(request_log_probs[-1], k=top_n) + top_n_results[req_idx] = [ + (top_n_logits.values.cpu(), top_n_logits.indices.cpu()) + ] + else: + # Compute top-n for all tokens in the request + top_n_per_token = [] + for token_idx in range(request_log_probs.size(0)): + top_n_logits = torch.topk(request_log_probs[token_idx], k=top_n) + top_n_per_token.append( + (top_n_logits.values.cpu(), top_n_logits.indices.cpu()) + ) + top_n_results[req_idx] = top_n_per_token + + return top_n_results if top_n_results else None def _dynamic_step_context_bookkeeping(self, new_sample) -> Dict[str, Tensor]: """Update the dynamic inference context after sampling. @@ -785,9 +898,6 @@ async def async_generate_output_tokens_dynamic_batch( cuda_graph_request_count (Optional[int]): Size of cuda graph used for this step. """ context = self.inference_wrapped_model.inference_context - materialize_only_last_token_logits = context.materialize_only_last_token_logits - - inference_wrapper_config = self.inference_wrapped_model.inference_wrapper_config # No tokens? 
if context.active_token_count == 0: @@ -814,10 +924,16 @@ async def async_generate_output_tokens_dynamic_batch( new_sample = self._dynamic_step_sample_logits(logits) return_log_probs = self._dynamic_step_log_probs_bookkeeping() - if return_log_probs: - log_probs = self._dynamic_step_calculate_log_probs(logits) - else: - log_probs = None + return_top_n_logprobs = self._dynamic_step_top_n_logprobs_bookkeeping() + + log_probs = None + top_n_logprobs = None + if return_log_probs or return_top_n_logprobs: + log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs(logits) + if return_top_n_logprobs: + top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs( + logits, log_probs_tensor + ) if skip_bookkeeping: request_bookkeeping = {} @@ -827,6 +943,7 @@ async def async_generate_output_tokens_dynamic_batch( ret = { "sample": new_sample, "log_probs": log_probs, + "top_n_logprobs": top_n_logprobs, "cuda_graph_request_count": cuda_graph_request_count, } ret.update(request_bookkeeping) diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py index 87a5aba23f..4f2c7dff71 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_context.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py @@ -890,7 +890,9 @@ def test_calculate_and_store_log_probs(self): prefill_new_tokens = torch.randint(0, 100, (num_active_requests,), device='cuda').long() # Call the function for prefill - prefill_log_probs = dynamic_context.calculate_log_probs(prefill_logits, prefill_new_tokens) + prefill_log_probs, _ = dynamic_context.calculate_log_probs( + prefill_logits, prefill_new_tokens + ) # Calculate expected prefill log probs for the selected tokens expected_prefill_log_probs = ( @@ -928,7 +930,7 @@ def test_calculate_and_store_log_probs(self): 1, num_active_requests, vocab_size, device='cuda', dtype=torch.float32 ) decode_new_tokens = torch.randint(0, 100, (num_active_requests,), device='cuda').long() - decode_log_probs = dynamic_context.calculate_log_probs(decode_logits, decode_new_tokens) + decode_log_probs, _ = dynamic_context.calculate_log_probs(decode_logits, decode_new_tokens) # Verify the stored decode log probabilities expected_decode_log_probs = torch.nn.functional.log_softmax( @@ -983,7 +985,7 @@ def test_calculate_and_store_log_probs(self): 0, 100, (num_active_requests_mixed_step,), device='cuda' ).long() - mixed_step_log_probs = dynamic_context.calculate_log_probs( + mixed_step_log_probs, _ = dynamic_context.calculate_log_probs( mixed_step_logits, mixed_step_new_tokens ) diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index 7b281ddf80..c05d00e32f 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -1085,51 +1085,131 @@ def test_chunked_prefill(self, model_provider: str): @pytest.mark.skipif( not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" ) - @pytest.mark.skip( - reason="test works in isolation, but memory dynamics change when run " - "within unt test suite." - ) - def test_suspend_resume_memory(self): - - # Run tests. - mem_usages = {} - for suspend_resume_interval in None, 8, 4, 2: # interval 1 acts funny. - - # Run test. - env = self._run_test(suspend_resume_interval=suspend_resume_interval, num_gap_steps=1) - - # Record memory usage. 
- mem_usages[suspend_resume_interval] = env.mem_usage - - # Clear memory to make recorded memories consistent between tests. - # TODO(@lmcafee): why is memory not automatically cleared? - # env.engine.suspend() # TODO(@lmcafee): useful? - del env - - # Utility methods. - get_alloc = lambda mem_stats: mem_stats["allocated_bytes.all.current"] - - # Validate overall 'end' memory usage. - golden_end_bytes = get_alloc(mem_usages[None]["end"]) - for interval, mem_usage in mem_usages.items(): - current_end_bytes = get_alloc(mem_usage["end"]) - assert math.isclose( - golden_end_bytes, current_end_bytes, rel_tol=0.01 - ), f"{current_end_bytes} != {golden_end_bytes}." - - # Validate 'suspend/resume' memory usage. - get_suspend_resume_bytes = lambda key: list( - get_alloc(list(d["suspend_resume"].values())[-1][key]) - for i, d in mem_usages.items() - if i is not None + @pytest.mark.parametrize("return_prompt_top_n_logprobs", [True, False]) + @torch.inference_mode() + def test_top_n_logprobs_dynamic(self, return_prompt_top_n_logprobs: bool): + """ + Test that top_n_logprobs are computed correctly in dynamic batching mode. + Verifies: + 1. top_n_logprobs are returned for generated tokens + 2. return_prompt_top_n_logprobs controls whether prompt top-n logprobs are returned + 3. The top-n values are consistent with the selected token's log prob + """ + # Build test environment with multiple requests of varying lengths + test_config = DynamicEngineTestConfig( + num_requests=4, + min_prompt_length=4, + max_prompt_length=12, + num_tokens_to_generate=4, + materialize_only_last_token_logits=False, ) - suspend_resume_mid_bytes = get_suspend_resume_bytes("mid") - suspend_resume_end_bytes = get_suspend_resume_bytes("end") - for mid_bytes in suspend_resume_mid_bytes: - assert math.isclose( - suspend_resume_mid_bytes[0], mid_bytes, rel_tol=0.01 - ), f"{mid_bytes} != {suspend_resume_mid_bytes[0]}." - for end_bytes in suspend_resume_end_bytes: - assert math.isclose( - suspend_resume_end_bytes[0], end_bytes, rel_tol=0.01 - ), f"{end_bytes} != {suspend_resume_end_bytes[0]}." 
+ env = self._build_test_env(test_config) + + # Create requests with top_n_logprobs enabled + top_n = 5 + requests_to_add = [] + for request in env.requests: + # Update sampling params to include top_n_logprobs + request.sampling_params = SamplingParams( + num_tokens_to_generate=test_config.num_tokens_to_generate, + termination_id=test_config.vocab_size - 1, + return_log_probs=True, + top_n_logprobs=top_n, + return_prompt_top_n_logprobs=return_prompt_top_n_logprobs, + top_k=10, # Add some sampling randomness + ) + requests_to_add.append(request) + + # Add requests and run inference + for request in requests_to_add: + env.engine._add_request(request) + + # Step engine until all requests are finished + while env.engine.has_unfinished_requests(): + result = env.engine.step_modern(verbose=False) + + # Validate results + for request in requests_to_add: + assert request.status == Status.COMPLETED, f"Request {request.request_id} not completed" + + # Validate generated top-n logprobs + assert hasattr( + request, 'generated_top_n_logprobs' + ), f"Request {request.request_id} missing generated_top_n_logprobs" + assert ( + request.generated_top_n_logprobs is not None + ), f"Request {request.request_id} has None generated_top_n_logprobs" + assert len(request.generated_top_n_logprobs) == len( + request.generated_tokens + ), f"Request {request.request_id}: generated_top_n_logprobs length mismatch" + + # Validate each top-n dict + for i, top_n_dict in enumerate(request.generated_top_n_logprobs): + assert isinstance( + top_n_dict, dict + ), f"Request {request.request_id}, token {i}: top_n_dict is not a dict" + assert ( + len(top_n_dict) <= top_n + ), f"Request {request.request_id}, token {i}: too many top-n entries" + assert ( + len(top_n_dict) > 0 + ), f"Request {request.request_id}, token {i}: empty top-n dict" + + # Validate prompt top-n logprobs based on return_prompt_top_n_logprobs flag + if return_prompt_top_n_logprobs: + assert hasattr( + request, 'prompt_top_n_logprobs' + ), f"Request {request.request_id} missing prompt_top_n_logprobs" + assert ( + request.prompt_top_n_logprobs is not None + ), f"Request {request.request_id} has None prompt_top_n_logprobs" + # Prompt top-n should have N-1 entries (excluding first token) + expected_prompt_top_n_len = len(request.prompt_tokens) - 1 + assert ( + len(request.prompt_top_n_logprobs) == expected_prompt_top_n_len + ), f"Request {request.request_id}: prompt_top_n_logprobs length {len(request.prompt_top_n_logprobs)} != expected {expected_prompt_top_n_len}" + + # Validate each prompt top-n dict + for i, top_n_dict in enumerate(request.prompt_top_n_logprobs): + assert isinstance( + top_n_dict, dict + ), f"Request {request.request_id}, prompt token {i}: top_n_dict is not a dict" + assert ( + len(top_n_dict) <= top_n + ), f"Request {request.request_id}, prompt token {i}: too many top-n entries" + assert ( + len(top_n_dict) > 0 + ), f"Request {request.request_id}, prompt token {i}: empty top-n dict" + else: + # When return_prompt_top_n_logprobs is False, prompt_top_n_logprobs should be None or empty + if hasattr(request, 'prompt_top_n_logprobs'): + assert ( + request.prompt_top_n_logprobs is None + or len(request.prompt_top_n_logprobs) == 0 + ), f"Request {request.request_id}: prompt_top_n_logprobs should be None or empty when return_prompt_top_n_logprobs is False" + + # Validate consistency between log_probs and top_n_logprobs + if hasattr(request, 'generated_log_probs') and request.generated_log_probs is not None: + assert len(request.generated_log_probs) == 
len( + request.generated_top_n_logprobs + ), f"Request {request.request_id}: generated_log_probs and generated_top_n_logprobs length mismatch" + + # Check that the selected token's log prob appears in the top-n + for i, (log_prob, top_n_dict, token_id) in enumerate( + zip( + request.generated_log_probs, + request.generated_top_n_logprobs, + request.generated_tokens, + ) + ): + # Get the token string for this token_id + token_str = env.engine.controller.tokenizer.detokenize([token_id]) + # The selected token should be in the top-n + assert ( + token_str in top_n_dict + ), f"Request {request.request_id}, token {i}: selected token '{token_str}' not in top-n" + # The log prob should match (with some tolerance for floating point precision) + # Using 0.1 tolerance to account for FP16/BF16 precision in mixed precision training + assert ( + abs(log_prob - top_n_dict[token_str]) < 0.1 + ), f"Request {request.request_id}, token {i}: log_prob mismatch {log_prob} vs {top_n_dict[token_str]}" diff --git a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py index ee6bc5b246..cad2377455 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py @@ -732,6 +732,135 @@ def test_zero_tokens_generated_batch_vs_single(self): == request_single.prompt_top_n_logprobs[i][token_str] ) + @pytest.mark.parametrize("return_prompt_top_n_logprobs", [True, False]) + @pytest.mark.parametrize("materialize_only_last_token_logits", [True, False]) + def test_dynamic_top_n_logprobs_calculation( + self, return_prompt_top_n_logprobs: bool, materialize_only_last_token_logits: bool + ): + """ + Test the _dynamic_step_calculate_top_n_logprobs function directly. + Verifies: + 1. top_n_logprobs are computed for all requests + 2. return_prompt_top_n_logprobs controls computation for prompt tokens + 3. 
Correct number of tokens are returned for each request + """ + batch_size = 4 + self.setup_model(torch.bfloat16, batch_size=batch_size, static=False) + self.mock_tokenizer.eod = self.vocab_size + + context = self.text_generation_controller.inference_wrapped_model.inference_context + context.materialize_only_last_token_logits = materialize_only_last_token_logits + + # Prepare sampling params + top_n = 5 + request_metadata_labels = DynamicInferenceRequest.get_metadata_labels() + request_metadata = torch.empty( + (batch_size, len(request_metadata_labels)), dtype=torch.float32 + ).cuda() + + # Set top_n_logprobs for all requests + request_metadata[:, request_metadata_labels["top_n_logprobs"]] = top_n + request_metadata[:, request_metadata_labels["return_prompt_top_n_logprobs"]] = float( + return_prompt_top_n_logprobs + ) + + # Bookkeeping + self.text_generation_controller._dynamic_step_sample_bookkeeping( + request_metadata=request_metadata + ) + + if materialize_only_last_token_logits: + # Decode mode: logits for last tokens only + logits = torch.randn(1, batch_size, self.vocab_size).cuda() + + # Set up context state for decode mode + context.paused_request_count = 0 + context.total_request_count = batch_size + + # Compute log probabilities (required by _dynamic_step_calculate_top_n_logprobs) + # Note: squeeze(0) to match what calculate_log_probs does in dynamic_context.py + log_probs_tensor = torch.nn.functional.log_softmax(logits.squeeze(0), dim=-1) + + # Calculate top-n logprobs + top_n_results = self.text_generation_controller._dynamic_step_calculate_top_n_logprobs( + logits, log_probs_tensor + ) + + # Validate results + assert top_n_results is not None, "top_n_results should not be None" + assert ( + len(top_n_results) == batch_size + ), f"Expected {batch_size} requests, got {len(top_n_results)}" + + for req_idx in range(batch_size): + assert req_idx in top_n_results, f"Request {req_idx} missing from results" + top_n_list = top_n_results[req_idx] + + # In decode mode, should have exactly 1 token per request + assert ( + len(top_n_list) == 1 + ), f"Request {req_idx}: expected 1 token, got {len(top_n_list)}" + + top_n_values, top_n_indices = top_n_list[0] + assert top_n_values.shape[0] == top_n, f"Expected {top_n} values" + assert top_n_indices.shape[0] == top_n, f"Expected {top_n} indices" + else: + # Prefill mode: logits for all tokens + # Simulate different prompt lengths + query_lengths = [4, 6, 5, 7] # Different lengths for each request + total_tokens = sum(query_lengths) + + # Set up context state + context.paused_request_count = 0 + context.total_request_count = batch_size + context.active_token_count = total_tokens + context.request_query_lengths = torch.tensor( + [0] * context.paused_request_count + query_lengths, dtype=torch.int32, device='cuda' + ) + + # Create logits for all tokens + logits = torch.randn(1, total_tokens, self.vocab_size).cuda() + + # Compute log probabilities (required by _dynamic_step_calculate_top_n_logprobs) + # Note: squeeze(0) to match what calculate_log_probs does in dynamic_context.py + log_probs_tensor = torch.nn.functional.log_softmax(logits.squeeze(0), dim=-1) + + # Calculate top-n logprobs + top_n_results = self.text_generation_controller._dynamic_step_calculate_top_n_logprobs( + logits, log_probs_tensor + ) + + # Validate results + assert top_n_results is not None, "top_n_results should not be None" + assert ( + len(top_n_results) == batch_size + ), f"Expected {batch_size} requests, got {len(top_n_results)}" + + for req_idx in range(batch_size): + 
assert req_idx in top_n_results, f"Request {req_idx} missing from results" + top_n_list = top_n_results[req_idx] + + if return_prompt_top_n_logprobs: + # Should have top-n for all tokens + expected_count = query_lengths[req_idx] + assert ( + len(top_n_list) == expected_count + ), f"Request {req_idx}: expected {expected_count} tokens, got {len(top_n_list)}" + else: + # Should have top-n for only the last token (first generated token) + assert ( + len(top_n_list) == 1 + ), f"Request {req_idx}: expected 1 token when return_prompt_top_n_logprobs=False, got {len(top_n_list)}" + + # Validate each token's top-n + for token_idx, (top_n_values, top_n_indices) in enumerate(top_n_list): + assert ( + top_n_values.shape[0] == top_n + ), f"Request {req_idx}, token {token_idx}: expected {top_n} values" + assert ( + top_n_indices.shape[0] == top_n + ), f"Request {req_idx}, token {token_idx}: expected {top_n} indices" + @pytest.mark.parametrize("static", [True, False]) @pytest.mark.parametrize("tp_size", [1, 2]) @pytest.mark.parametrize("pp_size", [1, 2])
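
Below is a minimal, standalone sketch of the per-token top-n extraction that _dynamic_step_calculate_top_n_logprobs performs and of the {token_str: logprob} dictionaries that post_process_requests assembles in this change. It uses plain PyTorch with a toy vocabulary and a hypothetical toy_detokenize stand-in for tokenizer.detokenize, assumes the logits have already been narrowed to the active tokens, and is an illustration of the technique rather than the Megatron-Core API itself.

# Standalone sketch (assumptions: plain PyTorch, toy vocabulary, hypothetical
# toy_detokenize in place of tokenizer.detokenize). Mirrors the logic of
# _dynamic_step_calculate_top_n_logprobs / post_process_requests above.
from typing import Dict, List

import torch
import torch.nn.functional as F


def toy_detokenize(token_ids: List[int]) -> str:
    # Hypothetical stand-in for tokenizer.detokenize([token_id]).
    return "tok_" + "_".join(str(t) for t in token_ids)


def extract_top_n_logprobs(
    logits: torch.Tensor,        # [total_active_tokens, vocab_size]
    query_lengths: List[int],    # tokens contributed by each active request
    top_n: int,
    return_prompt_top_n: bool,
) -> List[List[Dict[str, float]]]:
    """Return, per request, one {token_str: logprob} dict per token position."""
    # Log-softmax over the vocabulary, as calculate_log_probs does.
    log_probs = F.log_softmax(logits.float(), dim=-1)
    # Split the flat token dimension across request boundaries.
    per_request = log_probs.split(query_lengths, dim=0)

    results: List[List[Dict[str, float]]] = []
    for request_log_probs in per_request:
        # During prefill, only the last position corresponds to the first
        # generated token; earlier positions belong to the prompt.
        if not return_prompt_top_n and request_log_probs.size(0) > 1:
            request_log_probs = request_log_probs[-1:]
        token_dicts: List[Dict[str, float]] = []
        for row in request_log_probs:
            values, indices = torch.topk(row, k=top_n)
            token_dicts.append(
                {
                    toy_detokenize([idx]): val
                    for val, idx in zip(values.tolist(), indices.tolist())
                }
            )
        results.append(token_dicts)
    return results


if __name__ == "__main__":
    vocab_size, query_lengths = 32, [4, 2]
    logits = torch.randn(sum(query_lengths), vocab_size)
    out = extract_top_n_logprobs(logits, query_lengths, top_n=5, return_prompt_top_n=True)
    print([len(per_token) for per_token in out])  # -> [4, 2]

In the actual engine path, the same torch.topk results are routed either to request.prompt_top_n_logprobs or request.generated_top_n_logprobs depending on return_prompt_top_n_logprobs and how many entries have accumulated relative to the prompt length, as post_process_requests does in the diff above.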