Add Qwen3 Moe #2260
base: master
@@ -0,0 +1,369 @@
import math

import keras
from keras import ops

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm
from keras_hub.src.utils.keras_utils import clone_initializer
from keras_hub.src.utils.keras_utils import fused_attention_op_available


class Qwen3MoeAttention(keras.layers.Layer):
    """A multi-head attention layer for Qwen3Moe models
    This attention implementation supports grouped-query attention (GQA) where
    the number of key-value heads can be less than the number of query heads.

    Args:
        num_query_heads: Number of query heads.
        num_key_value_heads: Number of key/value heads (for GQA).
        rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position
            Embedding).
        rope_scaling_factor: Scaling factor for RoPE, used for extending
            context length.
        kernel_initializer: Initializer for the kernel weights.
        dropout: Dropout rate for attention weights.
        sliding_window_size: Size of the sliding window for attention.
        **kwargs: Additional keyword arguments to pass to the Layer.
    """
Comment on lines +17 to +28:

The docstring is missing documentation for the `head_dim` and `layer_norm_epsilon` arguments (a similar issue exists for the …). The suggested `Args:` section below also adds type annotations, per the style guide:

    Args:
        num_query_heads: int. Number of query heads.
        num_key_value_heads: int. Number of key/value heads (for GQA).
        head_dim: int. The dimension of each attention head.
        rope_max_wavelength: int. Maximum wavelength for RoPE (Rotary Position
            Embedding).
        rope_scaling_factor: float. Scaling factor for RoPE, used for extending
            context length.
        kernel_initializer: Initializer for the kernel weights.
        dropout: float. Dropout rate for attention weights.
        layer_norm_epsilon: float. The epsilon value for layer normalization.
        sliding_window_size: int. Size of the sliding window for attention.
        **kwargs: Additional keyword arguments to pass to the Layer.
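As a quick illustration of the GQA grouping described in the docstring above (hypothetical head counts, not the actual Qwen3-MoE configuration):

    # Hypothetical head counts, purely to illustrate the grouped-query layout.
    num_query_heads = 8
    num_key_value_heads = 2
    # Mirrors `self.num_key_value_groups` in `__init__` below: each key/value
    # head is shared by num_query_heads // num_key_value_heads = 4 query heads.
    num_key_value_groups = num_query_heads // num_key_value_heads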

    def __init__(
        self,
        num_query_heads,
        num_key_value_heads,
        head_dim=None,
        rope_max_wavelength=10000,
        rope_scaling_factor=1,
        kernel_initializer="glorot_uniform",
        dropout=0.0,
        layer_norm_epsilon=1e-6,
        sliding_window_size=None,
        **kwargs,
    ):
        super().__init__(
            **kwargs,
        )
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.dropout = dropout

        self.layer_norm_epsilon = layer_norm_epsilon

        self.num_key_value_groups = num_query_heads // num_key_value_heads
        self.rope_max_wavelength = rope_max_wavelength

        self.kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )

        self.rope_scaling_factor = rope_scaling_factor
        self.sliding_window_size = sliding_window_size

    def build(self, inputs_shape):
        # Einsum variables:
        # b = batch size
        # q = query length
        # k = key/value length
        # m = model dim
        # u = num query heads
        # v = num key/value heads
        # h = head dim
        hidden_dim = inputs_shape[-1]
        if not self.head_dim:
            self.head_dim = hidden_dim // self.num_query_heads

        self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self._query_dense = keras.layers.EinsumDense(
            equation="bqm,muh->bquh",
            output_shape=(None, self.num_query_heads, self.head_dim),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="query",
        )
        self._query_dense.build(inputs_shape)

        self._query_dense_layer_norm = Qwen3MoeLayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            head_dim=self.head_dim,
            name="query_dense_layernorm",
        )
        self._query_dense_layer_norm.build(inputs_shape)

        self._key_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                self.head_dim,
            ),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="key",
        )
        self._key_dense.build(inputs_shape)

        self._key_dense_layer_norm = Qwen3MoeLayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            head_dim=self.head_dim,
            name="key_dense_layernorm",
        )
        self._key_dense_layer_norm.build(inputs_shape)

        self._value_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                self.head_dim,
            ),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="value",
        )
        self._value_dense.build(inputs_shape)

        self._softmax = keras.layers.Softmax(
            axis=-1,
            dtype="float32",
            name="attention_softmax",
        )

        self._dropout_layer = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
        )

        self._output_dense = keras.layers.EinsumDense(
            equation="bquh,uhm->bqm",
            output_shape=(None, hidden_dim),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self._output_dense.build(
            (None, None, self.num_query_heads, self.head_dim)
        )

        self.rotary_embedding_layer = RotaryEmbedding(
            max_wavelength=self.rope_max_wavelength,
            scaling_factor=self.rope_scaling_factor,
            dtype=self.dtype_policy,
        )

        self._dot_product_equation = "bquh,bkuh->buqk"
        self._combine_equation = "buqk,bkuh->bquh"

        self.built = True

    def call(
        self,
        hidden_states,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        """Applies attention mechanism to the input hidden states.

        Args:
            hidden_states: Input tensor of shape [batch_size, seq_length,
                hidden_size].
            attention_mask: Mask tensor of shape [batch_size, seq_length,
                seq_length].
            cache: Optional cached key and value tensors.
            cache_update_index: Index at which to update the cache.
            training: Boolean indicating whether in training mode.

        Returns:
            attention_output: Output tensor after applying attention.
            cache: Updated cache tensors (if cache is provided).
        """
        start_index = (
            cache_update_index if cache_update_index is not None else 0
        )

        query = self._query_dense(hidden_states)
        query = self._query_dense_layer_norm(query)

        # Compute RoPE for queries
        query = self.rotary_embedding_layer(query, start_index=start_index)

        def _compute_key_value(x):
            key = self._key_dense(x)
            key = self._key_dense_layer_norm(key)
            key = self.rotary_embedding_layer(key, start_index=start_index)

            value = self._value_dense(x)

            return key, value

        if cache is not None:
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                key = key_cache
                value = value_cache
            else:
                key_update, value_update = _compute_key_value(hidden_states)
                start = [0, cache_update_index, 0, 0]
                key = ops.slice_update(key_cache, start, key_update)
                value = ops.slice_update(value_cache, start, value_update)
                cache = ops.stack((key, value), axis=1)
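                # The stacked cache has shape
                # [batch, 2, cache_length, num_key_value_heads, head_dim],
                # with index 0 holding keys and index 1 holding values.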
        else:
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )
            key, value = _compute_key_value(hidden_states)

        # [batch_shape, seq_len, num_key_value_heads, head_dim]
        # -> [batch_shape, seq_len, num_heads, head_dim]
        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

        attention_output = self._compute_attention(
            query,
            key,
            value,
            attention_mask,
            cache_update_index=cache_update_index,
        )

        attention_output = self._dropout_layer(
            attention_output, training=training
        )

        attention_output = self._output_dense(attention_output)

        if cache is not None:
            return attention_output, cache
        return attention_output

    def _masked_softmax(self, attention_scores, attention_mask=None):
        """Applies softmax with optional masking.

        Args:
            attention_scores: Attention score tensor.
            attention_mask: Optional mask tensor.

        Returns:
            Masked softmax attention weights.
        """
        if attention_mask is not None:
            return self._softmax(
                attention_scores, attention_mask[:, None, :, :]
            )
        return self._softmax(attention_scores)

    def _compute_attention(
        self, query, key, value, attention_mask=None, cache_update_index=None
    ):
        """Computes attention using query, key, and value tensors.
        Uses Flash Attention when available for better performance.

        Args:
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.
            attention_mask: Optional mask tensor.
            cache_update_index: Index for sliding window computation.

        Returns:
            attention_output: Output tensor after applying attention.
        """
        if fused_attention_op_available():
            # Use `dot_product_attention` with Flash Attention support if
            # available.
            if attention_mask is not None:
                attention_mask = ops.expand_dims(attention_mask, axis=1)
                attention_mask = ops.cast(attention_mask, dtype="bool")
            attention_output = ops.dot_product_attention(
                query,
                key,
                value,
                mask=attention_mask,
                scale=self._inv_norm_factor,
            )
            return attention_output

        attention_scores = ops.einsum(self._dot_product_equation, query, key)

        attention_scores = ops.multiply(
            attention_scores,
            ops.cast(self._inv_norm_factor, self.compute_dtype),
        )
        if self.sliding_window_size:
            attention_mask = self._mask_sliding_window(
                attention_mask,
                cache_update_index=cache_update_index
                if cache_update_index is not None
                else 0,
            )
        attention_scores = self._masked_softmax(
            attention_scores, attention_mask
        )
        attention_scores = ops.cast(attention_scores, self.compute_dtype)
        attention_output = ops.einsum(
            self._combine_equation, attention_scores, value
        )

        return attention_output

    def _mask_sliding_window(
        self,
        attention_mask,
        cache_update_index=0,
    ):
        """Creates and combines a sliding window mask with the attention mask.

        Args:
            attention_mask: Original attention mask.
            cache_update_index: Starting index for the sliding window.

        Returns:
            Combined attention mask with sliding window constraints.
        """
        _, query_len, key_len = ops.shape(attention_mask)
        # Compute the sliding window for square attention.
        all_ones = ops.ones((key_len, key_len), "bool")
        if keras.config.backend() == "tensorflow":
            # TODO: triu/tril has issues with dynamic shape on the tensorflow
            # backend. We should fix, but use `band_part` for now.
            import tensorflow as tf

            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
            band_size = ops.cast(band_size, "int32")
            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
        else:
            sliding_mask = ops.triu(
                all_ones, -1 * self.sliding_window_size + 1
            ) * ops.tril(all_ones, self.sliding_window_size - 1)
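        # For example, with a hypothetical key_len=5 and sliding_window_size=3,
        # `sliding_mask` is True exactly where |i - j| <= 2, i.e. a boolean
        # band around the diagonal.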
        # Slice the window for short queries during generation.
        start = (cache_update_index, 0)
        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
        sliding_mask = ops.expand_dims(sliding_mask, 0)
        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_query_heads": self.num_query_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_factor": self.rope_scaling_factor,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "dropout": self.dropout,
                "sliding_window_size": self.sliding_window_size,
                "head_dim": self.head_dim,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config
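A minimal sketch of how the layer could be exercised in isolation. The module path is an assumption that mirrors the `qwen3_moe_layernorm` import above, and the head counts are hypothetical, not the shipped Qwen3-MoE configuration:

    import numpy as np

    # Assumed module path for the file added in this PR.
    from keras_hub.src.models.qwen3_moe.qwen3_moe_attention import (
        Qwen3MoeAttention,
    )

    # Hypothetical small config: 8 query heads sharing 2 key/value heads (GQA).
    layer = Qwen3MoeAttention(
        num_query_heads=8,
        num_key_value_heads=2,
        head_dim=16,
    )

    # Inputs are [batch_size, seq_length, hidden_size]; the output projection
    # maps the attended heads back to hidden_size features.
    hidden_states = np.random.uniform(size=(1, 12, 128)).astype("float32")
    output = layer(hidden_states)
    print(output.shape)  # (1, 12, 128)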