
Fix: paligemma type error #627


Open · wants to merge 9 commits into main

Conversation

jp1924 (Contributor) commented Mar 25, 2025

Summary

@Tcc0403

Background
#524 (comment)

While working on PR #524, the following error occurred in the paligemma section, so I investigated the cause.

The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM.

As you can see in the screenshot below, isinstance oddly returns False even though the type check is against a language_model that is a GemmaForCausalLM.

(screenshot of the failing isinstance check)

So I suggest that liger_kernel dispatch on model.config.text_config.model_type directly, instead of relying on isinstance, when applying the patch. A minimal sketch of the idea is shown below.
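
The sketch below is only illustrative of the proposed check, not the final patch; it reuses the apply_liger_kernel_to_gemma / apply_liger_kernel_to_gemma2 helpers and argument names that appear in the traceback further down, and assumes text_config.model_type is "gemma" for PaliGemma-1 and "gemma2" for PaliGemma-2:

# Illustrative sketch: dispatch on the config's model_type string rather than isinstance,
# so the check keeps working even if the GemmaForCausalLM class object has been replaced
# (e.g. by a module reload).
language_model = model.language_model
text_model_type = model.config.text_config.model_type  # "gemma" or "gemma2"

if text_model_type == "gemma":
    apply_liger_kernel_to_gemma(
        rope=rope,
        cross_entropy=False,
        fused_linear_cross_entropy=False,
        rms_norm=rms_norm,
        geglu=geglu,
        model=language_model,
    )
elif text_model_type == "gemma2":
    apply_liger_kernel_to_gemma2(
        rope=rope,
        cross_entropy=False,
        fused_linear_cross_entropy=False,
        rms_norm=rms_norm,
        geglu=geglu,
        model=language_model,
    )
else:
    raise TypeError(
        f"The language_model of a PaliGemma model must be Gemma or Gemma2, got model_type={text_model_type!r}."
    )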

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

eljandoubi (Contributor)

@jp1924 If the language model does not pass the type test, there may be patches that were not reverted.

jp1924 (Contributor, Author) commented Mar 25, 2025

@eljandoubi Thank you for your interest!

However, this error occurs even when running only the paligemma convergence tests in isolation. Specifically, when paligemma1 and paligemma2 run in sequence, it fails immediately in the paligemma1 test.

So, this is not an error caused by another multimodal test where a patch applied to siglip wasn't fully reverted.

Below is the model when the error occurs, and there’s no sign of any liger patch applied:

PaliGemmaForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(256, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-3): 4 x SiglipEncoderLayer(
            (self_attn): SiglipAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear(in_features=1152, out_features=2048, bias=True)
              (fc2): Linear(in_features=2048, out_features=1152, bias=True)
            )
            (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
          )
        )
      )
      (post_layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
      (head): SiglipMultiheadAttentionPoolingHead(
        (attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
        )
        (layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
        (mlp): SiglipMLP(
          (activation_fn): PytorchGELUTanh()
          (fc1): Linear(in_features=1152, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=1152, bias=True)
        )
      )
    )
  )
  (multi_modal_projector): PaliGemmaMultiModalProjector(
    (linear): Linear(in_features=1152, out_features=1024, bias=True)
  )
  (language_model): GemmaForCausalLM(
    (model): GemmaModel(
      (embed_tokens): Embedding(32000, 1024, padding_idx=0)
      (layers): ModuleList(
        (0-3): 4 x GemmaDecoderLayer(
          (self_attn): GemmaAttention(
            (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
          )
          (mlp): GemmaMLP(
            (gate_proj): Linear(in_features=1024, out_features=2048, bias=False)
            (up_proj): Linear(in_features=1024, out_features=2048, bias=False)
            (down_proj): Linear(in_features=2048, out_features=1024, bias=False)
            (act_fn): PytorchGELUTanh()
          )
          (input_layernorm): GemmaRMSNorm((1024,), eps=1e-06)
          (post_attention_layernorm): GemmaRMSNorm((1024,), eps=1e-06)
        )
      )
      (norm): GemmaRMSNorm((1024,), eps=1e-06)
      (rotary_emb): GemmaRotaryEmbedding()
    )
    (lm_head): Linear(in_features=1024, out_features=32000, bias=False)
  )
)

Looking at the other multimodal tests, paligemma is the only model that uses siglip.

That’s why I made this modification.

eljandoubi (Contributor) commented Mar 25, 2025

@jp1924 Tests for PaliGemma succeed on my local machine and in the repository's CI/CD pipeline. Additionally, I've used it to train the model with real-world data using huggingface integration on the cloud. Can you provide more details about your environment?

Tcc0403 (Collaborator) commented Mar 25, 2025

@jp1924 Can you provide the code you ran to reproduce the issue with only paligemma1/2 test cases?

jp1924 (Contributor, Author) commented Mar 27, 2025

@eljandoubi @Tcc0403
To debug this, I set up a Docker environment, then built and ran the source from linkedin/Liger-Kernel.

Here’s the environment used for debugging:

transformers             4.50.1
torch                    2.5.1+cu121
numpy                    2.1.2
flash_attn               2.7.4.post1
Platform: Linux-5.15.0-126-generic-x86_64-with-glibc2.35

The test was limited to test_mini_models_multimodal.py, which includes the paligemma cases.

pytest /root/workspace/or-liger/test/convergence/bf16/test_mini_models_multimodal.py -k=mini_paligemma

Scenario 1
When kwargs["model"] = model is not used during the convergence test:

Without passing the model to kwargs and without any modifications, the test runs normally.

Scenario 2
When kwargs["model"] = model is used during the convergence test:

AttributeError: 'LigerLayerNorm' object has no attribute 'normalized_shape'

An error occurs, as mentioned in this comment:
#524 (comment)

After resolving this issue and running it again, the following error appears:

TypeError: The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM.

This second error is the one this PR addresses.

Detailed log
==================================================================================================================== test session starts =====================================================================================================================
platform linux -- Python 3.10.12, pytest-8.3.5, pluggy-1.5.0
rootdir: /root/workspace/or-liger
configfile: pyproject.toml
collecting ... 
-------------------------------------------------------------------------------------------------------------------- live log collection ---------------------------------------------------------------------------------------------------------------------
INFO     datasets:config.py:54 PyTorch version 2.5.1+cu121 available.
collected 5 items / 3 deselected / 2 selected                                                                                                                                                                                                                

test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_paligemma-32-0.0001-dtype3-0.001-0.01-0.1-0.01-0.01-0.01] FAILED                                                                                                 [ 50%]
test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_paligemma2-32-0.0001-dtype4-0.001-0.01-0.1-0.01-0.01-0.01] FAILED                                                                                                [100%]

========================================================================================================================== FAILURES ==========================================================================================================================
_________________________________________________________________________________ test_mini_model_multimodal[mini_paligemma-32-0.0001-dtype3-0.001-0.01-0.1-0.01-0.01-0.01] __________________________________________________________________________________

model_name = 'mini_paligemma', num_steps = 32, lr = 0.0001, dtype = torch.bfloat16, loss_atol = 0.001, loss_rtol = 0.01, logits_atol = 0.1, logits_rtol = 0.01, param_atol = 0.01, param_rtol = 0.01

    @pytest.mark.parametrize(
        "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
        [
            pytest.param(
                "mini_qwen2_vl",
                32,
                1e-4,
                torch.bfloat16,
                1e-3,
                1e-2,
                1e-1,
                1e-2,
                1e-2,
                1e-2,
                marks=[
                    pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
                    pytest.mark.skipif(
                        not QWEN2_VL_AVAILABLE,
                        reason="Qwen2-VL not available in this version of transformers",
                    ),
                    pytest.mark.skipif(device == "xpu", reason="skip for XPU"),
                ],
            ),
            pytest.param(
                "mini_qwen2_5_vl",
                32,
                1e-4,
                torch.bfloat16,
                1e-3,
                1e-2,
                1e-1,
                1e-2,
                1e-2,
                1e-2,
                marks=[
                    pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
                    pytest.mark.skipif(
                        not QWEN2_5_VL_AVAILABLE,
                        reason="Qwen2.5-VL not available in this version of transformers",
                    ),
                    pytest.mark.skipif(device == "xpu", reason="skip for XPU"),
                ],
            ),
            pytest.param(
                "mini_mllama",
                32,
                1e-4,
                torch.bfloat16,
                1e-3,
                1e-2,
                1e-1,
                1e-2,
                1e-2,
                1e-2,
                marks=[
                    pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
                    pytest.mark.skipif(
                        not MLLAMA_AVAILABLE,
                        reason="Mllama not available in this version of transformers",
                    ),
                ],
            ),
            pytest.param(
                "mini_paligemma",
                32,
                1e-4,
                torch.bfloat16,
                1e-3,
                1e-2,
                1e-1,
                1e-2,
                1e-2,
                1e-2,
                marks=[
                    pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
                    pytest.mark.skipif(
                        not PALIGEMMA_AVAILABLE,
                        reason="Paligemma not available in this version of transformers",
                    ),
                ],
            ),
            pytest.param(
                "mini_paligemma2",
                32,
                1e-4,
                torch.bfloat16,
                1e-3,
                1e-2,
                1e-1,
                1e-2,
                1e-2,
                1e-2,
                marks=[
                    pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
                    pytest.mark.skipif(
                        not PALIGEMMA_AVAILABLE,
                        reason="Paligemma2 not available in this version of transformers",
                    ),
                ],
            ),
        ],
    )
    def test_mini_model_multimodal(
        model_name,
        num_steps,
        lr,
        dtype,
        loss_atol,
        loss_rtol,
        logits_atol,
        logits_rtol,
        param_atol,
        param_rtol,
    ):
        # Non-liger models should be initialized and tested first to avoid the module being overridden
        expected_output = run_mini_model_multimodal(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
    
>       actual_output = run_mini_model_multimodal(
            model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True
        )

test/convergence/bf16/test_mini_models_multimodal.py:676: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
test/convergence/bf16/test_mini_models_multimodal.py:532: in run_mini_model_multimodal
    MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

rope = True, cross_entropy = False, fused_linear_cross_entropy = False, layer_norm = True, rms_norm = True, geglu = True
model = PaliGemmaForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
 ...rotary_emb): GemmaRotaryEmbedding()
    )
    (lm_head): Linear(in_features=1024, out_features=32000, bias=False)
  )
)

    def apply_liger_kernel_to_paligemma(
        rope: bool = True,
        cross_entropy: bool = False,
        fused_linear_cross_entropy: bool = True,
        layer_norm: bool = True,
        rms_norm: bool = True,
        geglu: bool = True,
        model: PreTrainedModel = None,
    ) -> None:
        """
        Apply Liger kernels to replace original implementation in HuggingFace PaliGemma
    
        Args:
            rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
            cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
            fused_linear_cross_entropy (bool):
                Whether to apply Liger's fused linear cross entropy loss. Default is True.
                `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
                If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
            layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
            rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
            geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
            model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.
        """
        assert not (cross_entropy and fused_linear_cross_entropy), (
            "cross_entropy and fused_linear_cross_entropy cannot both be True."
        )
    
        # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
    
        from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
        from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
        from transformers.models.paligemma import modeling_paligemma
        from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
        from transformers.models.siglip import modeling_siglip
        from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
        from transformers.models.siglip.modeling_siglip import SiglipVisionModel
    
        from liger_kernel.transformers.model.paligemma import lce_forward
        from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
    
        # The vision_tower is a SiglipVisionModel
        if layer_norm:
            modeling_siglip.nn.LayerNorm = LigerLayerNorm
    
        # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
        # The multi_modal_projector is Linear, nothing to do
    
        # The language_model is GemmaForCausalLM or Gemma2ForCausalLM
        apply_liger_kernel_to_gemma(
            rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
        )
        apply_liger_kernel_to_gemma2(
            rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
        )
        # Handle loss function
        if cross_entropy:
            modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
        if fused_linear_cross_entropy:
            if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
            else:  # if version < 4.46.1
                logger.warning(TRANSFORMER_DEPRECATION_WARNING)
                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
    
        if model is not None:
            # The model instance already exists, so we need to additionally patch the
            # instance variables that reference already-instantiated modules
    
            if not isinstance(model, PaliGemmaForConditionalGeneration):
                raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")
    
            vision_tower: SiglipVisionModel = model.vision_tower
    
            _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
    
            for layer in vision_tower.vision_model.encoder.layers:
                layer: SiglipEncoderLayer
                if layer_norm:
                    _patch_layer_norm_module(layer.layer_norm1)
                    _patch_layer_norm_module(layer.layer_norm2)
    
            language_model = model.language_model
    
            if isinstance(language_model, GemmaForCausalLM):
                apply_liger_kernel_to_gemma(
                    rope=rope,
                    cross_entropy=False,
                    fused_linear_cross_entropy=False,
                    rms_norm=rms_norm,
                    geglu=geglu,
                    model=language_model,
                )
    
            elif isinstance(language_model, Gemma2ForCausalLM):
                apply_liger_kernel_to_gemma2(
                    rope=rope,
                    cross_entropy=False,
                    fused_linear_cross_entropy=False,
                    rms_norm=rms_norm,
                    geglu=geglu,
                    model=language_model,
                )
            else:
>               raise TypeError(
                    "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
                )
E               TypeError: The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM.

src/liger_kernel/transformers/monkey_patch.py:720: TypeError
-------------------------------------------------------------------------------------------------------------------- Captured stdout call --------------------------------------------------------------------------------------------------------------------
Liger kernel patches have been reverted.



Step 0, Loss: 11.872488975524902
Step 1, Loss: 2.908055067062378
Step 2, Loss: 1.80923330783844
Step 3, Loss: 1.7920551300048828
Step 4, Loss: 1.4004158973693848
Step 5, Loss: 1.259793996810913
Step 6, Loss: 1.3108876943588257
Step 7, Loss: 1.0080386400222778
Step 8, Loss: 1.0473616123199463
Step 9, Loss: 1.1045541763305664
Step 10, Loss: 0.862948477268219
Step 11, Loss: 0.9946214556694031
Step 12, Loss: 0.7224061489105225
Step 13, Loss: 0.7274509072303772
Step 14, Loss: 0.79590904712677
Step 15, Loss: 0.8812856078147888
Step 16, Loss: 0.8721840977668762
Step 17, Loss: 0.8142553567886353
Step 18, Loss: 0.8549021482467651
Step 19, Loss: 0.65662682056427
Step 20, Loss: 0.6154780983924866
Step 21, Loss: 0.45201078057289124
Step 22, Loss: 0.6145424246788025
Step 23, Loss: 0.6072823405265808
Step 24, Loss: 0.41036954522132874
Step 25, Loss: 0.7026124000549316
Step 26, Loss: 0.3765292763710022
Step 27, Loss: 0.4183444380760193
Step 28, Loss: 0.5464693903923035
Step 29, Loss: 0.6562455296516418
Step 30, Loss: 0.5901074409484863
Step 31, Loss: 0.29771801829338074
Liger kernel patches have been reverted.
-------------------------------------------------------------------------------------------------------------------- Captured stderr call --------------------------------------------------------------------------------------------------------------------
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
_________________________________________________________________________________ test_mini_model_multimodal[mini_paligemma2-32-0.0001-dtype4-0.001-0.01-0.1-0.01-0.01-0.01] _________________________________________________________________________________

model_name = 'mini_paligemma2', num_steps = 32, lr = 0.0001, dtype = torch.bfloat16, loss_atol = 0.001, loss_rtol = 0.01, logits_atol = 0.1, logits_rtol = 0.01, param_atol = 0.01, param_rtol = 0.01

    @pytest.mark.parametrize(
        "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
        [
            pytest.param(
                "mini_qwen2_vl",
                32,
                1e-4,
                torch.bfloat16,
                1e-3,
                1e-2,
                1e-1,
                1e-2,
                1e-2,
                1e-2,
                marks=[
                    pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
                    pytest.mark.skipif(
                        not QWEN2_VL_AVAILABLE,
                        reason="Qwen2-VL not available in this version of transformers",
                    ),
                    pytest.mark.skipif(device == "xpu", reason="skip for XPU"),
                ],
            ),
            pytest.param(
                "mini_qwen2_5_vl",
                32,
                1e-4,
                torch.bfloat16,
                1e-3,
                1e-2,
                1e-1,
                1e-2,
                1e-2,
                1e-2,
                marks=[
                    pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
                    pytest.mark.skipif(
                        not QWEN2_5_VL_AVAILABLE,
                        reason="Qwen2.5-VL not available in this version of transformers",
                    ),
                    pytest.mark.skipif(device == "xpu", reason="skip for XPU"),
                ],
            ),
            pytest.param(
                "mini_mllama",
                32,
                1e-4,
                torch.bfloat16,
                1e-3,
                1e-2,
                1e-1,
                1e-2,
                1e-2,
                1e-2,
                marks=[
                    pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
                    pytest.mark.skipif(
                        not MLLAMA_AVAILABLE,
                        reason="Mllama not available in this version of transformers",
                    ),
                ],
            ),
            pytest.param(
                "mini_paligemma",
                32,
                1e-4,
                torch.bfloat16,
                1e-3,
                1e-2,
                1e-1,
                1e-2,
                1e-2,
                1e-2,
                marks=[
                    pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
                    pytest.mark.skipif(
                        not PALIGEMMA_AVAILABLE,
                        reason="Paligemma not available in this version of transformers",
                    ),
                ],
            ),
            pytest.param(
                "mini_paligemma2",
                32,
                1e-4,
                torch.bfloat16,
                1e-3,
                1e-2,
                1e-1,
                1e-2,
                1e-2,
                1e-2,
                marks=[
                    pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
                    pytest.mark.skipif(
                        not PALIGEMMA_AVAILABLE,
                        reason="Paligemma2 not available in this version of transformers",
                    ),
                ],
            ),
        ],
    )
    def test_mini_model_multimodal(
        model_name,
        num_steps,
        lr,
        dtype,
        loss_atol,
        loss_rtol,
        logits_atol,
        logits_rtol,
        param_atol,
        param_rtol,
    ):
        # Non-liger models should be initialized and tested first to avoid the module being overridden
        expected_output = run_mini_model_multimodal(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
    
>       actual_output = run_mini_model_multimodal(
            model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True
        )

test/convergence/bf16/test_mini_models_multimodal.py:676: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
test/convergence/bf16/test_mini_models_multimodal.py:532: in run_mini_model_multimodal
    MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

rope = True, cross_entropy = False, fused_linear_cross_entropy = False, layer_norm = True, rms_norm = True, geglu = True
model = PaliGemmaForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
 ...otary_emb): Gemma2RotaryEmbedding()
    )
    (lm_head): Linear(in_features=1024, out_features=32000, bias=False)
  )
)

    def apply_liger_kernel_to_paligemma(
        rope: bool = True,
        cross_entropy: bool = False,
        fused_linear_cross_entropy: bool = True,
        layer_norm: bool = True,
        rms_norm: bool = True,
        geglu: bool = True,
        model: PreTrainedModel = None,
    ) -> None:
        """
        Apply Liger kernels to replace original implementation in HuggingFace PaliGemma
    
        Args:
            rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
            cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
            fused_linear_cross_entropy (bool):
                Whether to apply Liger's fused linear cross entropy loss. Default is True.
                `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
                If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
            layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
            rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
            geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
            model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.
        """
        assert not (cross_entropy and fused_linear_cross_entropy), (
            "cross_entropy and fused_linear_cross_entropy cannot both be True."
        )
    
        # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
    
        from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
        from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
        from transformers.models.paligemma import modeling_paligemma
        from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
        from transformers.models.siglip import modeling_siglip
        from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
        from transformers.models.siglip.modeling_siglip import SiglipVisionModel
    
        from liger_kernel.transformers.model.paligemma import lce_forward
        from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
    
        # The vision_tower is a SiglipVisionModel
        if layer_norm:
            modeling_siglip.nn.LayerNorm = LigerLayerNorm
    
        # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
        # The multi_modal_projector is Linear, nothing to do
    
        # The language_model is GemmaForCausalLM or Gemma2ForCausalLM
        apply_liger_kernel_to_gemma(
            rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
        )
        apply_liger_kernel_to_gemma2(
            rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
        )
        # Handle loss function
        if cross_entropy:
            modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
        if fused_linear_cross_entropy:
            if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
            else:  # if version < 4.46.1
                logger.warning(TRANSFORMER_DEPRECATION_WARNING)
                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
    
        if model is not None:
            # The model instance already exists, so we need to additionally patch the
            # instance variables that reference already-instantiated modules
    
            if not isinstance(model, PaliGemmaForConditionalGeneration):
                raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")
    
            vision_tower: SiglipVisionModel = model.vision_tower
    
            _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
    
            for layer in vision_tower.vision_model.encoder.layers:
                layer: SiglipEncoderLayer
                if layer_norm:
                    _patch_layer_norm_module(layer.layer_norm1)
                    _patch_layer_norm_module(layer.layer_norm2)
    
            language_model = model.language_model
    
            if isinstance(language_model, GemmaForCausalLM):
                apply_liger_kernel_to_gemma(
                    rope=rope,
                    cross_entropy=False,
                    fused_linear_cross_entropy=False,
                    rms_norm=rms_norm,
                    geglu=geglu,
                    model=language_model,
                )
    
            elif isinstance(language_model, Gemma2ForCausalLM):
                apply_liger_kernel_to_gemma2(
                    rope=rope,
                    cross_entropy=False,
                    fused_linear_cross_entropy=False,
                    rms_norm=rms_norm,
                    geglu=geglu,
                    model=language_model,
                )
            else:
>               raise TypeError(
                    "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
                )
E               TypeError: The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM.

src/liger_kernel/transformers/monkey_patch.py:720: TypeError
-------------------------------------------------------------------------------------------------------------------- Captured stdout call --------------------------------------------------------------------------------------------------------------------
Liger kernel patches have been reverted.



Step 0, Loss: 11.491681098937988
Step 1, Loss: 5.270604133605957
Step 2, Loss: 2.1230340003967285
Step 3, Loss: 1.9017783403396606
Step 4, Loss: 1.393043875694275
Step 5, Loss: 1.2913684844970703
Step 6, Loss: 1.3962070941925049
Step 7, Loss: 1.1137727499008179
Step 8, Loss: 1.162048101425171
Step 9, Loss: 1.226529836654663
Step 10, Loss: 0.9842991232872009
Step 11, Loss: 1.1158515214920044
Step 12, Loss: 0.8610438704490662
Step 13, Loss: 0.8677318692207336
Step 14, Loss: 0.9316790699958801
Step 15, Loss: 0.9952855110168457
Step 16, Loss: 0.9811475872993469
Step 17, Loss: 0.9185448288917542
Step 18, Loss: 0.9610309600830078
Step 19, Loss: 0.7695668935775757
Step 20, Loss: 0.7231904864311218
Step 21, Loss: 0.553207516670227
Step 22, Loss: 0.7114338278770447
Step 23, Loss: 0.7015824913978577
Step 24, Loss: 0.5099937915802002
Step 25, Loss: 0.7909064292907715
Step 26, Loss: 0.468069851398468
Step 27, Loss: 0.5070297718048096
Step 28, Loss: 0.626304566860199
Step 29, Loss: 0.7336246967315674
Step 30, Loss: 0.6568887829780579
Step 31, Loss: 0.3821035921573639
Liger kernel patches have been reverted.
-------------------------------------------------------------------------------------------------------------------- Captured stderr call --------------------------------------------------------------------------------------------------------------------
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
====================================================================================================================== warnings summary ======================================================================================================================
../../../usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:1441
  /usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:1441: PytestConfigWarning: Unknown config option: asyncio_mode
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_paligemma-32-0.0001-dtype3-0.001-0.01-0.1-0.01-0.01-0.01]
test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_paligemma2-32-0.0001-dtype4-0.001-0.01-0.1-0.01-0.01-0.01]
  /root/workspace/or-liger/test/utils.py:159: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
    input_ids = torch.cat([torch.tensor(item["input_ids"]) for item in data])

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================================================================== short test summary info ===================================================================================================================
FAILED test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_paligemma-32-0.0001-dtype3-0.001-0.01-0.1-0.01-0.01-0.01] - TypeError: The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM.
FAILED test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_paligemma2-32-0.0001-dtype4-0.001-0.01-0.1-0.01-0.01-0.01] - TypeError: The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM.
======================================================================================================== 2 failed, 3 deselected, 3 warnings in 30.95s ========================================================================================================

Tcc0403 (Collaborator) commented Mar 27, 2025

Context: we instantiate the model first and then apply the monkey patch to the instance.

The change looks like this:

def run_mini_model_multimodal(
    model_name="mini_qwen2_vl",
    num_steps=100,
    dtype=torch.bfloat16,
    lr=1e-5,
    with_liger=False,
):
    set_seed(42)
    model = create_model(model_name).to(dtype).to(device)   # instantiate first

    revert_kwargs = {"model_config": MINI_MODEL_SETUPS[model_name]}
    if "mllama" in model_name:
        revert_kwargs["model_type"] = "conditional_generation"

    if with_liger is True:
        kwargs = {
            "rope": True,
            "rms_norm": True,
            "cross_entropy": False,
        }

        if "qwen2_5_vl" not in model_name:
            kwargs["layer_norm"] = True

        if "gemma" in model_name:
            kwargs["geglu"] = True
        else:
            kwargs["swiglu"] = True
        kwargs["model"] = model  # add arg model
        MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)   # patch later
    else:
        MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)
...

Here's the reproduction code: gist
Output:

❯ python3 revert_then_patch.py
True
0x55d0cd1d08a0, id of model.language_model.__class__  # object 0
0x55d0cd1d08a0, id of modeling_gemma.GemmaForCausalLM # object 0
=========================== importlib.reload(modeling_gemma) ====================================
False
0x55d0cd1d08a0, id of model.language_model.__class__  # object 0
0x55d0cec372d0, id of modeling_gemma.GemmaForCausalLM # object 1

The root cause is that the revert step reloads the modules, so GemmaForCausalLM becomes a new class object when it is imported again inside apply_liger_kernel_to_paligemma(), while the existing model instance still references the old class.

This is not a common use case, but this fix covers the edge case without rewriting the current revert functions.
@eljandoubi what do you think?
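
For illustration only (this is not the linked gist; the tiny config values below are placeholders), the failure mode can be reproduced in a few lines:

# Minimal sketch of the reload/isinstance failure mode.
import importlib

from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.configuration_gemma import GemmaConfig

config = GemmaConfig(
    vocab_size=64, hidden_size=32, intermediate_size=64,
    num_hidden_layers=1, num_attention_heads=4, num_key_value_heads=4, head_dim=8,
)
lm = modeling_gemma.GemmaForCausalLM(config)

print(isinstance(lm, modeling_gemma.GemmaForCausalLM))  # True

# Reloading the module (what the revert path effectively does) rebinds
# modeling_gemma.GemmaForCausalLM to a brand-new class object.
importlib.reload(modeling_gemma)

# The instance still points at the old class, so the isinstance check now fails.
print(isinstance(lm, modeling_gemma.GemmaForCausalLM))  # False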

eljandoubi (Contributor) commented Mar 27, 2025

@jp1924 @Tcc0403 If we are unable to revert the patch, then when PaliGemma-2 is tested after PaliGemma-1, its SigLIP model is created with Liger-Kernel applied for both the expected (baseline) run and the actual (Liger) run.
The issue lies in reverting the patch for the created model, not in the PaliGemma patch itself. A better solution is a more robust revert patch, plus tests for the PaliGemma-1 and PaliGemma-2 monkey patches in test/transformers/test_monkey_patch.py; a rough sketch of one possible revert approach follows.
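
Purely as a hypothetical illustration of such a revert (the revert_paligemma_patches helper and the _ORIGINALS dict are made up, and the real revert helpers are more involved): snapshot the original attributes before patching and restore those exact objects afterwards, instead of reloading modules, so class identities stay stable.

# Hypothetical sketch: record originals before patching, restore them on revert,
# and avoid importlib.reload entirely so isinstance() checks keep passing.
from transformers.models.paligemma import modeling_paligemma
from transformers.models.siglip import modeling_siglip

# Capture these BEFORE any Liger patch is applied.
_ORIGINALS = {
    "siglip_layer_norm": modeling_siglip.nn.LayerNorm,
    "paligemma_forward": modeling_paligemma.PaliGemmaForConditionalGeneration.forward,
}

def revert_paligemma_patches():
    # Restoring the exact original objects keeps the classes that the already
    # instantiated model references identical to the ones the module exposes.
    modeling_siglip.nn.LayerNorm = _ORIGINALS["siglip_layer_norm"]
    modeling_paligemma.PaliGemmaForConditionalGeneration.forward = _ORIGINALS["paligemma_forward"]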

Tcc0403 (Collaborator) commented Mar 30, 2025

The issue lies in reverting a patch for the created model, not the patch for Paligemma itself. A better solution is to create a more robust revert patch and add tests for the monkey patch of Paligemma-1 and Paligemma-2 in test/transformers/test_monkey_patch.py.

@eljandoubi I agree. We really need a more robust function to revert these monkey patches for our tests.
