Skip to content

Commit 6bafe92

Browse files
committed
Merge branch 'add_siglip_converter' into 'main'
Add siglip converter to multimodal example See merge request ADLR/megatron-lm!2214
2 parents 4876ee1 + bc4874c commit 6bafe92

File tree

3 files changed

+155
-1
lines changed

3 files changed

+155
-1
lines changed

examples/multimodal/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Follow the instructions in `megatron-lm/docs/llama_mistral.md` to download weigh
2323
This example uses the OpenAI CLIP `ViT-L/14@336px` Vision model. To download the weights from OpenAI and convert them to a format that can be loaded in megatron, please run the following:
2424

2525
```
26-
python examples/multimodal/clip_converter.py --download-root /some/download/folder --output /some/output/folder --tensor-parallel-size 4 --use-te
26+
python examples/multimodal/model_converter/clip_converter.py --download-root /some/download/folder --output /some/output/folder --tensor-parallel-size 4 --use-te
2727
```
2828

2929
### Combined model checkpoint
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
import argparse
3+
import os
4+
from transformers import PaliGemmaForConditionalGeneration
5+
import torch
6+
7+
8+
def convert(output_path, tensor_parallel_size, use_te):
9+
device = "cuda"
10+
11+
model_id = "google/paligemma-3b-pt-448"
12+
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
13+
14+
model = model.to(device)
15+
16+
print(model.config)
17+
for name, tensor in model.state_dict().items():
18+
if "vision_model" not in name:
19+
continue
20+
shape_str = "(" + ", ".join([str(x) for x in tensor.shape]) + ")"
21+
print(f"{name:<75} {shape_str:>20}")
22+
23+
state_dict = model.state_dict()
24+
new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)]
25+
26+
def add_chunck_tensor(new_tensor, new_name, chunk_dim=None):
27+
if chunk_dim is None:
28+
new_tensors = [new_tensor for _ in range(tensor_parallel_size)]
29+
else:
30+
new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim)
31+
32+
for i in range(tensor_parallel_size):
33+
# chunk() creates a view of a bigger tensor. clone() is used here to avoid excessive storage.
34+
new_state_dicts[i]["model"][new_name] = new_tensors[i].clone()
35+
36+
# TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility.
37+
extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2")
38+
is_extra_state_layer = any([l in new_name for l in extra_state_layers])
39+
if use_te and is_extra_state_layer:
40+
layer = new_name.split(".")[-2]
41+
if layer in extra_state_layers:
42+
extra_state_name = (
43+
new_name[: new_name.rfind(".") + 1] + "_extra_state"
44+
) # Replace the weight name.
45+
new_state_dicts[i]["model"][extra_state_name] = None
46+
47+
for name, tensor in state_dict.items():
48+
if tensor.dtype == torch.float16:
49+
state_dict[name] = tensor.to(torch.float32)
50+
51+
add_chunck_tensor(
52+
state_dict["vision_tower.vision_model.embeddings.position_embedding.weight"],
53+
"position_embeddings.weight")
54+
add_chunck_tensor(
55+
state_dict["vision_tower.vision_model.embeddings.patch_embedding.weight"],
56+
"conv1.weight")
57+
add_chunck_tensor(
58+
state_dict["vision_tower.vision_model.embeddings.patch_embedding.bias"],
59+
"conv1.bias")
60+
61+
head_dim = 72
62+
num_head = 16
63+
for layer_idx in range(27):
64+
origin_base = f"vision_tower.vision_model.encoder.layers.{layer_idx}"
65+
target_base = f"decoder.layers.{layer_idx}"
66+
67+
for param_type in ["weight", "bias"]:
68+
# QKV
69+
q_proj_params = state_dict[f"{origin_base}.self_attn.q_proj.{param_type}"]
70+
k_proj_params = state_dict[f"{origin_base}.self_attn.k_proj.{param_type}"]
71+
v_proj_params = state_dict[f"{origin_base}.self_attn.v_proj.{param_type}"]
72+
# Do some tensor manipulation because megatron expect one tensor
73+
# projection for the QKV in the order
74+
# [(Q1, K1, V1), (Q2, K2, V2), ...] where Qi is the query of the
75+
# i-th head with dimension num_head.
76+
new_tensor = torch.concatenate([
77+
q_proj_params.view(num_head, head_dim, -1),
78+
k_proj_params.view(num_head, head_dim, -1),
79+
v_proj_params.view(num_head, head_dim, -1)], axis=1).view(
80+
3*head_dim*num_head, -1)
81+
if param_type == "bias":
82+
new_tensor = new_tensor[:, 0]
83+
new_name = f"{target_base}.self_attention.linear_qkv.{param_type}"
84+
add_chunck_tensor(new_tensor, new_name, chunk_dim=0)
85+
# linear_proj
86+
add_chunck_tensor(
87+
state_dict[f"{origin_base}.self_attn.out_proj.{param_type}"],
88+
f"{target_base}.self_attention.linear_proj.{param_type}",
89+
chunk_dim=1 if param_type == "weight" else None)
90+
# layer_norm
91+
new_name = f"{target_base}.input_layernorm.{param_type}"
92+
if use_te:
93+
new_name = f"{target_base}.self_attention.linear_qkv.layer_norm_{param_type}"
94+
add_chunck_tensor(
95+
state_dict[f"{origin_base}.layer_norm1.{param_type}"],
96+
new_name)
97+
# FC 1
98+
add_chunck_tensor(
99+
state_dict[f"{origin_base}.mlp.fc1.{param_type}"],
100+
f"{target_base}.mlp.linear_fc1.{param_type}",
101+
chunk_dim=0)
102+
# FC 2
103+
add_chunck_tensor(
104+
state_dict[f"{origin_base}.mlp.fc2.{param_type}"],
105+
f"{target_base}.mlp.linear_fc2.{param_type}",
106+
chunk_dim=1 if param_type=="weight" else None)
107+
# layer_norm
108+
new_name = f"{target_base}.pre_mlp_layernorm.{param_type}"
109+
if use_te:
110+
new_name = f"{target_base}.mlp.linear_fc1.layer_norm_{param_type}"
111+
add_chunck_tensor(
112+
state_dict[f"{origin_base}.layer_norm2.{param_type}"],
113+
new_name)
114+
115+
add_chunck_tensor(
116+
state_dict["vision_tower.vision_model.post_layernorm.weight"],
117+
"ln_post.weight")
118+
add_chunck_tensor(
119+
state_dict["vision_tower.vision_model.post_layernorm.bias"],
120+
"ln_post.bias")
121+
122+
for i in range(tensor_parallel_size):
123+
output_dir_tp = os.path.join(output_path, "iter_0000001", f"mp_rank_0{i}")
124+
os.makedirs(output_dir_tp)
125+
output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt")
126+
torch.save(new_state_dicts[i], output_path_tp)
127+
128+
129+
if __name__ == "__main__":
130+
parser = argparse.ArgumentParser(
131+
description="""
132+
Convert SigLIP weights to megatron format.
133+
134+
135+
Example usage:
136+
python siglip_converter.py --tensor-parallel-size 4 --output google_paligemma_3b_pt_44_mcore_tp_4 --use-te
137+
138+
examples/multimodal/combine_mistral_clip.sh /lustre/fsw/portfolios/llmservice/users/jbarker/workspace/checkpoints/Mistral-7B-Instruct-v0.3-mcore-tp4 google_paligemma_3b_pt_44_mcore_tp_4 mistral_7b_instruct_v0p3_google_paligemma_3b_pt_44_mcore_tp_4
139+
""",
140+
formatter_class=argparse.RawDescriptionHelpFormatter,
141+
)
142+
parser.add_argument(
143+
"--output", type=str, required=True, help="output directory for megatron state dict file(s)"
144+
)
145+
parser.add_argument(
146+
"--tensor-parallel-size", type=int, default=1, help="model tensor parallel size"
147+
)
148+
parser.add_argument("--use-te", action="store_true", help="Use Transformer Engine")
149+
150+
args = parser.parse_args()
151+
152+
convert(args.output, args.tensor_parallel_size, args.use_te)
153+
154+
print("done.")

0 commit comments

Comments
 (0)