Description
Issue Type
Documentation Bug
Source
source
Keras Version
2.14
Custom Code
Yes
OS Platform and Distribution
Ubuntu 22.04
Python version
3.10
GPU model and memory
Nvidia RTX4070 (12GB)
Current Behavior?
Hi,
I've spotted a mistake in the Vision Transformer examples on Keras.io [3, 4, 5, 6, 7].
In all five of the examples below, the ViT architecture is built with a single hyper-parameter named projection_dim, which is used both as the model's hidden dimension and as the per-head dimension for queries, keys, and values in the multi-head attention layer. These two hyper-parameters shouldn't be the same. However, according to [1], they are connected:
hidden dimension = number of heads * qkv dimension
One simple way to verify this issue is to calculate the total number of trainable parameters of the model.
Using the architecture from the Keras.io examples with the same hyper-parameters as the Vision Transformer Base, the model has only about 15 million parameters, whereas ViT-Base has about 86 million [2].
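As a quick illustration (a minimal sketch of my own, with illustrative variable names, not code taken from the examples), the parameter count of a single MultiHeadAttention layer already shows the gap between the two interpretations:

import keras
from keras import layers

num_heads, qkv_dim = 12, 64        # ViT-Base attention settings
hidden_dim = num_heads * qkv_dim   # 768

def attention_params(token_dim):
    # Count the parameters of one attention layer applied to 196 tokens of width token_dim
    # (196 = number of 16x16 patches in a 224x224 image).
    tokens = keras.Input(shape=(196, token_dim))
    outputs = layers.MultiHeadAttention(num_heads=num_heads, key_dim=qkv_dim)(tokens, tokens)
    return keras.Model(tokens, outputs).count_params()

print(attention_params(qkv_dim))     # tokens of width 64, as in the Keras.io examples (~0.2M)
print(attention_params(hidden_dim))  # tokens of width 768, as in ViT-Base (~2.4M)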
To fix this issue:
- A hidden dimension parameter can be defined as: hidden_dim = projection_dim * num_heads
- The encoded patches should be projected to the hidden dimension, instead of projection_dim: encoded_patches = PatchEncoder(num_patches, hidden_dim)(patches)
- The transformer_units should also use the hidden dimension: transformer_units = [hidden_dim * 2, hidden_dim]
Then, if the same hyper-parameters as in the original paper are used, the number of trainable parameters matches ViT-Base.
I understand that the authors may have used alternative versions of the original model, but this particular modification can significantly change the behaviour of the model.
If you'll need any further information, please let me know.
Best wishes,
Angelos
[1] See Table 3 in the original Transformer paper: https://arxiv.org/pdf/1706.03762
[2] https://arxiv.org/pdf/2010.11929
[3] https://keras.io/examples/vision/image_classification_with_vision_transformer/
[4] https://keras.io/examples/vision/vit_small_ds/
[5] https://keras.io/examples/vision/object_detection_using_vision_transformer/
[6] https://keras.io/examples/vision/token_learner/
[7] https://keras.io/examples/vision/vit_small_ds/
Standalone code to reproduce the issue or tutorial link
# Below is the ViT model builder (create_vit_object_detector, identical to [3]),
# with the same hyper-parameters as ViT-Base, including the hidden dimension hyper-parameter.
# If hidden_dimension is set equal to projection_dim (as implied in the Keras.io examples),
# the model has 15M parameters.
# If it is set to 768 (= projection_dim * num_heads), it has 86M parameters, like the original model.
# The code uses TensorFlow 2.14.
#%% Import libraries
import keras
from keras import layers
import tensorflow as tf
#%% define required functions and classes
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.activations.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        input_shape = tf.shape(images)
        batch_size = input_shape[0]
        height = input_shape[1]
        width = input_shape[2]
        channels = input_shape[3]
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size
        patches = tf.image.extract_patches(
            images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="SAME",
        )
        patches = tf.reshape(
            patches,
            (
                batch_size,
                num_patches_h * num_patches_w,
                self.patch_size * self.patch_size * channels,
            ),
        )
        return patches

    def get_config(self):
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    # Override function to avoid error while saving model
    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "input_shape": input_shape,
                "patch_size": patch_size,
                "num_patches": num_patches,
                "projection_dim": projection_dim,
                "num_heads": num_heads,
                "transformer_units": transformer_units,
                "transformer_layers": transformer_layers,
                "mlp_head_units": mlp_head_units,
            }
        )
        return config

    def call(self, patch):
        positions = tf.expand_dims(
            tf.experimental.numpy.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        projected_patches = self.projection(patch)
        encoded = projected_patches + self.position_embedding(positions)
        return encoded
def create_vit_object_detector(
    input_shape,
    patch_size,
    num_patches,
    projection_dim,
    num_heads,
    transformer_units,
    transformer_layers,
    mlp_head_units,
    hidden_dimension,
):
    inputs = keras.Input(shape=input_shape)
    # Create patches
    patches = Patches(patch_size)(inputs)
    # Encode patches
    encoded_patches = PatchEncoder(num_patches, hidden_dimension)(patches)
    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
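        # Note: in keras.layers.MultiHeadAttention, key_dim is the per-head
        # query/key/value size, so the concatenated attention width is
        # num_heads * key_dim; with the hyper-parameters used below
        # (num_heads=12, key_dim=projection_dim=64) this equals hidden_dimension=768.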
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])
    # Create a [batch_size, hidden_dimension] tensor.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.GlobalAveragePooling1D()(representation)
    representation = layers.Dropout(0.3)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.3)
    bounding_box = layers.Dense(1024)(
        features
    )  # Output head (1024 units here, rather than the original example's four bounding-box outputs)
    # return Keras model.
    return keras.Model(inputs=inputs, outputs=bounding_box)
#%% model parameters
image_size = 224
patch_size = 16
input_shape = (image_size, image_size, 3) # input image shape
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 12
hidden_dimension = projection_dim * num_heads
# Size of the transformer layers
transformer_units = [
    hidden_dimension * 3,
    hidden_dimension,
]
transformer_layers = 12
mlp_head_units = [3072, 3072] # Size of the dense layers
vit_object_detector = create_vit_object_detector(
    input_shape,
    patch_size,
    num_patches,
    projection_dim,
    num_heads,
    transformer_units,
    transformer_layers,
    mlp_head_units,
    hidden_dimension,
)
vit_object_detector.summary()
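#%% Optional check (a sketch, not part of the Keras.io example): rebuild the model
# with the hidden dimension set to projection_dim (as implied by the examples) and to
# projection_dim * num_heads (as in ViT-Base), and compare the parameter counts
# (roughly 15M vs roughly 86M).
for hidden in (projection_dim, projection_dim * num_heads):
    model = create_vit_object_detector(
        input_shape,
        patch_size,
        num_patches,
        projection_dim,
        num_heads,
        [hidden * 3, hidden],
        transformer_layers,
        mlp_head_units,
        hidden,
    )
    print(f"hidden dimension {hidden}: {model.count_params():,} parameters")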
Relevant log output
No response