Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ loss.backward()
| Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma3 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma4 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma4_text` | RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma4 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma4` | RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Expand Down
3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma4 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma4_text # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
Expand Down Expand Up @@ -119,6 +120,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_gemma2",
"apply_liger_kernel_to_gemma3",
"apply_liger_kernel_to_gemma3_text",
"apply_liger_kernel_to_gemma4",
"apply_liger_kernel_to_gemma4_text",
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_glm4v",
Expand Down Expand Up @@ -207,6 +209,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_gemma2",
"apply_liger_kernel_to_gemma3",
"apply_liger_kernel_to_gemma3_text",
"apply_liger_kernel_to_gemma4",
"apply_liger_kernel_to_gemma4_text",
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_glm4v",
Expand Down
146 changes: 146 additions & 0 deletions src/liger_kernel/transformers/model/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast

try:
from liger_kernel.transformers.model.output_classes import LigerGemma4CausalLMOutputWithPast
except ImportError:
# Older transformers without gemma4 — multimodal_forward is then unreachable
# because monkey_patch.apply_liger_kernel_to_gemma4 imports gemma4 modules
# behind the same try/except.
LigerGemma4CausalLMOutputWithPast = None

logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -149,3 +157,141 @@ def causal_forward(
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)


def multimodal_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
input_features_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
image_position_ids: Optional[torch.LongTensor] = None,
video_position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
mm_token_type_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
**lm_kwargs,
):
r"""Fused-linear-cross-entropy forward for ``Gemma4ForConditionalGeneration``.

Mirrors :func:`liger_kernel.transformers.model.gemma3.multimodal_forward`
with Gemma 4-specific kwargs (``pixel_values_videos``, ``input_features``,
``image_position_ids``, ``video_position_ids``, ``mm_token_type_ids``) and
output fields (``image_hidden_states``, ``audio_hidden_states``).

The win on Gemma 4 multimodal is large: vocab=262,144 means the (B, T, V)
fp32 logits tensor is ~17 GB at T=8192 in bf16 (and another ~34 GB once the
loss path upcasts), OOMing 96 GB cards. Routing loss through
``LigerForCausalLMLoss`` materializes only the loss scalar.

labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either
be in `[0, ..., config.text_config.vocab_size]` or -100 (see `input_ids`
docstring). Tokens with indices set to `-100` are ignored.

logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`,
calculate logits for all `input_ids` (special case). If a `torch.Tensor`,
must be 1D corresponding to the indices to keep in the sequence-length
dimension.
"""

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
input_features=input_features,
attention_mask=attention_mask,
input_features_mask=input_features_mask,
position_ids=position_ids,
image_position_ids=image_position_ids,
video_position_ids=video_position_ids,
past_key_values=past_key_values,
mm_token_type_ids=mm_token_type_ids,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**lm_kwargs,
)

hidden_states = outputs[0]
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]

text_cfg = self.config.get_text_config()
softcap = getattr(text_cfg, "final_logit_softcapping", None)
shift_labels = lm_kwargs.pop("shift_labels", None)

loss = None
logits = None
token_accuracy = None
predicted_tokens = None

if skip_logits is None:
skip_logits = self.training and (labels is not None or shift_labels is not None)

if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=text_cfg.hidden_size,
final_logit_softcapping=softcap,
**lm_kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
else:
logits = self.lm_head(kept_hidden_states)
if softcap is not None:
logits = logits / softcap
logits = torch.tanh(logits)
logits = logits * softcap
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=text_cfg.vocab_size,
**lm_kwargs,
)

if not return_dict:
output_tuple = (logits,) + outputs[1:]
output_tuple = (loss,) + output_tuple if loss is not None else output_tuple
output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple
output_tuple = output_tuple + (predicted_tokens,) if predicted_tokens is not None else output_tuple
return output_tuple

return LigerGemma4CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=getattr(outputs, "image_hidden_states", None),
audio_hidden_states=getattr(outputs, "audio_hidden_states", None),
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
13 changes: 13 additions & 0 deletions src/liger_kernel/transformers/model/output_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
except Exception:
_Gemma3CausalLMOutputWithPast = None

try:
from transformers.models.gemma4.modeling_gemma4 import Gemma4CausalLMOutputWithPast as _Gemma4CausalLMOutputWithPast
except Exception:
_Gemma4CausalLMOutputWithPast = None

