diff --git a/README.md b/README.md index 69926e4ee..91304aa24 100644 --- a/README.md +++ b/README.md @@ -260,6 +260,7 @@ loss.backward() | Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Gemma3 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Gemma4 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma4_text` | RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Gemma4 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma4` | RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 835f9d4fa..74ac8d7a3 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -43,6 +43,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma4 # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma4_text # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401 @@ -119,6 +120,7 @@ def __getattr__(name: str): "apply_liger_kernel_to_gemma2", "apply_liger_kernel_to_gemma3", "apply_liger_kernel_to_gemma3_text", + "apply_liger_kernel_to_gemma4", "apply_liger_kernel_to_gemma4_text", "apply_liger_kernel_to_glm4", "apply_liger_kernel_to_glm4v", @@ -207,6 +209,7 @@ def __getattr__(name: str): "apply_liger_kernel_to_gemma2", "apply_liger_kernel_to_gemma3", "apply_liger_kernel_to_gemma3_text", + "apply_liger_kernel_to_gemma4", "apply_liger_kernel_to_gemma4_text", "apply_liger_kernel_to_glm4", "apply_liger_kernel_to_glm4v", diff --git a/src/liger_kernel/transformers/model/gemma4.py b/src/liger_kernel/transformers/model/gemma4.py index 50e8e8496..6c31f72ce 100644 --- a/src/liger_kernel/transformers/model/gemma4.py +++ b/src/liger_kernel/transformers/model/gemma4.py @@ -11,6 +11,14 @@ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast +try: + from liger_kernel.transformers.model.output_classes import LigerGemma4CausalLMOutputWithPast +except ImportError: + # Older transformers without gemma4 — multimodal_forward is then unreachable + # because monkey_patch.apply_liger_kernel_to_gemma4 imports gemma4 modules + # behind the same try/except. 
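+    # A None sentinel (rather than re-raising) keeps this module importable
+    # on those older versions; only the gemma4-specific paths consult it.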
+    LigerGemma4CausalLMOutputWithPast = None
+

 logger = logging.get_logger(__name__)

@@ -149,3 +157,141 @@ def causal_forward(
         token_accuracy=token_accuracy,
         predicted_tokens=predicted_tokens,
     )
+
+
+def multimodal_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    pixel_values_videos: Optional[torch.FloatTensor] = None,
+    input_features: Optional[torch.FloatTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    input_features_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    image_position_ids: Optional[torch.LongTensor] = None,
+    video_position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    mm_token_type_ids: Optional[torch.LongTensor] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **lm_kwargs,
+):
+    r"""Fused-linear-cross-entropy forward for ``Gemma4ForConditionalGeneration``.
+
+    Mirrors :func:`liger_kernel.transformers.model.gemma3.multimodal_forward`
+    with Gemma 4-specific kwargs (``pixel_values_videos``, ``input_features``,
+    ``image_position_ids``, ``video_position_ids``, ``mm_token_type_ids``) and
+    output fields (``image_hidden_states``, ``audio_hidden_states``).
+
+    The win on Gemma 4 multimodal is large: vocab=262,144 means the (B, T, V)
+    logits tensor is ~17 GB in bf16 at B=4, T=8192 (and another ~34 GB once the
+    loss path upcasts to fp32), OOMing 96 GB cards. Routing loss through
+    ``LigerForCausalLMLoss`` materializes only the loss scalar.
+
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either
+        be in `[0, ..., config.text_config.vocab_size]` or -100 (see `input_ids`
+        docstring). Tokens with indices set to `-100` are ignored.
+
+    logits_to_keep (`int` or `torch.Tensor`, *optional*):
+        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`,
+        calculate logits for all `input_ids` (special case). If a `torch.Tensor`,
+        must be 1D corresponding to the indices to keep in the sequence-length
+        dimension.
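+
+    Illustrative arithmetic behind the sizes quoted above (B=4 is an assumed
+    batch size, not a measured configuration):
+
+        bf16 logits:  4 * 8192 * 262144 * 2 bytes ~ 17.2 GB
+        fp32 upcast:  4 * 8192 * 262144 * 4 bytes ~ 34.4 GB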
+ """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + input_features=input_features, + attention_mask=attention_mask, + input_features_mask=input_features_mask, + position_ids=position_ids, + image_position_ids=image_position_ids, + video_position_ids=video_position_ids, + past_key_values=past_key_values, + mm_token_type_ids=mm_token_type_ids, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **lm_kwargs, + ) + + hidden_states = outputs[0] + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + text_cfg = self.config.get_text_config() + softcap = getattr(text_cfg, "final_logit_softcapping", None) + shift_labels = lm_kwargs.pop("shift_labels", None) + + loss = None + logits = None + token_accuracy = None + predicted_tokens = None + + if skip_logits is None: + skip_logits = self.training and (labels is not None or shift_labels is not None) + + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=text_cfg.hidden_size, + final_logit_softcapping=softcap, + **lm_kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + else: + logits = self.lm_head(kept_hidden_states) + if softcap is not None: + logits = logits / softcap + logits = torch.tanh(logits) + logits = logits * softcap + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=text_cfg.vocab_size, + **lm_kwargs, + ) + + if not return_dict: + output_tuple = (logits,) + outputs[1:] + output_tuple = (loss,) + output_tuple if loss is not None else output_tuple + output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple + output_tuple = output_tuple + (predicted_tokens,) if predicted_tokens is not None else output_tuple + return output_tuple + + return LigerGemma4CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=getattr(outputs, "image_hidden_states", None), + audio_hidden_states=getattr(outputs, "audio_hidden_states", None), + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/output_classes.py b/src/liger_kernel/transformers/model/output_classes.py index 4fb5aaee3..c2ef4a7c7 100644 --- a/src/liger_kernel/transformers/model/output_classes.py +++ b/src/liger_kernel/transformers/model/output_classes.py @@ -19,6 +19,11 @@ except Exception: _Gemma3CausalLMOutputWithPast = None +try: + from transformers.models.gemma4.modeling_gemma4 import Gemma4CausalLMOutputWithPast as _Gemma4CausalLMOutputWithPast +except Exception: + _Gemma4CausalLMOutputWithPast = None + try: from 
transformers.models.glm4v_moe.modeling_glm4v_moe import (
        Glm4vMoeCausalLMOutputWithPast as _Glm4vMoeCausalLMOutputWithPast,
@@ -108,6 +113,14 @@ class LigerGemma3CausalLMOutputWithPast(_Gemma3CausalLMOutputWithPast):
         predicted_tokens: Optional[torch.LongTensor] = None


+if _Gemma4CausalLMOutputWithPast is not None:
+
+    @dataclass
+    class LigerGemma4CausalLMOutputWithPast(_Gemma4CausalLMOutputWithPast):
+        token_accuracy: Optional[torch.FloatTensor] = None
+        predicted_tokens: Optional[torch.LongTensor] = None
+
+
 if _Glm4vMoeCausalLMOutputWithPast is not None:

     @dataclass
diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py
index 6d97950d7..6f4e98f1c 100755
--- a/src/liger_kernel/transformers/monkey_patch.py
+++ b/src/liger_kernel/transformers/monkey_patch.py
@@ -1355,7 +1355,11 @@ def _maybe_patch_scaled_norm(module):
     # The model instance already exists, so we need to additionally patch the
     # instance variables that reference already-instantiated modules
-    causal_lm_types = tuple(cls for cls in (Gemma4ForCausalLM, Gemma4TextForCausalLM) if cls is not None)
+    # `isinstance(cls, type)` filter (rather than `cls is not None`) so we
+    # also drop the MagicMock the test harness substitutes for
+    # `Gemma4TextForCausalLM` under `with patch("...modeling_gemma4")` —
+    # an `isinstance` check against a non-class entry raises TypeError.
+    causal_lm_types = tuple(cls for cls in (Gemma4ForCausalLM, Gemma4TextForCausalLM) if isinstance(cls, type))
     if isinstance(model, causal_lm_types) or isinstance(model, Gemma4TextModel):
         # get the base model from the model instance
         base_model = model.model if isinstance(model, causal_lm_types) else model
@@ -1384,6 +1388,130 @@ def _maybe_patch_scaled_norm(module):
         raise TypeError("The model must be Gemma4ForCausalLM, Gemma4TextForCausalLM, or Gemma4TextModel.")


+def apply_liger_kernel_to_gemma4(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    geglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Gemma4
+    multimodal models (`Gemma4ForConditionalGeneration`).
+
+    Dispatches on class: text-only variants (`Gemma4ForCausalLM`,
+    `Gemma4TextForCausalLM`, `Gemma4TextModel`) are routed to
+    :func:`apply_liger_kernel_to_gemma4_text` for backwards compatibility, so the
+    same entry point works for both shapes when an instance is supplied.
+
+    The primary win is the fused-linear-cross-entropy path on the multimodal
+    class: with vocab=262,144, the (B, T, V) logits tensor is ~17 GB in bf16 at
+    B=4, T=8192 (and ~34 GB once the loss path upcasts to fp32), OOMing 96 GB
+    cards. Fused CE materializes only the loss scalar.
+
+    Out of scope (deferred to future PRs):
+    - Vision and audio tower kernel swaps. Gemma 4's vision and audio towers
+      are loaded via `AutoModel.from_config(config.vision_config)` and
+      `AutoModel.from_config(config.audio_config)` respectively, so their
+      module classes are polymorphic. A safe class-level swap would need to
+      enumerate supported tower types — out of scope here; FLCE on the LM head
+      is what resolves the training OOM.
+    - PLE (Per-Layer Embeddings) kernels. PLE state passes through the inner
+      forward unchanged; verified end-to-end on E4B-it without explicit
+      handling.
+    - Gemma 4 MoE expert kernels (`Gemma4TextExperts`); guarded out via the
+      same `enable_moe_block` check used in the text path.
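+
+    Usage sketch (the auto-class and checkpoint id are illustrative
+    placeholders, not tested references):
+
+        from transformers import AutoModelForImageTextToText
+
+        from liger_kernel.transformers import apply_liger_kernel_to_gemma4
+
+        apply_liger_kernel_to_gemma4()  # class-level patch, before loading
+        model = AutoModelForImageTextToText.from_pretrained("google/gemma-4-e4b-it")
+
+    Passing ``model=`` instead patches an already-instantiated model in place.
+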
+ + Args: + rope (bool): Currently a no-op (HF's apply_rotary_pos_emb signature is + incompatible with Liger's fused variant). Default False. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. + Default False. Mutually exclusive with `fused_linear_cross_entropy`. + fused_linear_cross_entropy (bool): Fused linear CE for memory + efficiency. Default True. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default True. + geglu (bool): Whether to apply Liger's GeGLU MLP. Default True. + model (PreTrainedModel): An already-instantiated model to patch + in-place. Default None. + """ + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) + + from transformers.models.gemma4 import modeling_gemma4 + from transformers.models.gemma4.modeling_gemma4 import Gemma4ForCausalLM + from transformers.models.gemma4.modeling_gemma4 import Gemma4ForConditionalGeneration + from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel + + from liger_kernel.transformers.model.gemma4 import multimodal_forward + + Gemma4TextForCausalLM = getattr(modeling_gemma4, "Gemma4TextForCausalLM", None) + + # Dispatch: if the caller passed a text-only instance, route to the text + # path so this entry point works as a single user-facing API. + if model is not None: + # `isinstance(cls, type)` filter (rather than `cls is not None`) so we + # also drop the MagicMock the test harness substitutes for + # `Gemma4TextForCausalLM` under `with patch("...modeling_gemma4")` — + # an `isinstance` check against a non-class entry raises TypeError. + text_classes = tuple( + cls for cls in (Gemma4ForCausalLM, Gemma4TextForCausalLM, Gemma4TextModel) if isinstance(cls, type) + ) + if isinstance(model, text_classes): + apply_liger_kernel_to_gemma4_text( + rope=rope, + cross_entropy=cross_entropy, + fused_linear_cross_entropy=fused_linear_cross_entropy, + rms_norm=rms_norm, + geglu=geglu, + model=model, + ) + return + if not isinstance(model, Gemma4ForConditionalGeneration): + raise TypeError( + "The model must be Gemma4ForConditionalGeneration (or a Gemma 4 " + "text-only variant; those are routed to " + "apply_liger_kernel_to_gemma4_text)." + ) + + # Class-level patches for the text decoder layers (RMSNorm, GeGLU MLP). + # We disable FLCE here because the multimodal class needs its own forward + # (handles pixel_values / input_features / mm_token_type_ids / etc.) — we + # install that below. + apply_liger_kernel_to_gemma4_text( + rope=rope, + cross_entropy=False, + fused_linear_cross_entropy=False, + rms_norm=rms_norm, + geglu=geglu, + ) + + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(multimodal_forward, model) + else: + modeling_gemma4.Gemma4ForConditionalGeneration.forward = multimodal_forward + + if model is not None: + # Recurse into the language model for instance-level RMSNorm / GeGLU + # patching. (The class-level swap above already covers freshly + # instantiated modules; this catches the already-built ones.) 
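+        # Assumes the Gemma3-style layout: `model.model` is the Gemma4Model
+        # wrapper and `.language_model` its text decoder.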
+        apply_liger_kernel_to_gemma4_text(
+            rope=rope,
+            cross_entropy=False,
+            fused_linear_cross_entropy=False,
+            rms_norm=rms_norm,
+            geglu=geglu,
+            model=model.model.language_model,
+        )
+
+
 def apply_liger_kernel_to_paligemma(
     rope: bool = True,
     cross_entropy: bool = False,
@@ -3362,6 +3490,7 @@ def __init__(self, hidden_size, eps=1e-6, **kwargs):
     "gemma3_text": apply_liger_kernel_to_gemma3_text,
     "gemma3": apply_liger_kernel_to_gemma3,
     "gemma4_text": apply_liger_kernel_to_gemma4_text,
+    "gemma4": apply_liger_kernel_to_gemma4,
     "glm4": apply_liger_kernel_to_glm4,
     "glm4v": apply_liger_kernel_to_glm4v,
     "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py
index cd9b834cb..333e17092 100755
--- a/test/transformers/test_monkey_patch.py
+++ b/test/transformers/test_monkey_patch.py
@@ -1957,6 +1957,89 @@ def test_apply_liger_kernel_to_instance_for_gemma4_text():
         pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")


+@pytest.mark.skipif(not is_gemma4_available(), reason="gemma4 module not available")
+def test_apply_liger_kernel_to_instance_for_gemma4_conditional_generation():
+    # Ensure any monkey patching is cleaned up for subsequent tests
+    with patch("transformers.models.gemma4.modeling_gemma4"):
+        from transformers.models.gemma4.modeling_gemma4 import Gemma4ForConditionalGeneration
+
+        from liger_kernel.transformers.model.gemma4 import multimodal_forward as gemma4_multimodal_forward
+
+        # Minimal dense-path text config — same knobs pinned off as the
+        # text-only test above (no PLE, MoE, KV-share, double-wide MLP).
+        text_config = transformers.models.gemma4.configuration_gemma4.Gemma4TextConfig(
+            dtype=torch.bfloat16,
+            rms_norm_eps=1e-5,
+            hidden_size=32,
+            intermediate_size=64,
+            num_hidden_layers=2,
+            num_attention_heads=2,
+            num_key_value_heads=1,
+            head_dim=16,
+            num_kv_shared_layers=0,
+            use_double_wide_mlp=False,
+            enable_moe_block=False,
+            hidden_size_per_layer_input=0,
+        )
+        # Vision/audio configs left as None — Gemma4Model guards each tower
+        # behind an `if config.vision_config is not None` (resp.
+        # `config.audio_config`) check, so a None-towers model still
+        # constructs as Gemma4ForConditionalGeneration and exercises the
+        # multimodal forward we're patching. The towers themselves are
+        # polymorphic (AutoModel.from_config) and not in this PR's scope.
+        config = transformers.models.gemma4.configuration_gemma4.Gemma4Config(
+            text_config=text_config,
+            vision_config=None,
+            audio_config=None,
+        )
+
+        dummy_model_instance = Gemma4ForConditionalGeneration._from_config(config)
+        assert isinstance(dummy_model_instance, Gemma4ForConditionalGeneration)
+
+        # Pre-patch: forward and language-model norms must NOT be Liger.
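+        # (Comparing source text via inspect.getsource, as the surrounding
+        # tests do, detects patches both for class-level swaps and for
+        # MethodType-bound instance forwards, where identity checks can
+        # mislead.)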
+        assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(gemma4_multimodal_forward)
+        assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) != inspect.getsource(
+            LigerRMSNorm.forward
+        )
+        for layer in dummy_model_instance.model.language_model.layers:
+            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerGEGLUMLP.forward)
+            assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
+            assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
+            assert inspect.getsource(layer.pre_feedforward_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
+            assert inspect.getsource(layer.post_feedforward_layernorm.forward) != inspect.getsource(
+                LigerRMSNorm.forward
+            )
+            assert inspect.getsource(layer.self_attn.q_norm.forward) != inspect.getsource(LigerRMSNorm.forward)
+            assert inspect.getsource(layer.self_attn.k_norm.forward) != inspect.getsource(LigerRMSNorm.forward)
+
+        _apply_liger_kernel_to_instance(model=dummy_model_instance)
+
+        # Post-patch: top-level forward is multimodal_forward, language_model
+        # norms / MLPs are Liger.
+        assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(gemma4_multimodal_forward)
+        assert inspect.getsource(dummy_model_instance.model.language_model.norm.forward) == inspect.getsource(
+            LigerRMSNorm.forward
+        )
+        for layer in dummy_model_instance.model.language_model.layers:
+            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerGEGLUMLP.forward)
+            assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
+            assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
+            assert inspect.getsource(layer.pre_feedforward_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
+            assert inspect.getsource(layer.post_feedforward_layernorm.forward) == inspect.getsource(
+                LigerRMSNorm.forward
+            )
+            assert inspect.getsource(layer.self_attn.q_norm.forward) == inspect.getsource(LigerRMSNorm.forward)
+            assert inspect.getsource(layer.self_attn.k_norm.forward) == inspect.getsource(LigerRMSNorm.forward)
+            v_norm = getattr(layer.self_attn, "v_norm", None)
+            if v_norm is not None:
+                # with_scale=False → intentionally not patched.
+                assert inspect.getsource(v_norm.forward) != inspect.getsource(LigerRMSNorm.forward)
+
+        try:
+            print(dummy_model_instance)
+        except Exception as e:
+            pytest.fail(f"An exception occurred in extra_expr: {type(e).__name__} - {e}")
+
+
 def test_apply_liger_kernel_to_instance_for_qwen2():
     # Ensure any monkey patching is cleaned up for subsequent tests
     with patch("transformers.models.qwen2.modeling_qwen2"):
diff --git a/test/utils.py b/test/utils.py
index 0af7c47ce..831f00e81 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -505,6 +505,21 @@ def revert_liger_kernel_to_gemma4_text(model_config: MiniModelConfig):
     print("Liger kernel patches have been reverted.")


+def revert_liger_kernel_to_gemma4(model_config: MiniModelConfig):
+    """Revert all Liger kernel patches applied to Gemma4 multimodal model."""
+
+    from transformers.models.gemma4 import modeling_gemma4
+
+    # Vision/audio towers are loaded via AutoModel.from_config, so their
+    # module classes are polymorphic — no class-level swap to revert there.
+    # Reloading modeling_gemma4 resets Gemma4RMSNorm / Gemma4TextMLP /
+    # Gemma4ForConditionalGeneration.forward, which is the surface the
+    # multimodal patch touches.
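+    # Rebinding model_config.model_class below matters as much as the reload:
+    # the harness may still hold the pre-reload class object, whose forward is
+    # the Liger version; the reloaded module provides a fresh, unpatched class.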
+ importlib.reload(modeling_gemma4) + model_config.model_class = modeling_gemma4.Gemma4ForConditionalGeneration + print("Liger kernel patches have been reverted.") + + def revert_liger_kernel_to_gemma3(model_config: MiniModelConfig): """ Revert all Liger kernel patches applied to Gemma3.