Commit b121945

Add llama converter (no dependency on internal Nvidia code) - part 1/2 (#545)
## What does this PR do?

Add llama converter (no dependency on internal Nvidia code) - part 1/2:

- Change top-level dependencies in convert_llama3_to_decilm.py from puzzle_tools.... to modelopt.....
- Add the modelopt.torch._compress.tools module.
- Remove tokenization_mistral.py (not used).

Scope of part 2/2 (will come once part 1/2 is merged):

- Change all deeper dependencies from puzzle_tools.... to modelopt....
- test_convert_llama3_config_to_decilm_config.py should run without any internal Nvidia dependencies.

---------

Signed-off-by: Daniel Korzekwa <[email protected]>
1 parent 50a580c
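Concretely, the top-level change is a pure import-path move; here is the before/after for convert_llama3_to_decilm.py, assembled verbatim from the diff further down:

```python
# Before: internal Nvidia dependencies
from puzzle_tools.checkpoint_utils import copy_tokenizer
from puzzle_tools.checkpoint_utils_hf import copy_deci_lm_hf_code
from puzzle_tools.conversion_utils import convert_model_weights_to_decilm

# After: self-contained modelopt modules
from modelopt.torch._compress.decilm.conversion_utils import convert_model_weights_to_decilm
from modelopt.torch._compress.tools.checkpoint_utils import copy_tokenizer
from modelopt.torch._compress.tools.checkpoint_utils_hf import copy_deci_lm_hf_code
```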

File tree

13 files changed: +827 −383 lines

modelopt/torch/_compress/compress.py

Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@
 from omegaconf import DictConfig
 from puzzle_tools.runtime import IRuntime
 
-from modelopt.torch._compress.hydra import initialize_hydra_config_for_dir
+from modelopt.torch._compress.tools.hydra import initialize_hydra_config_for_dir
 
 
 def compress(
modelopt/torch/_compress/decilm/conversion_utils.py (new file; path inferred from the import in the convert_llama3_to_decilm.py diff below)

Lines changed: 157 additions & 0 deletions

@@ -0,0 +1,157 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import re
from collections import defaultdict

from safetensors.torch import load_file, save_file
from tqdm import tqdm


def convert_name(name):
    return name.replace("feed_forward", "mlp").replace("language_model.", "")


def convert_routed_experts_weight(llama_name, weight):
    assert ".experts." in llama_name, "Only use this func to convert weights of routed experts"
    llama_name_prefix = llama_name.split(".experts.")[0]
    deci_name_prefix = convert_name(llama_name_prefix)

    experts_state_dict = {}
    for i_expert, expert_weight in enumerate(weight.unbind(dim=0)):
        expert_prefix = f"{deci_name_prefix}.experts.{i_expert}"
        if "gate_up_proj" in llama_name:
            gate_weight, up_weight = expert_weight.transpose(0, 1).chunk(2, dim=0)
            experts_state_dict[f"{expert_prefix}.gate_proj.weight"] = gate_weight.contiguous()
            experts_state_dict[f"{expert_prefix}.up_proj.weight"] = up_weight.contiguous()
        elif "down_proj" in llama_name:
            down_weight = expert_weight.transpose(0, 1)
            experts_state_dict[f"{expert_prefix}.down_proj.weight"] = down_weight.contiguous()
        else:
            raise ValueError(f"Unknown expert weight: {llama_name}")

    return experts_state_dict
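The gate_up_proj branch above assumes the checkpoint stacks all routed experts into one fused tensor, with each expert's gate and up projections concatenated. A minimal sketch of that layout (the dimensions are made up, and the shape itself is inferred from the transpose/chunk logic, not stated in the PR):

```python
import torch

num_experts, hidden, intermediate = 4, 8, 16  # hypothetical sizes

# One (hidden, 2 * intermediate) gate_up_proj matrix per expert, stacked on dim 0.
gate_up = torch.randn(num_experts, hidden, 2 * intermediate)

expert0 = gate_up.unbind(dim=0)[0]                      # (hidden, 2 * intermediate)
gate_w, up_w = expert0.transpose(0, 1).chunk(2, dim=0)  # two (intermediate, hidden) halves

# Matches the (out_features, in_features) layout of torch.nn.Linear.weight.
assert gate_w.shape == up_w.shape == (intermediate, hidden)
```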

def get_layer_subblock(param):
    if param.startswith("model.embed_tokens."):
        return "embeddings"
    if param.startswith("lm_head.") or param == "model.norm.weight":
        return "lm_head"
    m = re.match(r"model\.layers\.(\d+)\.(.+)", param)
    if m:
        layer, suffix = m.groups()
        if suffix.startswith(("self_attn.", "input_layernorm.weight")):
            return f"block_{layer}_attention"
        elif suffix.startswith(("mlp.", "post_attention_layernorm.weight")):
            return f"block_{layer}_ffn"
    return None
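For reference, here is how get_layer_subblock buckets typical checkpoint parameter names (inputs chosen by hand for illustration; the import path is inferred from the convert_llama3_to_decilm.py diff below):

```python
from modelopt.torch._compress.decilm.conversion_utils import get_layer_subblock

assert get_layer_subblock("model.embed_tokens.weight") == "embeddings"
assert get_layer_subblock("lm_head.weight") == "lm_head"
assert get_layer_subblock("model.layers.7.self_attn.q_proj.weight") == "block_7_attention"
assert get_layer_subblock("model.layers.7.mlp.experts.0.up_proj.weight") == "block_7_ffn"
assert get_layer_subblock("model.rotary_emb.inv_freq") is None  # falls outside every subblock
```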

def convert_model_weights_to_decilm(llama_hf_dir, output_dir, is_llama4=False):
    index_path = os.path.join(llama_hf_dir, "model.safetensors.index.json")
    single_file_path = os.path.join(llama_hf_dir, "model.safetensors")

    # Check if we have a sharded model (with index) or single file model
    if os.path.exists(index_path):
        # Sharded model - use existing logic
        with open(index_path) as f:
            index = json.load(f)
        param_to_file = index["weight_map"]
        all_param_names = list(param_to_file.keys())
    elif os.path.exists(single_file_path):
        # Single file model - create a synthetic index
        data = load_file(single_file_path)
        all_param_names = list(data.keys())
        param_to_file = dict.fromkeys(all_param_names, "model.safetensors")
    else:
        raise FileNotFoundError(
            f"Neither {index_path} nor {single_file_path} found. Cannot determine model format."
        )

    name_map = {
        name: convert_name(name)
        for name in all_param_names
        if name.startswith("language_model.") or not is_llama4
    }

    # Reverse map: file -> set of params
    file_to_params = defaultdict(set)
    for name, file in param_to_file.items():
        file_to_params[file].add(name)

    # Determine subblocks needed
    subblocks = defaultdict(list)
    for old_name, new_name in name_map.items():
        subblock = get_layer_subblock(new_name)
        if subblock:
            subblocks[subblock].append((old_name, new_name))

    # Output directory
    out_dir = os.path.join(output_dir, "subblocks_safetensors")
    os.makedirs(out_dir, exist_ok=True)

    # New weight index
    new_index = {"metadata": {"format": "pt"}, "weight_map": {}}

    # For single file models, load all data once
    if os.path.exists(single_file_path) and not os.path.exists(index_path):
        all_data = load_file(single_file_path)
    else:
        all_data = None

    for subblock, param_pairs in tqdm(subblocks.items(), desc="Processing subblocks"):
        tensors = {}

        if all_data is not None:
            # Single file model - get tensors from pre-loaded data
            for old_name, new_name in param_pairs:
                if old_name in all_data:
                    if ".experts." not in old_name:
                        tensors[new_name] = all_data[old_name]
                    else:
                        experts_state_dict = convert_routed_experts_weight(
                            old_name, all_data[old_name]
                        )
                        tensors.update(experts_state_dict)
        else:
            # Sharded model - load only needed files for this subblock
            param_files = {param_to_file[old] for old, _ in param_pairs}
            for file in param_files:
                data = load_file(os.path.join(llama_hf_dir, file))
                for old_name, new_name in param_pairs:
                    if param_to_file[old_name] == file and old_name in data:
                        if ".experts." not in old_name:
                            tensors[new_name] = data[old_name]
                        else:
                            experts_state_dict = convert_routed_experts_weight(
                                old_name, data[old_name]
                            )
                            tensors.update(experts_state_dict)

        # Save this subblock
        subblock_file = f"{subblock}.safetensors"
        save_file(tensors, os.path.join(out_dir, subblock_file))

        # Update index
        for new_name in tensors:
            new_index["weight_map"][new_name] = f"subblocks_safetensors/{subblock_file}"

    # Save new index file
    with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f:
        json.dump(new_index, f, indent=2)

    print(f"✅ Finished saving subblocks and index to {output_dir}")
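A minimal usage sketch of the converter (the checkpoint paths are placeholders, not from this PR):

```python
from modelopt.torch._compress.decilm.conversion_utils import convert_model_weights_to_decilm

# llama_hf_dir must contain either model.safetensors or a sharded
# checkpoint with model.safetensors.index.json.
convert_model_weights_to_decilm(
    llama_hf_dir="/checkpoints/llama3-8b-hf",
    output_dir="/checkpoints/llama3-8b-decilm",
    is_llama4=False,  # True keeps only language_model.* weights (Llama-4 checkpoints)
)
```

This writes one <subblock>.safetensors file per attention/FFN block under output_dir/subblocks_safetensors/, plus a remapped model.safetensors.index.json in output_dir.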

modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py

Lines changed: 3 additions & 3 deletions
@@ -21,12 +21,12 @@
 
 import torch
 from fire import Fire
-from puzzle_tools.checkpoint_utils import copy_tokenizer
-from puzzle_tools.checkpoint_utils_hf import copy_deci_lm_hf_code
-from puzzle_tools.conversion_utils import convert_model_weights_to_decilm
 from transformers import LlamaConfig
 
+from modelopt.torch._compress.decilm.conversion_utils import convert_model_weights_to_decilm
 from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig
+from modelopt.torch._compress.tools.checkpoint_utils import copy_tokenizer
+from modelopt.torch._compress.tools.checkpoint_utils_hf import copy_deci_lm_hf_code
 
 """
 example: