Commit 3b83c3f

Revert "Dynamic engine suspend/resume via prefill. (#1982)"
This reverts commit 29eed5d.
1 parent 29eed5d · commit 3b83c3f

21 files changed: +414 -1189 lines
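
For context, the reverted change (#1982) let the example driver suspend and resume the dynamic engine between steps: it wrapped `engine.step_modern()` in a try/except for `EngineSuspendedError` and periodically called `engine.suspend()` and `engine.resume()`, driven by a `--suspend-resume-interval` flag. Below is a minimal, hedged sketch of that pattern, reconstructed from the removed lines in the first diff; `ToyEngine` and the literal interval value are hypothetical stand-ins added only so the sketch runs on its own and are not part of Megatron-LM.

    # Sketch of the suspend/resume driver loop removed by this revert.
    # `step_modern`, `suspend`, `resume`, and `EngineSuspendedError` mirror names
    # in the removed lines below; `ToyEngine` is a made-up stand-in.
    class EngineSuspendedError(Exception):
        """Raised when a step is attempted while the engine is suspended."""

    class ToyEngine:
        def __init__(self):
            self.suspended = False
            self.step_count = 0

        def step_modern(self, verbose=False):
            if self.suspended:
                raise EngineSuspendedError("engine is suspended")
            self.step_count += 1
            return {"step_time": 0.0}

        def suspend(self):
            self.suspended = True

        def resume(self):
            self.suspended = False

    engine = ToyEngine()
    suspend_resume_interval = 4  # stand-in for the removed --suspend-resume-interval flag
    attempted_step_count = 0

    for _ in range(16):
        # Attempt a step; a suspended engine raises instead of stepping.
        try:
            result = engine.step_modern(verbose=True)
        except EngineSuspendedError as e:
            result = e  # ignore the error so resume() below still runs
        attempted_step_count += 1

        # Periodically suspend, then resume half an interval later.
        if attempted_step_count % suspend_resume_interval == 0:
            engine.suspend()
        if (attempted_step_count - suspend_resume_interval // 2) % suspend_resume_interval == 0:
            engine.resume()

        # Skip per-step bookkeeping for steps attempted while suspended.
        if isinstance(result, EngineSuspendedError):
            continue

    print(f"completed {engine.step_count} of {attempted_step_count} attempted steps")

With this revert, the driver calls `engine.step_modern()` directly again, and the flag, the exception handling, and the suspend/resume bookkeeping are removed.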

examples/inference/gpt/gpt_dynamic_inference.py

Lines changed: 64 additions & 128 deletions
@@ -1,7 +1,6 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 
 import hashlib
-import io
 import json
 import math
 import os
@@ -14,23 +13,14 @@
 from tqdm import tqdm
 from typing import Dict, List, Tuple, Optional
 
-sys.path.append(
-    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
-)
+import torch
+from tqdm import tqdm
 
-import megatron
-from examples.inference.gpt.utils import (
-    Request,
-    add_common_inference_args,
-    build_dynamic_engine_setup_prefix,
-    build_requests,
-    get_curr_time,
-)
 from megatron.core.inference.contexts.dynamic_context import (
     ContextOverflowError,
     DynamicInferenceContext,
 )
-from megatron.core.inference.engines import DynamicInferenceEngine, EngineSuspendedError
+from megatron.core.inference.engines import DynamicInferenceEngine
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
     GPTInferenceWrapper,
 )
@@ -63,14 +53,14 @@
     build_requests,
     get_curr_time,
 )
+from megatron.training import get_args
+from megatron.training import get_model as _get_model
+from megatron.training import get_tokenizer, initialize_megatron
 from megatron.training.checkpointing import load_checkpoint
 
-from model_provider import model_provider
-from gpt_builders import gpt_builder
-
-torch.serialization.add_safe_globals([io.BytesIO])
-torch.serialization.add_safe_globals([megatron.core.rerun_state_machine.RerunState])
-torch.serialization.add_safe_globals([megatron.core.rerun_state_machine.RerunDiagnostic])
+import torch
+import io
+import megatron
 
 
 def add_dynamic_inference_args(parser: ArgumentParser) -> ArgumentParser:
@@ -86,13 +76,7 @@ def add_dynamic_inference_args(parser: ArgumentParser) -> ArgumentParser:
     )
     group.add_argument(
         "--termination-id", type=int, default=None,
-        help="Termination ID that overrides `tokenizer.eod`.",
-    )
-    group.add_argument(
-        "--suspend-resume-interval", type=int, default=None,
-        help="Suspend and resume the dynamic engine every "
-        "`suspend_resume_interval` steps. This is used to tet the suspend/resume "
-        "system.",
+        help="Termination ID that overrides `tokenizer.eod`."
     )
     group.add_argument('--inference-repeat-n', type=int, default=1, help="Repeat inference iterations N times for benchmarking.")
 
@@ -264,12 +248,12 @@ def run_inference(
     num_requests_total = len(requests)
     num_requests_added = 0
     num_requests_finished = 0
+    step_id = 0
     step_times = {"prefill": [], "decode": []}
     add_times = []
     output_times = []
     tbar = tqdm(total=num_requests_total)
     total_output_tokens = 0
-    attempted_step_count = 0
     if args.cuda_graph_impl == "local":
         cuda_graph_request_count_map = {r:0 for r in engine.context.cuda_graph_request_counts}
     else:
@@ -312,75 +296,36 @@ def _add_request():
 
         # Step inference engine (i.e., generate a token for each active request).
         # Before step, we haven't done the scheduling, so we cannot know the is_decode_only
-        try:
-            result = engine.step_modern(verbose=True)
-        except EngineSuspendedError as e:
-            result = e
-            pass # ignore error in order to call 'engine.resume()' below.
-        attempted_step_count += 1
-
+        result = engine.step_modern(verbose=True)
         # After step, we lost track of last iteration's is_decode_only, so we need to get it from the engine
         is_decode_only = engine.is_decode_only
-
-        # Test suspending and resuming engine.
-        if args.suspend_resume_interval is not None:
-
-            # Suspend.
-            if attempted_step_count % args.suspend_resume_interval == 0:
-                print("**** step %d/%d ... suspend." % (engine.step_count, attempted_step_count))
-                engine.suspend()
-
-            # Resume, 0+ attempted steps later.
-            if (
-                attempted_step_count > 0
-                and
-                (attempted_step_count - args.suspend_resume_interval // 2)
-                % args.suspend_resume_interval == 0
-            ):
-                print("**** step %d/%d ... resume." % (engine.step_count, attempted_step_count))
-                engine.resume()
-
-        # If engine suspended, continue to next iter.
-        if isinstance(result, EngineSuspendedError):
-            continue
+        step_id += 1
 
         # Record cuda_graph_request_count.
         cuda_graph_request_count = result["cuda_graph_request_count"]
         if args.cuda_graph_impl == "local" and cuda_graph_request_count is not None:
            cuda_graph_request_count_map[cuda_graph_request_count] += 1
 
         # Update requests.
-        active_request_ids = result["active_request_ids"]
-        finished_request_records = result["finished_request_records"]
+        active_requests = result["active_requests"]
+        finished_requests = result["finished_requests"]
         step_time = result["step_time"]
-        if len(active_request_ids) > 0 or len(finished_request_records) > 0:
+        if len(active_requests) > 0 or len(finished_requests) > 0:
             if is_decode_only:
                 step_times["decode"].append(step_time)
             else:
                 step_times["prefill"].append(step_time)
 
         # Append output tokens.
         output_start = get_curr_time()
-        for finished_request_record in finished_request_records:
-
-            finished_request = finished_request_record.merge(engine.controller.tokenizer)
-
-            # Update local request object.
+        for finished_request in finished_requests:
             request = requests[finished_request.request_id]
+            request.output_tokens = finished_request.generated_tokens
+            total_output_tokens += len(request.output_tokens)
             request.time_end = get_curr_time()
+            request.output_text = finished_request.generated_text
             request.state = "finished"
             request.request_id = finished_request.request_id
-
-            # Update prompt, in case engine has been suspended and resumed.
-            request.prompt_tokens = finished_request.prompt_tokens
-            request.prompt_text = finished_request.prompt
-
-            # Get output tokens and text.
-            request.output_tokens = finished_request.generated_tokens
-            request.output_text = finished_request.generated_text
-            total_output_tokens += len(request.output_tokens)
-
-            # Log probs.
             if finished_request.sampling_params.return_log_probs:
                 request.log_probs = (
                     finished_request.prompt_log_probs + finished_request.generated_log_probs
@@ -516,9 +461,7 @@ def escape_str(s):
         unique_prompt_map[request.prompt_text].append(request_idx)
 
     # Print unique prompts + outputs.
-    text_hashes = []
     for unique_idx, (prompt_text, request_idxs) in enumerate(unique_prompt_map.items()):
-
         # ---- Prompt summary line ----
         prompt_len = len(requests[request_idxs[0]].prompt_tokens)
         escaped_prompt_text = escape_str(prompt_text)
@@ -533,20 +476,15 @@ def escape_str(s):
         # ---- Print each unique output ----
         for output_text, output_request_idxs in output_map.items():
             if output_text is not None:
-                # Use hash of prompt + generated text in case engine was
-                # suspended and resumed, which misaligns boundary between
-                # prompt and generated tokens.
-                o_hash = hashlib.sha256(
-                    (prompt_text + output_text).encode()
-                ).hexdigest()[:6]
+                o_hash = hashlib.sha256(output_text.encode()).hexdigest()[:6]
                 o_len = len(requests[output_request_idxs[0]].output_tokens)
                 escaped_output_text = escape_str(output_text)
+                print(f" >>>> [n {len(output_request_idxs)}, l {o_len}, hash {o_hash}] {escaped_output_text}")
             else:
                 o_hash = "--"
                 o_len = 0
                 escaped_output_text = "--"
-            print(f" >>>> [n {len(output_request_idxs)}, {o_len} tokens, hash {o_hash}] {escaped_output_text}")
-            text_hashes.append(o_hash)
+                print(f" >>>> [n {len(output_request_idxs)}, {o_len} tokens, hash {o_hash}] {escaped_output_text}")
 
     # Write results to JSON. Primarily used for functional testing.
     if args.output_path:
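
The hash change in the hunk above is easier to see in isolation: the reverted code (removed lines) hashed the prompt concatenated with the generated text, so the short hash stayed stable even when a suspend/resume cycle shifted the prompt/generation boundary, while the restored code hashes the generated text alone. A small, self-contained illustration; the example strings are made up:

    import hashlib

    # Hypothetical example strings, only to show the difference in hashing.
    prompt_text = "What is 2 + 2?"
    output_text = " The answer is 4."

    # Restored behavior: hash only the generated text.
    hash_output_only = hashlib.sha256(output_text.encode()).hexdigest()[:6]

    # Reverted behavior: hash prompt + generated text, so the hash does not
    # change if tokens move between "prompt" and "generated" after a
    # suspend/resume cycle shifts that boundary.
    hash_prompt_and_output = hashlib.sha256((prompt_text + output_text).encode()).hexdigest()[:6]

    print(hash_output_only, hash_prompt_and_output)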
@@ -574,49 +512,47 @@ def escape_str(s):
         with open(args.output_path, "w") as fp:
             json.dump(json_results, fp, indent=1)
 
-    # Timing results.
-    stats = torch.cuda.memory_stats()
-    throughput = total_output_tokens / total_time
-    print("~~~")
-    peak_alloc_gb = stats["allocated_bytes.all.peak"] / 1024**3
-    peak_resvd_gb = stats["reserved_bytes.all.peak"] / 1024**3
-
-    p_times = step_times["prefill"]
-    d_times = step_times["decode"]
-
-    p_total = sum(p_times)
-    d_total = sum(d_times)
-
-    p_count = len(p_times)
-    d_count = len(d_times)
-
-    p_mean = p_total / p_count
-    d_mean = d_total / d_count if d_count != 0 else 0.
-
-    # Commented out for now as the step/add/output times are not calculated correctly.
-    # print(
-    #     f"{setup_prefix} … "
-    #     f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … "
-    #     f"total time: {step_total:.3f}s … "
-    #     f"step time: total {step_total:.3f}s "
-    #     f"[ p {p_total:.3f}s, d {d_total:.3f}s ], "
-    #     f"mean [ p {p_mean:.3f}s, d {d_mean:.3f}s ], "
-    #     f"count [ p {p_count}, d {d_count} ]."
-    # )
-    capture_str = (
-        f"{engine.capture_stats['time']:.2f} sec"
-        if engine.capture_stats else
-        "--"
-    )
-    print(
-        f"{setup_prefix} … "
-        f"throughput: {throughput:.3f} tok/s",
-        f"total time: {total_time:.3f}s … "
-        f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … "
-        f"steps: {engine.step_count:d} … "
-        f"capture {capture_str} … "
-    )
-    print("~~~")
+    # Timing results.
+    print("~~~")
+    peak_alloc_gb = stats["allocated_bytes.all.peak"] / 1024**3
+    peak_resvd_gb = stats["reserved_bytes.all.peak"] / 1024**3
+
+    p_times = step_times["prefill"]
+    d_times = step_times["decode"]
+
+    p_total = sum(p_times)
+    d_total = sum(d_times)
+
+    p_count = len(p_times)
+    d_count = len(d_times)
+
+    p_mean = p_total / p_count
+    d_mean = d_total / d_count
+
+    # Commented out for now as the step/add/output times are not calculated correctly.
+    # print(
+    #     f"{setup_prefix} … "
+    #     f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … "
+    #     f"total time: {step_total:.3f}s … "
+    #     f"step time: total {step_total:.3f}s "
+    #     f"[ p {p_total:.3f}s, d {d_total:.3f}s ], "
+    #     f"mean [ p {p_mean:.3f}s, d {d_mean:.3f}s ], "
+    #     f"count [ p {p_count}, d {d_count} ]."
+    # )
+    capture_str = (
+        f"{engine.capture_stats['time']:.2f} sec"
+        if engine.capture_stats else
+        "--"
+    )
+    print(" … ".join((
+        f"{setup_prefix}",
+        f"throughput: {throughput:.3f} tok/s",
+        f"total time: {total_time:.3f}s",
+        f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB",
+        f"steps: {engine.step_count:d}",
+        f"capture {capture_str}",
+    )))
+    print("~~~")
 
     # Stop Nsight profiler.
     if os.environ.get("NSIGHT_PREFIX"):

examples/inference/gpt/gpt_dynamic_inference_12b.sh

Lines changed: 2 additions & 4 deletions
@@ -26,7 +26,9 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
 : ${ACTIVE_BUFFER_SIZE_GB=50.}
 
 # Cuda graphs.
+: ${CUDA_GRAPH_IMPL=local}
 : ${NUM_CUDA_GRAPHS=16}
+: ${CUDA_GRAPH_SHARE_IO_BUFFERS=1}
 
 # Miscellaneous.
 : ${USE_COORDINATOR=0}
@@ -85,10 +87,6 @@ if [ "${NUM_CUDA_GRAPHS}" != "0" ]; then
         --cuda-graph-impl local \
         --inference-dynamic-batching-num-cuda-graphs ${NUM_CUDA_GRAPHS} \
     "
-else
-    ARGS+=" \
-        --cuda-graph-impl none \
-    "
 fi
 
 # Prompts.

examples/inference/gpt/gpt_dynamic_inference_357m.sh

Lines changed: 2 additions & 4 deletions
@@ -27,7 +27,9 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
 : ${ACTIVE_BUFFER_SIZE_GB=50.}
 
 # Cuda graphs.
+: ${CUDA_GRAPH_IMPL=local}
 : ${NUM_CUDA_GRAPHS=16}
+: ${CUDA_GRAPH_SHARE_IO_BUFFERS=1}
 
 # Miscellaneous.
 : ${USE_COORDINATOR=0}
@@ -71,10 +73,6 @@ if [ "${NUM_CUDA_GRAPHS}" != "0" ]; then
         --cuda-graph-impl local \
         --inference-dynamic-batching-num-cuda-graphs ${NUM_CUDA_GRAPHS} \
     "
-else
-    ARGS+=" \
-        --cuda-graph-impl none \
-    "
 fi
 
 # Prompts.
