|
| 1 | +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. |
| 2 | +import argparse |
| 3 | +import os |
| 4 | +from transformers import PaliGemmaForConditionalGeneration |
| 5 | +import torch |
| 6 | + |
| 7 | + |
| 8 | +def convert(output_path, tensor_parallel_size, use_te): |
| 9 | + device = "cuda" |
| 10 | + |
| 11 | + model_id = "google/paligemma-3b-pt-448" |
| 12 | + model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval() |
| 13 | + |
| 14 | + model = model.to(device) |
| 15 | + |
| 16 | + print(model.config) |
| 17 | + for name, tensor in model.state_dict().items(): |
| 18 | + if "vision_model" not in name: |
| 19 | + continue |
| 20 | + shape_str = "(" + ", ".join([str(x) for x in tensor.shape]) + ")" |
| 21 | + print(f"{name:<75} {shape_str:>20}") |
| 22 | + |
| 23 | + state_dict = model.state_dict() |
| 24 | + new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)] |
| 25 | + |
| 26 | + def add_chunck_tensor(new_tensor, new_name, chunk_dim=None): |
| 27 | + if chunk_dim is None: |
| 28 | + new_tensors = [new_tensor for _ in range(tensor_parallel_size)] |
| 29 | + else: |
| 30 | + new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim) |
| 31 | + |
| 32 | + for i in range(tensor_parallel_size): |
| 33 | + # chunk() creates a view of a bigger tensor. clone() is used here to avoid excessive storage. |
| 34 | + new_state_dicts[i]["model"][new_name] = new_tensors[i].clone() |
| 35 | + |
| 36 | + # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility. |
| 37 | + extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2") |
| 38 | + is_extra_state_layer = any([l in new_name for l in extra_state_layers]) |
| 39 | + if use_te and is_extra_state_layer: |
| 40 | + layer = new_name.split(".")[-2] |
| 41 | + if layer in extra_state_layers: |
| 42 | + extra_state_name = ( |
| 43 | + new_name[: new_name.rfind(".") + 1] + "_extra_state" |
| 44 | + ) # Replace the weight name. |
| 45 | + new_state_dicts[i]["model"][extra_state_name] = None |
| 46 | + |
| 47 | + for name, tensor in state_dict.items(): |
| 48 | + if tensor.dtype == torch.float16: |
| 49 | + state_dict[name] = tensor.to(torch.float32) |
| 50 | + |
| 51 | + add_chunck_tensor( |
| 52 | + state_dict["vision_tower.vision_model.embeddings.position_embedding.weight"], |
| 53 | + "position_embeddings.weight") |
| 54 | + add_chunck_tensor( |
| 55 | + state_dict["vision_tower.vision_model.embeddings.patch_embedding.weight"], |
| 56 | + "conv1.weight") |
| 57 | + add_chunck_tensor( |
| 58 | + state_dict["vision_tower.vision_model.embeddings.patch_embedding.bias"], |
| 59 | + "conv1.bias") |
| 60 | + |
| 61 | + head_dim = 72 |
| 62 | + num_head = 16 |
| 63 | + for layer_idx in range(27): |
| 64 | + origin_base = f"vision_tower.vision_model.encoder.layers.{layer_idx}" |
| 65 | + target_base = f"decoder.layers.{layer_idx}" |
| 66 | + |
| 67 | + for param_type in ["weight", "bias"]: |
| 68 | + # QKV |
| 69 | + q_proj_params = state_dict[f"{origin_base}.self_attn.q_proj.{param_type}"] |
| 70 | + k_proj_params = state_dict[f"{origin_base}.self_attn.k_proj.{param_type}"] |
| 71 | + v_proj_params = state_dict[f"{origin_base}.self_attn.v_proj.{param_type}"] |
| 72 | + # Do some tensor manipulation because megatron expect one tensor |
| 73 | + # projection for the QKV in the order |
| 74 | + # [(Q1, K1, V1), (Q2, K2, V2), ...] where Qi is the query of the |
| 75 | + # i-th head with dimension num_head. |
| 76 | + new_tensor = torch.concatenate([ |
| 77 | + q_proj_params.view(num_head, head_dim, -1), |
| 78 | + k_proj_params.view(num_head, head_dim, -1), |
| 79 | + v_proj_params.view(num_head, head_dim, -1)], axis=1).view( |
| 80 | + 3*head_dim*num_head, -1) |
| 81 | + if param_type == "bias": |
| 82 | + new_tensor = new_tensor[:, 0] |
| 83 | + new_name = f"{target_base}.self_attention.linear_qkv.{param_type}" |
| 84 | + add_chunck_tensor(new_tensor, new_name, chunk_dim=0) |
| 85 | + # linear_proj |
| 86 | + add_chunck_tensor( |
| 87 | + state_dict[f"{origin_base}.self_attn.out_proj.{param_type}"], |
| 88 | + f"{target_base}.self_attention.linear_proj.{param_type}", |
| 89 | + chunk_dim=1 if param_type == "weight" else None) |
| 90 | + # layer_norm |
| 91 | + new_name = f"{target_base}.input_layernorm.{param_type}" |
| 92 | + if use_te: |
| 93 | + new_name = f"{target_base}.self_attention.linear_qkv.layer_norm_{param_type}" |
| 94 | + add_chunck_tensor( |
| 95 | + state_dict[f"{origin_base}.layer_norm1.{param_type}"], |
| 96 | + new_name) |
| 97 | + # FC 1 |
| 98 | + add_chunck_tensor( |
| 99 | + state_dict[f"{origin_base}.mlp.fc1.{param_type}"], |
| 100 | + f"{target_base}.mlp.linear_fc1.{param_type}", |
| 101 | + chunk_dim=0) |
| 102 | + # FC 2 |
| 103 | + add_chunck_tensor( |
| 104 | + state_dict[f"{origin_base}.mlp.fc2.{param_type}"], |
| 105 | + f"{target_base}.mlp.linear_fc2.{param_type}", |
| 106 | + chunk_dim=1 if param_type=="weight" else None) |
| 107 | + # layer_norm |
| 108 | + new_name = f"{target_base}.pre_mlp_layernorm.{param_type}" |
| 109 | + if use_te: |
| 110 | + new_name = f"{target_base}.mlp.linear_fc1.layer_norm_{param_type}" |
| 111 | + add_chunck_tensor( |
| 112 | + state_dict[f"{origin_base}.layer_norm2.{param_type}"], |
| 113 | + new_name) |
| 114 | + |
| 115 | + add_chunck_tensor( |
| 116 | + state_dict["vision_tower.vision_model.post_layernorm.weight"], |
| 117 | + "ln_post.weight") |
| 118 | + add_chunck_tensor( |
| 119 | + state_dict["vision_tower.vision_model.post_layernorm.bias"], |
| 120 | + "ln_post.bias") |
| 121 | + |
| 122 | + for i in range(tensor_parallel_size): |
| 123 | + output_dir_tp = os.path.join(output_path, "iter_0000001", f"mp_rank_0{i}") |
| 124 | + os.makedirs(output_dir_tp) |
| 125 | + output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt") |
| 126 | + torch.save(new_state_dicts[i], output_path_tp) |
| 127 | + |
| 128 | + |
| 129 | +if __name__ == "__main__": |
| 130 | + parser = argparse.ArgumentParser( |
| 131 | + description=""" |
| 132 | +Convert SigLIP weights to megatron format. |
| 133 | +
|
| 134 | +
|
| 135 | +Example usage: |
| 136 | +python siglip_converter.py --tensor-parallel-size 4 --output google_paligemma_3b_pt_44_mcore_tp_4 --use-te |
| 137 | +
|
| 138 | +examples/multimodal/combine_mistral_clip.sh /lustre/fsw/portfolios/llmservice/users/jbarker/workspace/checkpoints/Mistral-7B-Instruct-v0.3-mcore-tp4 google_paligemma_3b_pt_44_mcore_tp_4 mistral_7b_instruct_v0p3_google_paligemma_3b_pt_44_mcore_tp_4 |
| 139 | +""", |
| 140 | + formatter_class=argparse.RawDescriptionHelpFormatter, |
| 141 | + ) |
| 142 | + parser.add_argument( |
| 143 | + "--output", type=str, required=True, help="output directory for megatron state dict file(s)" |
| 144 | + ) |
| 145 | + parser.add_argument( |
| 146 | + "--tensor-parallel-size", type=int, default=1, help="model tensor parallel size" |
| 147 | + ) |
| 148 | + parser.add_argument("--use-te", action="store_true", help="Use Transformer Engine") |
| 149 | + |
| 150 | + args = parser.parse_args() |
| 151 | + |
| 152 | + convert(args.output, args.tensor_parallel_size, args.use_te) |
| 153 | + |
| 154 | + print("done.") |
0 commit comments