From 70e1a2c46932e55ed6fd150862ce7775a339a42d Mon Sep 17 00:00:00 2001
From: hengtaoguo
Date: Thu, 29 May 2025 19:49:15 +0000
Subject: [PATCH] Add unit tests for the Llama4 vision unfold convolution and
 pixel shuffle MLP layers

---
 layer_unit_test.ipynb | 315 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 315 insertions(+)
 create mode 100644 layer_unit_test.ipynb

diff --git a/layer_unit_test.ipynb b/layer_unit_test.ipynb
new file mode 100644
index 0000000000..a338539d77
--- /dev/null
+++ b/layer_unit_test.ipynb
@@ -0,0 +1,315 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "51ab58f1",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "jax 0.015534584410488605 pt 0.015534583479166031\n"
+     ]
+    }
+   ],
+   "source": [
+    "\"\"\"\n",
+    "Copyright 2025 Google LLC\n",
+    "\n",
+    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+    "you may not use this file except in compliance with the License.\n",
+    "You may obtain a copy of the License at\n",
+    "\n",
+    "     https://www.apache.org/licenses/LICENSE-2.0\n",
+    "\n",
+    "Unless required by applicable law or agreed to in writing, software\n",
+    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+    "See the License for the specific language governing permissions and\n",
+    "limitations under the License.\n",
+    "\"\"\"\n",
+    "\n",
+    "\"\"\"Tests for the Llama4 vision unfold convolution and pixel shuffle MLP layers.\n",
+    "\n",
+    "Reference PyTorch implementation:\n",
+    "https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py\n",
+    "\"\"\"\n",
+    "import math\n",
+    "import unittest\n",
+    "\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "from torch import nn\n",
+    "import torch.nn.functional as F\n",
+    "\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "\n",
+    "from MaxText.layers import llama4\n",
+    "\n",
+    "# pylint: disable=line-too-long, missing-function-docstring\n",
+    "\n",
+    "\n",
+    "def to_jax(pt_tensor: torch.Tensor) -> jax.Array:\n",
+    "  return jnp.asarray(pt_tensor.detach().numpy())\n",
+    "\n",
+    "\n",
+    "class Llama4UnfoldConvolutionTest(unittest.TestCase):\n",
+    "  \"\"\"Test for the Llama4 Unfold Convolution implementation.\"\"\"\n",
+    "\n",
+    "  def __copy_weights(self, pt_model, params):\n",
+    "    \"\"\"Copy weights from the PyTorch model to the JAX model.\n",
+    "\n",
+    "    Args:\n",
+    "      pt_model: PyTorch Llama4UnfoldConvolution model\n",
+    "      params: JAX model parameters\n",
+    "    \"\"\"\n",
+    "    # Create new params with copied weights. PyTorch nn.Linear stores its\n",
+    "    # weight as (out_features, in_features), while the Flax kernel is\n",
+    "    # (in_features, out_features), hence the transpose.\n",
+    "    updated_params = jax.tree_util.tree_map(lambda x: x, params)\n",
+    "    updated_params[\"params\"][\"vit_unfold_linear\"][\"kernel\"] = to_jax(pt_model.linear.weight).T\n",
+    "    return updated_params\n",
+    "\n",
+    "  def test_unfold_convolution(self):\n",
+    "    \"\"\"Test for the Llama4 Unfold Convolution implementation.\"\"\"\n",
+    "    # Test parameters, following the Llama4 config:\n",
+    "    # https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/config.json\n",
+    "    batch_size = 10\n",
+    "    num_channels = 3\n",
+    "    image_size = 336\n",
+    "    patch_size = 14\n",
+    "    hidden_size = 1408\n",
+    "\n",
+    "    # Create random input tensor\n",
+    "    inputs_pt = torch.randn(batch_size, num_channels, image_size, image_size)\n",
+    "\n",
+    "    # PyTorch reference implementation, following\n",
+    "    # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py#L1279\n",
+    "    class Llama4UnfoldConvolution(nn.Module):\n",
+    "      \"\"\"Llama4 Unfold Convolution implementation.\"\"\"\n",
+    "\n",
+    "      def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.patch_size = patch_size\n",
+    "        self.hidden_size = hidden_size\n",
+    "        kernel_size = patch_size\n",
+    "        if isinstance(kernel_size, int):\n",
+    "          kernel_size = (kernel_size, kernel_size)\n",
+    "        self.unfold = nn.Unfold(kernel_size=kernel_size, stride=patch_size)\n",
+    "        self.linear = nn.Linear(num_channels * kernel_size[0] * kernel_size[1], hidden_size, bias=False)\n",
+    "\n",
+    "      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n",
+    "        # num_patches = (image_size // patch_size) ** 2\n",
+    "        # hidden_states shape: torch.Size([batch_size, num_channels, image_size, image_size])\n",
+    "        hidden_states = self.unfold(hidden_states)\n",
+    "        # hidden_states shape: torch.Size([batch_size, num_channels * patch_size * patch_size, num_patches])\n",
+    "        hidden_states = hidden_states.permute(0, 2, 1)\n",
+    "        # hidden_states shape: torch.Size([batch_size, num_patches, num_channels * patch_size * patch_size])\n",
+    "        hidden_states = self.linear(hidden_states)\n",
+    "        # hidden_states shape: torch.Size([batch_size, num_patches, hidden_size])\n",
+    "        return hidden_states\n",
+    "\n",
+    "    # Initialize PyTorch model\n",
+    "    pt_model = Llama4UnfoldConvolution()\n",
+    "    pt_model.eval()\n",
+    "    pt_output = pt_model(inputs_pt)\n",
+    "\n",
+    "    # JAX implementation\n",
+    "    class JaxConfig:\n",
+    "\n",
+    "      def __init__(self):\n",
+    "        self.patch_size_for_vit = patch_size\n",
+    "        self.hidden_size_for_vit = hidden_size\n",
+    "        self.dtype_mm = jnp.float32\n",
+    "\n",
+    "    # Initialize JAX model\n",
+    "    jax_model = llama4.Llama4UnfoldConvolution(JaxConfig())\n",
+    "    params = jax_model.init(jax.random.PRNGKey(0), to_jax(inputs_pt))\n",
+    "\n",
+    "    # Copy weights from PyTorch to JAX\n",
+    "    pt_params = self.__copy_weights(pt_model, params)\n",
+    "\n",
+    "    # Run JAX forward pass with the copied params\n",
+    "    jax_output = jax_model.apply(pt_params, to_jax(inputs_pt))\n",
+    "\n",
+    "    # Compare shapes\n",
+    "    self.assertEqual(pt_output.shape, jax_output.shape)\n",
+    "\n",
+    "    # Compare outputs with reasonable tolerances\n",
+    "    np.testing.assert_allclose(to_jax(pt_output), jax_output, rtol=1e-3, atol=0.05)\n",
+    "\n",
+    "\n",
+    "class Llama4VisionPixelShuffleMLPTest(unittest.TestCase):\n",
+    "  \"\"\"Test for the Llama4 Vision Pixel Shuffle MLP implementation.\"\"\"\n",
+    "\n",
+    "  def __copy_weights(self, pt_model, params):\n",
+    "    \"\"\"Copy weights from the PyTorch model to the JAX model.\n",
+    "\n",
+    "    Args:\n",
+    "      pt_model: PyTorch Llama4VisionPixelShuffleMLP model\n",
+    "      params: JAX model parameters\n",
+    "    \"\"\"\n",
+    "    # Create new params with copied weights for both MLP layers,\n",
+    "    # transposed because PyTorch Linear weights are (out_features, in_features)\n",
+    "    # while Flax kernels are (in_features, out_features).\n",
+    "    updated_params = jax.tree_util.tree_map(lambda x: x, params)\n",
updated_params[\"params\"][\"pixel_shuffle_mlp\"][\"vit_pixel_shuffle_mlp_fc1\"][\"kernel\"] = to_jax(pt_model.mlp.fc1.weight).T\n", + " updated_params[\"params\"][\"pixel_shuffle_mlp\"][\"vit_pixel_shuffle_mlp_fc2\"][\"kernel\"] = to_jax(pt_model.mlp.fc2.weight).T\n", + " return updated_params\n", + "\n", + " def test_pixel_shuffle_mlp(self):\n", + " \"\"\"Test for the Llama4 Vision Pixel Shuffle MLP implementation.\"\"\"\n", + " # Test parameters\n", + " # following config https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/config.json\n", + " batch_size = 10\n", + " num_patches = 24 * 24 # 336/14 = 24 patches per side\n", + " hidden_size = 1408\n", + " intermediate_size = 5632\n", + " projector_input_dim = 4096\n", + " projector_output_dim = 4096\n", + " pixel_shuffle_ratio = 0.5\n", + " projector_dropout = 0.0\n", + "\n", + " def pixel_shuffle(input_tensor, shuffle_ratio):\n", + " # input_tensor: [batch_size, num_patches, channels]\n", + " batch_size, num_patches, channels = input_tensor.shape\n", + " patch_size = int(math.sqrt(num_patches))\n", + "\n", + " input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)\n", + " batch_size, height, width, channels = input_tensor.size()\n", + "\n", + " reshaped_tensor = input_tensor.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio))\n", + " reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()\n", + "\n", + " reshaped_tensor = reshaped_tensor.view(\n", + " batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio), int(channels / (shuffle_ratio**2))\n", + " )\n", + " reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()\n", + "\n", + " output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])\n", + " return output_tensor\n", + "\n", + " # PyTorch implementation\n", + " class Llama4VisionMLP2(nn.Module):\n", + " \"\"\"Llama4 Vision MLP2 implementation.\"\"\"\n", + "\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " self.hidden_size = config.hidden_size\n", + " self.intermediate_size = config.intermediate_size\n", + " self.fc1 = nn.Linear(self.intermediate_size, config.projector_input_dim, bias=False)\n", + " self.fc2 = nn.Linear(config.projector_output_dim, config.projector_output_dim, bias=False)\n", + " self.activation_fn = nn.GELU()\n", + " self.dropout = config.projector_dropout\n", + "\n", + " def forward(self, hidden_states):\n", + " hidden_states = self.fc1(hidden_states)\n", + " hidden_states = self.activation_fn(hidden_states)\n", + " hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)\n", + " return self.activation_fn(self.fc2(hidden_states))\n", + "\n", + " class Llama4VisionPixelShuffleMLP(nn.Module):\n", + " \"\"\"Llama4 Vision Pixel Shuffle MLP implementation.\"\"\"\n", + "\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " self.pixel_shuffle_ratio = config.pixel_shuffle_ratio\n", + " self.inner_dim = int(config.projector_input_dim // (self.pixel_shuffle_ratio**2))\n", + " self.output_dim = config.projector_output_dim\n", + " self.mlp = Llama4VisionMLP2(config)\n", + "\n", + " def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:\n", + " # encoded_patches shape: torch.Size([batch_size, num_patches, hidden_size])\n", + " encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)\n", + " return self.mlp(encoded_patches)\n", + " # result shape: torch.Size([batch_size, num_patches * (pixel_shuffle_rate**2), 
+    "\n",
+    "    # Initialize PyTorch model\n",
+    "    class Config:\n",
+    "\n",
+    "      def __init__(self):\n",
+    "        self.hidden_size = hidden_size\n",
+    "        self.intermediate_size = intermediate_size\n",
+    "        self.projector_input_dim = projector_input_dim\n",
+    "        self.projector_output_dim = projector_output_dim\n",
+    "        self.pixel_shuffle_ratio = pixel_shuffle_ratio\n",
+    "        self.projector_dropout = projector_dropout\n",
+    "\n",
+    "    # Create random input tensor\n",
+    "    inputs_pt = torch.randn(batch_size, num_patches, hidden_size)\n",
+    "\n",
+    "    pt_model = Llama4VisionPixelShuffleMLP(Config())\n",
+    "    pt_model.eval()\n",
+    "    pt_output = pt_model(inputs_pt)\n",
+    "\n",
+    "    # JAX implementation\n",
+    "    class JaxConfig:\n",
+    "\n",
+    "      def __init__(self):\n",
+    "        self.pixel_shuffle_ratio_for_vit = pixel_shuffle_ratio\n",
+    "        self.projector_input_dim_for_vit = projector_input_dim\n",
+    "        self.projector_output_dim_for_vit = projector_output_dim\n",
+    "        self.dtype_mm = jnp.float32\n",
+    "        self.projector_dropout_for_vit = projector_dropout\n",
+    "\n",
+    "    # Initialize JAX model\n",
+    "    jax_model = llama4.Llama4VisionPixelShuffleMLP(JaxConfig())\n",
+    "    params = jax_model.init(jax.random.PRNGKey(0), to_jax(inputs_pt))\n",
+    "\n",
+    "    # Copy weights from PyTorch to JAX\n",
+    "    pt_params = self.__copy_weights(pt_model, params)\n",
+    "\n",
+    "    # Run JAX forward pass with the copied params (dropout disabled)\n",
+    "    jax_output = jax_model.apply(pt_params, to_jax(inputs_pt), deterministic=True)\n",
+    "\n",
+    "    # Compare shapes\n",
+    "    self.assertEqual(pt_output.shape, jax_output.shape)\n",
+    "\n",
+    "    # Compare outputs with reasonable tolerances\n",
+    "    np.testing.assert_allclose(to_jax(pt_output), jax_output, rtol=1e-3, atol=0.05)\n",
+    "    print(f\"jax {jax_output.mean()} pt {pt_output.mean()}\")\n",
+    "\n",
+    "\n",
+    "# Run both tests directly when the cell is executed.\n",
+    "Llama4UnfoldConvolutionTest().test_unfold_convolution()\n",
+    "Llama4VisionPixelShuffleMLPTest().test_pixel_shuffle_mlp()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}