Skip to content

[WIP] llama4 ckpt conversion #1816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 130 additions & 3 deletions MaxText/llama4_ckpt_unscanned.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def _hf_to_maxtext_mapping(layer_idx: int = -1) -> dict:
"""
# pylint: disable=line-too-long
return {
## language model mappping
"language_model.model.embed_tokens.weight": "tok_embeddings.weight",
"language_model.model.norm.weight": "norm.weight",
"language_model.lm_head.weight": "output.weight",
Expand All @@ -145,18 +146,43 @@ def _hf_to_maxtext_mapping(layer_idx: int = -1) -> dict:
f"language_model.model.layers.{layer_idx}.self_attn.k_proj.weight": f"layers.{layer_idx}.attention.wk.weight",
f"language_model.model.layers.{layer_idx}.self_attn.v_proj.weight": f"layers.{layer_idx}.attention.wv.weight",
f"language_model.model.layers.{layer_idx}.self_attn.o_proj.weight": f"layers.{layer_idx}.attention.wo.weight",
# MoE
# MoE in language model
f"language_model.model.layers.{layer_idx}.feed_forward.router.weight": f"layers.{layer_idx}.feed_forward.gate.weight",
f"language_model.model.layers.{layer_idx}.feed_forward.experts.down_proj": f"layers.{layer_idx}.feed_forward.experts.down_proj",
# NOTE: this contains up_proj and gate_proj concated together (we'll split/chunk them later)
f"language_model.model.layers.{layer_idx}.feed_forward.experts.gate_up_proj": f"layers.{layer_idx}.feed_forward.experts.gate_up_proj",
f"language_model.model.layers.{layer_idx}.feed_forward.shared_expert.gate_proj.weight": f"layers.{layer_idx}.feed_forward.shared_experts.gate_proj.weight",
f"language_model.model.layers.{layer_idx}.feed_forward.shared_expert.down_proj.weight": f"layers.{layer_idx}.feed_forward.shared_experts.down_proj.weight",
f"language_model.model.layers.{layer_idx}.feed_forward.shared_expert.up_proj.weight": f"layers.{layer_idx}.feed_forward.shared_experts.up_proj.weight",
# FFN
# FFN in language model
f"language_model.model.layers.{layer_idx}.feed_forward.up_proj.weight": f"layers.{layer_idx}.feed_forward.w1.weight",
f"language_model.model.layers.{layer_idx}.feed_forward.gate_proj.weight": f"layers.{layer_idx}.feed_forward.w3.weight",
f"language_model.model.layers.{layer_idx}.feed_forward.down_proj.weight": f"layers.{layer_idx}.feed_forward.w2.weight",

# ## vision model mapping
# "vision_model.class_embedding": "vision_encoder.",
# "vision_model.positional_embedding_vlm": "",
# "vision_model.patch_embedding.linear.weight": "",
# "vision_model.layernorm_pre.weight": "",
# "vision_model.layernorm_pre.bias": "",
# "vision_model.layernorm_post.weight": "",
# "vision_model.layernorm_post.bias": "",
# "vision_model.model.layers.{layer_idx}.input_layernorm.weight": "",
# "vision_model.model.layers.{layer_idx}.input_layernorm.bias": "",
# "vision_model.model.layers.{layer_idx}.self_attn.q_proj.weight": "",
# "vision_model.model.layers.{layer_idx}.self_attn.q_proj.bias": "",
# "vision_model.model.layers.{layer_idx}.self_attn.k_proj.weight": "",
# "vision_model.model.layers.{layer_idx}.self_attn.k_proj.bias": "",
# "vision_model.model.layers.{layer_idx}.self_attn.v_proj.weight": "",
# "vision_model.model.layers.{layer_idx}.self_attn.v_proj.bias": "",
# "vision_model.model.layers.{layer_idx}.self_attn.o_proj.weight": "",
# "vision_model.model.layers.{layer_idx}.self_attn.o_proj.bias": "",
# "vision_model.model.layers.{layer_idx}.post_attention_layernorm.weight": "",
# "vision_model.model.layers.{layer_idx}.post_attention_layernorm.bias": "",
# "vision_model.model.layers.{layer_idx}.mlp.fc1.weight": "",
# "vision_model.model.layers.{layer_idx}.mlp.fc1.bias": "",
# "vision_model.model.layers.{layer_idx}.mlp.fc2.weight": "",
# "vision_model.model.layers.{layer_idx}.mlp.fc2.bias": "",
}


Expand Down Expand Up @@ -194,6 +220,10 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m
Returns:
jax_weights (dict): Dictionary containing the converted weights.
"""
num_hidden_layers_for_vit = model_params.get("num_layers_vit", 0)
num_attention_heads_for_vit = model_params.get("num_att_head_vit", 0)
hidden_size_for_vit = model_params.get("hidden_size_vit", 0)
head_dim_for_vit = hidden_size_for_vit // num_attention_heads_for_vit
base_num_decoder_layers = model_params["num_layers"]
base_num_query_heads = model_params["num_heads"]
head_dim = model_params["dims_per_head"]
Expand All @@ -217,7 +247,8 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m
layer = int(parts[3]) if "layers" in key else 0
# TODO: update when mutli-modality support is added
if "vision" in key or "multi_modal_projector" in key:
print("WARNING: skipping vision or multi-modal key: ", key)
#print("WARNING: skipping vision or multi-modal key: ", key)
chkpt_vars[key] = f.get_tensor(key)
else:
mapped_key = _hf_to_maxtext_mapping(layer)[key]
chkpt_vars[mapped_key] = f.get_tensor(key)
Expand All @@ -230,8 +261,104 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m
"logits_dense": {"kernel": None},
},
"token_embedder": {"embedding": None},
"vision_encoder": {
"Llama4VisionModel_0": {
"Llama4VisionEncoder_0": {},
"class_embedding": None,
"positional_embedding_vlm": None,
"Llama4UnfoldConvolution_0": {"vit_unfold_linear": {"kernel": None}},
"layernorm_pre": {},
"layernorm_post": {},
"Llama4VisionPixelShuffleMLP_0": {},
},
"Llama4MultiModalProjector_0": {"vit_multi_modal_projector": {"kernel": None}},
},
}

