Commit 666df4f

clean up
1 parent 04ba668 commit 666df4f

3 files changed: +17 -45 lines changed


tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 9 additions & 13 deletions
@@ -562,13 +562,12 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         # Reset the global cuda graph dummy request to None in warmup.
         self.cuda_graph_runner.padding_dummy_request = None

+        # TODO: current warmup_request is not suitable for context parallelism.
         cp_type = self.mapping.cp_config.get('cp_type', None)
         if cp_type is not None:
-            if cp_type in [CpType.ULYSSES, CpType.STAR]:
-                assert False, "cp_type must be HELIX for helix benchmarking."
-            print("[ModelEngine::warmup] EARLY RETURN since cp_type ",
-                  cp_type)
-            return
+            logger.info("[ModelEngine::warmup] Skipping warmup for cp_type: ",
+                        cp_type.name)
+            return

         self._run_torch_compile_warmup(resource_manager)
         self._run_autotuner_warmup(resource_manager)
@@ -1063,12 +1062,10 @@ def _init_max_seq_len(self):
             # NOTE: py_executor_creator makes sure that the executor uses this
             # smaller value as its max_seq_len too.
             logger.warning(
-                f"\n*******************************************************\n"
-                f"Specified {self.max_seq_len=} is larger than what the model can support\n"
-                f"({inferred_max_seq_len}). NOT Setting max_seq_len to {inferred_max_seq_len}. "
-                f"ARE YOU SURE ABOUT THIS?\n"
-                f"*******************************************************\n")
-            # self.max_seq_len = inferred_max_seq_len
+                f"Specified {self.max_seq_len=} is larger than what the model can support "
+                f"({inferred_max_seq_len}). Setting max_seq_len to {inferred_max_seq_len}. "
+            )
+            self.max_seq_len = inferred_max_seq_len

     def _infer_max_seq_len_from_config(self) -> int:

@@ -2137,8 +2134,7 @@ def _prepare_tp_inputs_no_cache(
         attn_metadata.padded_num_tokens = padded_num_tokens if padded_num_tokens != num_tokens else None

         if self.enable_attention_dp:
-            all_rank_num_tokens = self.dist.allgather(attn_metadata.num_tokens)
-            attn_metadata.all_rank_num_tokens = all_rank_num_tokens
+            attn_metadata.all_rank_num_tokens = attn_all_rank_num_tokens

         virtual_num_tokens = num_tokens
         if attn_metadata.padded_num_tokens is not None:

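For reference, here is a minimal, self-contained sketch of the post-patch behavior in model_engine.py: warmup is skipped whenever a cp_type is configured, and max_seq_len is clamped back down to the inferred limit. The class, enum, and constructor arguments below are hypothetical stand-ins for illustration, not the real TensorRT-LLM APIs.

# Minimal sketch (illustration only): the post-patch warmup guard and the
# restored max_seq_len clamp. _EngineSketch, its CpType enum, and the
# constructor arguments are hypothetical stand-ins, not TensorRT-LLM classes.
import enum
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("sketch")


class CpType(enum.Enum):
    ULYSSES = 0
    STAR = 1
    HELIX = 2


class _EngineSketch:

    def __init__(self, cp_config, max_seq_len, inferred_max_seq_len):
        self.cp_config = cp_config
        self.max_seq_len = max_seq_len
        self.inferred_max_seq_len = inferred_max_seq_len

    def warmup(self) -> None:
        # Any configured cp_type now skips warmup entirely instead of
        # asserting that it must be HELIX.
        cp_type = self.cp_config.get('cp_type', None)
        if cp_type is not None:
            logger.info("Skipping warmup for cp_type: %s", cp_type.name)
            return
        # ... torch.compile / autotuner / CUDA-graph warmup would run here ...

    def init_max_seq_len(self) -> None:
        # The clamp is restored: a user-specified max_seq_len larger than what
        # the model supports is reduced to the inferred value, with a warning.
        if self.max_seq_len > self.inferred_max_seq_len:
            logger.warning(
                f"Specified max_seq_len={self.max_seq_len} is larger than what "
                f"the model can support ({self.inferred_max_seq_len}). "
                f"Setting max_seq_len to {self.inferred_max_seq_len}.")
            self.max_seq_len = self.inferred_max_seq_len


if __name__ == "__main__":
    engine = _EngineSketch({'cp_type': CpType.HELIX}, 131072, 8192)
    engine.warmup()            # returns early because a cp_type is configured
    engine.init_max_seq_len()  # clamps 131072 down to 8192
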
tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 2 additions & 21 deletions
@@ -1889,9 +1889,6 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):

         for req in scheduled_batch.generation_requests:
             if req.is_disagg_generation_transmission_complete:
-                print(
-                    "[PyExecutor::_prepare_disagg_gen_transmission_complete]: TRANSMISSION COMPLETE for request ID: ",
-                    req.py_request_id)
                 req.state = LlmRequestState.GENERATION_IN_PROGRESS
                 req.context_current_position = req.prompt_len
                 req.decoding_iter = 1
@@ -1903,9 +1900,6 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
                 beam_width = req.sampling_config.beam_width

                 for beam in range(0, beam_width):
-                    print(
-                        f"[PyExecutor::_prepare_disagg_gen_transmission_complete]: Adding new token {torch.tensor(first_gen_tokens[beam])} for beam {beam}."
-                    )
                     req.add_new_token(first_gen_tokens[beam], beam)

     @nvtx_range("_recv_disagg_gen_cache")
@@ -2001,24 +1995,12 @@ def _forward_step(self,
         )
         def forward(scheduled_requests, resource_manager, new_tensors_device,
                     gather_context_logits, cache_indirection_buffer):
-            # iter_begin = time.time()
-            result = self.model_engine.forward(
+            return self.model_engine.forward(
                 scheduled_requests,
                 resource_manager,
                 new_tensors_device,
                 gather_context_logits=gather_context_logits,
                 cache_indirection_buffer=cache_indirection_buffer)
-            # torch.cuda.synchronize()
-            # iter_end = time.time()
-            # iter_latency_ms = (iter_end - iter_begin) * 1e3
-            # if self.model_engine.iter_counter > 10 and self.dist.rank == 0:
-            #     logger.info(f"[PyExecutor::_forward_step] CUSTOM LOG: iter={self.model_engine.iter_counter}, "
-            #                 f"rank={self.dist.rank}, "
-            #                 f"active_requests={len(self.active_requests)}, "
-            #                 f"scheduled_generation_requests={len(scheduled_requests.generation_requests)}, "
-            #                 f"scheduled_batch_size={scheduled_requests.batch_size}, "
-            #                 f"iter_latency_ms={iter_latency_ms}ms")
-            return result

         try:
             gather_context_logits = any(
@@ -2085,8 +2067,7 @@ def _update_request_states_star_attention(
     @nvtx_range("_update_request_states")
     def _update_request_states(self, scheduled_requests: ScheduledRequests):
         cp_config = self.dist.cp_config
-        # note: helix parallelism uses the same logic as tp parallelism here
-        if 'cp_type' in cp_config and cp_config['cp_type'] != CpType.HELIX:
+        if 'cp_type' in cp_config:
             cp_type = cp_config['cp_type']
             if cp_type == CpType.STAR:
                 self._update_request_states_star_attention(scheduled_requests)

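For reference, a minimal sketch of the _update_request_states dispatch after the HELIX exclusion is removed: any configured cp_type now enters the cp branch instead of falling through to the TP path. Only the STAR case appears in the hunk above, so the other handlers below are hypothetical stand-ins for illustration.

# Minimal sketch (illustration only): the cp_type dispatch in
# _update_request_states after the HELIX exclusion is dropped. Only the STAR
# branch is visible in the hunk above; _ExecutorSketch and the non-STAR/TP
# handlers are hypothetical stand-ins.
import enum


class CpType(enum.Enum):
    STAR = 0
    HELIX = 1


class _ExecutorSketch:

    def __init__(self, cp_config):
        self.cp_config = cp_config

    def _update_request_states_star_attention(self, scheduled_requests):
        print("STAR-attention state update")

    def _update_request_states_other_cp(self, scheduled_requests):
        print("other cp_type state update (e.g. HELIX; outside this hunk)")

    def _update_request_states_tp(self, scheduled_requests):
        print("plain TP state update")

    def _update_request_states(self, scheduled_requests) -> None:
        cp_config = self.cp_config
        # Before this change the condition also required
        # cp_config['cp_type'] != CpType.HELIX, which sent HELIX down the
        # plain TP path. Now every configured cp_type is dispatched here.
        if 'cp_type' in cp_config:
            cp_type = cp_config['cp_type']
            if cp_type == CpType.STAR:
                self._update_request_states_star_attention(scheduled_requests)
            else:
                self._update_request_states_other_cp(scheduled_requests)
        else:
            self._update_request_states_tp(scheduled_requests)


_ExecutorSketch({'cp_type': CpType.HELIX})._update_request_states(scheduled_requests=None)
_ExecutorSketch({})._update_request_states(scheduled_requests=None)
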
tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 6 additions & 11 deletions
@@ -177,10 +177,6 @@ def __init__(
         indexer_k_cache_index_head_dim: int = 0,
         **kwargs,
     ) -> None:
-        # Couple of places where we assume tokens_per_block is 32: Let's assert here for now.
-        # 1) block assignment in merge_helix_requests
-        # 2) computation of cache_transceiver_config.max_tokens_in_buffer.
-        assert tokens_per_block == 32, "tokens_per_block must be 32 for helix benchmarking."
         self.mapping = mapping
         self.dtype = dtype
         self.kv_cache_type = kv_cache_type
@@ -443,18 +439,17 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):

         for req in generation_batch:
             # Skip allocating KV cache at decode for inactive helix ranks.
-            ##################################################################
-            # TODO: This should be set elsewhere. For now, we hardcode that last rank is active.
-            # Maybe right after pyexecutor._schedule() or in sampler.update_requests() at end of
-            # executor loop for next step.
             if self.mapping.has_cp_helix():
                 if self.mapping.cp_rank != self.mapping.cp_size - 1:
                     req.py_helix_is_inactive_rank = True
-            ##################################################################
             if req.py_helix_is_inactive_rank:
-                # print(f"[ResourceManager::prepare_resources][rank {self.mapping.rank}] Skipping KV allocation for request {req.py_request_id}.")
+                print(
+                    f"[ResourceManager::prepare_resources][rank {self.mapping.rank}] Skipping KV allocation for request {req.py_request_id}."
+                )
                 continue
-            # print(f"[ResourceManager::prepare_resources][rank {self.mapping.rank}] Adding KV allocation for request {req.py_request_id}.")
+            print(
+                f"[ResourceManager::prepare_resources][rank {self.mapping.rank}] Adding KV allocation for request {req.py_request_id}."
+            )
             self.impl.add_token(req.py_request_id)
             for _ in range(get_draft_token_length(req)):
                 self.impl.add_token(req.py_request_id)

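For reference, a minimal sketch of the helix gating in prepare_resources: at decode, every cp rank except the last marks its generation requests as inactive and skips KV-cache token allocation. The mapping, request, and KV-cache objects below are hypothetical stand-ins for the real ones.

# Minimal sketch (illustration only): the helix rank gating in
# prepare_resources. Every cp rank except the last marks its generation
# requests as inactive and skips KV-cache token allocation. _Mapping,
# _Request, and _KvCacheStub are hypothetical stand-ins for the real mapping,
# request, and KV-cache manager objects.
from dataclasses import dataclass
from typing import List


@dataclass
class _Mapping:
    rank: int
    cp_rank: int
    cp_size: int

    def has_cp_helix(self) -> bool:
        return self.cp_size > 1


@dataclass
class _Request:
    py_request_id: int
    py_helix_is_inactive_rank: bool = False


class _KvCacheStub:

    def __init__(self):
        self.tokens_added: List[int] = []

    def add_token(self, request_id: int) -> None:
        self.tokens_added.append(request_id)


def prepare_generation_resources(mapping: _Mapping, impl: _KvCacheStub,
                                 generation_batch: List[_Request]) -> None:
    for req in generation_batch:
        # Skip allocating KV cache at decode for inactive helix ranks:
        # only the last cp rank stays active.
        if mapping.has_cp_helix():
            if mapping.cp_rank != mapping.cp_size - 1:
                req.py_helix_is_inactive_rank = True
        if req.py_helix_is_inactive_rank:
            print(f"[rank {mapping.rank}] Skipping KV allocation for request {req.py_request_id}.")
            continue
        print(f"[rank {mapping.rank}] Adding KV allocation for request {req.py_request_id}.")
        impl.add_token(req.py_request_id)


impl = _KvCacheStub()
prepare_generation_resources(_Mapping(rank=3, cp_rank=3, cp_size=4), impl,
                             [_Request(py_request_id=1)])
prepare_generation_resources(_Mapping(rank=0, cp_rank=0, cp_size=4), impl,
                             [_Request(py_request_id=1)])
print(impl.tokens_added)  # only the last cp rank allocated a token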