
Commit 29eed5d

lmcafee-nvidia authored (with tdene and Mcore Bot)
Dynamic engine suspend/resume via prefill. (#1982)
Co-authored-by: Teodor-Dumitru Ene <[email protected]>
Co-authored-by: Mcore Bot <[email protected]>
Co-authored-by: Teodor-Dumitru Ene <[email protected]>
1 parent 5e3fa28 commit 29eed5d

21 files changed (+1189 -414 lines)

examples/inference/gpt/gpt_dynamic_inference.py

Lines changed: 128 additions & 64 deletions
@@ -1,6 +1,7 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

 import hashlib
+import io
 import json
 import math
 import os
@@ -13,14 +14,23 @@
 from tqdm import tqdm
 from typing import Dict, List, Tuple, Optional

-import torch
-from tqdm import tqdm
+sys.path.append(
+    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
+)

+import megatron
+from examples.inference.gpt.utils import (
+    Request,
+    add_common_inference_args,
+    build_dynamic_engine_setup_prefix,
+    build_requests,
+    get_curr_time,
+)
 from megatron.core.inference.contexts.dynamic_context import (
     ContextOverflowError,
     DynamicInferenceContext,
 )
-from megatron.core.inference.engines import DynamicInferenceEngine
+from megatron.core.inference.engines import DynamicInferenceEngine, EngineSuspendedError
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
     GPTInferenceWrapper,
 )
@@ -53,14 +63,14 @@
     build_requests,
     get_curr_time,
 )
-from megatron.training import get_args
-from megatron.training import get_model as _get_model
-from megatron.training import get_tokenizer, initialize_megatron
 from megatron.training.checkpointing import load_checkpoint

-import torch
-import io
-import megatron
+from model_provider import model_provider
+from gpt_builders import gpt_builder
+
+torch.serialization.add_safe_globals([io.BytesIO])
+torch.serialization.add_safe_globals([megatron.core.rerun_state_machine.RerunState])
+torch.serialization.add_safe_globals([megatron.core.rerun_state_machine.RerunDiagnostic])


 def add_dynamic_inference_args(parser: ArgumentParser) -> ArgumentParser:
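Context for the three add_safe_globals() calls in the hunk above: torch.load(weights_only=True) only unpickles allow-listed types, so non-tensor objects stored in a checkpoint must be registered first. A minimal, self-contained sketch of that mechanism (not part of the diff; RunPhase is a hypothetical stand-in for the RerunState/RerunDiagnostic enums, and torch >= 2.4 is assumed):

import io
from enum import Enum

import torch

class RunPhase(Enum):  # hypothetical stand-in for the rerun-state enums
    PREFILL = 1
    DECODE = 2

# Allow-list the class so the strict weights_only unpickler may reconstruct it.
torch.serialization.add_safe_globals([RunPhase])

buf = io.BytesIO()
torch.save({"phase": RunPhase.DECODE}, buf)
buf.seek(0)
state = torch.load(buf, weights_only=True)  # would raise without the allow-listing
print(state["phase"])  # RunPhase.DECODE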
@@ -76,7 +86,13 @@ def add_dynamic_inference_args(parser: ArgumentParser) -> ArgumentParser:
     )
     group.add_argument(
         "--termination-id", type=int, default=None,
-        help="Termination ID that overrides `tokenizer.eod`."
+        help="Termination ID that overrides `tokenizer.eod`.",
+    )
+    group.add_argument(
+        "--suspend-resume-interval", type=int, default=None,
+        help="Suspend and resume the dynamic engine every "
+        "`suspend_resume_interval` steps. This is used to test the suspend/resume "
+        "system.",
     )
     group.add_argument('--inference-repeat-n', type=int, default=1, help="Repeat inference iterations N times for benchmarking.")

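For readers of the new --suspend-resume-interval flag, here is a standalone sketch (not part of the diff) of the cadence it drives in run_inference() below: suspend every `interval` attempted steps, and resume `interval // 2` attempted steps after each suspend. With this arithmetic the first resume call can precede the first suspend.

def suspend_resume_schedule(interval: int, num_steps: int):
    """Return (attempted_step, action) pairs for the test cadence described above."""
    events = []
    for attempted_step_count in range(1, num_steps + 1):
        if attempted_step_count % interval == 0:
            events.append((attempted_step_count, "suspend"))
        if (attempted_step_count - interval // 2) % interval == 0:
            events.append((attempted_step_count, "resume"))
    return events

print(suspend_resume_schedule(interval=4, num_steps=10))
# [(2, 'resume'), (4, 'suspend'), (6, 'resume'), (8, 'suspend'), (10, 'resume')]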
@@ -248,12 +264,12 @@ def run_inference(
     num_requests_total = len(requests)
     num_requests_added = 0
     num_requests_finished = 0
-    step_id = 0
     step_times = {"prefill": [], "decode": []}
     add_times = []
     output_times = []
     tbar = tqdm(total=num_requests_total)
     total_output_tokens = 0
+    attempted_step_count = 0
     if args.cuda_graph_impl == "local":
         cuda_graph_request_count_map = {r:0 for r in engine.context.cuda_graph_request_counts}
     else:
@@ -296,36 +312,75 @@ def _add_request():

         # Step inference engine (i.e., generate a token for each active request).
         # Before step, we haven't done the scheduling, so we cannot know the is_decode_only
-        result = engine.step_modern(verbose=True)
+        try:
+            result = engine.step_modern(verbose=True)
+        except EngineSuspendedError as e:
+            result = e
+            pass  # ignore error in order to call 'engine.resume()' below.
+        attempted_step_count += 1
+
         # After step, we lost track of last iteration's is_decode_only, so we need to get it from the engine
         is_decode_only = engine.is_decode_only
-        step_id += 1
+
+        # Test suspending and resuming engine.
+        if args.suspend_resume_interval is not None:
+
+            # Suspend.
+            if attempted_step_count % args.suspend_resume_interval == 0:
+                print("**** step %d/%d ... suspend." % (engine.step_count, attempted_step_count))
+                engine.suspend()
+
+            # Resume, 0+ attempted steps later.
+            if (
+                attempted_step_count > 0
+                and
+                (attempted_step_count - args.suspend_resume_interval // 2)
+                % args.suspend_resume_interval == 0
+            ):
+                print("**** step %d/%d ... resume." % (engine.step_count, attempted_step_count))
+                engine.resume()
+
+        # If engine suspended, continue to next iter.
+        if isinstance(result, EngineSuspendedError):
+            continue

         # Record cuda_graph_request_count.
         cuda_graph_request_count = result["cuda_graph_request_count"]
         if args.cuda_graph_impl == "local" and cuda_graph_request_count is not None:
             cuda_graph_request_count_map[cuda_graph_request_count] += 1

         # Update requests.
-        active_requests = result["active_requests"]
-        finished_requests = result["finished_requests"]
+        active_request_ids = result["active_request_ids"]
+        finished_request_records = result["finished_request_records"]
         step_time = result["step_time"]
-        if len(active_requests) > 0 or len(finished_requests) > 0:
+        if len(active_request_ids) > 0 or len(finished_request_records) > 0:
             if is_decode_only:
                 step_times["decode"].append(step_time)
             else:
                 step_times["prefill"].append(step_time)

         # Append output tokens.
         output_start = get_curr_time()
-        for finished_request in finished_requests:
+        for finished_request_record in finished_request_records:
+
+            finished_request = finished_request_record.merge(engine.controller.tokenizer)
+
+            # Update local request object.
             request = requests[finished_request.request_id]
-            request.output_tokens = finished_request.generated_tokens
-            total_output_tokens += len(request.output_tokens)
             request.time_end = get_curr_time()
-            request.output_text = finished_request.generated_text
             request.state = "finished"
             request.request_id = finished_request.request_id
+
+            # Update prompt, in case engine has been suspended and resumed.
+            request.prompt_tokens = finished_request.prompt_tokens
+            request.prompt_text = finished_request.prompt
+
+            # Get output tokens and text.
+            request.output_tokens = finished_request.generated_tokens
+            request.output_text = finished_request.generated_text
+            total_output_tokens += len(request.output_tokens)
+
+            # Log probs.
             if finished_request.sampling_params.return_log_probs:
                 request.log_probs = (
                     finished_request.prompt_log_probs + finished_request.generated_log_probs
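The try/except block above captures the consumer-side contract: while suspended, step_modern() raises EngineSuspendedError; the caller stores the exception as the step result, still runs the suspend/resume bookkeeping, and only then skips the per-step accounting via `continue`. A toy illustration of that control flow (the stub classes are invented for this sketch and are not the real DynamicInferenceEngine API):

class ToySuspendedError(Exception):
    pass

class ToyEngine:
    def __init__(self):
        self.suspended = False
        self.step_count = 0

    def step(self):
        if self.suspended:
            raise ToySuspendedError("engine is suspended")
        self.step_count += 1
        return {"step_time": 0.01}

    def suspend(self):
        self.suspended = True

    def resume(self):
        self.suspended = False

engine, interval = ToyEngine(), 3
for attempted_step_count in range(1, 10):
    try:
        result = engine.step()
    except ToySuspendedError as e:
        result = e  # keep the loop alive so resume() below can run
    if attempted_step_count % interval == 0:
        engine.suspend()
    if (attempted_step_count - interval // 2) % interval == 0:
        engine.resume()
    if isinstance(result, ToySuspendedError):
        continue  # nothing was stepped; skip per-step accounting
    # ... per-step accounting would go here ...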
@@ -461,7 +516,9 @@ def escape_str(s):
         unique_prompt_map[request.prompt_text].append(request_idx)

     # Print unique prompts + outputs.
+    text_hashes = []
     for unique_idx, (prompt_text, request_idxs) in enumerate(unique_prompt_map.items()):
+
         # ---- Prompt summary line ----
         prompt_len = len(requests[request_idxs[0]].prompt_tokens)
         escaped_prompt_text = escape_str(prompt_text)
@@ -476,15 +533,20 @@ def escape_str(s):
         # ---- Print each unique output ----
         for output_text, output_request_idxs in output_map.items():
             if output_text is not None:
-                o_hash = hashlib.sha256(output_text.encode()).hexdigest()[:6]
+                # Use hash of prompt + generated text in case engine was
+                # suspended and resumed, which misaligns boundary between
+                # prompt and generated tokens.
+                o_hash = hashlib.sha256(
+                    (prompt_text + output_text).encode()
+                ).hexdigest()[:6]
                 o_len = len(requests[output_request_idxs[0]].output_tokens)
                 escaped_output_text = escape_str(output_text)
-                print(f" >>>> [n {len(output_request_idxs)}, l {o_len}, hash {o_hash}] {escaped_output_text}")
             else:
                 o_hash = "--"
                 o_len = 0
                 escaped_output_text = "--"
-            print(f" >>>> [n {len(output_request_idxs)}, {o_len} tokens, hash {o_hash}] {escaped_output_text}")
+            print(f" >>>> [n {len(output_request_idxs)}, {o_len} tokens, hash {o_hash}] {escaped_output_text}")
+            text_hashes.append(o_hash)

     # Write results to JSON. Primarily used for functional testing.
     if args.output_path:
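The reworked o_hash above follows the in-diff comment: the commit title ("suspend/resume via prefill") and the prompt-update code suggest that, after a resume, previously generated tokens can be re-ingested as prompt, so the prompt/generated boundary shifts while the concatenated text stays the same. A small check of that invariant (not part of the diff):

import hashlib

def text_hash(prompt_text: str, output_text: str) -> str:
    # Hash of prompt + generated text, as in the diff above (first 6 hex chars).
    return hashlib.sha256((prompt_text + output_text).encode()).hexdigest()[:6]

# Same request text, but the prompt/generated boundary has shifted:
assert text_hash("The quick brown", " fox jumps") == text_hash("The quick brown fox", " jumps")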
@@ -512,47 +574,49 @@ def escape_str(s):
         with open(args.output_path, "w") as fp:
             json.dump(json_results, fp, indent=1)

-    # Timing results.
-    print("~~~")
-    peak_alloc_gb = stats["allocated_bytes.all.peak"] / 1024**3
-    peak_resvd_gb = stats["reserved_bytes.all.peak"] / 1024**3
-
-    p_times = step_times["prefill"]
-    d_times = step_times["decode"]
-
-    p_total = sum(p_times)
-    d_total = sum(d_times)
-
-    p_count = len(p_times)
-    d_count = len(d_times)
-
-    p_mean = p_total / p_count
-    d_mean = d_total / d_count
-
-    # Commented out for now as the step/add/output times are not calculated correctly.
-    # print(
-    #     f"{setup_prefix} … "
-    #     f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … "
-    #     f"total time: {step_total:.3f}s … "
-    #     f"step time: total {step_total:.3f}s "
-    #     f"[ p {p_total:.3f}s, d {d_total:.3f}s ], "
-    #     f"mean [ p {p_mean:.3f}s, d {d_mean:.3f}s ], "
-    #     f"count [ p {p_count}, d {d_count} ]."
-    # )
-    capture_str = (
-        f"{engine.capture_stats['time']:.2f} sec"
-        if engine.capture_stats else
-        "--"
-    )
-    print(" … ".join((
-        f"{setup_prefix}",
-        f"throughput: {throughput:.3f} tok/s",
-        f"total time: {total_time:.3f}s",
-        f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB",
-        f"steps: {engine.step_count:d}",
-        f"capture {capture_str}",
-    )))
-    print("~~~")
+    # Timing results.
+    stats = torch.cuda.memory_stats()
+    throughput = total_output_tokens / total_time
+    print("~~~")
+    peak_alloc_gb = stats["allocated_bytes.all.peak"] / 1024**3
+    peak_resvd_gb = stats["reserved_bytes.all.peak"] / 1024**3
+
+    p_times = step_times["prefill"]
+    d_times = step_times["decode"]
+
+    p_total = sum(p_times)
+    d_total = sum(d_times)
+
+    p_count = len(p_times)
+    d_count = len(d_times)
+
+    p_mean = p_total / p_count
+    d_mean = d_total / d_count if d_count != 0 else 0.
+
+    # Commented out for now as the step/add/output times are not calculated correctly.
+    # print(
+    #     f"{setup_prefix} … "
+    #     f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … "
+    #     f"total time: {step_total:.3f}s … "
+    #     f"step time: total {step_total:.3f}s "
+    #     f"[ p {p_total:.3f}s, d {d_total:.3f}s ], "
+    #     f"mean [ p {p_mean:.3f}s, d {d_mean:.3f}s ], "
+    #     f"count [ p {p_count}, d {d_count} ]."
+    # )
+    capture_str = (
+        f"{engine.capture_stats['time']:.2f} sec"
+        if engine.capture_stats else
+        "--"
+    )
+    print(
+        f"{setup_prefix} … "
+        f"throughput: {throughput:.3f} tok/s",
+        f"total time: {total_time:.3f}s … "
+        f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … "
+        f"steps: {engine.step_count:d} … "
+        f"capture {capture_str} … "
+    )
+    print("~~~")

     # Stop Nsight profiler.
     if os.environ.get("NSIGHT_PREFIX"):
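A small sketch (not from the diff) of the summary math in the new timing block, including the added zero-guard on the decode mean, presumably needed because a run can end without any pure-decode steps (for example, when the engine is repeatedly suspended); the numbers below are made up:

step_times = {"prefill": [0.120, 0.118, 0.125], "decode": []}  # hypothetical timings
total_output_tokens, total_time = 256, 3.2                     # hypothetical totals

p_times, d_times = step_times["prefill"], step_times["decode"]
p_mean = sum(p_times) / len(p_times)
d_mean = sum(d_times) / len(d_times) if len(d_times) != 0 else 0.
throughput = total_output_tokens / total_time  # tok/s

print(f"throughput: {throughput:.3f} tok/s … mean [ p {p_mean:.3f}s, d {d_mean:.3f}s ]")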

examples/inference/gpt/gpt_dynamic_inference_12b.sh

Lines changed: 4 additions & 2 deletions
@@ -26,9 +26,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
 : ${ACTIVE_BUFFER_SIZE_GB=50.}

 # Cuda graphs.
-: ${CUDA_GRAPH_IMPL=local}
 : ${NUM_CUDA_GRAPHS=16}
-: ${CUDA_GRAPH_SHARE_IO_BUFFERS=1}

 # Miscellaneous.
 : ${USE_COORDINATOR=0}
@@ -87,6 +85,10 @@ if [ "${NUM_CUDA_GRAPHS}" != "0" ]; then
         --cuda-graph-impl local \
         --inference-dynamic-batching-num-cuda-graphs ${NUM_CUDA_GRAPHS} \
     "
+else
+    ARGS+=" \
+        --cuda-graph-impl none \
+    "
 fi

 # Prompts.

examples/inference/gpt/gpt_dynamic_inference_357m.sh

Lines changed: 4 additions & 2 deletions
@@ -27,9 +27,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
 : ${ACTIVE_BUFFER_SIZE_GB=50.}

 # Cuda graphs.
-: ${CUDA_GRAPH_IMPL=local}
 : ${NUM_CUDA_GRAPHS=16}
-: ${CUDA_GRAPH_SHARE_IO_BUFFERS=1}

 # Miscellaneous.
 : ${USE_COORDINATOR=0}
@@ -73,6 +71,10 @@ if [ "${NUM_CUDA_GRAPHS}" != "0" ]; then
         --cuda-graph-impl local \
         --inference-dynamic-batching-num-cuda-graphs ${NUM_CUDA_GRAPHS} \
     "
+else
+    ARGS+=" \
+        --cuda-graph-impl none \
+    "
 fi

 # Prompts.