# vision model ###########################################
max_logging.log("Processing vision model")
jax_weights["vision_encoder"]["Llama4VisionModel_0"]["class_embedding"] = chkpt_vars["vision_model.class_embedding"].to(torch.float32).numpy().astype(CAST_DTYPE)
jax_weights["vision_encoder"]["Llama4VisionModel_0"]["positional_embedding_vlm"] = chkpt_vars["vision_model.positional_embedding_vlm"].to(torch.float32).numpy().astype(CAST_DTYPE)
jax_weights["vision_encoder"]["Llama4VisionModel_0"]["Llama4UnfoldConvolution_0"]["vit_unfold_linear"]["kernel"] = chkpt_vars["vision_model.patch_embedding.linear.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
jax_weights["vision_encoder"]["Llama4VisionModel_0"]["layernorm_pre"].update({
"scale": chkpt_vars["vision_model.layernorm_pre.weight"].to(torch.float32).numpy().astype(CAST_DTYPE),
"bias": chkpt_vars["vision_model.layernorm_pre.bias"].to(torch.float32).numpy().astype(CAST_DTYPE),
})
jax_weights["vision_encoder"]["Llama4VisionModel_0"]["layernorm_post"].update({
"scale": chkpt_vars["vision_model.layernorm_post.weight"].to(torch.float32).numpy().astype(CAST_DTYPE),
"bias": chkpt_vars["vision_model.layernorm_post.bias"].to(torch.float32).numpy().astype(CAST_DTYPE),
})

max_logging.log("Processing vision encoder")
for layer_idx in tqdm(range(num_hidden_layers_for_vit), desc="layers", leave=False):
layer_name = f"layers_{layer_idx}"
wq = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.q_proj.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
wk = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.k_proj.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
wv = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.v_proj.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
wo = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.o_proj.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
bq = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.q_proj.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
bk = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.k_proj.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
bv = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.v_proj.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
bo = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.o_proj.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)

