 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

 import hashlib
+import io
 import json
 import math
 import os
 from tqdm import tqdm
 from typing import Dict, List, Tuple, Optional

-import torch
-from tqdm import tqdm
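+# Extend sys.path so the repo-local imports below (megatron, examples.*,
+# model_provider, gpt_builders) can be found when this script is run directly.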
+sys.path.append(
+    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
+)

+import megatron
+from examples.inference.gpt.utils import (
+    Request,
+    add_common_inference_args,
+    build_dynamic_engine_setup_prefix,
+    build_requests,
+    get_curr_time,
+)
 from megatron.core.inference.contexts.dynamic_context import (
     ContextOverflowError,
     DynamicInferenceContext,
 )
-from megatron.core.inference.engines import DynamicInferenceEngine
+from megatron.core.inference.engines import DynamicInferenceEngine, EngineSuspendedError
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
     GPTInferenceWrapper,
 )
     build_requests,
     get_curr_time,
 )
-from megatron.training import get_args
-from megatron.training import get_model as _get_model
-from megatron.training import get_tokenizer, initialize_megatron
 from megatron.training.checkpointing import load_checkpoint

-import torch
-import io
-import megatron
+from model_provider import model_provider
+from gpt_builders import gpt_builder
+
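+# Allowlist the non-tensor objects that can appear in Megatron checkpoints so
+# that torch.load with weights_only=True (the default since PyTorch 2.6) can
+# deserialize them.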
+torch.serialization.add_safe_globals([io.BytesIO])
+torch.serialization.add_safe_globals([megatron.core.rerun_state_machine.RerunState])
+torch.serialization.add_safe_globals([megatron.core.rerun_state_machine.RerunDiagnostic])


 def add_dynamic_inference_args(parser: ArgumentParser) -> ArgumentParser:
@@ -76,7 +86,13 @@ def add_dynamic_inference_args(parser: ArgumentParser) -> ArgumentParser:
     )
     group.add_argument(
         "--termination-id", type=int, default=None,
-        help="Termination ID that overrides `tokenizer.eod`."
+        help="Termination ID that overrides `tokenizer.eod`.",
+    )
+    group.add_argument(
+        "--suspend-resume-interval", type=int, default=None,
+        help="Suspend and resume the dynamic engine every "
+        "`suspend_resume_interval` steps. This is used to test the suspend/resume "
+        "system.",
     )
     group.add_argument('--inference-repeat-n', type=int, default=1, help="Repeat inference iterations N times for benchmarking.")

@@ -248,12 +264,12 @@ def run_inference(
     num_requests_total = len(requests)
     num_requests_added = 0
     num_requests_finished = 0
-    step_id = 0
     step_times = {"prefill": [], "decode": []}
     add_times = []
     output_times = []
     tbar = tqdm(total=num_requests_total)
     total_output_tokens = 0
+    attempted_step_count = 0
     if args.cuda_graph_impl == "local":
         cuda_graph_request_count_map = {r: 0 for r in engine.context.cuda_graph_request_counts}
     else:
@@ -296,36 +312,75 @@ def _add_request():

         # Step inference engine (i.e., generate a token for each active request).
         # Before step, we haven't done the scheduling, so we cannot know the is_decode_only
-        result = engine.step_modern(verbose=True)
+        try:
+            result = engine.step_modern(verbose=True)
+        except EngineSuspendedError as e:
+            result = e  # Ignore the error so that engine.resume() can still be called below.
+        attempted_step_count += 1
+
         # After step, we lost track of last iteration's is_decode_only, so we need to get it from the engine
         is_decode_only = engine.is_decode_only
-        step_id += 1
+
+        # Test suspending and resuming engine.
+        if args.suspend_resume_interval is not None:
+
+            # Suspend.
+            if attempted_step_count % args.suspend_resume_interval == 0:
+                print("**** step %d/%d ... suspend." % (engine.step_count, attempted_step_count))
+                engine.suspend()
+
+            # Resume, 0+ attempted steps later.
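+            # With the suspend_resume_interval // 2 offset, this fires at
+            # attempted steps interval // 2, interval + interval // 2, ...,
+            # i.e. offset half an interval from the suspends above.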
+            if (
+                attempted_step_count > 0
+                and (attempted_step_count - args.suspend_resume_interval // 2)
+                % args.suspend_resume_interval == 0
+            ):
+                print("**** step %d/%d ... resume." % (engine.step_count, attempted_step_count))
+                engine.resume()
+
+        # If the engine was suspended during this step (step_modern() raised
+        # EngineSuspendedError above), skip this iteration's bookkeeping.
+        if isinstance(result, EngineSuspendedError):
+            continue

         # Record cuda_graph_request_count.
         cuda_graph_request_count = result["cuda_graph_request_count"]
         if args.cuda_graph_impl == "local" and cuda_graph_request_count is not None:
             cuda_graph_request_count_map[cuda_graph_request_count] += 1

         # Update requests.
-        active_requests = result["active_requests"]
-        finished_requests = result["finished_requests"]
+        active_request_ids = result["active_request_ids"]
+        finished_request_records = result["finished_request_records"]
         step_time = result["step_time"]
-        if len(active_requests) > 0 or len(finished_requests) > 0:
+        if len(active_request_ids) > 0 or len(finished_request_records) > 0:
             if is_decode_only:
                 step_times["decode"].append(step_time)
             else:
                 step_times["prefill"].append(step_time)

         # Append output tokens.
         output_start = get_curr_time()
-        for finished_request in finished_requests:
+        for finished_request_record in finished_request_records:
+
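+            # Merge the record's per-segment entries (split wherever the engine
+            # was suspended and resumed) into a single finished request; the
+            # tokenizer is used to rebuild the prompt and generated text.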
+            finished_request = finished_request_record.merge(engine.controller.tokenizer)
+
+            # Update local request object.
             request = requests[finished_request.request_id]
-            request.output_tokens = finished_request.generated_tokens
-            total_output_tokens += len(request.output_tokens)
             request.time_end = get_curr_time()
-            request.output_text = finished_request.generated_text
             request.state = "finished"
             request.request_id = finished_request.request_id
+
+            # Update prompt, in case engine has been suspended and resumed.
+            request.prompt_tokens = finished_request.prompt_tokens
+            request.prompt_text = finished_request.prompt
+
+            # Get output tokens and text.
+            request.output_tokens = finished_request.generated_tokens
+            request.output_text = finished_request.generated_text
+            total_output_tokens += len(request.output_tokens)
+
+            # Log probs.
             if finished_request.sampling_params.return_log_probs:
                 request.log_probs = (
                     finished_request.prompt_log_probs + finished_request.generated_log_probs
@@ -461,7 +516,9 @@ def escape_str(s):
         unique_prompt_map[request.prompt_text].append(request_idx)

     # Print unique prompts + outputs.
+    text_hashes = []
     for unique_idx, (prompt_text, request_idxs) in enumerate(unique_prompt_map.items()):
+
         # ---- Prompt summary line ----
         prompt_len = len(requests[request_idxs[0]].prompt_tokens)
         escaped_prompt_text = escape_str(prompt_text)
@@ -476,15 +533,20 @@ def escape_str(s):
         # ---- Print each unique output ----
         for output_text, output_request_idxs in output_map.items():
             if output_text is not None:
-                o_hash = hashlib.sha256(output_text.encode()).hexdigest()[:6]
+                # Use hash of prompt + generated text in case engine was
+                # suspended and resumed, which misaligns boundary between
+                # prompt and generated tokens.
+                o_hash = hashlib.sha256(
+                    (prompt_text + output_text).encode()
+                ).hexdigest()[:6]
                 o_len = len(requests[output_request_idxs[0]].output_tokens)
                 escaped_output_text = escape_str(output_text)
-                print(f" >>>> [n {len(output_request_idxs)}, l {o_len}, hash {o_hash}] {escaped_output_text}")
             else:
                 o_hash = "--"
                 o_len = 0
                 escaped_output_text = "--"
-            print(f" >>>> [n {len(output_request_idxs)}, {o_len} tokens, hash {o_hash}] {escaped_output_text}")
+            print(f" >>>> [n {len(output_request_idxs)}, {o_len} tokens, hash {o_hash}] {escaped_output_text}")
+            text_hashes.append(o_hash)

     # Write results to JSON. Primarily used for functional testing.
     if args.output_path:
@@ -512,47 +574,49 @@ def escape_str(s):
         with open(args.output_path, "w") as fp:
             json.dump(json_results, fp, indent=1)

-    # Timing results.
-    print("~~~")
-    peak_alloc_gb = stats["allocated_bytes.all.peak"] / 1024**3
-    peak_resvd_gb = stats["reserved_bytes.all.peak"] / 1024**3
-
-    p_times = step_times["prefill"]
-    d_times = step_times["decode"]
-
-    p_total = sum(p_times)
-    d_total = sum(d_times)
-
-    p_count = len(p_times)
-    d_count = len(d_times)
-
-    p_mean = p_total / p_count
-    d_mean = d_total / d_count
-
-    # Commented out for now as the step/add/output times are not calculated correctly.
-    # print(
-    #     f"{setup_prefix} … "
-    #     f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … "
-    #     f"total time: {step_total:.3f}s … "
-    #     f"step time: total {step_total:.3f}s "
-    #     f"[ p {p_total:.3f}s, d {d_total:.3f}s ], "
-    #     f"mean [ p {p_mean:.3f}s, d {d_mean:.3f}s ], "
-    #     f"count [ p {p_count}, d {d_count} ]."
-    # )
-    capture_str = (
-        f"{engine.capture_stats['time']:.2f} sec"
-        if engine.capture_stats else
-        "--"
-    )
-    print(" … ".join((
-        f"{setup_prefix}",
-        f"throughput: {throughput:.3f} tok/s",
-        f"total time: {total_time:.3f}s",
-        f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB",
-        f"steps: {engine.step_count:d}",
-        f"capture {capture_str}",
-    )))
-    print("~~~")
+    # Timing results.
+    stats = torch.cuda.memory_stats()
+    throughput = total_output_tokens / total_time
+    print("~~~")
+    peak_alloc_gb = stats["allocated_bytes.all.peak"] / 1024**3
+    peak_resvd_gb = stats["reserved_bytes.all.peak"] / 1024**3
+
+    p_times = step_times["prefill"]
+    d_times = step_times["decode"]
+
+    p_total = sum(p_times)
+    d_total = sum(d_times)
+
+    p_count = len(p_times)
+    d_count = len(d_times)
+
+    p_mean = p_total / p_count if p_count != 0 else 0.
+    d_mean = d_total / d_count if d_count != 0 else 0.
+
+    # Commented out for now as the step/add/output times are not calculated correctly.
+    # print(
+    #     f"{setup_prefix} … "
+    #     f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … "
+    #     f"total time: {step_total:.3f}s … "
+    #     f"step time: total {step_total:.3f}s "
+    #     f"[ p {p_total:.3f}s, d {d_total:.3f}s ], "
+    #     f"mean [ p {p_mean:.3f}s, d {d_mean:.3f}s ], "
+    #     f"count [ p {p_count}, d {d_count} ]."
+    # )
+    capture_str = (
+        f"{engine.capture_stats['time']:.2f} sec"
+        if engine.capture_stats else
+        "--"
+    )
+    print(
+        f"{setup_prefix} … "
+        f"throughput: {throughput:.3f} tok/s … "
+        f"total time: {total_time:.3f}s … "
+        f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … "
+        f"steps: {engine.step_count:d} … "
+        f"capture {capture_str}"
+    )
+    print("~~~")

     # Stop Nsight profiler.
     if os.environ.get("NSIGHT_PREFIX"):