try:
from transformers.models.glm4v_moe.modeling_glm4v_moe import (
Glm4vMoeCausalLMOutputWithPast as _Glm4vMoeCausalLMOutputWithPast,
Expand Down Expand Up @@ -108,6 +113,14 @@ class LigerGemma3CausalLMOutputWithPast(_Gemma3CausalLMOutputWithPast):
predicted_tokens: Optional[torch.LongTensor] = None


if _Gemma4CausalLMOutputWithPast is not None:

@dataclass
class LigerGemma4CausalLMOutputWithPast(_Gemma4CausalLMOutputWithPast):
token_accuracy: Optional[torch.FloatTensor] = None
predicted_tokens: Optional[torch.LongTensor] = None


if _Glm4vMoeCausalLMOutputWithPast is not None:

@dataclass
Expand Down
131 changes: 130 additions & 1 deletion src/liger_kernel/transformers/monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,7 +1355,11 @@ def _maybe_patch_scaled_norm(module):
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules

causal_lm_types = tuple(cls for cls in (Gemma4ForCausalLM, Gemma4TextForCausalLM) if cls is not None)
# `isinstance(cls, type)` filter (rather than `cls is not None`) so we
# also drop the MagicMock the test harness substitutes for
# `Gemma4TextForCausalLM` under `with patch("...modeling_gemma4")` —
# an `isinstance` check against a non-class entry raises TypeError.
causal_lm_types = tuple(cls for cls in (Gemma4ForCausalLM, Gemma4TextForCausalLM) if isinstance(cls, type))
if isinstance(model, causal_lm_types) or isinstance(model, Gemma4TextModel):
# get the base model from the model instance
base_model = model.model if isinstance(model, causal_lm_types) else model
Expand Down Expand Up @@ -1384,6 +1388,130 @@ def _maybe_patch_scaled_norm(module):
raise TypeError("The model must be Gemma4ForCausalLM, Gemma4TextForCausalLM, or Gemma4TextModel.")


def apply_liger_kernel_to_gemma4(
rope: bool = False,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
geglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Gemma4
multimodal models (`Gemma4ForConditionalGeneration`).

Dispatches on class: text-only variants (`Gemma4ForCausalLM`,
`Gemma4TextForCausalLM`, `Gemma4TextModel`) are routed to
:func:`apply_liger_kernel_to_gemma4_text` for backwards compatibility, so the
same entry point works for both shapes when an instance is supplied.

The primary win is the fused-linear-cross-entropy path on the multimodal
class: with vocab=262,144, the (B, T, V) fp32 logits tensor is ~17 GB at
T=8192 in bf16 (and ~34 GB once the loss path upcasts), OOMing 96 GB cards.
Fused CE materializes only the loss scalar.

Out of scope (deferred to future PRs):
- Vision and audio tower kernel swaps. Gemma 4's vision and audio towers
are loaded via `AutoModel.from_config(config.vision_config)` and
`AutoModel.from_config(config.audio_config)` respectively, so their
module classes are polymorphic. A safe class-level swap would need to
enumerate supported tower types — out of scope here; FLCE on the LM head
is what unblocks training OOM.
- PLE (Per-Layer Embeddings) kernels. PLE state passes through the inner
forward unchanged; verified end-to-end on E4B-it without explicit
handling.
- Gemma 4 MoE expert kernels (`Gemma4TextExperts`); guarded out via the
same `enable_moe_block` check used in the text path.

Args:
rope (bool): Currently a no-op (HF's apply_rotary_pos_emb signature is
incompatible with Liger's fused variant). Default False.
cross_entropy (bool): Whether to apply Liger's cross entropy loss.
Default False. Mutually exclusive with `fused_linear_cross_entropy`.
fused_linear_cross_entropy (bool): Fused linear CE for memory
efficiency. Default True.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default True.
geglu (bool): Whether to apply Liger's GeGLU MLP. Default True.
model (PreTrainedModel): An already-instantiated model to patch
in-place. Default None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)

from transformers.models.gemma4 import modeling_gemma4
from transformers.models.gemma4.modeling_gemma4 import Gemma4ForCausalLM
from transformers.models.gemma4.modeling_gemma4 import Gemma4ForConditionalGeneration
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel

from liger_kernel.transformers.model.gemma4 import multimodal_forward

Gemma4TextForCausalLM = getattr(modeling_gemma4, "Gemma4TextForCausalLM", None)

# Dispatch: if the caller passed a text-only instance, route to the text
# path so this entry point works as a single user-facing API.
if model is not None:
# `isinstance(cls, type)` filter (rather than `cls is not None`) so we
# also drop the MagicMock the test harness substitutes for
# `Gemma4TextForCausalLM` under `with patch("...modeling_gemma4")` —
# an `isinstance` check against a non-class entry raises TypeError.
text_classes = tuple(
cls for cls in (Gemma4ForCausalLM, Gemma4TextForCausalLM, Gemma4TextModel) if isinstance(cls, type)
)
if isinstance(model, text_classes):
apply_liger_kernel_to_gemma4_text(
rope=rope,
cross_entropy=cross_entropy,
fused_linear_cross_entropy=fused_linear_cross_entropy,
rms_norm=rms_norm,
geglu=geglu,
model=model,
)
return
if not isinstance(model, Gemma4ForConditionalGeneration):
raise TypeError(
"The model must be Gemma4ForConditionalGeneration (or a Gemma 4 "
"text-only variant; those are routed to "
"apply_liger_kernel_to_gemma4_text)."
)

# Class-level patches for the text decoder layers (RMSNorm, GeGLU MLP).
# We disable FLCE here because the multimodal class needs its own forward
# (handles pixel_values / input_features / mm_token_type_ids / etc.) — we
# install that below.
apply_liger_kernel_to_gemma4_text(
rope=rope,
cross_entropy=False,
fused_linear_cross_entropy=False,
rms_norm=rms_norm,
geglu=geglu,
)

if cross_entropy:
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy

if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(multimodal_forward, model)
else:
modeling_gemma4.Gemma4ForConditionalGeneration.forward = multimodal_forward

if model is not None:
# Recurse into the language model for instance-level RMSNorm / GeGLU
# patching. (The class-level swap above already covers freshly
# instantiated modules; this catches the already-built ones.)
apply_liger_kernel_to_gemma4_text(
rope=rope,
cross_entropy=False,
fused_linear_cross_entropy=False,
rms_norm=rms_norm,
geglu=geglu,
model=model.model.language_model,
)


def apply_liger_kernel_to_paligemma(
rope: bool = True,
cross_entropy: bool = False,
Expand Down Expand Up @@ -3362,6 +3490,7 @@ def __init__(self, hidden_size, eps=1e-6, **kwargs):
"gemma3_text": apply_liger_kernel_to_gemma3_text,
"gemma3": apply_liger_kernel_to_gemma3,
"gemma4_text": apply_liger_kernel_to_gemma4_text,
"gemma4": apply_liger_kernel_to_gemma4,
"glm4": apply_liger_kernel_to_glm4,
"glm4v": apply_liger_kernel_to_glm4v,
"glm4v_moe": apply_liger_kernel_to_glm4v_moe,
Expand Down
Loading