wq = np.reshape(wq, [hidden_size_for_vit, num_attention_heads_for_vit, head_dim_for_vit])
wk = np.reshape(wk, [hidden_size_for_vit, num_attention_heads_for_vit, head_dim_for_vit])
wv = np.reshape(wv, [hidden_size_for_vit, num_attention_heads_for_vit, head_dim_for_vit])
wo = np.reshape(wo, [num_attention_heads_for_vit, head_dim_for_vit, hidden_size_for_vit])
bq = np.reshape(bq, [num_attention_heads_for_vit, head_dim_for_vit])
bk = np.reshape(bk, [num_attention_heads_for_vit, head_dim_for_vit])
bv = np.reshape(bv, [num_attention_heads_for_vit, head_dim_for_vit])

self_attention_vision = {
"query": {"kernel": wq , "bias": bq},
"key": {"kernel": wk , "bias": bk},
"value": {"kernel": wv , "bias": bv},
"out": {"kernel": wo , "bias": bo},
}

fc1_w = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc1.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
fc2_w = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc2.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
fc1_b = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc1.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
fc2_b = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc2.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
vision_mlp = {
"vit_encoder_layer_mlp_fc1": {"kernel": fc1_w, "bias": fc1_b},
"vit_encoder_layer_mlp_fc2": {"kernel": fc2_w, "bias": fc2_b},
}

jax_weights["vision_encoder"]["Llama4VisionModel_0"]["Llama4VisionEncoder_0"].update(
{
layer_name: {
"self_attention_vision": self_attention_vision,
"Llama4VisionMLP_0": vision_mlp,
"input_layer_norm": {
"scale": chkpt_vars[f"vision_model.model.layers.{layer_idx}.input_layernorm.weight"].to(torch.float32).numpy().astype(CAST_DTYPE),
"bias": chkpt_vars[f"vision_model.model.layers.{layer_idx}.input_layernorm.bias"].to(torch.float32).numpy().astype(CAST_DTYPE),
},
"post_attention_layer_norm": {
"scale": chkpt_vars[f"vision_model.model.layers.{layer_idx}.post_attention_layernorm.weight"].to(torch.float32).numpy().astype(CAST_DTYPE),
"bias": chkpt_vars[f"vision_model.model.layers.{layer_idx}.post_attention_layernorm.bias"].to(torch.float32).numpy().astype(CAST_DTYPE),
},
}
}
)

max_logging.log("Processing pixel shuffle mlp")
adaptor_fc1 = chkpt_vars["vision_model.vision_adapter.mlp.fc1.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
adaptor_fc2 = chkpt_vars["vision_model.vision_adapter.mlp.fc2.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
jax_weights["vision_encoder"]["Llama4VisionModel_0"]["Llama4VisionPixelShuffleMLP_0"].update(
{
"pixel_shuffle_mlp": {
"vit_pixel_shuffle_mlp_fc1": {"kernel": adaptor_fc1},
"vit_pixel_shuffle_mlp_fc2": {"kernel": adaptor_fc2},
},
}
)

max_logging.log("Processing multimodal projector")
jax_weights["vision_encoder"]["Llama4MultiModalProjector_0"]["vit_multi_modal_projector"]["kernel"] = chkpt_vars["multi_modal_projector.linear_1.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()

# language model ###########################################
max_logging.log("Processing language model")
# decoder norm scale ###########################################
max_logging.log("Processing decoder norm scale")
decoder_norm_scale = chkpt_vars["norm.weight"].to(torch.float32).numpy().astype(CAST_DTYPE)
Expand Down
5 changes: 4 additions & 1 deletion MaxText/llama_or_mistral_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
"vocab": 128256,
},
"llama4-17b-16e": {
"num_layers": 48,
"num_layers": 1,
"num_heads": 40,
"num_kv_heads": 8,
"dims_per_head": 128,
Expand All @@ -141,6 +141,9 @@
"rope_type": "llama3.1",
"scale_query": False,
"interleave_moe_layer_step": 1,
"num_layers_vit": 1,
"num_att_head_vit": 16,
"hidden_size_vit": 1408,
},
"llama4-17b-128e": {
"num_layers": 48,
Expand Down
Loading