32 changes: 16 additions & 16 deletions examples/inference/gpt/gpt_dynamic_inference.py
@@ -7,6 +7,7 @@
import os
import pickle
import sys
import warnings
import torch
from argparse import ArgumentParser
from collections import defaultdict
@@ -56,20 +57,6 @@

from megatron.core.utils import configure_nvtx_profiling

import json

from examples.inference.gpt.utils import (
Request,
add_common_inference_args,
build_dynamic_engine_setup_prefix,
build_requests,
get_curr_time,
)
from megatron.training.checkpointing import load_checkpoint

from model_provider import model_provider
from gpt_builders import gpt_builder

torch.serialization.add_safe_globals([io.BytesIO])
torch.serialization.add_safe_globals([megatron.core.rerun_state_machine.RerunState])
torch.serialization.add_safe_globals([megatron.core.rerun_state_machine.RerunDiagnostic])
@@ -188,7 +175,7 @@ def get_inference_context(
buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb,
max_tokens=args.inference_dynamic_batching_max_tokens,
tensor_model_parallel_size=args.tensor_model_parallel_size,
materialize_only_last_token_logits=not args.return_log_probs,
materialize_only_last_token_logits=not (args.return_log_probs or args.return_prompt_top_n_logprobs),
mamba_inference_state_config=mamba_inference_state_config,
cache_mla_latent=args.multi_latent_attention and args.cache_mla_latents,
kv_lora_rank=args.kv_lora_rank if args.multi_latent_attention else None,
@@ -389,9 +376,15 @@ def _add_request():

# Log probs.
if finished_request.sampling_params.return_log_probs:
if not finished_request.prompt_log_probs:
finished_request.prompt_log_probs = []
request.log_probs = (
finished_request.prompt_log_probs + finished_request.generated_log_probs
)
if finished_request.sampling_params.top_n_logprobs > 0:
request.generated_top_n_logprobs = finished_request.generated_top_n_logprobs
if finished_request.sampling_params.return_prompt_top_n_logprobs:
request.prompt_top_n_logprobs = finished_request.prompt_top_n_logprobs
num_requests_finished += 1
output_times.append(get_curr_time() - output_start)

@@ -434,9 +427,12 @@ def main():
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
skip_prompt_log_probs=args.skip_prompt_log_probs,
return_log_probs=args.return_log_probs,
num_tokens_to_generate=args.num_tokens_to_generate,
termination_id=args.termination_id if args.termination_id is not None else tokenizer.eod,
top_n_logprobs=args.top_n_logprobs,
return_prompt_top_n_logprobs=args.return_prompt_top_n_logprobs,
)

model = get_model()
@@ -553,13 +549,16 @@ def escape_str(s):
# Write every 'n' requests, plus the final request.
for i, req in enumerate(requests):
if i % args.output_every_n_results == 0 or i == len(requests) - 1:
print(f' Attributes of request {i}: {req.__dict__}')
result_dict = {
"input_prompt": req.prompt_text,
"generated_text": req.output_text,
"generated_tokens": req.output_tokens,
"latency": req.time_end - req.time_start,
"cuda_graph_request_count_map" : result["cuda_graph_request_count_map"],
"step_count" : engine.step_count,
"top_n_logprobs" : getattr(req, 'generated_top_n_logprobs', None),
"prompt_top_n_logprobs" : getattr(req, 'prompt_top_n_logprobs', None),
}
if req.sampling_params.return_log_probs:
response_logprobs = req.log_probs
@@ -569,6 +568,7 @@
# Track system-level throughput as a test / debug metric
json_results["throughput"] = throughputs

print(f' Saving results to {args.output_path}')
with open(args.output_path, "w") as fp:
json.dump(json_results, fp, indent=1)

@@ -622,4 +622,4 @@ def escape_str(s):


if __name__ == "__main__":
main()
main()
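
For reference, this is roughly what one entry of the JSON written to `args.output_path` looks like once the new fields are wired in. The prompt text, token id, latency, and logprob values below are invented for illustration, and `prompt_top_n_logprobs` stays `None` unless `--return-prompt-top-n-logprobs` is passed.

```python
# Illustrative only: one results entry as assembled by result_dict above.
# All concrete values here are made up.
example_entry = {
    "input_prompt": "Hello",
    "generated_text": " world",
    "generated_tokens": [995],
    "latency": 0.042,
    "cuda_graph_request_count_map": {},
    "step_count": 3,
    # One {token_string: logprob} dict per generated token when top-n logprobs are requested.
    "top_n_logprobs": [{" world": -0.11, " there": -2.31}],
    # Populated only when --return-prompt-top-n-logprobs is set.
    "prompt_top_n_logprobs": None,
}
```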
15 changes: 14 additions & 1 deletion examples/inference/gpt/utils.py
@@ -66,6 +66,12 @@ def add_common_inference_args(parser: ArgumentParser) -> ArgumentParser:
default=0,
help='Return the top n logprobs for the generated tokens and their corresponding token as a dictionary',
)
group.add_argument(
"--return-prompt-top-n-logprobs",
action='store_true',
default=False,
help='Return the top n logprobs for the prompt tokens and their corresponding token as a dictionary',
)
group.add_argument(
"--incoming-requests-per-step",
type=int, default=None,
@@ -96,6 +102,12 @@ def add_common_inference_args(parser: ArgumentParser) -> ArgumentParser:
default="gpt",
help="Model provider",
)
group.add_argument(
"--skip-prompt-log-probs",
action='store_true',
default=False,
help='Skip computing log probs for the prompt tokens.',
)
group.add_argument(
"--output-path",
type=str,
@@ -300,7 +312,8 @@ def get_requests_from_file(
# Load prompts.
n_prompts = sum(1 for _ in open(args.prompt_file))
prompts = []
sampling_params = get_default_sampling_params(tokenizer.eod)
if sampling_params is None:
sampling_params = get_default_sampling_params(tokenizer.eod)
sampling_params_list = []
with open(args.prompt_file) as f:
for line in tqdm(f.readlines(), "read prompt file", total=n_prompts):
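Both new options are plain store_true flags, so the driver script reads them through the usual dash-to-underscore argparse dests (`args.return_prompt_top_n_logprobs`, `args.skip_prompt_log_probs`). A minimal standalone sketch that mirrors just the two new add_argument calls rather than the full `add_common_inference_args` parser:

```python
from argparse import ArgumentParser

# Standalone parser mirroring only the two flags added in this change.
parser = ArgumentParser()
parser.add_argument("--return-prompt-top-n-logprobs", action="store_true", default=False)
parser.add_argument("--skip-prompt-log-probs", action="store_true", default=False)

args = parser.parse_args(["--return-prompt-top-n-logprobs"])
print(args.return_prompt_top_n_logprobs)  # True
print(args.skip_prompt_log_probs)         # False; store_true flags default to False
```

Note that the engine (dynamic_engine.py below) additionally asserts that `return_log_probs` is set whenever prompt or generated top-n logprobs are requested.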
8 changes: 5 additions & 3 deletions megatron/core/inference/contexts/dynamic_context.py
@@ -1749,7 +1749,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T

def calculate_log_probs(
self, logits: Tensor, new_tokens: Tensor, only_last_token_logits: Optional[bool] = False
) -> List[List[float]]:
) -> Tuple[List[List[float]], Tensor]:
Contributor:
This is a breaking change, will this break any downstream code or is this private?

Contributor:
This is a good point that I missed.

Subscribing to this thread; I am also very interested in the answer to this, since I intended to optimize the calculate_log_probs method in 2 PRs' time (circa Dec 10th?).

@wdykas I think you are the main author of this method. Do you know how much (if any) downstream dependence there should be on it?

Contributor (Author):
This is only used by internal code, so it shouldn't be a problem.

"""Calculate log probs for all active requests and return them.

TODO: @wdykas support top-n log probs.
@@ -1762,14 +1762,16 @@ def calculate_log_probs(
Returns:
List of lists where each inner list contains log probs for a request in the
same order as the active requests (from paused_request_count to total_request_count).
log_probs (Tensor): Used to compute top n logprobs later if required.
"""

# Calculate log_probs (sequence_length x vocab_size)
log_probs = F.log_softmax(logits.squeeze(0).float(), dim=-1)

if only_last_token_logits or self.is_decode_only():
seq_idx = torch.arange(len(new_tokens), dtype=torch.int32, device=logits.device)
selected_log_probs = log_probs[seq_idx, new_tokens]
return [[lp] for lp in selected_log_probs.flatten().tolist()]
return [[lp] for lp in selected_log_probs.flatten().tolist()], log_probs

# Get the selected token ids for all tokens.
# We shift the active token window left by one to remove the first prompt token for
@@ -1812,7 +1814,7 @@
)

# Convert each log prob tensor into a list
return [lp.tolist() for lp in selected_log_probs_list]
return [lp.tolist() for lp in selected_log_probs_list], log_probs

def get_kvcache_utilization_stats(self) -> dict:
"""Compute KV cache buffer utilization stats for the current step.
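The extra `log_probs` tensor returned by `calculate_log_probs` is what lets the caller derive top-n logprobs without recomputing the softmax. A minimal sketch of that follow-up step, with made-up shapes and n, producing the (top_n_logprobs, top_n_indices) pairs that the engine's `post_process_requests` (dynamic_engine.py below) consumes:

```python
import torch
import torch.nn.functional as F

# Made-up sizes for illustration.
seq_len, vocab_size, top_n = 4, 32, 5
logits = torch.randn(seq_len, vocab_size)

# Same normalization as calculate_log_probs: log-softmax over the vocab dimension.
log_probs = F.log_softmax(logits.float(), dim=-1)

# Top-n per token position; downstream code treats each row as a
# (top_n_logprobs, top_n_indices) pair.
top_n_values, top_n_indices = torch.topk(log_probs, k=top_n, dim=-1)
per_token_pairs = list(zip(top_n_values.unbind(0), top_n_indices.unbind(0)))
```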
75 changes: 64 additions & 11 deletions megatron/core/inference/engines/dynamic_engine.py
@@ -687,6 +687,14 @@ def _add_request(
request.sampling_params.num_tokens_to_generate is None
or request.sampling_params.num_tokens_total is None
)
if request.sampling_params.return_prompt_top_n_logprobs:
assert (
request.sampling_params.return_log_probs
), "return_prompt_top_n_logprobs requires sampling_params.return_log_probs to be True"
if request.sampling_params.top_n_logprobs > 0:
assert (
request.sampling_params.return_log_probs
), "top_n_logprobs requires sampling_params.return_log_probs to be True"
if request.sampling_params.num_tokens_total is not None:
request.sampling_params.num_tokens_to_generate = (
request.sampling_params.num_tokens_total - len(request.prompt_tokens)
@@ -785,6 +793,7 @@ def post_process_requests(
step_time: float,
sample: torch.Tensor,
log_probs: torch.Tensor,
top_n_logprobs: Optional[Dict[int, List[Tuple[torch.Tensor, torch.Tensor]]]] = None,
) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest]]:
"""
Handles post-processing for requests after a step.
@@ -795,6 +804,8 @@
step_time (float): The latency of the last step
sample: (torch.Tensor): The newly generated tokens for each request
log_probs: (List): Log probs for each request
top_n_logprobs: (Dict): Top-n log probs for each request. Maps request_idx to
list of (top_n_logprobs, top_n_indices) tuples.

Returns:
A list of active requests and completed requests as `DynamicInferenceRequest` objects
@@ -806,8 +817,8 @@

log_probs_iter = log_probs if log_probs else repeat(None)

for request_id, token, request_log_probs in zip(
request_ids.tolist(), sample.tolist(), log_probs_iter
for req_idx, (request_id, token, request_log_probs) in enumerate(
zip(request_ids.tolist(), sample.tolist(), log_probs_iter)
):
request: DynamicInferenceRequest = self.get_request(request_id)
if request_id != self.context.chunked_prefill_request_id:
@@ -823,19 +834,19 @@
request.generated_log_probs = []
# If the request log probs span > 1 token we are in prefill
if len(request_log_probs) > 1:
request.prompt_log_probs.extend(request_log_probs)
# Add all but the last logprob to prompt_log_probs (last is for first generated token)
request.prompt_log_probs.extend(request_log_probs[:-1])
# Add the last logprob to generated_log_probs (first generated token)
request.generated_log_probs.extend(request_log_probs[-1:])
else:
if (
# If it is a chunked prefill request
len(request.prompt_log_probs) > 0
# And we are missing the last token for prefill
and len(request.prompt_log_probs) < len(request.prompt_tokens)
and len(request.prompt_log_probs) < len(request.prompt_tokens) - 1
# And we need to track full prefill
and not self.context.materialize_only_last_token_logits
):
assert (
len(request.prompt_log_probs) == len(request.prompt_tokens) - 1
), "Prompt log probs length is not equal to prompt tokens length - 1"
request.prompt_log_probs.extend(request_log_probs)
else:
request.generated_log_probs.extend(request_log_probs)
@@ -874,7 +885,43 @@ def post_process_requests(
request.prompt_log_probs = []
request.prompt_log_probs.extend(request_log_probs)
request.generated_log_probs = []
active_request_ids.append(request_id)

active_request_ids.append(request_id)

# Process top_n_logprobs if available (unified for both regular and chunked prefill)
if top_n_logprobs is not None and req_idx in top_n_logprobs:
# Initialize lists if they don't exist
if request.prompt_top_n_logprobs is None:
request.prompt_top_n_logprobs = []
if request.generated_top_n_logprobs is None:
request.generated_top_n_logprobs = []

top_n_data_list = top_n_logprobs[req_idx]
prompt_length = len(request.prompt_tokens)

# Process each token's top-n logprobs
for top_n_values, top_n_indices in top_n_data_list:
logit_dict = {}
for logprob, logprob_index in zip(
top_n_values.cpu().tolist(), top_n_indices.cpu().tolist()
):
key = self.controller.tokenizer.detokenize([logprob_index])
logit_dict[key] = logprob

# Simple decision: check total count accumulated so far
total_accumulated = len(request.prompt_top_n_logprobs) + len(
request.generated_top_n_logprobs
)

# If return_prompt_top_n_logprobs is True and we haven't reached prompt end,
# append to prompt_top_n_logprobs. Otherwise append to generated_top_n_logprobs.
if (
request.sampling_params.return_prompt_top_n_logprobs
and total_accumulated < prompt_length - 1
):
request.prompt_top_n_logprobs.append(logit_dict)
else:
request.generated_top_n_logprobs.append(logit_dict)

return active_request_ids, finished_request_records

@@ -1059,12 +1106,14 @@ async def async_bookkeep(
"""
# Increment finished_request_count.
cuda_graph_request_count = None

if step_result is not None:
active_request_ids = step_result["active_request_ids"]
newly_paused_request_ids = step_result["newly_paused_request_ids"]
finished_request_ids = step_result["finished_request_ids"]
sample = step_result["sample"]
log_probs = step_result["log_probs"]
top_n_logprobs = step_result.get("top_n_logprobs", None)
cuda_graph_request_count = step_result["cuda_graph_request_count"]

# Add paused events.
@@ -1074,10 +1123,14 @@

# Mark requests finished.
[self.get_request(i).add_event_finish() for i in finished_request_ids.tolist()]

# Add finished events.
active_request_ids, finished_request_records = self.post_process_requests(
active_request_ids, finished_request_ids, step_time, sample, log_probs
(active_request_ids, finished_request_records) = self.post_process_requests(
active_request_ids,
finished_request_ids,
step_time,
sample,
log_probs,
top_n_logprobs,
)

else:
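The prompt/generated split in `post_process_requests` hinges on the fact that the logits at the last prompt position already predict the first generated token, hence the `prompt_length - 1` boundary. A standalone restatement of that routing step (a hypothetical helper, not part of this PR), assuming a tokenizer object with a `detokenize(token_ids)` method like the controller's tokenizer:

```python
from typing import List, Tuple

import torch


def route_top_n(
    pairs: List[Tuple[torch.Tensor, torch.Tensor]],
    tokenizer,
    prompt_length: int,
    return_prompt_top_n_logprobs: bool,
):
    """Split per-token {token_string: logprob} dicts into prompt vs. generated lists."""
    prompt_top_n, generated_top_n = [], []
    for values, indices in pairs:
        logit_dict = {
            tokenizer.detokenize([token_id]): logprob
            for logprob, token_id in zip(values.cpu().tolist(), indices.cpu().tolist())
        }
        accumulated = len(prompt_top_n) + len(generated_top_n)
        # The last prompt token's logits predict the first generated token,
        # so only the first prompt_length - 1 entries belong to the prompt.
        if return_prompt_top_n_logprobs and accumulated < prompt_length - 1:
            prompt_top_n.append(logit_dict)
        else:
            generated_top_n.append(logit_dict)
    return prompt_top_n, generated_top_n
```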
2 changes: 2 additions & 0 deletions megatron/core/inference/inference_request.py
@@ -325,6 +325,8 @@ def get_metadata_labels() -> Dict[str, int]:
"termination_id",
"return_log_probs",
"skip_prompt_log_probs",
"top_n_logprobs",
"return_prompt_top_n_logprobs",
]
return {k: v for v, k in enumerate(ret)}
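
Because the two new names are appended at the end of the ordered field list, the indices of the existing metadata fields are unchanged and the new fields simply take the next positions. A toy illustration of the {field_name: index} mapping this comprehension builds (only the tail of the real list is shown, so the indices start at 0 here):

```python
# Tail of the ordered field list; the earlier fields are omitted in this toy example.
ret = [
    "termination_id",
    "return_log_probs",
    "skip_prompt_log_probs",
    "top_n_logprobs",
    "return_prompt_top_n_logprobs",
]
labels = {k: v for v, k in enumerate(ret)}
# {'termination_id': 0, 'return_log_probs': 1, 'skip_prompt_log_probs': 2,
#  'top_n_logprobs': 3, 'return_prompt_top_n_logprobs': 4}
```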
