@@ -1577,7 +1577,6 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
15771577 PretrainedConfig ]):
15781578
15791579 def __init__ (self , model_config : ModelConfig [PretrainedConfig ]):
1580- ###############################################################################
15811580 self .mapping_with_cp = None
15821581 # Note: Currently the usage of mapping is all over the place making its usage brittle
15831582 # in this file. As a temporary WAR, we hold on to an original copy of mapping when CP
@@ -1606,7 +1605,6 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
16061605 moe_ep_size = model_config .mapping .moe_ep_size ,
16071606 enable_attention_dp = model_config .mapping .enable_attention_dp )
16081607 model_config ._frozen = True
1609- ###############################################################################
16101608
16111609 # Rename some keys of quant_config_dict to support legacy checkpoints
16121610 if model_config .quant_config_dict is not None :
@@ -1656,7 +1654,6 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
16561654 self .epilogue .extend (self .draft_model .mtp_layers )
16571655 self .epilogue .append (self .spec_worker )
16581656
1659- ###############################################################################
16601657 # Undo any manipulations done to mapping.
16611658 if self .mapping_with_cp is not None :
16621659 print (
@@ -1665,7 +1662,6 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
16651662 model_config ._frozen = False
16661663 model_config .mapping = self .mapping_with_cp
16671664 model_config ._frozen = True
1668- ###############################################################################
16691665
16701666 def forward (
16711667 self ,
@@ -1677,33 +1673,6 @@ def forward(
16771673 return_context_logits : bool = False ,
16781674 ** kwargs ,
16791675 ) -> torch .Tensor :
1680- # with use_torch_printoptions(sci_mode=False,
1681- # threshold=16,
1682- # edgeitems=2,
1683- # linewidth=120):
1684- # print(
1685- # f"[DeepseekV3ForCausalLM::forward][rank {self.model_config.mapping.rank}] input_ids: {input_ids}"
1686- # )
1687- # print(
1688- # f"[DeepseekV3ForCausalLM::forward][rank {self.model_config.mapping.rank}] position_ids: {position_ids}"
1689- # )
1690- # print(
1691- # f"[DeepseekV3ForCausalLM::forward][rank {self.model_config.mapping.rank}] helix_is_inactive_rank: {attn_metadata.helix_is_inactive_rank}"
1692- # )
1693- # print(
1694- # f"[DeepseekV3ForCausalLM::forward][rank {self.model_config.mapping.rank}] kv_cache_params.num_cached_tokens_per_seq: {attn_metadata.kv_cache_params.num_cached_tokens_per_seq}"
1695- # )
1696- # print(
1697- # f"[DeepseekV3ForCausalLM::forward][rank {self.model_config.mapping.rank}] kv_lens_cuda: {attn_metadata.kv_lens_cuda}"
1698- # )
1699- # assert attn_metadata.kv_cache_manager.tokens_per_block == 32
1700- # block_ids_per_seq = attn_metadata.kv_cache_manager.get_batch_cache_indices(
1701- # attn_metadata.request_ids)
1702- # for request_id, block_ids in zip(attn_metadata.request_ids,
1703- # block_ids_per_seq):
1704- # print(
1705- # f"[DeepseekV3ForCausalLM::forward][rank {self.model_config.mapping.rank}] request_id: {request_id}, block_ids: {torch.tensor(block_ids)}"
1706- # )
17071676 return super ().forward (attn_metadata = attn_metadata ,
17081677 input_ids = input_ids ,
17091678 position_ids = position_ids ,
@@ -1712,147 +1681,6 @@ def forward(
17121681 return_context_logits = return_context_logits ,
17131682 ** kwargs )
17141683
1715- def _save_block_information_to_disk (self , attn_metadata : AttentionMetadata ,
1716- position_ids : torch .Tensor ):
1717- """Save KV cache block information to disk using safetensors format."""
1718- import json
1719- from pathlib import Path
1720-
1721- import safetensors .torch
1722-
1723- # Only save on rank 0 in prefill mode.
1724- if (attn_metadata .helix_is_inactive_rank is not None
1725- or self .model_config .mapping .rank != 0
1726- or len (position_ids [0 ]) != 52 ):
1727- return
1728-
1729- # Create directory for saving block data
1730- save_dir = Path (
1731- "/home/bbuddharaju/scratch/TensorRT-LLM_MK/prefill_helix_all_layers"
1732- )
1733- save_dir .mkdir (exist_ok = True )
1734-
1735- block_ids_per_seq = attn_metadata .kv_cache_manager .get_batch_cache_indices (
1736- attn_metadata .request_ids )
1737- for request_id , block_ids in zip (attn_metadata .request_ids ,
1738- block_ids_per_seq ):
1739- # Save blocks for requests with exactly 2 blocks.
1740- if len (block_ids ) == 2 :
1741- request_save_dir = save_dir / f"request_{ request_id } "
1742- request_save_dir .mkdir (exist_ok = True )
1743-
1744- # Iterate through all layers and save KV cache buffers for each layer.
1745- for layer_idx in range (self .config .num_hidden_layers ):
1746- # Get KV cache buffers for this layer.
1747- kv_buffer = attn_metadata .kv_cache_manager .get_buffers (
1748- layer_idx )
1749-
1750- # Save each block separately for this layer.
1751- for i , block_id in enumerate (block_ids ):
1752- # Get block data from KV cache for this layer.
1753- request_kv_data = kv_buffer [block_id ]
1754-
1755- # Create separate data dictionary for this block.
1756- block_data = {"block_data" : request_kv_data .cpu ()}
1757-
1758- # Create separate metadata for this block, including layer information.
1759- block_metadata = {
1760- "request_id" : int (request_id ),
1761- "layer_idx" : int (layer_idx ),
1762- "block_id" : int (block_id ),
1763- "block_index" : i ,
1764- "block_shape" : list (request_kv_data .shape ),
1765- "tokens_per_block" :
1766- attn_metadata .kv_cache_manager .tokens_per_block ,
1767- "rank" : self .model_config .mapping .rank ,
1768- }
1769-
1770- # Save each block's data separately using safetensors, including layer in filename.
1771- block_safetensors_path = request_save_dir / f"layer_{ layer_idx } _block_id_{ block_id } _rank_{ self .model_config .mapping .rank } .safetensors"
1772- safetensors .torch .save_file (block_data ,
1773- str (block_safetensors_path ))
1774-
1775- # Save each block's metadata separately as JSON, including layer in filename.
1776- block_metadata_path = request_save_dir / f"layer_{ layer_idx } _block_id_{ block_id } _rank_{ self .model_config .mapping .rank } _metadata.json"
1777- with open (block_metadata_path , 'w' ) as f :
1778- json .dump (block_metadata , f , indent = 2 )
1779-
1780- print (
1781- f"[DeepseekV3ForCausalLM::_save_block_information_to_disk][rank { self .model_config .mapping .rank } ] "
1782- f"Saved layer { layer_idx } block (ID: { block_id } ) for request { request_id } , shape: { request_kv_data .shape } "
1783- f"to { block_safetensors_path .name } " )
1784-
1785- print (
1786- f"[DeepseekV3ForCausalLM::_save_block_information_to_disk][rank { self .model_config .mapping .rank } ] "
1787- f"Saved block information for request { request_id } to { request_save_dir } "
1788- )
1789-
1790- def _read_block_information_from_disk (self ,
1791- attn_metadata : AttentionMetadata ,
1792- position_ids : torch .Tensor ):
1793- """Read KV cache block information from disk using safetensors format."""
1794- from pathlib import Path
1795-
1796- import safetensors .torch
1797-
1798- # Early return in prefill mode.
1799- if (attn_metadata .helix_is_inactive_rank is None ):
1800- return
1801-
1802- # Early return if this isn't the first decode step.
1803- if (position_ids [0 ][0 ].item () != 52 ):
1804- print (
1805- f"[DeepseekV3ForCausalLM::_save_block_information_to_disk][rank { self .model_config .mapping .rank } ] "
1806- f"Early return in decode mode because this isn't the first decode step { position_ids [0 ][0 ].item ()} "
1807- )
1808- return
1809-
1810- block_ids_per_seq = attn_metadata .kv_cache_manager .get_batch_cache_indices (
1811- attn_metadata .request_ids )
1812- for request_id , block_ids in zip (attn_metadata .request_ids ,
1813- block_ids_per_seq ):
1814-
1815- # Read blocks for requests with exactly 1 block.
1816- assert len (block_ids ) == 1
1817-
1818- # Read KV cache for all layers.
1819- for layer_idx in range (self .config .num_hidden_layers ):
1820- # Determine file path based on rank and layer.
1821- if self .model_config .mapping .rank == 0 :
1822- # Inactive rank.
1823- read_file = Path (
1824- f"/home/bbuddharaju/scratch/TensorRT-LLM_MK/prefill_helix_all_layers/request_2048/layer_{ layer_idx } _block_id_257_rank_0.safetensors"
1825- )
1826- else :
1827- # Active rank.
1828- read_file = Path (
1829- f"/home/bbuddharaju/scratch/TensorRT-LLM_MK/prefill_helix_all_layers/request_2048/layer_{ layer_idx } _block_id_258_rank_0.safetensors"
1830- )
1831-
1832- # Get KV cache buffers for this layer.
1833- kv_buffer = attn_metadata .kv_cache_manager .get_buffers (
1834- layer_idx )
1835-
1836- # Get block data from KV cache.
1837- request_kv_data = kv_buffer [block_ids [0 ]]
1838-
1839- # Load block data from disk.
1840- loaded_data = safetensors .torch .load_file (read_file )
1841- block_read_data = loaded_data ['block_data' ].to (
1842- request_kv_data .device )
1843-
1844- # Copy block data to KV cache.
1845- request_kv_data .copy_ (block_read_data )
1846-
1847- print (
1848- f"[DeepseekV3ForCausalLM::_read_block_information_from_disk][rank { self .model_config .mapping .rank } ] "
1849- f"Layer { layer_idx } : request_kv_data: { request_kv_data } " )
1850-
1851- print (
1852- f"[DeepseekV3ForCausalLM::_read_block_information_from_disk][rank { self .model_config .mapping .rank } ] "
1853- f"Read block data for request { request_id } , layer { layer_idx } , shape: { block_read_data .shape } "
1854- f"from { read_file .name } " )
1855-
18561684 def load_weights (self , weights : Dict ):
18571685 weight_loader = DeepseekV3WeightLoader (self )
18581686 weight_loader .load_weights (weights )
0 commit comments