8 changes: 8 additions & 0 deletions examples/inference/gpt/gpt_dynamic_inference.py
@@ -330,6 +330,10 @@ def _add_request():
request.log_probs = (
finished_request.prompt_log_probs + finished_request.generated_log_probs
)
if finished_request.sampling_params.top_n_logprobs > 0:
request.generated_top_n_logprobs = getattr(
finished_request, 'generated_top_n_logprobs', None
)
num_requests_finished += 1
output_times.append(get_curr_time() - output_start)

@@ -375,6 +379,7 @@ def main():
return_log_probs=args.return_log_probs,
num_tokens_to_generate=args.num_tokens_to_generate,
termination_id=args.termination_id if args.termination_id is not None else tokenizer.eod,
top_n_logprobs=args.top_n_logprobs,
)

model = get_model()
@@ -493,13 +498,15 @@ def escape_str(s):
# Write every 'n' requests, plus the final request.
for i, req in enumerate(requests):
if i % args.output_every_n_results == 0 or i == len(requests) - 1:
print(f' Attributes of request {i}: {req.__dict__}')
result_dict = {
"input_prompt": req.prompt_text,
"generated_text": req.output_text,
"generated_tokens": req.output_tokens,
"latency": req.time_end - req.time_start,
"cuda_graph_request_count_map" : result["cuda_graph_request_count_map"],
"step_count" : engine.step_count,
"top_n_logprobs" : getattr(req, 'generated_top_n_logprobs', None),
}
if req.sampling_params.return_log_probs:
response_logprobs = req.log_probs
@@ -509,6 +516,7 @@ def escape_str(s):
# Track system-level throughput as a test / debug metric
json_results["throughput"] = throughputs

print(f' Saving results to {args.output_path}')
with open(args.output_path, "w") as fp:
json.dump(json_results, fp, indent=1)

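For reference, the new "top_n_logprobs" field in each saved result can be inspected after the run. The sketch below is illustrative only: the top-level layout of json_results (apart from the "throughput" key), the output file path, and the per-position {token: logprob} shape of each "top_n_logprobs" entry (inferred from the Dict[int, List[Dict[str, float]]] annotation in dynamic_engine.py below) are all assumptions.

import json

# Minimal sketch: read back the results file written above and print the most
# likely top-n token at each generated position of every request.
with open("results.json") as fp:  # path is an assumption
    json_results = json.load(fp)

for key, entry in json_results.items():
    if key == "throughput" or not isinstance(entry, dict):
        continue  # skip the system-level throughput metric
    top_n = entry.get("top_n_logprobs")
    if not top_n:
        continue  # None when top_n_logprobs was not requested
    for position, token_logprobs in enumerate(top_n):
        best = max(token_logprobs, key=token_logprobs.get)
        print(f"step {position}: {best!r} ({token_logprobs[best]:.3f})")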
27 changes: 24 additions & 3 deletions megatron/core/inference/engines/dynamic_engine.py
@@ -542,6 +542,7 @@ def post_process_requests(
step_time: float,
sample: torch.Tensor,
log_probs: torch.Tensor,
top_n_logprobs_dict: Optional[Dict[int, List[Dict[str, float]]]] = None,
) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest]]:
"""
Handles post-processing for requests after a step.
@@ -552,6 +553,7 @@ def post_process_requests(
step_time (float): The latency of the last step
sample: (torch.Tensor): The newly generated tokens for each request
log_probs: (List): Log probs for each request
top_n_logprobs_dict (Optional[Dict]): Top n log probs for each request

Returns:
A list of active requests and completed requests as `DynamicInferenceRequest` objects
@@ -563,9 +565,17 @@

log_probs_iter = log_probs if log_probs else repeat(None)

for request_id, token, request_log_probs in zip(
# Create a mapping from context indices to request_ids for top_n_logprobs
context_idx_to_request_id = {}
if top_n_logprobs_dict is not None:
active_request_ids = self.context.request_ids[
self.context.paused_request_count : self.context.total_request_count
].tolist()
context_idx_to_request_id = {i: rid for i, rid in enumerate(active_request_ids)}

for request_idx, (request_id, token, request_log_probs) in enumerate(zip(
request_ids.tolist(), sample.tolist(), log_probs_iter
):
)):
request: DynamicInferenceRequest = self.requests[request_id]
if request_id != self.context.chunked_prefill_request_id:
request.generated_tokens.append(token)
Expand Down Expand Up @@ -597,6 +607,15 @@ def post_process_requests(
else:
request.generated_log_probs.extend(request_log_probs)

# Handle top_n_logprobs
if top_n_logprobs_dict is not None and request_idx in context_idx_to_request_id:
context_idx = request_idx
if context_idx in top_n_logprobs_dict and top_n_logprobs_dict[context_idx]:
if not hasattr(request, 'generated_top_n_logprobs') or request.generated_top_n_logprobs is None:
request.generated_top_n_logprobs = []
# Append the top_n_logprobs for this step
request.generated_top_n_logprobs.extend(top_n_logprobs_dict[context_idx])

if request_id in finished_request_ids:
request.generated_length = len(request.generated_tokens)
request.status = Status.COMPLETED
@@ -767,6 +786,7 @@ async def async_step(
finished_request_ids = result["finished_request_ids"]
sample = result["sample"]
log_probs = result["log_probs"]
top_n_logprobs_dict = result.get("top_n_logprobs_dict", None)
cuda_graph_request_count = result["cuda_graph_request_count"]

# Add paused events.
Expand All @@ -779,7 +799,8 @@ async def async_step(

# Add finished events.
(active_requests, finished_requests) = self.post_process_requests(
active_request_ids, finished_request_ids, step_time, sample, log_probs
active_request_ids, finished_request_ids, step_time, sample, log_probs,
top_n_logprobs_dict=top_n_logprobs_dict
)

else:
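The changes above key top_n_logprobs_dict by a request's position in the active slice of the batch (its context index) rather than by request id, and translate between the two before accumulating. A standalone sketch of that accumulation step follows; the helper name and the flat per_request_logprobs store are hypothetical, while the real logic lives in post_process_requests above.

from typing import Dict, List, Optional

def accumulate_top_n_logprobs(
    active_request_ids: List[int],
    per_request_logprobs: Dict[int, List[Dict[str, float]]],
    top_n_logprobs_dict: Optional[Dict[int, List[Dict[str, float]]]],
) -> None:
    # top_n_logprobs_dict is keyed by context index, i.e. the request's position
    # within the active portion of the batch, so translate positions to request
    # ids before extending each request's running list of per-step top-n logprobs.
    if top_n_logprobs_dict is None:
        return
    context_idx_to_request_id = dict(enumerate(active_request_ids))
    for context_idx, step_logprobs in top_n_logprobs_dict.items():
        request_id = context_idx_to_request_id.get(context_idx)
        if request_id is None or not step_logprobs:
            continue
        per_request_logprobs.setdefault(request_id, []).extend(step_logprobs)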
@@ -782,6 +782,7 @@ async def async_generate_output_tokens_dynamic_batch(
finished_request_ids (Tensor): Finished request IDs.
sample (Tensor): New sample.
log_probs (Optional[Tensor]): Log probabilities of the new sample, if requested.
top_n_logprobs_dict (Optional[Dict]): Top n log probabilities for each request, if requested.
cuda_graph_request_count (Optional[int]): Size of cuda graph used for this step.
"""
context = self.inference_wrapped_model.inference_context
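For reference, a hand-written example of the top_n_logprobs_dict shape documented above: keys are context indices within the active batch, and each value holds one {token: logprob} dict per step covered by the call. The tokens and numbers below are made up, not real model output.

# Illustrative values only.
top_n_logprobs_dict = {
    0: [{" the": -0.12, " a": -2.31, " an": -3.05}],  # context index 0
    1: [{"\n": -0.44, ".": -1.87, "!": -4.10}],       # context index 1
}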