Safetensors conversion #2290


Open
wants to merge 13 commits into base: master
140 changes: 140 additions & 0 deletions keras_hub/src/utils/transformers/export_gemma_to_safetensor.py
@@ -0,0 +1,140 @@
import json
import os
import shutil
import warnings

import torch
from safetensors.torch import save_file
Review comment (Member):

does this work on all backends? or do we need to flip between versions depending on the backend? worth testing out
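For context, the conversion below only reads Keras weights through get_weights(), which returns NumPy arrays on every Keras 3 backend, so torch is used purely for serialization; in principle torch only needs to be installed, not selected as the backend, though this is worth verifying as the comment suggests. A minimal sketch for exercising the exporter under a non-torch backend (the backend choice and output path are assumptions, not part of this PR):

import os

# Select the backend before Keras is imported (assumes JAX is installed).
os.environ["KERAS_BACKEND"] = "jax"

from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.utils.transformers.export_gemma_to_safetensor import (
    export_to_hf,
)

# get_weights() hands back NumPy arrays regardless of backend, so the
# torch.from_numpy calls inside export_to_hf should work as long as torch
# is installed alongside the active backend.
model = GemmaCausalLM.from_preset("gemma_2b_en")
export_to_hf(model, "/tmp/gemma_hf_export")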



def convert_to_hf_config(keras_config):
    hf_config = {
        "vocab_size": keras_config.vocabulary_size,
        "num_hidden_layers": keras_config.num_layers,
        "num_attention_heads": keras_config.num_query_heads,
        "num_key_value_heads": keras_config.num_key_value_heads,
        "hidden_size": keras_config.hidden_dim,
        "intermediate_size": keras_config.intermediate_dim // 2,
        "head_dim": keras_config.head_dim,
        "max_position_embeddings": 8192,
    }
    return hf_config
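# For reference, a rough sketch of what this mapping should produce for the
# gemma_2b_en preset exercised in the test below. These numbers are quoted
# from the published Gemma 2B configuration, not taken from this PR:
#
# convert_to_hf_config(backbone) ->
# {
#     "vocab_size": 256000,
#     "num_hidden_layers": 18,
#     "num_attention_heads": 8,
#     "num_key_value_heads": 1,
#     "hidden_size": 2048,
#     "intermediate_size": 16384,  # Keras intermediate_dim (32768) // 2
#     "head_dim": 256,
#     "max_position_embeddings": 8192,
# }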


def export_to_hf(keras_model, path):
Review comment from @abheesht17 (Collaborator), Jun 19, 2025:

Also, do you think we should refactor some of the common code across models to a separate file? We can then expose that as the API.

So, this is what the directory keras_hub/src/utils/transformers/convert_to_safetensor/ would look like:

  • export.py: this will have the common code. We will expose this as the API. It will also check whether we support safetensors conversion for the passed model yet.
  • gemma.py: this will just have a way to create the weight dictionary for Gemma. Inside export.py, we will call the weight conversion function specific to the passed model.

Pinging @mattdangerw to confirm if we should do this now or at a later point.

Review comment (Member):

I think we could land this and do the API bit at a later point, though I agree it's an important concern. I'm not sure if we want a method like model.save_to_preset() or a function like some_export(model). Any thoughts?
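To make the proposal above concrete, here is a minimal sketch of what the shared export.py dispatch could look like. The names (EXPORTERS, get_gemma_weights_dict, export_to_safetensors) are hypothetical and not part of this PR:

def get_gemma_weights_dict(backbone):
    # In the proposed layout this would live in gemma.py and hold the
    # per-layer mapping currently inside export_to_hf below.
    raise NotImplementedError


EXPORTERS = {
    "GemmaBackbone": get_gemma_weights_dict,
}


def export_to_safetensors(keras_model, path):
    """Shared entry point: check support, delegate to the model-specific
    weight-dict builder, then do the common config/tokenizer/weights saving.
    """
    backbone = keras_model.backbone
    name = type(backbone).__name__
    if name not in EXPORTERS:
        raise ValueError(
            f"Safetensors export is not supported for {name} yet."
        )
    weights_dict = EXPORTERS[name](backbone)
    # ... shared saving logic (config.json, model.safetensors, tokenizer).
    return weights_dict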

"""This function converts a Keras Gemma model to Hugging Face format by:
- Extracting and mapping weights from the Keras backbone to safetensors.
- Saving the configuration as 'config.json'.
- Saving weights in 'model.safetensors'.
- Saving tokenizer assets.
Args:
keras_model: The Keras Gemma model (e.g., GemmaCausalLM) to convert.
path: str. Path of the directory to which the safetensors file,
config and tokenizer will be saved.
"""
backbone = keras_model.backbone
hf_config = convert_to_hf_config(backbone)

weights_dict = {}

# Map token embedding
token_embedding = backbone.get_layer("token_embedding").get_weights()[0]
weights_dict["model.embed_tokens.weight"] = torch.from_numpy(
token_embedding
)

for i in range(backbone.num_layers):
decoder_layer = backbone.get_layer(f"decoder_block_{i}")

# Pre-attention normalization
pre_attn_norm = decoder_layer.pre_attention_norm.get_weights()[0]
weights_dict[f"model.layers.{i}.input_layernorm.weight"] = (
torch.from_numpy(pre_attn_norm)
)

# Attention query projection
query_kernel = decoder_layer.attention.query_dense.get_weights()[0]
query_kernel = (
torch.from_numpy(query_kernel)
.permute(1, 0, 2)
.reshape(-1, backbone.hidden_dim)
.T
)
weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = query_kernel

# Attention key projection
key_kernel = decoder_layer.attention.key_dense.get_weights()[0][0]
key_kernel = torch.from_numpy(key_kernel).T
weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = key_kernel

# Attention value projection
value_kernel = decoder_layer.attention.value_dense.get_weights()[0][0]
value_kernel = torch.from_numpy(value_kernel).T
weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = value_kernel

# Attention output projection
out_kernel = decoder_layer.attention.output_dense.get_weights()[0]
out_kernel = (
torch.from_numpy(out_kernel)
.permute(2, 0, 1)
.reshape(backbone.hidden_dim, -1)
)
weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = out_kernel

# Post-attention normalization
post_attn_norm = decoder_layer.pre_ffw_norm.get_weights()[0]
weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
torch.from_numpy(post_attn_norm)
)

# MLP gate projection
gate_kernel = decoder_layer.gating_ffw.get_weights()[0]
gate_kernel = torch.from_numpy(gate_kernel).T
weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = gate_kernel

# MLP up projection
up_kernel = decoder_layer.gating_ffw_2.get_weights()[0]
up_kernel = torch.from_numpy(up_kernel).T
weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = up_kernel

# MLP down projection
down_kernel = decoder_layer.ffw_linear.get_weights()[0]
down_kernel = torch.from_numpy(down_kernel).T
weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = down_kernel

# Map final normalization
final_norm = backbone.get_layer("final_normalization").get_weights()[0]
weights_dict["model.norm.weight"] = torch.from_numpy(final_norm)

# Tie lm_head.weight to embedding weights
weights_dict["lm_head.weight"] = weights_dict[
"model.embed_tokens.weight"
].clone()

# Save config
os.makedirs(path, exist_ok=True)
config_path = os.path.join(path, "config.json")
with open(config_path, "w") as f:
json.dump(hf_config, f)

# Make tensors contiguous before saving
weights_dict_contiguous = {
k: v.contiguous() for k, v in weights_dict.items()
}

# Save weights
weights_path = os.path.join(path, "model.safetensors")
save_file(weights_dict_contiguous, weights_path)

# Save tokenizer assets
keras_model.preprocessor.tokenizer.save_assets(path)

# Rename vocabulary file
vocab_spm_path = os.path.join(path, "vocabulary.spm")
tokenizer_model_path = os.path.join(path, "tokenizer.model")
if os.path.exists(vocab_spm_path):
shutil.move(vocab_spm_path, tokenizer_model_path)
else:
warnings.warn(
f"{vocab_spm_path} not found. Tokenizer may not load correctly."
)
@@ -0,0 +1,44 @@
import os

import pytest
import torch
from transformers import GemmaForCausalLM
from transformers import GemmaTokenizer

from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.utils.transformers.export_gemma_to_safetensor import (
    export_to_hf,
)


class TestGemmaExport(TestCase):
    @pytest.mark.large
    def test_export_to_hf(self):
        # Load Keras model
        keras_model = GemmaCausalLM.from_preset("gemma_2b_en")
        input_text = "All hail RCB"
        max_length = 25

        # Export to Hugging Face format into a temporary directory
        export_path = os.path.join(self.get_temp_dir(), "export_to_hf")
        export_to_hf(keras_model, export_path)

        # Load Hugging Face model and tokenizer
        hf_model = GemmaForCausalLM.from_pretrained(export_path)
        hf_tokenizer = GemmaTokenizer.from_pretrained(export_path)

        # Generate text with Keras model
        keras_output = keras_model.generate(input_text, max_length=max_length)

        # Generate text with Hugging Face model
        hf_inputs = hf_tokenizer(input_text, return_tensors="pt")
        with torch.no_grad():
            hf_outputs = hf_model.generate(
                **hf_inputs, max_length=max_length, do_sample=False
            )
        hf_output_text = hf_tokenizer.decode(
            hf_outputs[0], skip_special_tokens=True
        )

        self.assertEqual(keras_output, hf_output_text)
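Beyond comparing generated text, a quick numerical spot check of the exported weights could complement this test. A sketch, reusing keras_model and export_path from the test above (the tolerance is an assumption; tensor names follow the mapping in export_to_hf):

import numpy as np
from safetensors.torch import load_file

# Load the exported safetensors file directly and compare one tensor
# against its Keras source.
exported = load_file(os.path.join(export_path, "model.safetensors"))
keras_embedding = keras_model.backbone.get_layer(
    "token_embedding"
).get_weights()[0]
np.testing.assert_allclose(
    exported["model.embed_tokens.weight"].numpy(),
    keras_embedding,
    rtol=1e-6,
)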
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -18,3 +18,4 @@ sentencepiece
tensorflow-datasets
safetensors
pillow
transformers