 # limitations under the License.

 from collections import defaultdict
+from dataclasses import dataclass
 from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple

 from .interface import CachedSequenceInterface, GetInferenceModel


+@dataclass
+class ReportingInfo:
+    print_log: bool = False
+    enable_iter_perf_stats: bool = False
+    enable_iter_req_stats: bool = False
+
+
 class _CacheManagerWithFakePool(KVCacheManager):
     """We use the default KVCacheManager but with a fake pool by setting head_dim=0.

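The new ReportingInfo dataclass groups the three reporting flags that were previously hard-coded in the engine. A minimal sketch of how a caller might construct it (defaults keep all reporting off):

    quiet = ReportingInfo()  # print_log / perf stats / req stats all False
    verbose = ReportingInfo(print_log=True, enable_iter_perf_stats=True)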
@@ -123,14 +131,19 @@ def build_from_config(cls, ad_config: LlmArgs):
             vocab_size_padded=factory.vocab_size_padded,
             chunk_size=factory.chunk_size,
         )
+        reporting_info = ReportingInfo(
+            print_log=False,
+            enable_iter_perf_stats=ad_config.enable_iter_perf_stats,
+            enable_iter_req_stats=ad_config.enable_iter_req_stats,
+        )
         # TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__,
         # ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm.

         # construct inference optimizer
         build_and_optimize = InferenceOptimizer(factory=factory, config=ad_config.transforms)

         # construct engine
-        return cls(build_and_optimize, seq_info, device, max_beam_width)
+        return cls(build_and_optimize, seq_info, device, max_beam_width, reporting_info)

     @torch.inference_mode()
     def __init__(
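With this change, build_from_config forwards the user-facing iteration-stats flags from ad_config into the engine instead of silently disabling them. A hedged sketch of the resulting flow (the LlmArgs construction is abbreviated and not part of this diff):

    ad_config = LlmArgs(..., enable_iter_perf_stats=True, enable_iter_req_stats=True)
    engine = ADEngine.build_from_config(ad_config)
    # engine.llm_args.enable_iter_perf_stats is now True instead of the old hard-coded False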
@@ -139,20 +152,23 @@ def __init__(
         seq_info: SequenceInfo,
         device: DeviceLikeType,
         max_beam_width: int = 1,
+        reporting_info: ReportingInfo = ReportingInfo(),
     ) -> None:
         """Initialize the engine with model and sequence information."""
         # NOTE (lucaslie): create a fake Namespace to satisfy PyExecutor requirements...
         # This is not correctly declared in the base ModelEngine class though...
         self.llm_args = SimpleNamespace()
-        self.llm_args.print_iter_log = False
-        self.llm_args.enable_iter_perf_stats = False
-        self.llm_args.enable_iter_req_stats = False
+        self.llm_args.print_iter_log = reporting_info.print_log
+        self.llm_args.enable_iter_perf_stats = reporting_info.enable_iter_perf_stats
+        self.llm_args.enable_iter_req_stats = reporting_info.enable_iter_req_stats
         self.llm_args.stream_interval = 1
         self.llm_args.attention_dp_config = None
         self.llm_args.batch_wait_timeout_ms = 0
         self.llm_args.batch_wait_timeout_iters = 0
         self.llm_args.batch_wait_max_tokens_ratio = 0.0
         self.llm_args.max_num_tokens = seq_info.max_num_tokens
+        self.iter_counter = 0
+        self.iter_states = {}

         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
         self.max_beam_width = max_beam_width
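Note that the default reporting_info=ReportingInfo() is evaluated once at function definition time, which is fine here because the engine only reads from it. A minimal sketch of direct construction, using the same positional arguments that build_from_config passes above:

    engine_quiet = ADEngine(build_and_optimize, seq_info, device)  # reporting stays off, as before
    engine_stats = ADEngine(
        build_and_optimize, seq_info, device, max_beam_width=1,
        reporting_info=ReportingInfo(enable_iter_perf_stats=True),
    )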
@@ -196,6 +212,9 @@ def _prepare_inputs(
         extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list)

         dummy_token = -1
+        num_ctx_requests = len(context_requests)
+        num_ctx_tokens = 0
+        num_generation_tokens = 0

         # look at context requests first
         for request in context_requests:
@@ -206,6 +225,7 @@ def _prepare_inputs(
             begin_compute = request.context_current_position
             end_compute = begin_compute + request.context_chunk_size
             prompt_tokens = all_prompt_tokens[begin_compute:end_compute]
+            num_ctx_tokens += len(prompt_tokens)

             input_ids.append(prompt_tokens)
             input_pos.append(begin_compute)
@@ -238,6 +258,7 @@ def _prepare_inputs(
             input_pos.append(request.max_beam_num_tokens)
             flat_gather_idx.append(request.py_batch_idx)

+            num_generation_tokens += 1
             request.py_batch_idx = request.seq_slot

             # store seq slot idx
@@ -267,6 +288,10 @@ def _prepare_inputs(
             scatter_ref=dummy_token,
         )

+        self.iter_states["num_ctx_requests"] = num_ctx_requests
+        self.iter_states["num_ctx_tokens"] = num_ctx_tokens
+        # TODO: handle extend requests and draft requests for specdec
+        self.iter_states["num_generation_tokens"] = num_generation_tokens
         return last_logit_only

     @nvtx_range("ad_compute_logits")
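An illustrative example of the bookkeeping added above (made-up batch, not from this PR): two context requests with 7 and 5 prompt tokens left to compute plus three generation requests would leave the engine with:

    # iter_states == {"num_ctx_requests": 2, "num_ctx_tokens": 12, "num_generation_tokens": 3}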
@@ -294,6 +319,7 @@ def forward(
         # convert requests and store in sequence info object
         new_tokens = getattr(new_tensors_device, "new_tokens", None)
         last_logit_only = self._prepare_inputs(scheduled_requests, resource_manager, new_tokens)
+        self.iter_counter += 1

         # compute all logits
         logits = self._compute_logits()
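Together, iter_counter and iter_states give the executor the per-iteration data it needs when enable_iter_perf_stats is on. A hedged sketch of how an executor-side stats hook might read them (the real PyExecutor plumbing lives outside this file):

    def collect_iter_stats(engine: "ADEngine") -> dict:
        # hypothetical helper, not part of this PR
        return {"iter": engine.iter_counter, **engine.iter_states}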