21 changes: 19 additions & 2 deletions examples/inference/gpt/gpt_dynamic_inference.py
@@ -163,7 +163,7 @@ def get_inference_context(
active_buffer_size_gb=args.inference_dynamic_batching_active_buffer_size_gb,
max_tokens=args.inference_dynamic_batching_max_tokens,
tensor_model_parallel_size=args.tensor_model_parallel_size,
materialize_only_last_token_logits=not args.return_log_probs,
materialize_only_last_token_logits=not (args.return_log_probs or args.return_prompt_top_n_logprobs),
layer_type_list=layer_type_list,
mamba_conv_states_shape=mamba_conv_states_shape,
mamba_ssm_states_shape=mamba_ssm_states_shape,
@@ -327,9 +327,19 @@ def _add_request():
request.state = "finished"
request.request_id = finished_request.request_id
if finished_request.sampling_params.return_log_probs:
if not finished_request.prompt_log_probs:
finished_request.prompt_log_probs = []
request.log_probs = (
finished_request.prompt_log_probs + finished_request.generated_log_probs
)
if finished_request.sampling_params.top_n_logprobs > 0:
request.generated_top_n_logprobs = getattr(
finished_request, 'generated_top_n_logprobs', None
)
if finished_request.sampling_params.return_prompt_top_n_logprobs:
request.prompt_top_n_logprobs = getattr(
finished_request, 'prompt_top_n_logprobs', None
)
Comment on lines +399 to +406 (Contributor):

Personally, I strongly oppose conditional attributes that make us have to use getattr. I do not see a reason why we cannot just put generated_top_n_logprobs and prompt_top_n_logprobs directly into the InferenceRequest dataclass.

It is more confusing for the InferenceRequest object to gain unlisted attributes over time than it is for us to list all the attributes ahead of time, even if not all flows will need them.
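A minimal sketch of the suggested alternative, assuming a pared-down request dataclass; the two field names come from this PR, everything else below is illustrative:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class InferenceRequest:  # illustrative subset, not the actual Megatron-LM class
    request_id: int
    prompt_tokens: List[int] = field(default_factory=list)
    # Declared up front so every flow can read them without getattr();
    # flows that never compute top-n log probs simply leave them as None.
    prompt_top_n_logprobs: Optional[List[Dict[str, float]]] = None
    generated_top_n_logprobs: Optional[List[Dict[str, float]]] = None


# Call sites can then use plain attribute access, e.g.
# request.generated_top_n_logprobs
# instead of getattr(finished_request, 'generated_top_n_logprobs', None).
```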

num_requests_finished += 1
output_times.append(get_curr_time() - output_start)

@@ -372,9 +382,12 @@ def main():
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
skip_prompt_log_probs=args.skip_prompt_log_probs,
return_log_probs=args.return_log_probs,
num_tokens_to_generate=args.num_tokens_to_generate,
termination_id=args.termination_id if args.termination_id is not None else tokenizer.eod,
top_n_logprobs=args.top_n_logprobs,
return_prompt_top_n_logprobs=args.return_prompt_top_n_logprobs,
Comment on lines +453 to +454 (Contributor):

nit, but I think that we should be consistent with our naming. The rest of the codebase uses log_probs. Newly added attributes should as well, instead of logprobs.

There are several examples of logprobs in this PR.
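A tiny illustration of the requested convention; the identifiers below are hypothetical and shown only to contrast the two spellings:

```python
# Existing convention elsewhere in the codebase:
return_log_probs = True
prompt_log_probs: list = []

# New attributes would follow the same log_probs spelling, e.g.
return_prompt_top_n_log_probs = True  # rather than return_prompt_top_n_logprobs
generated_top_n_log_probs: list = []  # rather than generated_top_n_logprobs
```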

)

model = get_model()
@@ -493,13 +506,16 @@ def escape_str(s):
# Write every 'n' requests, plus the final request.
for i, req in enumerate(requests):
if i % args.output_every_n_results == 0 or i == len(requests) - 1:
print(f' Attributes of request {i}: {req.__dict__}')
result_dict = {
"input_prompt": req.prompt_text,
"generated_text": req.output_text,
"generated_tokens": req.output_tokens,
"latency": req.time_end - req.time_start,
"cuda_graph_request_count_map" : result["cuda_graph_request_count_map"],
"step_count" : engine.step_count,
"top_n_logprobs" : getattr(req, 'generated_top_n_logprobs', None),
"prompt_top_n_logprobs" : getattr(req, 'prompt_top_n_logprobs', None),
Comment on lines +588 to +589 (Contributor):

See my comment above about getattr versus having the attributes be explicitly listed in the dataclass.

}
if req.sampling_params.return_log_probs:
response_logprobs = req.log_probs
Expand All @@ -509,6 +525,7 @@ def escape_str(s):
# Track system-level throughput as a test / debug metric
json_results["throughput"] = throughputs

print(f' Saving results to {args.output_path}')
with open(args.output_path, "w") as fp:
json.dump(json_results, fp, indent=1)

@@ -560,4 +577,4 @@ def escape_str(s):


if __name__ == "__main__":
main()
main()
12 changes: 12 additions & 0 deletions examples/inference/gpt/utils.py
@@ -60,6 +60,12 @@ def add_common_inference_args(parser: ArgumentParser) -> ArgumentParser:
default=0,
help='Return the top n logprobs for the generated tokens and their corresponding token as a dictionary',
)
group.add_argument(
"--return-prompt-top-n-logprobs",
action='store_true',
default=False,
help='Return the top n logprobs for the prompt tokens and their corresponding token as a dictionary',
)
group.add_argument(
"--incoming-requests-per-step",
type=int, default=None,
@@ -90,6 +96,12 @@ def add_common_inference_args(parser: ArgumentParser) -> ArgumentParser:
default="gpt",
help="Model provider",
)
group.add_argument(
"--skip-prompt-log-probs",
action='store_true',
default=False,
help='Skip prompt log probs.',
)
group.add_argument(
"--output-path",
type=str,
6 changes: 3 additions & 3 deletions megatron/core/inference/contexts/dynamic_context.py
@@ -1676,7 +1676,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T

def calculate_log_probs(
self, logits: Tensor, new_tokens: Tensor, only_last_token_logits: Optional[bool] = False
) -> List[List[float]]:
) -> Tuple[List[List[float]], Tensor]:
"""Calculate log probs for all active requests and return them.

TODO: @wdykas support top-n log probs.
Expand All @@ -1696,7 +1696,7 @@ def calculate_log_probs(
if only_last_token_logits or self.is_decode_only():
seq_idx = torch.arange(len(new_tokens), dtype=torch.int32, device=logits.device)
selected_log_probs = log_probs[seq_idx, new_tokens]
return [[lp] for lp in selected_log_probs.flatten().tolist()]
return [[lp] for lp in selected_log_probs.flatten().tolist()], log_probs

# Get the selected token ids for all tokens.
# We shift the active token window left by one to remove the first prompt token for
@@ -1739,7 +1739,7 @@ def calculate_log_probs(
)

# Convert each log prob tensor into a list
return [lp.tolist() for lp in selected_log_probs_list]
return [lp.tolist() for lp in selected_log_probs_list], log_probs

def get_kvcache_utilization_stats(self) -> dict:
"""Compute KV cache buffer utilization stats for the current step.
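Context for the signature change above: calculate_log_probs now returns the full log-prob tensor alongside the selected per-token values, which is what lets the engine derive top-n log probs. A minimal sketch of how a caller might do that with torch.topk; the helper name and shapes below are assumptions for illustration, not the engine's actual code:

```python
import torch


def top_n_from_log_probs(log_probs: torch.Tensor, n: int):
    """Illustrative helper: highest-n log probs per token position.

    log_probs: [num_tokens, vocab_size] tensor of log-softmaxed logits.
    Returns (values, indices), each of shape [num_tokens, n].
    """
    return torch.topk(log_probs, k=n, dim=-1)


# Usage sketch (names assumed):
# selected, full_log_probs = context.calculate_log_probs(logits, new_tokens)
# values, indices = top_n_from_log_probs(full_log_probs, n=sampling_params.top_n_logprobs)
```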
109 changes: 99 additions & 10 deletions megatron/core/inference/engines/dynamic_engine.py
@@ -443,6 +443,10 @@ def _add_request(
request.sampling_params.num_tokens_to_generate is None
or request.sampling_params.num_tokens_total is None
)
if request.sampling_params.top_n_logprobs > 0:
assert (
request.sampling_params.return_log_probs
), "top_n_logprobs requires sampling_params.return_log_probs to be True"
if request.sampling_params.num_tokens_total is not None:
request.sampling_params.num_tokens_to_generate = (
request.sampling_params.num_tokens_total - len(request.prompt_tokens)
@@ -542,6 +546,7 @@ def post_process_requests(
step_time: float,
sample: torch.Tensor,
log_probs: torch.Tensor,
top_n_logprobs: Optional[Dict[int, List[Tuple[torch.Tensor, torch.Tensor]]]] = None,
) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest]]:
"""
Handles post-processing for requests after a step.
@@ -552,6 +557,8 @@
step_time (float): The latency of the last step
sample: (torch.Tensor): The newly generated tokens for each request
log_probs: (List): Log probs for each request
top_n_logprobs: (Dict): Top-n log probs for each request. Maps request_idx to
list of (top_n_logprobs, top_n_indices) tuples.

Returns:
A list of active requests and completed requests as `DynamicInferenceRequest` objects
@@ -563,8 +570,8 @@

log_probs_iter = log_probs if log_probs else repeat(None)

for request_id, token, request_log_probs in zip(
request_ids.tolist(), sample.tolist(), log_probs_iter
for req_idx, (request_id, token, request_log_probs) in enumerate(
zip(request_ids.tolist(), sample.tolist(), log_probs_iter)
):
request: DynamicInferenceRequest = self.requests[request_id]
if request_id != self.context.chunked_prefill_request_id:
@@ -580,23 +587,58 @@
request.generated_log_probs = []
# If the request log probs span > 1 token we are in prefill
if len(request_log_probs) > 1:
request.prompt_log_probs.extend(request_log_probs)
# Add all but the last logprob to prompt_log_probs (last is for first generated token)
request.prompt_log_probs.extend(request_log_probs[:-1])
# Add the last logprob to generated_log_probs (first generated token)
request.generated_log_probs.extend(request_log_probs[-1:])
else:
if (
# If it is a chunked prefill request
len(request.prompt_log_probs) > 0
# And we are missing the last token for prefill
and len(request.prompt_log_probs) < len(request.prompt_tokens)
and len(request.prompt_log_probs) < len(request.prompt_tokens) - 1
# And we need to track full prefill
and not self.context.materialize_only_last_token_logits
):
assert (
len(request.prompt_log_probs) == len(request.prompt_tokens) - 1
), "Prompt log probs length is not equal to prompt tokens length - 1"
request.prompt_log_probs.extend(request_log_probs)
else:
request.generated_log_probs.extend(request_log_probs)

# Process top_n_logprobs if available
if top_n_logprobs is not None and req_idx in top_n_logprobs:
# Initialize lists if they don't exist
if request.prompt_top_n_logprobs is None:
request.prompt_top_n_logprobs = []
if request.generated_top_n_logprobs is None:
request.generated_top_n_logprobs = []

top_n_data_list = top_n_logprobs[req_idx]
prompt_length = len(request.prompt_tokens)

# Process each token's top-n logprobs
for top_n_values, top_n_indices in top_n_data_list:
logit_dict = {}
for logprob, logprob_index in zip(
top_n_values.tolist(), top_n_indices.tolist()
):
key = self.controller.tokenizer.detokenize([logprob_index])
logit_dict[key] = logprob

# Simple decision: check total count accumulated so far
total_accumulated = len(request.prompt_top_n_logprobs) + len(
request.generated_top_n_logprobs
)

# If return_prompt_top_n_logprobs is True and we haven't reached prompt end,
# append to prompt_top_n_logprobs. Otherwise append to generated_top_n_logprobs.
if (
request.sampling_params.return_prompt_top_n_logprobs
and total_accumulated < prompt_length - 1
):
request.prompt_top_n_logprobs.append(logit_dict)
else:
request.generated_top_n_logprobs.append(logit_dict)

if request_id in finished_request_ids:
request.generated_length = len(request.generated_tokens)
request.status = Status.COMPLETED
Expand Down Expand Up @@ -630,7 +672,49 @@ def post_process_requests(
request.prompt_log_probs = []
request.prompt_log_probs.extend(request_log_probs)
request.generated_log_probs = []
active_requests.append(request)

# Process top_n_logprobs for chunked prefill if available
if top_n_logprobs is not None and req_idx in top_n_logprobs:
# Initialize lists if they don't exist
if (
not hasattr(request, 'prompt_top_n_logprobs')
or request.prompt_top_n_logprobs is None
):
request.prompt_top_n_logprobs = []
if (
not hasattr(request, 'generated_top_n_logprobs')
or request.generated_top_n_logprobs is None
):
request.generated_top_n_logprobs = []

top_n_data_list = top_n_logprobs[req_idx]
prompt_length = len(request.prompt_tokens)

for top_n_values, top_n_indices in top_n_data_list:
# Convert to dictionary format: {token_str: logprob}
logit_dict = {}
for logprob, logprob_index in zip(
top_n_values.cpu().tolist(), top_n_indices.cpu().tolist()
):
key = self.controller.tokenizer.detokenize([logprob_index])
logit_dict[key] = logprob

# Simple decision: check total count accumulated so far
total_accumulated = len(request.prompt_top_n_logprobs) + len(
request.generated_top_n_logprobs
)

# If return_prompt_top_n_logprobs is True and we haven't reached prompt end,
# append to prompt_top_n_logprobs. Otherwise append to generated_top_n_logprobs.
if (
request.sampling_params.return_prompt_top_n_logprobs
and total_accumulated < prompt_length - 1
):
request.prompt_top_n_logprobs.append(logit_dict)
else:
request.generated_top_n_logprobs.append(logit_dict)

active_requests.append(request)

return active_requests, finished_requests
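For readers following the new top_n_logprobs plumbing above, a small self-contained sketch of the shape post_process_requests expects and how each (values, indices) pair becomes a {token: logprob} dict; the toy vocabulary stands in for the real tokenizer and is an assumption for illustration only:

```python
import torch

# Per the updated docstring: request_idx -> list of (top_n_values, top_n_indices),
# one tuple per token processed in this step (a decode step contributes one tuple).
top_n_logprobs = {
    0: [(torch.tensor([-0.5, -2.25]), torch.tensor([11, 42]))],
}

# Toy stand-in for self.controller.tokenizer.detokenize([index]):
toy_vocab = {11: " the", 42: " a"}

values, indices = top_n_logprobs[0][0]
logit_dict = {
    toy_vocab[idx]: logprob
    for logprob, idx in zip(values.tolist(), indices.tolist())
}
print(logit_dict)  # {' the': -0.5, ' a': -2.25}
```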

@@ -767,6 +851,7 @@ async def async_step(
finished_request_ids = result["finished_request_ids"]
sample = result["sample"]
log_probs = result["log_probs"]
top_n_logprobs = result.get("top_n_logprobs", None)
cuda_graph_request_count = result["cuda_graph_request_count"]

# Add paused events.
Expand All @@ -776,10 +861,14 @@ async def async_step(

# Mark requests finished.
[self.requests[i].add_event_finish() for i in finished_request_ids.tolist()]

# Add finished events.
(active_requests, finished_requests) = self.post_process_requests(
active_request_ids, finished_request_ids, step_time, sample, log_probs
active_request_ids,
finished_request_ids,
step_time,
sample,
log_probs,
top_n_logprobs,
)

else:
2 changes: 2 additions & 0 deletions megatron/core/inference/inference_request.py
@@ -307,6 +307,8 @@ def get_metadata_labels() -> Dict[str, int]:
"termination_id",
"return_log_probs",
"skip_prompt_log_probs",
"top_n_logprobs",
"return_prompt_top_n_logprobs",
]
return {k: v for v, k in enumerate(ret)}
