 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 
 import hashlib
-import io
 import json
 import math
 import os
 from tqdm import tqdm
 from typing import Dict, List, Tuple, Optional
 
-sys.path.append(
-    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
-)
+import torch
+from tqdm import tqdm
 
-import megatron
-from examples.inference.gpt.utils import (
-    Request,
-    add_common_inference_args,
-    build_dynamic_engine_setup_prefix,
-    build_requests,
-    get_curr_time,
-)
 from megatron.core.inference.contexts.dynamic_context import (
     ContextOverflowError,
     DynamicInferenceContext,
 )
-from megatron.core.inference.engines import DynamicInferenceEngine, EngineSuspendedError
+from megatron.core.inference.engines import DynamicInferenceEngine
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
     GPTInferenceWrapper,
 )
     build_requests,
     get_curr_time,
 )
+from megatron.training import get_args
+from megatron.training import get_model as _get_model
+from megatron.training import get_tokenizer, initialize_megatron
 from megatron.training.checkpointing import load_checkpoint
 
-from model_provider import model_provider
-from gpt_builders import gpt_builder
-
-torch.serialization.add_safe_globals([io.BytesIO])
-torch.serialization.add_safe_globals([megatron.core.rerun_state_machine.RerunState])
-torch.serialization.add_safe_globals([megatron.core.rerun_state_machine.RerunDiagnostic])
+import torch
+import io
+import megatron
 
 
 def add_dynamic_inference_args(parser: ArgumentParser) -> ArgumentParser:
@@ -86,13 +76,7 @@ def add_dynamic_inference_args(parser: ArgumentParser) -> ArgumentParser:
     )
     group.add_argument(
         "--termination-id", type=int, default=None,
-        help="Termination ID that overrides `tokenizer.eod`.",
-    )
-    group.add_argument(
-        "--suspend-resume-interval", type=int, default=None,
-        help="Suspend and resume the dynamic engine every "
-        "`suspend_resume_interval` steps. This is used to test the suspend/resume "
-        "system.",
+        help="Termination ID that overrides `tokenizer.eod`."
     )
     group.add_argument('--inference-repeat-n', type=int, default=1, help="Repeat inference iterations N times for benchmarking.")
 
@@ -264,12 +248,12 @@ def run_inference(
     num_requests_total = len(requests)
     num_requests_added = 0
     num_requests_finished = 0
+    step_id = 0
     step_times = {"prefill": [], "decode": []}
     add_times = []
     output_times = []
     tbar = tqdm(total=num_requests_total)
     total_output_tokens = 0
-    attempted_step_count = 0
     if args.cuda_graph_impl == "local":
         cuda_graph_request_count_map = {r: 0 for r in engine.context.cuda_graph_request_counts}
     else:
@@ -312,75 +296,36 @@ def _add_request():
 
         # Step inference engine (i.e., generate a token for each active request).
         # Before step, we haven't done the scheduling, so we cannot know the is_decode_only
-        try:
-            result = engine.step_modern(verbose=True)
-        except EngineSuspendedError as e:
-            result = e
-            pass  # ignore error in order to call 'engine.resume()' below.
-        attempted_step_count += 1
-
+        result = engine.step_modern(verbose=True)
         # After step, we lost track of last iteration's is_decode_only, so we need to get it from the engine
         is_decode_only = engine.is_decode_only
-
-        # Test suspending and resuming engine.
-        if args.suspend_resume_interval is not None:
-
-            # Suspend.
-            if attempted_step_count % args.suspend_resume_interval == 0:
-                print("**** step %d/%d ... suspend." % (engine.step_count, attempted_step_count))
-                engine.suspend()
-
-            # Resume, 0+ attempted steps later.
-            if (
-                attempted_step_count > 0
-                and
-                (attempted_step_count - args.suspend_resume_interval // 2)
-                % args.suspend_resume_interval == 0
-            ):
-                print("**** step %d/%d ... resume." % (engine.step_count, attempted_step_count))
-                engine.resume()
-
-        # If engine suspended, continue to next iter.
-        if isinstance(result, EngineSuspendedError):
-            continue
+        step_id += 1
 
         # Record cuda_graph_request_count.
         cuda_graph_request_count = result["cuda_graph_request_count"]
         if args.cuda_graph_impl == "local" and cuda_graph_request_count is not None:
             cuda_graph_request_count_map[cuda_graph_request_count] += 1
 
         # Update requests.
-        active_request_ids = result["active_request_ids"]
-        finished_request_records = result["finished_request_records"]
+        active_requests = result["active_requests"]
+        finished_requests = result["finished_requests"]
         step_time = result["step_time"]
-        if len(active_request_ids) > 0 or len(finished_request_records) > 0:
+        if len(active_requests) > 0 or len(finished_requests) > 0:
             if is_decode_only:
                 step_times["decode"].append(step_time)
             else:
                 step_times["prefill"].append(step_time)
 
         # Append output tokens.
         output_start = get_curr_time()
-        for finished_request_record in finished_request_records:
-
-            finished_request = finished_request_record.merge(engine.controller.tokenizer)
-
-            # Update local request object.
+        for finished_request in finished_requests:
             request = requests[finished_request.request_id]
+            request.output_tokens = finished_request.generated_tokens
+            total_output_tokens += len(request.output_tokens)
             request.time_end = get_curr_time()
+            request.output_text = finished_request.generated_text
             request.state = "finished"
             request.request_id = finished_request.request_id
-
-            # Update prompt, in case engine has been suspended and resumed.
-            request.prompt_tokens = finished_request.prompt_tokens
-            request.prompt_text = finished_request.prompt
-
-            # Get output tokens and text.
-            request.output_tokens = finished_request.generated_tokens
-            request.output_text = finished_request.generated_text
-            total_output_tokens += len(request.output_tokens)
-
-            # Log probs.
             if finished_request.sampling_params.return_log_probs:
                 request.log_probs = (
                     finished_request.prompt_log_probs + finished_request.generated_log_probs
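With the suspend/resume path removed, the per-step flow above reduces to a plain dictionary read on the result of `step_modern`. A condensed sketch of the new control flow (names taken from this diff; the loop scaffolding and finish counting are assumptions):

    while num_requests_finished < num_requests_total:
        result = engine.step_modern(verbose=True)  # one prefill/decode step
        is_decode_only = engine.is_decode_only     # read back after scheduling
        step_id += 1
        for finished_request in result["finished_requests"]:
            request = requests[finished_request.request_id]
            request.output_tokens = finished_request.generated_tokens
            request.output_text = finished_request.generated_text
            num_requests_finished += 1
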
@@ -516,9 +461,7 @@ def escape_str(s):
         unique_prompt_map[request.prompt_text].append(request_idx)
 
     # Print unique prompts + outputs.
-    text_hashes = []
     for unique_idx, (prompt_text, request_idxs) in enumerate(unique_prompt_map.items()):
-
         # ---- Prompt summary line ----
         prompt_len = len(requests[request_idxs[0]].prompt_tokens)
         escaped_prompt_text = escape_str(prompt_text)
@@ -533,20 +476,15 @@ def escape_str(s):
         # ---- Print each unique output ----
         for output_text, output_request_idxs in output_map.items():
             if output_text is not None:
-                # Use hash of prompt + generated text in case engine was
-                # suspended and resumed, which misaligns boundary between
-                # prompt and generated tokens.
-                o_hash = hashlib.sha256(
-                    (prompt_text + output_text).encode()
-                ).hexdigest()[:6]
+                o_hash = hashlib.sha256(output_text.encode()).hexdigest()[:6]
                 o_len = len(requests[output_request_idxs[0]].output_tokens)
                 escaped_output_text = escape_str(output_text)
+                print(f" >>>> [n {len(output_request_idxs)}, l {o_len}, hash {o_hash}] {escaped_output_text}")
             else:
                 o_hash = "--"
                 o_len = 0
                 escaped_output_text = "--"
-            print(f" >>>> [n {len(output_request_idxs)}, {o_len} tokens, hash {o_hash}] {escaped_output_text}")
-            text_hashes.append(o_hash)
+                print(f" >>>> [n {len(output_request_idxs)}, {o_len} tokens, hash {o_hash}] {escaped_output_text}")
 
     # Write results to JSON. Primarily used for functional testing.
     if args.output_path:
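Since the engine can no longer be suspended mid-request, the prompt/generation boundary is stable and the output hash above covers only the generated text. A standalone illustration of the 6-hex-character label (the sample string is hypothetical):

    import hashlib

    output_text = "Hello, world!"  # hypothetical generated text
    o_hash = hashlib.sha256(output_text.encode()).hexdigest()[:6]
    print(o_hash)  # stable short label used to group identical outputs
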
@@ -574,49 +512,47 @@ def escape_str(s):
         with open(args.output_path, "w") as fp:
             json.dump(json_results, fp, indent=1)
 
-    # Timing results.
-    stats = torch.cuda.memory_stats()
-    throughput = total_output_tokens / total_time
-    print("~~~")
-    peak_alloc_gb = stats["allocated_bytes.all.peak"] / 1024**3
-    peak_resvd_gb = stats["reserved_bytes.all.peak"] / 1024**3
-
-    p_times = step_times["prefill"]
-    d_times = step_times["decode"]
-
-    p_total = sum(p_times)
-    d_total = sum(d_times)
-
-    p_count = len(p_times)
-    d_count = len(d_times)
-
-    p_mean = p_total / p_count
-    d_mean = d_total / d_count if d_count != 0 else 0.
-
-    # Commented out for now as the step/add/output times are not calculated correctly.
-    # print(
-    #     f"{setup_prefix} … "
-    #     f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … "
-    #     f"total time: {step_total:.3f}s … "
-    #     f"step time: total {step_total:.3f}s "
-    #     f"[ p {p_total:.3f}s, d {d_total:.3f}s ], "
-    #     f"mean [ p {p_mean:.3f}s, d {d_mean:.3f}s ], "
-    #     f"count [ p {p_count}, d {d_count} ]."
-    # )
-    capture_str = (
-        f"{engine.capture_stats['time']:.2f} sec"
-        if engine.capture_stats else
-        "--"
-    )
-    print(
-        f"{setup_prefix} … "
-        f"throughput: {throughput:.3f} tok/s",
-        f"total time: {total_time:.3f}s … "
-        f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … "
-        f"steps: {engine.step_count:d} … "
-        f"capture {capture_str} … "
-    )
-    print("~~~")
+    # Timing results.
+    print("~~~")
+    peak_alloc_gb = stats["allocated_bytes.all.peak"] / 1024**3
+    peak_resvd_gb = stats["reserved_bytes.all.peak"] / 1024**3
+
+    p_times = step_times["prefill"]
+    d_times = step_times["decode"]
+
+    p_total = sum(p_times)
+    d_total = sum(d_times)
+
+    p_count = len(p_times)
+    d_count = len(d_times)
+
+    p_mean = p_total / p_count
+    d_mean = d_total / d_count
+
+    # Commented out for now as the step/add/output times are not calculated correctly.
+    # print(
+    #     f"{setup_prefix} … "
+    #     f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … "
+    #     f"total time: {step_total:.3f}s … "
+    #     f"step time: total {step_total:.3f}s "
+    #     f"[ p {p_total:.3f}s, d {d_total:.3f}s ], "
+    #     f"mean [ p {p_mean:.3f}s, d {d_mean:.3f}s ], "
+    #     f"count [ p {p_count}, d {d_count} ]."
+    # )
+    capture_str = (
+        f"{engine.capture_stats['time']:.2f} sec"
+        if engine.capture_stats else
+        "--"
+    )
+    print(" … ".join((
+        f"{setup_prefix}",
+        f"throughput: {throughput:.3f} tok/s",
+        f"total time: {total_time:.3f}s",
+        f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB",
+        f"steps: {engine.step_count:d}",
+        f"capture {capture_str}",
+    )))
+    print("~~~")
 
     # Stop Nsight profiler.
     if os.environ.get("NSIGHT_PREFIX"):