41 changes: 40 additions & 1 deletion docs/guides/transformers.md
@@ -8,6 +8,8 @@ CTranslate2 supports selected models from Hugging Face's [Transformers](https://
* CodeGen
* DistilBERT
* Falcon
* Gemma 2
* Gemma 3 (text only)
* Llama
* M2M100
* MarianMT
@@ -80,7 +82,7 @@ print(tokenizer.decode(tokenizer.convert_tokens_to_ids(target), skip_special_tok

## BERT

[BERT](https://huggingface.co/docs/transformers/model_doc/bert) is pretrained model on English language using a masked language modeling objective.
[BERT](https://huggingface.co/docs/transformers/model_doc/bert) is a pretrained model on English language using a masked language modeling objective.

CTranslate2 only implements the `BertModel` class from Transformers which includes the Transformer encoder and the pooling layer. Task-specific layers should be run with PyTorch as shown in the example below.

@@ -183,6 +185,43 @@ output = tokenizer.decode(results[0].sequences_ids[0])
print(output)
```

## Gemma 3 (text only)

[Gemma 3](https://ai.google.dev/gemma/docs/core) is Google's family of lightweight, open-weight models, built on the same research and technology as Gemini.

Gemma models come in two flavors: instruction-tuned (`-it`) models and base models.

Instruction-tuned models expect a specific [prompt format](https://ai.google.dev/gemma/docs/core/prompt-structure), which you should follow when building prompts.

When converting an instruction-tuned model, CTranslate2 sets `<end_of_turn>` as the default end-of-sequence token.

To convert a model:

```bash
ct2-transformers-converter --model google/gemma-3-1b-it --output_dir gemma-3-1b-it
```
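
The conversion can also be run from Python through the converter API. A minimal sketch, assuming the same output directory as above; the `quantization` argument is optional and shown here only as an example:

```python
import ctranslate2

# Convert the Hugging Face checkpoint to the CTranslate2 format.
converter = ctranslate2.converters.TransformersConverter("google/gemma-3-1b-it")
converter.convert("gemma-3-1b-it", quantization="int8", force=True)
```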

Example of text generation with Gemma 3:

```python
import ctranslate2
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
gen = ctranslate2.Generator("gemma-3-1b-it")

# Build the prompt following the Gemma instruction-tuned template.
prompt = "<start_of_turn>user\nGenerate a 200 word text talking about George Orwell.<end_of_turn>\n<start_of_turn>model\n"
tokens = tok.convert_ids_to_tokens(tok.encode(prompt))

# Generation stops at the default end-of-sequence token, <end_of_turn>.
results = gen.generate_batch(
    [tokens],
    max_length=2048,
    sampling_temperature=0.1,
    include_prompt_in_result=False,
)
print(tok.convert_tokens_to_string(results[0].sequences[0]))
```
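
As an alternative to writing the prompt string by hand, the tokenizer's chat template can build it for you. A minimal sketch, assuming the tokenizer ships the Gemma 3 chat template:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

messages = [
    {"role": "user", "content": "Generate a 200 word text talking about George Orwell."}
]

# apply_chat_template returns token ids ending with the "<start_of_turn>model" turn marker.
ids = tok.apply_chat_template(messages, add_generation_prompt=True)
tokens = tok.convert_ids_to_tokens(ids)

# `tokens` can then be passed to generate_batch() exactly as in the example above.
```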


## Llama 2

[Llama 2](https://ai.meta.com/llama/) is a collection of pretrained and fine-tuned generative text models ranging in scale from 7 billion to 70 billion parameters.
3 changes: 3 additions & 0 deletions include/ctranslate2/layers/attention.h
@@ -2,6 +2,7 @@

#include "ctranslate2/layers/attention_layer.h"
#include "ctranslate2/padder.h"
#include "ctranslate2/layers/transformer.h"

namespace ctranslate2 {
namespace layers {
@@ -65,6 +66,8 @@ namespace ctranslate2 {
dim_t _relative_right_max_position;
const bool _merge_time_and_head_dims;
const dim_t _cache_time_dim;
std::unique_ptr<const LayerNorm> _q_norm; // Query normalization
std::unique_ptr<const LayerNorm> _k_norm; // Key normalization
};
}
}
186 changes: 186 additions & 0 deletions python/ctranslate2/converters/transformers.py
@@ -1819,6 +1819,192 @@ def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
gc.collect()


@register_loader("Gemma3TextConfig")
@register_loader("Gemma3Config")
class Gemma3Loader(ModelLoader):
@property
def architecture_name(self):
return "Gemma3ForCausalLM"

def get_model_spec(self, model):
num_layers = model.config.num_hidden_layers
num_heads = model.config.num_attention_heads
num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
if num_heads_kv == num_heads:
num_heads_kv = None

head_dim = model.config.head_dim

activation_config = getattr(
model.config, "hidden_activation", "gelu_pytorch_tanh"
)

# Get RoPE parameters
rope_theta = getattr(model.config, "rope_theta", 1_000_000) # Global: 1M
rope_local_base_freq = getattr(
model.config, "rope_local_base_freq", 10_000
) # Local: 10k

# Get sliding window configuration
sliding_window = getattr(model.config, "sliding_window", 1024)
layer_types = getattr(model.config, "layer_types", None)

quantization_config = getattr(model.config, "quantization_config", None)
if quantization_config:
if quantization_config.quant_method == "awq":
quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
if quant_type is None:
raise NotImplementedError(
"Quantization type '%s' is not yet implemented."
% quantization_config.quant_method
)
else:
quant_type = common_spec.Quantization.CT2

# Create base spec using from_config
spec = transformer_spec.TransformerDecoderModelSpec.from_config(
num_layers,
num_heads,
activation=(
common_spec.Activation.GELU
if activation_config == "gelu"
else common_spec.Activation.GELUTanh
),
pre_norm=True,
ffn_glu=True,
rms_norm=True,
rotary_dim=head_dim,
rotary_interleave=False,
rotary_base=rope_local_base_freq, # Default to local base freq
num_heads_kv=num_heads_kv,
head_dim=head_dim,
sliding_window=sliding_window, # Default to local sliding window
pre_post_layer_norm=True,
qk_norm=True,
)

# Store layer_types for use in set_decoder
self._layer_types = layer_types

# Override per-layer settings for global vs local attention
for i, layer_type in enumerate(layer_types):
layer = spec.decoder.layer[i]
if layer_type == "full_attention":
layer.self_attention.rotary_base = np.dtype("float32").type(rope_theta)
layer.self_attention.sliding_window = np.dtype("int32").type(0)
elif layer_type == "sliding_attention":
layer.self_attention.rotary_base = np.dtype("float32").type(
rope_local_base_freq
)
layer.self_attention.sliding_window = np.dtype("int32").type(
sliding_window
)

self.set_decoder(spec.decoder, model.model, quant_type)
self.set_linear(spec.decoder.projection, model.lm_head)
return spec

def get_vocabulary(self, model, tokenizer):
tokens = super().get_vocabulary(model, tokenizer)

extra_ids = model.config.vocab_size - len(tokens)
for i in range(extra_ids):
tokens.append("<extra_id_%d>" % i)
if model.config.vocab_size < len(tokens):
tokens = tokens[: model.config.vocab_size]

return tokens

def set_vocabulary(self, spec, tokens):
spec.register_vocabulary(tokens)

def set_config(self, config, model, tokenizer):
config.bos_token = tokenizer.bos_token
config.unk_token = tokenizer.unk_token

if (
hasattr(tokenizer, "chat_template")
and isinstance(tokenizer.chat_template, str)
and tokenizer.chat_template.strip()
):
config.eos_token = "<end_of_turn>"
else:
config.eos_token = tokenizer.eos_token

def set_layer_norm(self, spec, layer_norm):
spec.gamma = layer_norm.weight + 1.0

def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
spec.scale_embeddings = True
spec.start_from_zero_embedding = False
self.set_embeddings(spec.embeddings, module.embed_tokens) # Input
self.set_layer_norm(spec.layer_norm, module.norm) # Output

for layer_spec, layer in zip(spec.layer, module.layers):
self.set_layer_norm(layer_spec.input_layer_norm, layer.input_layernorm)

self.set_layer_norm(
layer_spec.post_attention_layer_norm, layer.post_attention_layernorm
)

self.set_layer_norm(
layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
)

self.set_layer_norm(
layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
)

# Set QK-norm weights (Gemma 3 uses this instead of soft-capping)
self.set_layer_norm(
layer_spec.self_attention.q_norm, layer.self_attn.q_norm
)
self.set_layer_norm(
layer_spec.self_attention.k_norm, layer.self_attn.k_norm
)

# Set attention projections
split_layers = [common_spec.LinearSpec() for _ in range(3)]
self.set_linear(
split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
)
self.set_linear(
split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
)
self.set_linear(
split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
)

if quant_type == common_spec.Quantization.CT2:
utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
else:
cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
utils.fuse_linear_prequant(
layer_spec.self_attention.linear[0], split_layers, cc_dim
)

self.set_linear(
layer_spec.self_attention.linear[1],
layer.self_attn.o_proj,
quant_type=quant_type,
)

# Set FFN weights
self.set_linear(
layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
)
self.set_linear(
layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
)
self.set_linear(
layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
)

delattr(layer, "self_attn")
delattr(layer, "mlp")
gc.collect()


@register_loader("MistralConfig")
class MistralLoader(ModelLoader):
@property
6 changes: 6 additions & 0 deletions python/ctranslate2/specs/attention_spec.py
@@ -32,6 +32,8 @@ def __init__(
num_heads_kv=None,
head_dim=None,
sliding_window=None,
qk_norm=False,
qk_norm_rms=True,
):
self.queries_scale = model_spec.OPTIONAL

@@ -40,6 +42,10 @@
common_spec.LinearSpec() for _ in range(2 if self_attention else 3)
]

if qk_norm:
self.q_norm = common_spec.LayerNormSpec(rms_norm=qk_norm_rms)
self.k_norm = common_spec.LayerNormSpec(rms_norm=qk_norm_rms)

if relative_position:
self.relative_position_keys = None
self.relative_position_values = None
7 changes: 7 additions & 0 deletions python/ctranslate2/specs/transformer_spec.py
@@ -109,6 +109,7 @@ def __init__(
quant_type: Optional[common_spec.Quantization] = None,
quant_group_size: Optional[int] = None,
quant_bits: Optional[int] = None,
qk_norm: Optional[bool] = False,
):
"""Initializes a Transformer decoder specification.

@@ -222,6 +223,7 @@ def __init__(
num_heads_kv=num_heads_kv,
head_dim=head_dim,
sliding_window=sliding_window,
qk_norm=qk_norm,
)
for _ in range(num_layers)
]
@@ -286,6 +288,7 @@ def __init__(
num_heads_kv=None,
head_dim=None,
sliding_window=None,
qk_norm=False,
):
self.self_attention = attention_spec.MultiHeadAttentionSpec(
self_attention=True,
@@ -302,13 +305,15 @@
num_heads_kv=num_heads_kv,
head_dim=head_dim,
sliding_window=sliding_window,
qk_norm=qk_norm,
)

if with_encoder_attention:
self.attention = attention_spec.MultiHeadAttentionSpec(
rms_norm=rms_norm,
num_heads_kv=num_heads_kv,
sliding_window=sliding_window,
qk_norm=qk_norm,
)

self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)
@@ -557,6 +562,7 @@ def from_config(
quant_type: Optional[common_spec.Quantization] = None,
quant_group_size: Optional[int] = None,
quant_bits: Optional[int] = None,
qk_norm: Optional[bool] = False,
):
"""Creates a Transformer decoder model specification.

@@ -631,6 +637,7 @@
quant_type=quant_type,
quant_group_size=quant_group_size,
quant_bits=quant_bits,
qk_norm=qk_norm,
)

return cls(decoder)