Commit 50436a1

more cleanup
1 parent 6e5fe5d commit 50436a1

11 files changed: +24 -304 lines changed

Lines changed: 7 additions & 1 deletion
@@ -1 +1,7 @@
-["Global warming is the long term rise in Earth temperature caused by greenhouse gases from human activity, burning fossil fuels, and deforestation. It leads to melting ice, rising seas, and extreme weather that threaten ecosystems, wildlife, and people. Urgent global action is "]
+[
+    "What is the capital of Germany?",
+    "Explain the theory of relativity.",
+    "What are the benefits of using asyncio in Python?",
+    "Describe the process of photosynthesis.",
+    "How does a blockchain work?"
+]

tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 7 additions & 12 deletions
@@ -347,8 +347,6 @@ def __init__(self, mapping: Mapping):
         mapping_with_helix = None
         if self.mapping.cp_size > 1:
             print(f"[MPIDist::__init__] Repurposing CP ranks to TP for Helix.")
-            # TODO: More principled thing to do would be to update mapping to account for
-            # repurposing of CP ranks to TP.
             mapping_with_helix = copy.deepcopy(self.mapping)
             mapping_without_helix = Mapping(
                 world_size=self.mapping.world_size,
@@ -401,15 +399,20 @@ def recv_object(self, src, tag=0):
         return mpi_recv_object(src, tag)

     def create_tp_comm(self):
-        print(f"[MPIDist::create_tp_comm] rank: {self.mapping.rank}, tp_rank: {self.mapping.tp_rank}, tp_group: {self.mapping.tp_group}")
         new_group = mpi_comm().group.Incl(self.mapping.tp_group)
         self.tp_comm = mpi_comm().Create_group(new_group)

     def create_pp_comm(self):
-        print(f"[MPIDist::create_pp_comm] rank: {self.mapping.rank}, pp_rank: {self.mapping.pp_rank}, pp_group: {self.mapping.pp_group}")
         new_group = mpi_comm().group.Incl(self.mapping.pp_group)
         self.pp_comm = mpi_comm().Create_group(new_group)

+    def create_cp_comm(self):
+        new_group = mpi_comm().group.Incl(self.mapping.cp_group)
+        self.cp_comm = mpi_comm().Create_group(new_group)
+
+    def cp_allgather(self, obj):
+        return self.cp_comm.allgather(obj)
+
     def tp_allgather(self, obj):
         return self.tp_comm.allgather(obj)

@@ -430,14 +433,6 @@ def pp_gather(self, obj):
     def pp_broadcast(self, obj, root=0):
         return self.pp_comm.bcast(obj, root)

-    def create_cp_comm(self):
-        print(f"[MPIDist::create_cp_comm] rank: {self.mapping.rank}, cp_rank: {self.mapping.cp_rank}, cp_group: {self.mapping.cp_group}")
-        new_group = mpi_comm().group.Incl(self.mapping.cp_group)
-        self.cp_comm = mpi_comm().Create_group(new_group)
-
-    def cp_allgather(self, obj):
-        return self.cp_comm.allgather(obj)
-

 class MultiHandleWrapper:
     """

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 0 additions & 172 deletions
@@ -1577,7 +1577,6 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
                                                         PretrainedConfig]):

     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
-        ###############################################################################
         self.mapping_with_cp = None
         # Note: Currently the usage of mapping is all over the place making its usage brittle
         # in this file. As a temporary WAR, we hold on to an original copy of mapping when CP
@@ -1606,7 +1605,6 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                 moe_ep_size=model_config.mapping.moe_ep_size,
                 enable_attention_dp=model_config.mapping.enable_attention_dp)
             model_config._frozen = True
-        ###############################################################################

         # Rename some keys of quant_config_dict to support legacy checkpoints
         if model_config.quant_config_dict is not None:
@@ -1656,7 +1654,6 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
             self.epilogue.extend(self.draft_model.mtp_layers)
             self.epilogue.append(self.spec_worker)

-        ###############################################################################
         # Undo any manipulations done to mapping.
         if self.mapping_with_cp is not None:
             print(
@@ -1665,7 +1662,6 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
             model_config._frozen = False
             model_config.mapping = self.mapping_with_cp
             model_config._frozen = True
-        ###############################################################################

     def forward(
         self,
@@ -1677,33 +1673,6 @@ def forward(
         return_context_logits: bool = False,
         **kwargs,
     ) -> torch.Tensor:
-        # with use_torch_printoptions(sci_mode=False,
-        #                             threshold=16,
-        #                             edgeitems=2,
-        #                             linewidth=120):
-        #     print(
-        #         f"[DeepseekV3ForCausalLM::forward][rank {self.model_config.mapping.rank}] input_ids: {input_ids}"
-        #     )
-        #     print(
-        #         f"[DeepseekV3ForCausalLM::forward][rank {self.model_config.mapping.rank}] position_ids: {position_ids}"
-        #     )
-        #     print(
-        #         f"[DeepseekV3ForCausalLM::forward][rank {self.model_config.mapping.rank}] helix_is_inactive_rank: {attn_metadata.helix_is_inactive_rank}"
-        #     )
-        #     print(
-        #         f"[DeepseekV3ForCausalLM::forward][rank {self.model_config.mapping.rank}] kv_cache_params.num_cached_tokens_per_seq: {attn_metadata.kv_cache_params.num_cached_tokens_per_seq}"
-        #     )
-        #     print(
-        #         f"[DeepseekV3ForCausalLM::forward][rank {self.model_config.mapping.rank}] kv_lens_cuda: {attn_metadata.kv_lens_cuda}"
-        #     )
-        #     assert attn_metadata.kv_cache_manager.tokens_per_block == 32
-        #     block_ids_per_seq = attn_metadata.kv_cache_manager.get_batch_cache_indices(
-        #         attn_metadata.request_ids)
-        #     for request_id, block_ids in zip(attn_metadata.request_ids,
-        #                                      block_ids_per_seq):
-        #         print(
-        #             f"[DeepseekV3ForCausalLM::forward][rank {self.model_config.mapping.rank}] request_id: {request_id}, block_ids: {torch.tensor(block_ids)}"
-        #         )
         return super().forward(attn_metadata=attn_metadata,
                                input_ids=input_ids,
                                position_ids=position_ids,
@@ -1712,147 +1681,6 @@ def forward(
                                return_context_logits=return_context_logits,
                                **kwargs)

-    def _save_block_information_to_disk(self, attn_metadata: AttentionMetadata,
-                                        position_ids: torch.Tensor):
-        """Save KV cache block information to disk using safetensors format."""
-        import json
-        from pathlib import Path
-
-        import safetensors.torch
-
-        # Only save on rank 0 in prefill mode.
-        if (attn_metadata.helix_is_inactive_rank is not None
-                or self.model_config.mapping.rank != 0
-                or len(position_ids[0]) != 52):
-            return
-
-        # Create directory for saving block data
-        save_dir = Path(
-            "/home/bbuddharaju/scratch/TensorRT-LLM_MK/prefill_helix_all_layers"
-        )
-        save_dir.mkdir(exist_ok=True)
-
-        block_ids_per_seq = attn_metadata.kv_cache_manager.get_batch_cache_indices(
-            attn_metadata.request_ids)
-        for request_id, block_ids in zip(attn_metadata.request_ids,
-                                         block_ids_per_seq):
-            # Save blocks for requests with exactly 2 blocks.
-            if len(block_ids) == 2:
-                request_save_dir = save_dir / f"request_{request_id}"
-                request_save_dir.mkdir(exist_ok=True)
-
-                # Iterate through all layers and save KV cache buffers for each layer.
-                for layer_idx in range(self.config.num_hidden_layers):
-                    # Get KV cache buffers for this layer.
-                    kv_buffer = attn_metadata.kv_cache_manager.get_buffers(
-                        layer_idx)
-
-                    # Save each block separately for this layer.
-                    for i, block_id in enumerate(block_ids):
-                        # Get block data from KV cache for this layer.
-                        request_kv_data = kv_buffer[block_id]
-
-                        # Create separate data dictionary for this block.
-                        block_data = {"block_data": request_kv_data.cpu()}
-
-                        # Create separate metadata for this block, including layer information.
-                        block_metadata = {
-                            "request_id": int(request_id),
-                            "layer_idx": int(layer_idx),
-                            "block_id": int(block_id),
-                            "block_index": i,
-                            "block_shape": list(request_kv_data.shape),
-                            "tokens_per_block":
-                            attn_metadata.kv_cache_manager.tokens_per_block,
-                            "rank": self.model_config.mapping.rank,
-                        }
-
-                        # Save each block's data separately using safetensors, including layer in filename.
-                        block_safetensors_path = request_save_dir / f"layer_{layer_idx}_block_id_{block_id}_rank_{self.model_config.mapping.rank}.safetensors"
-                        safetensors.torch.save_file(block_data,
-                                                    str(block_safetensors_path))

-                        # Save each block's metadata separately as JSON, including layer in filename.
-                        block_metadata_path = request_save_dir / f"layer_{layer_idx}_block_id_{block_id}_rank_{self.model_config.mapping.rank}_metadata.json"
-                        with open(block_metadata_path, 'w') as f:
-                            json.dump(block_metadata, f, indent=2)
-
-                        print(
-                            f"[DeepseekV3ForCausalLM::_save_block_information_to_disk][rank {self.model_config.mapping.rank}] "
-                            f"Saved layer {layer_idx} block (ID: {block_id}) for request {request_id}, shape: {request_kv_data.shape} "
-                            f"to {block_safetensors_path.name}")
-
-                print(
-                    f"[DeepseekV3ForCausalLM::_save_block_information_to_disk][rank {self.model_config.mapping.rank}] "
-                    f"Saved block information for request {request_id} to {request_save_dir}"
-                )
-
-    def _read_block_information_from_disk(self,
-                                          attn_metadata: AttentionMetadata,
-                                          position_ids: torch.Tensor):
-        """Read KV cache block information from disk using safetensors format."""
-        from pathlib import Path
-
-        import safetensors.torch
-
-        # Early return in prefill mode.
-        if (attn_metadata.helix_is_inactive_rank is None):
-            return
-
-        # Early return if this isn't the first decode step.
-        if (position_ids[0][0].item() != 52):
-            print(
-                f"[DeepseekV3ForCausalLM::_save_block_information_to_disk][rank {self.model_config.mapping.rank}] "
-                f"Early return in decode mode because this isn't the first decode step {position_ids[0][0].item()}"
-            )
-            return
-
-        block_ids_per_seq = attn_metadata.kv_cache_manager.get_batch_cache_indices(
-            attn_metadata.request_ids)
-        for request_id, block_ids in zip(attn_metadata.request_ids,
-                                         block_ids_per_seq):
-
-            # Read blocks for requests with exactly 1 block.
-            assert len(block_ids) == 1
-
-            # Read KV cache for all layers.
-            for layer_idx in range(self.config.num_hidden_layers):
-                # Determine file path based on rank and layer.
-                if self.model_config.mapping.rank == 0:
-                    # Inactive rank.
-                    read_file = Path(
-                        f"/home/bbuddharaju/scratch/TensorRT-LLM_MK/prefill_helix_all_layers/request_2048/layer_{layer_idx}_block_id_257_rank_0.safetensors"
-                    )
-                else:
-                    # Active rank.
-                    read_file = Path(
-                        f"/home/bbuddharaju/scratch/TensorRT-LLM_MK/prefill_helix_all_layers/request_2048/layer_{layer_idx}_block_id_258_rank_0.safetensors"
-                    )
-
-                # Get KV cache buffers for this layer.
-                kv_buffer = attn_metadata.kv_cache_manager.get_buffers(
-                    layer_idx)
-
-                # Get block data from KV cache.
-                request_kv_data = kv_buffer[block_ids[0]]
-
-                # Load block data from disk.
-                loaded_data = safetensors.torch.load_file(read_file)
-                block_read_data = loaded_data['block_data'].to(
-                    request_kv_data.device)
-
-                # Copy block data to KV cache.
-                request_kv_data.copy_(block_read_data)
-
-                print(
-                    f"[DeepseekV3ForCausalLM::_read_block_information_from_disk][rank {self.model_config.mapping.rank}] "
-                    f"Layer {layer_idx}: request_kv_data: {request_kv_data}")
-
-            print(
-                f"[DeepseekV3ForCausalLM::_read_block_information_from_disk][rank {self.model_config.mapping.rank}] "
-                f"Read block data for request {request_id}, layer {layer_idx}, shape: {block_read_data.shape} "
-                f"from {read_file.name}")
-
     def load_weights(self, weights: Dict):
         weight_loader = DeepseekV3WeightLoader(self)
         weight_loader.load_weights(weights)

tensorrt_llm/_torch/modules/attention.py

Lines changed: 2 additions & 4 deletions
@@ -709,6 +709,7 @@ def __init__(
         self.hidden_size = hidden_size
         self.num_heads = num_attention_heads
         self.num_key_value_heads = num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         assert self.num_heads == self.num_key_value_heads, "num_heads must be equal to num_key_value_heads"
         self.qk_nope_head_dim = qk_nope_head_dim
         self.qk_rope_head_dim = qk_rope_head_dim
@@ -761,7 +762,7 @@ def __init__(
         if self.mapping.has_cp_ulysses():
             raise NotImplementedError("MLA doesn't support CP Ulyssees yet")
         if self.mapping.cp_size > 1:
-            assert self.mapping.cp_config['cp_type'] == CpType.HELIX
+            assert self.mapping.cp_config['cp_type'] == CpType.HELIX, f"CP type must be HELIX for MLA, but got {self.mapping.cp_config['cp_type']}."

         mapping = Mapping(
             world_size=tp_size * pp_size * cp_size,
@@ -1093,9 +1094,6 @@ def _attn_forward_gen(self, attn_backend: AttentionBackend, q: torch.Tensor,

     def create_output(self, hidden_states: torch.Tensor, num_contexts: int):
         num_tokens = hidden_states.shape[0]
-        # note: for testing Helix parallelism, we ensure that the output is
-        # large enough for the context phase, but we then cut it again in
-        # `forward_context`
         hidden_size = self.o_proj.in_features
         if self.enable_unit_test and num_contexts > 0:
             # note: for testing Helix parallelism, we ensure that the output is
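A small illustrative sketch (assumptions, not repo code: the helper name and the example head counts are made up) of the quantity stored by the new num_key_value_groups attribute above: with grouped-query attention, each KV head is shared by num_heads // num_key_value_heads query heads, and under the MLA assertion that the two counts match, the group size is always 1.

# Hypothetical helper mirroring the arithmetic of the added attribute.
def kv_groups(num_heads: int, num_key_value_heads: int) -> int:
    assert num_heads % num_key_value_heads == 0, "head counts must divide evenly"
    return num_heads // num_key_value_heads

assert kv_groups(128, 128) == 1  # MLA case enforced by the assert in __init__
assert kv_groups(32, 8) == 4     # a typical GQA layout, shown for comparison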

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 0 additions & 8 deletions
@@ -313,7 +313,6 @@ def _fetch_and_process_requests(
         new_requests = self._validate_and_filter_requests(new_requests)

         # Attach Python objects to requests
-        # @B: What's the significance of this condition?
         if py_request_objects and (self.dist.tp_size > 1 or self.dist.has_pp
                                    or self.dist.cp_size
                                    > 1) and self.dist.rank > 0:
@@ -693,13 +692,6 @@ def _merge_helix_requests(self, new_requests: list[RequestQueueItem],
                 input_ids_this_rank = input_ids_this_rank[:-padding_len]
                 position_ids_this_rank = position_ids_this_rank[:-padding_len]

-            print(
-                f"[ExecutorRequestQueue::_merge_helix_requests][{curr_cp_rank}]: input_ids_this_rank: {torch.tensor(input_ids_this_rank)}"
-            )
-            print(
-                f"[ExecutorRequestQueue::_merge_helix_requests][{curr_cp_rank}]: position_ids_this_rank: {torch.tensor(position_ids_this_rank)}"
-            )
-            # TODO: Figure how to pass down position_ids_this_rank to LLMRequest.
             req = executor_request_to_llm_request(
                 req_id=req_item.id,
                 executor_request=req_item.request,

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 15 deletions
@@ -562,8 +562,6 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         cp_type = self.mapping.cp_config.get('cp_type', None)
         if cp_type is not None:
             if cp_type in [CpType.ULYSSES, CpType.STAR]:
-                assert False, "cp_type must be HELIX for helix benchmarking."
-            print("[ModelEngine::warmup] EARLY RETURN since cp_type ", cp_type)
             return

         self._run_torch_compile_warmup(resource_manager)
@@ -1059,14 +1057,10 @@ def _init_max_seq_len(self):
             # NOTE: py_executor_creator makes sure that the executor uses this
             # smaller value as its max_seq_len too.
             logger.warning(
-                f"\n*******************************************************\n"
-                f"Specified {self.max_seq_len=} is larger than what the model can support\n"
-                f"({inferred_max_seq_len}). NOT Setting max_seq_len to {inferred_max_seq_len}. "
-                f"ARE YOU SURE ABOUT THIS?\n"
-                f"*******************************************************\n"
+                f"Specified {self.max_seq_len=} is larger than what the model can support "
+                f"({inferred_max_seq_len}). Setting max_seq_len to {inferred_max_seq_len}. "
             )
-            # self.max_seq_len = inferred_max_seq_len
-            pass
+            self.max_seq_len = inferred_max_seq_len

     def _infer_max_seq_len_from_config(self) -> int:

@@ -2134,9 +2128,7 @@ def _prepare_tp_inputs_no_cache(
         attn_metadata.padded_num_tokens = padded_num_tokens if padded_num_tokens != num_tokens else None

         if self.enable_attention_dp:
-            all_rank_num_tokens = self.dist.allgather(
-                attn_metadata.num_tokens)
-            attn_metadata.all_rank_num_tokens = all_rank_num_tokens
+            attn_metadata.all_rank_num_tokens = attn_all_rank_num_tokens

         virtual_num_tokens = num_tokens
         if attn_metadata.padded_num_tokens is not None:
@@ -2195,9 +2187,7 @@ def _prepare_tp_inputs_no_cache(
                 spec_all_rank_num_tokens = [
                     item[1] for item in all_rank_num_tokens
                 ]
-                all_rank_num_seqs = [
-                    item[2] for item in all_rank_num_tokens
-                ]
+                all_rank_num_seqs = [item[2] for item in all_rank_num_tokens]
                 attn_metadata.all_rank_num_tokens = attn_all_rank_num_tokens
                 spec_metadata.all_rank_num_tokens = spec_all_rank_num_tokens
                 spec_metadata.all_rank_num_seqs = all_rank_num_seqs