
Commit 1eae941

[#9237][feat] enable iter stats in autodeploy (#9278)
Signed-off-by: Shreyas Misra <[email protected]>
1 parent a7c0b54 commit 1eae941

4 files changed: +50 −7 lines


tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 10 additions & 0 deletions
@@ -197,6 +197,16 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         "backends, this should equal max_seq_len. Temporary field until tokens_per_block gets "
         "properly passed through.",
     )
+    enable_iter_perf_stats: bool = Field(
+        default=False, description="Enable iteration performance statistics.", status="prototype"
+    )
+
+    enable_iter_req_stats: bool = Field(
+        default=False,
+        description="If true, enables per request stats per iteration. Must also set "
+        "enable_iter_perf_stats to true to get request stats.",
+        status="prototype",
+    )

     ### VALIDATION #################################################################################
     @model_validator(mode="after")
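
For context, a minimal usage sketch of the two new fields follows. It assumes the AutoDeployConfig fields above can be passed straight through as LLM keyword arguments (as the trtllm-bench change below does via kwargs), that the generic LLM.get_stats() accessor also applies to this backend, and it uses a placeholder model name.

# Hypothetical usage sketch -- not part of this commit.
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM

# Assumption: the AutoDeployConfig fields are accepted as LLM kwargs,
# mirroring how trtllm-bench injects enable_iter_perf_stats below.
llm = AutoDeployLLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model name
    enable_iter_perf_stats=True,
    enable_iter_req_stats=True,  # only takes effect with perf stats enabled
)

llm.generate(["Hello, world!"])

# Assumption: the generic LLM.get_stats() accessor surfaces the
# per-iteration stats collected by the executor for this backend too.
for iteration_stats in llm.get_stats(timeout=2):
    print(iteration_stats)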

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 30 additions & 4 deletions
@@ -10,6 +10,7 @@
 # limitations under the License.

 from collections import defaultdict
+from dataclasses import dataclass
 from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple

@@ -46,6 +47,13 @@
 from .interface import CachedSequenceInterface, GetInferenceModel


+@dataclass
+class ReportingInfo:
+    print_log: bool = False
+    enable_iter_perf_stats: bool = False
+    enable_iter_req_stats: bool = False
+
+
 class _CacheManagerWithFakePool(KVCacheManager):
     """We use the default KVCacheManager but with a fake pool by setting head_dim=0.

@@ -123,14 +131,19 @@ def build_from_config(cls, ad_config: LlmArgs):
             vocab_size_padded=factory.vocab_size_padded,
             chunk_size=factory.chunk_size,
         )
+        reporting_info = ReportingInfo(
+            print_log=False,
+            enable_iter_perf_stats=ad_config.enable_iter_perf_stats,
+            enable_iter_req_stats=ad_config.enable_iter_req_stats,
+        )
         # TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__,
         # ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm.

         # construct inference optimizer
         build_and_optimize = InferenceOptimizer(factory=factory, config=ad_config.transforms)

         # construct engine
-        return cls(build_and_optimize, seq_info, device, max_beam_width)
+        return cls(build_and_optimize, seq_info, device, max_beam_width, reporting_info)

     @torch.inference_mode()
     def __init__(
@@ -139,20 +152,23 @@ def __init__(
         seq_info: SequenceInfo,
         device: DeviceLikeType,
         max_beam_width: int = 1,
+        reporting_info: ReportingInfo = ReportingInfo(),
     ) -> None:
         """Initialize the engine with model and sequence information."""
         # NOTE (lucaslie): create a fake Namespace to satisfy PyExecutor requirements...
         # This is not correctly declared in the base ModelEngine class though...
         self.llm_args = SimpleNamespace()
-        self.llm_args.print_iter_log = False
-        self.llm_args.enable_iter_perf_stats = False
-        self.llm_args.enable_iter_req_stats = False
+        self.llm_args.print_iter_log = reporting_info.print_log
+        self.llm_args.enable_iter_perf_stats = reporting_info.enable_iter_perf_stats
+        self.llm_args.enable_iter_req_stats = reporting_info.enable_iter_req_stats
         self.llm_args.stream_interval = 1
         self.llm_args.attention_dp_config = None
         self.llm_args.batch_wait_timeout_ms = 0
         self.llm_args.batch_wait_timeout_iters = 0
         self.llm_args.batch_wait_max_tokens_ratio = 0.0
         self.llm_args.max_num_tokens = seq_info.max_num_tokens
+        self.iter_counter = 0
+        self.iter_states = {}

         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
         self.max_beam_width = max_beam_width
@@ -196,6 +212,9 @@ def _prepare_inputs(
         extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list)

         dummy_token = -1
+        num_ctx_requests = len(context_requests)
+        num_ctx_tokens = 0
+        num_generation_tokens = 0

         # look at context requests first
         for request in context_requests:
@@ -206,6 +225,7 @@ def _prepare_inputs(
             begin_compute = request.context_current_position
             end_compute = begin_compute + request.context_chunk_size
             prompt_tokens = all_prompt_tokens[begin_compute:end_compute]
+            num_ctx_tokens += len(prompt_tokens)

             input_ids.append(prompt_tokens)
             input_pos.append(begin_compute)
@@ -238,6 +258,7 @@ def _prepare_inputs(
             input_pos.append(request.max_beam_num_tokens)
             flat_gather_idx.append(request.py_batch_idx)

+            num_generation_tokens += 1
             request.py_batch_idx = request.seq_slot

             # store seq slot idx
@@ -267,6 +288,10 @@ def _prepare_inputs(
             scatter_ref=dummy_token,
         )

+        self.iter_states["num_ctx_requests"] = num_ctx_requests
+        self.iter_states["num_ctx_tokens"] = num_ctx_tokens
+        # TODO: handle extend requests and draft requests for specdec
+        self.iter_states["num_generation_tokens"] = num_generation_tokens
         return last_logit_only

     @nvtx_range("ad_compute_logits")
@@ -294,6 +319,7 @@ def forward(
         # convert requests and store in sequence info object
         new_tokens = getattr(new_tensors_device, "new_tokens", None)
         last_logit_only = self._prepare_inputs(scheduled_requests, resource_manager, new_tokens)
+        self.iter_counter += 1

         # compute all logits
         logits = self._compute_logits()
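
To make the wiring above easier to follow, below is a small self-contained sketch (not the engine code itself) of how ReportingInfo carries the config flags into the fake llm_args namespace and how the per-iteration counters are accumulated; the Request class is a stand-in for the real context/generation request objects.

# Standalone sketch of the reporting wiring added above.
# The Request class is a stand-in for the real request objects.
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Dict, List


@dataclass
class ReportingInfo:
    print_log: bool = False
    enable_iter_perf_stats: bool = False
    enable_iter_req_stats: bool = False


@dataclass
class Request:
    prompt_tokens: List[int] = field(default_factory=list)
    is_context: bool = True


def make_llm_args(info: ReportingInfo) -> SimpleNamespace:
    # Mirrors ADEngine.__init__: the flags land on the fake namespace
    # that PyExecutor inspects when assembling iteration stats.
    args = SimpleNamespace()
    args.print_iter_log = info.print_log
    args.enable_iter_perf_stats = info.enable_iter_perf_stats
    args.enable_iter_req_stats = info.enable_iter_req_stats
    return args


def count_iter_states(requests: List[Request]) -> Dict[str, int]:
    # Mirrors the counting in _prepare_inputs: context requests contribute
    # their prompt tokens, generation requests contribute one token each.
    ctx = [r for r in requests if r.is_context]
    gen = [r for r in requests if not r.is_context]
    return {
        "num_ctx_requests": len(ctx),
        "num_ctx_tokens": sum(len(r.prompt_tokens) for r in ctx),
        "num_generation_tokens": len(gen),
    }


if __name__ == "__main__":
    info = ReportingInfo(enable_iter_perf_stats=True, enable_iter_req_stats=True)
    print(make_llm_args(info))
    print(count_iter_states([Request([1, 2, 3]), Request(is_context=False)]))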

tensorrt_llm/bench/benchmark/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -107,12 +107,12 @@ def get_llm(runtime_config: RuntimeConfig, kwargs: dict):
     if runtime_config.backend != None:
         ignore_trt_only_args(kwargs, runtime_config.backend)

+    if runtime_config.iteration_log is not None:
+        kwargs["enable_iter_perf_stats"] = True
+
     if runtime_config.backend == 'pytorch':
         llm_cls = PyTorchLLM

-        if runtime_config.iteration_log is not None:
-            kwargs["enable_iter_perf_stats"] = True
-
     elif runtime_config.backend == "_autodeploy":
         from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM

tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py

Lines changed: 7 additions & 0 deletions
@@ -31,13 +31,20 @@ def run_benchmark(
             "_autodeploy",
             "--dataset",
             dataset_path,
+            "--iteration_log",
+            "iteration_log.log",
             "--extra_llm_api_options",
             f"{extra_llm_api_options_path}",
         ]
     )
     result = runner.invoke(main, args, catch_exceptions=False)
     assert result.exit_code == 0

+    with open("iteration_log.log", "r") as f:
+        lines = f.readlines()
+        assert len(lines) > 0
+        # TODO: add more checks
+

 def prepare_dataset(root_dir: str, temp_dir: str, model_path_or_name: str):
     _DATASET_NAME = "synthetic_128_128.txt"
