From 84eb19328d5510585fa0e2f9c4d5ac884050630c Mon Sep 17 00:00:00 2001 From: HayrapetyanZhirayr Date: Wed, 2 Jul 2025 15:45:25 +0000 Subject: [PATCH] Fix qwen3 model implementation in torchtune, closes #2866 --- torchtune/models/qwen3/_attention.py | 279 ++++++++++ torchtune/models/qwen3/_component_builders.py | 501 ++++++++++++++++++ torchtune/models/qwen3/_model_builders.py | 47 +- 3 files changed, 804 insertions(+), 23 deletions(-) create mode 100644 torchtune/models/qwen3/_attention.py create mode 100644 torchtune/models/qwen3/_component_builders.py diff --git a/torchtune/models/qwen3/_attention.py b/torchtune/models/qwen3/_attention.py new file mode 100644 index 0000000000..8ab3e50fc8 --- /dev/null +++ b/torchtune/models/qwen3/_attention.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +import torch +from torch import nn +from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention +from torchtune.modules.kv_cache import KVCache + +logger = logging.getLogger(__name__) + + +class Qwen3Attention(nn.Module): + """ + Basically, it is standard multihead attention, but with QK-norm applied before + the RoPE. It is unusual for most of the models, but Qwen3 became an exception to the rule. + + Args: + embed_dim (int): embedding dimension for the model + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``, + for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``. + head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``. + q_proj (nn.Module): projection layer for query. + k_proj (nn.Module): projection layer for key. + v_proj (nn.Module): projection layer for value. + output_proj (nn.Module): projection layer for output. + pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings. + q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied + before updating from kv_cache. This means it will only support token wide normalization and not + batch or sequence wide normalization. + k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is. + kv_cache (Optional[KVCache]): KVCache object used to cache key and value + max_seq_len (int): maximum sequence length supported by the model. + This is needed to compute the RoPE Cache. Default: 4096. + is_causal (bool): sets the default mask to causal when no mask is provided + attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function. + Default value is 0.0. + + Raises: + ValueError: + If ``num_heads % num_kv_heads != 0``, **or** + if ``embed_dim % num_heads != 0``, **or** + if ``attn_dropout < 0`` or ``attn_dropout > 1``, **or** + if q_norm is defined without k_norm or vice versa + """ + + def __init__( + self, + *, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + q_proj: nn.Module, + k_proj: nn.Module, + v_proj: nn.Module, + output_proj: nn.Module, + pos_embeddings: Optional[nn.Module] = None, + q_norm: Optional[nn.Module] = None, + k_norm: Optional[nn.Module] = None, + kv_cache: Optional[KVCache] = None, + max_seq_len: int = 4096, + is_causal: bool = True, + attn_dropout: float = 0.0, + ) -> None: + super().__init__() + if num_heads % num_kv_heads != 0: + raise ValueError( + f"num_heads ({num_heads}) must be divisible by " + f"num_kv_heads ({num_kv_heads})" + ) + + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})" + ) + + if attn_dropout < 0 or attn_dropout > 1: + raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0") + + if bool(q_norm) ^ bool(k_norm): + raise ValueError("q and k norm must be set together") + + # Set attributes + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.embed_dim = embed_dim + self.attn_dropout = attn_dropout + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.is_causal = is_causal + + # Set layers + self.kv_cache = kv_cache + self.q_proj = q_proj + self.k_proj = k_proj + self.v_proj = v_proj + self.output_proj = output_proj + self.q_norm = q_norm + self.k_norm = k_norm + self.pos_embeddings = pos_embeddings + + # Use flex attention if supported and we are sample packing + self._attention_call = _sdpa_or_flex_attention() + + # this flag indicates whether to update the kv-cache during forward + # passes. when disabled, we can have the cache setup but still + # perform normal forward passes + self.cache_enabled = False + + def setup_cache( + self, batch_size: int, dtype: torch.dtype, max_seq_len: int + ) -> None: + """Setup key value caches for attention calculation. If called + after kv_cache is already setup, this will be skipped. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + max_seq_len (int): maximum sequence length model will be run with. + """ + # Don't overwrite user defined kv_cache from init + if self.kv_cache is not None: + logger.warning( + "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping." + ) + else: + self.kv_cache = KVCache( + batch_size=batch_size, + max_seq_len=max_seq_len, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + self.cache_enabled = True + + def reset_cache(self): + """Reset the key value caches.""" + if self.kv_cache is None: + raise RuntimeError( + "Key value caches are not setup. Call ``setup_caches()`` first." + ) + self.kv_cache.reset() + + def forward( + self, + x: torch.Tensor, + y: Optional[torch.Tensor] = None, + *, + mask: Optional[_MaskType] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor with shape [b x s_x x d] for the query + y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input + for k and v. For self attention, x=y. Optional only with kv_cache enabled. + mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication + and before the softmax. Either: + + A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``, + or ``[b x s x self.decoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers. + A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means + token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask + is used by default. + + A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence + created via `create_block_mask `_. We use + :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks. + Default is None. + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b x s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Raises: + ValueError: If no ``y`` input and ``kv_cache`` is not enabled. + + Returns: + torch.Tensor: output tensor with attention applied + + Notation used for tensor shapes: + - b: batch size + - s_x: sequence length for x + - s_y: sequence length for y + - n_h: num heads + - n_kv: num kv heads + - d: embed dim + - h_d: head dim + """ + # x has shape [b, s_x, d] + # y has shape [b, s_y, d] + b, s_x, _ = x.shape + s_y = y.shape[1] if y is not None else 0 + + # q has shape [b, s_x, num_heads * head_dim] + q = self.q_proj(x) # [b, s, n_h*h_d] + + # number of queries per key/value + q_per_kv = self.num_heads // self.num_kv_heads + q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim) # [b, s, n_h, h_d] + + # Normalize q + if self.q_norm is not None: + q = q.transpose(1, 2) # [b, s, n_h, h_d] + q = self.q_norm(q) # [b, n_h, s, h_d] + q = q.transpose(1, 2) # [b, s, n_h, h_d] + + # Apply positional embeddings after q-norm + if self.pos_embeddings is not None: + q = self.pos_embeddings(q, input_pos=input_pos) # [b, s, n_h, h_d] + + q = q.transpose(1, 2) # [b, n_h, s, h_d] + + if y is None: + if self.kv_cache is None or not self.cache_enabled: + raise ValueError( + "Must provide y input or use kv_cache to enable streaming decoding" + ) + k = self.kv_cache.k_cache + v = self.kv_cache.v_cache + else: + # Update k and v shape, positional embeddings, and normalization + + # k,v shape [b, s_y, num_kv_heads * head_dim] + k = self.k_proj(y) # [b, s, n_kv*h_d] + v = self.v_proj(y) # [b, s, n_kv*h_d] + + # Apply positional embeddings + # k,v shape: [b, s_y, n_kv, h_d] + k = k.view(b, s_y, -1, self.head_dim) # [b, s, n_kv, h_d] + v = v.view(b, s_y, -1, self.head_dim) # [b, s, n_kv, h_d] + + # Normalize k + if self.k_norm is not None: + k = k.transpose(1, 2) # [b, n_kv, s, h_d] + k = self.k_norm(k) # [b, n_kv, s, h_d] + k = k.transpose(1, 2) # [b, s, n_kv, h_d] + + # Apply positional embeddings after k-norm + if self.pos_embeddings is not None: + k = self.pos_embeddings(k, input_pos=input_pos) # [b, s, n_h, h_d] + + k = k.transpose(1, 2) # [b, n_kv, s, h_d] + v = v.transpose(1, 2) # [b, n_kv, s, h_d] + + # Update key-value cache + if self.kv_cache is not None and self.cache_enabled: + k, v = self.kv_cache.update(k, v) + + # If needed, expand the key and value tensors to have the same shape + # as the query tensor by copying values across the relevant dim + # k,v shape: [b, n_kv, s, h_d] -> [b, n_h, s, h_d] + if self.num_heads != self.num_kv_heads: + expand_shape = (b, self.num_kv_heads, q_per_kv, -1, self.head_dim) + k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2) + v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2) + + output = self._attention_call( + q, + k, + v, + mask=mask, + dropout_p=self.attn_dropout if self.training else 0.0, + is_causal=self.kv_cache is None and mask is None and self.is_causal, + ) + + # reshape the output to be the same shape as the input + output = output.transpose(1, 2).contiguous().view(b, s_x, -1) + return self.output_proj(output) \ No newline at end of file diff --git a/torchtune/models/qwen3/_component_builders.py b/torchtune/models/qwen3/_component_builders.py new file mode 100644 index 0000000000..d40c9f710d --- /dev/null +++ b/torchtune/models/qwen3/_component_builders.py @@ -0,0 +1,501 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from typing import Optional +from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook + +from torch import nn +from torchtune.modules.transformer import TransformerDecoder +from torchtune.models.qwen2._positional_embeddings import Qwen2RotaryPositionalEmbeddings +from torchtune.models.qwen3._attention import Qwen3Attention + +from torchtune.modules import ( + FeedForward, + RMSNorm, + TransformerSelfAttentionLayer, + TiedLinear +) + + +from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear + +""" +Component builders for the Qwen3 model and popular variants such as LoRA. + +torchtune provides composable building blocks. Builder functions help +stitch these building blocks into higher-level components. This design has +two benefits: +- The building blocks themselves are very flexible. For example, ``MultiHeadAttention`` +can take either nn.Linear or nn.LoRALinear for ``q_proj``. +- Builder functions expose a set of configurable params which keep the constructors of +the building blocks simple. +""" + + +def qwen3( + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + head_dim: Optional[int] = None, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + rope_base: float = 1_000_000.0, + tie_word_embeddings: bool = False, + q_proj_bias: bool = True, + k_proj_bias: bool = True, + v_proj_bias: bool = True, + q_norm: bool = False, + k_norm: bool = False, +) -> TransformerDecoder: + """ + Build the decoder associated with the Qwen2 model. This includes: + - Token embeddings + - num_layers number of TransformerSelfAttentionLayer blocks + - RMS Norm layer applied to the output of the transformer + - Final projection into token space + + Args: + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` + head_dim (Optional[int]): Dimension of each attention head. If not + specified, it defaults to `embed_dim // num_heads`. In GQA, `head_dim` is not necessarily equal to + `embed_dim // num_heads`, so this parameter allows the caller to explicitly specify a custom value. + norm_eps (float): epsilon in RMS norms. + rope_base (float): the base period of the RoPE embeddings. + tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied. + q_proj_bias (bool): whether to use bias in the query projection. + k_proj_bias (bool): whether to use bias in the key projection. + v_proj_bias (bool): whether to use bias in the value projection. + q_norm (bool): whether to use normalization in the query projection. + k_norm (bool): whether to use normalization in the key projection. + + Returns: + TransformerDecoder: Instantiation of Qwen2 model. + """ + head_dim = head_dim or embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + + rope = Qwen2RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = Qwen3Attention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=q_proj_bias), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=k_proj_bias), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=v_proj_bias), + output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False), + pos_embeddings=rope, + q_norm=RMSNorm(dim=head_dim, eps=norm_eps) if q_norm else None, # norm on head_dim + k_norm=RMSNorm(dim=head_dim, eps=norm_eps) if k_norm else None, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = qwen3_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) + + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + if tie_word_embeddings: + output_proj = TiedLinear(tok_embeddings) + else: + output_proj = nn.Linear(embed_dim, vocab_size, bias=False) + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + +def qwen3_mlp(dim: int, hidden_dim: int) -> FeedForward: + """ + Build the MLP layer associated with the Qwen2 model. + """ + gate_proj = nn.Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) + return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) + + +def lora_qwen3( + lora_attn_modules: list[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + *, + # qwen2 args + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + head_dim: Optional[int] = None, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + rope_base: float = 1_000_000.0, + tie_word_embeddings: bool = False, + q_proj_bias: bool = True, + k_proj_bias: bool = True, + v_proj_bias: bool = True, + q_norm: bool = False, + k_norm: bool = False, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + use_dora: bool = False, + # Quantization args + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Return a version of Qwen2 (an instance of :func:`~torchtune.models.qwen2.transformer.Qwen2TransformerDecoder`) + with LoRA applied based on the passed in configuration. + + Args: + lora_attn_modules (list[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` + norm_eps (float): epsilon in RMS norms. + rope_base (float): the base period of the RoPE embeddings. + tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied. + q_proj_bias (bool): whether to use bias in the query projection. + k_proj_bias (bool): whether to use bias in the key projection. + v_proj_bias (bool): whether to use bias in the value projection. + q_norm (bool): whether to use normalization in the query projection. + k_norm (bool): whether to use normalization in the key projection. + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + Returns: + TransformerDecoder: Instantiation of Qwen2 model with LoRA applied to + a subset of the attention projections in each layer. + + Raises: + ValueError: if ``apply_lora_to_output`` and ``tie_word_embeddings``. + + """ + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = lora_qwen3_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + head_dim=head_dim, + attn_dropout=attn_dropout, + norm_eps=norm_eps, + q_proj_bias=q_proj_bias, + k_proj_bias=k_proj_bias, + v_proj_bias=v_proj_bias, + q_norm=q_norm, + k_norm=k_norm, + rope_base=rope_base, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + + if apply_lora_to_mlp: + mlp = lora_qwen3_mlp( + dim=embed_dim, + hidden_dim=intermediate_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + use_dora=use_dora, + lora_dropout=lora_dropout, + ) + else: + mlp = qwen3_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + layers.append(layer) + + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + + if tie_word_embeddings: + if apply_lora_to_output: + raise ValueError( + "apply_lora_to_output is incompatible with tie_word_embeddings," + " as there would be no output to apply lora to!" + ) + output_proj = TiedLinear(tok_embeddings) + else: + # TODO: quantize_base is not applied to final output_proj currently. + adapter_cls = DoRALinear if use_dora else LoRALinear + output_proj = ( + adapter_cls(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout) + if apply_lora_to_output + else nn.Linear(embed_dim, vocab_size, bias=False) + ) + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=(embed_dim // num_heads), + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly + # so as to not increase peak memory + model._register_state_dict_hook( + partial( + reparametrize_as_dtype_state_dict_post_hook, + # TODO this is clowny, figure out a better way to get what precision the rest + # of the model is in + dtype=tok_embeddings.weight.dtype, + offload_to_cpu=True, + ) + ) + + return model + + +def lora_qwen3_self_attention( + lora_modules: list[LORA_ATTN_MODULES], + *, + # MultiHeadAttention args + embed_dim: int, + num_heads: int, + num_kv_heads: int, + max_seq_len: int, + head_dim: Optional[int] = None, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + rope_base: float = 1_000_000.0, + q_proj_bias: bool = True, + k_proj_bias: bool = True, + v_proj_bias: bool = True, + q_norm: bool = False, + k_norm: bool = False, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, +) -> Qwen3Attention: + """ + Return an instance of :func:`~torchtune.modules.MultiHeadAttention` with LoRA + applied to a subset of its linear layers + + Args: + lora_modules (list[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj", + "output_proj"}``. + embed_dim (int): embedding dimension for self-attention + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + norm_eps (float): epsilon in RMS norms. Default: 1e-5 + rope_base (float): the base period of the RoPE embeddings. Default: 1_000_000.0 + q_proj_bias (bool): whether to use bias in the query projection. + k_proj_bias (bool): whether to use bias in the key projection. + v_proj_bias (bool): whether to use bias in the value projection. + q_norm (bool): whether to use normalization in the query projection. + k_norm (bool): whether to use normalization in the key projection. + head_dim (Optional[int]): the dimension of each head. If not specified, is computed + as `embed_dim` // `num_heads` + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base (bool): Whether to quantize base model parameters for linear layers + LoRA is being applied to. Default is ``False``. + + Returns: + MultiHeadAttention: instantiation of self-attention module with LoRA + applied to a subset of Q, K, V, output projections. + + Raises: + ValueError: If lora_modules arg is an empty list + """ + if not lora_modules: + raise ValueError(f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules") + + head_dim = head_dim or embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + adapter_cls = DoRALinear if use_dora else LoRALinear + q_proj = ( + adapter_cls( + embed_dim, + num_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + use_bias=q_proj_bias, + quantize_base=quantize_base, + ) + if "q_proj" in lora_modules + else nn.Linear(embed_dim, num_heads * head_dim, bias=q_proj_bias) + ) + k_proj = ( + adapter_cls( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + use_bias=k_proj_bias, + quantize_base=quantize_base, + ) + if "k_proj" in lora_modules + else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=k_proj_bias) + ) + v_proj = ( + adapter_cls( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + use_bias=v_proj_bias, + quantize_base=quantize_base, + ) + if "v_proj" in lora_modules + else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=v_proj_bias) + ) + output_proj = ( + adapter_cls( + num_heads * head_dim, + embed_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "output_proj" in lora_modules + else nn.Linear(embed_dim, embed_dim, bias=False) + ) + rope = Qwen2RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = Qwen3Attention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=q_proj, + k_proj=k_proj, + v_proj=v_proj, + q_norm=RMSNorm(dim=head_dim, eps=norm_eps) if q_norm else None, + k_norm=RMSNorm(dim=head_dim, eps=norm_eps) if k_norm else None, + output_proj=output_proj, + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + return self_attn + + +def lora_qwen3_mlp( + *, + dim: int, + hidden_dim: int, + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, +) -> FeedForward: + adapter_cls = DoRALinear if use_dora else LoRALinear + gate_proj = adapter_cls( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + down_proj = adapter_cls( + in_dim=hidden_dim, + out_dim=dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + up_proj = adapter_cls( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + return FeedForward( + gate_proj=gate_proj, + down_proj=down_proj, + up_proj=up_proj, + ) \ No newline at end of file diff --git a/torchtune/models/qwen3/_model_builders.py b/torchtune/models/qwen3/_model_builders.py index e656346176..731033ce39 100644 --- a/torchtune/models/qwen3/_model_builders.py +++ b/torchtune/models/qwen3/_model_builders.py @@ -7,7 +7,8 @@ from torchtune.data._prompt_templates import _get_prompt_template, _TemplateType -from torchtune.models.qwen2._component_builders import lora_qwen2, qwen2 +from torchtune.models.qwen3._component_builders import lora_qwen3, qwen3 + from torchtune.models.qwen3._tokenizer import QWEN3_SPECIAL_TOKENS, Qwen3Tokenizer from torchtune.modules import TransformerDecoder from torchtune.modules.peft import LORA_ATTN_MODULES @@ -28,7 +29,7 @@ def qwen3_0_6b_base() -> TransformerDecoder: Returns: TransformerDecoder: Instantiation of Qwen3 0.6B base model """ - return qwen2( + return qwen3( vocab_size=151936, num_layers=28, num_heads=16, @@ -57,7 +58,7 @@ def qwen3_0_6b_instruct() -> TransformerDecoder: Returns: TransformerDecoder: Instantiation of Qwen3 0.6B instruct model """ - return qwen2( + return qwen3( vocab_size=151936, num_layers=28, num_heads=16, @@ -86,7 +87,7 @@ def qwen3_1_7b_base() -> TransformerDecoder: Returns: TransformerDecoder: Instantiation of Qwen3 1.7B base model """ - return qwen2( + return qwen3( vocab_size=151936, num_layers=28, num_heads=16, @@ -115,7 +116,7 @@ def qwen3_1_7b_instruct() -> TransformerDecoder: Returns: TransformerDecoder: Instantiation of Qwen3 1.7B instruct model """ - return qwen2( + return qwen3( vocab_size=151936, num_layers=28, num_heads=16, @@ -144,7 +145,7 @@ def qwen3_4b_base() -> TransformerDecoder: Returns: TransformerDecoder: Instantiation of Qwen3 4B base model """ - return qwen2( + return qwen3( vocab_size=151936, num_layers=36, num_heads=32, @@ -173,7 +174,7 @@ def qwen3_4b_instruct() -> TransformerDecoder: Returns: TransformerDecoder: Instantiation of Qwen3 4B instruct model """ - return qwen2( + return qwen3( vocab_size=151936, num_layers=36, num_heads=32, @@ -202,7 +203,7 @@ def qwen3_8b_base() -> TransformerDecoder: Returns: TransformerDecoder: Instantiation of Qwen3 8B base model """ - return qwen2( + return qwen3( vocab_size=151936, num_layers=36, num_heads=32, @@ -230,7 +231,7 @@ def qwen3_8b_instruct() -> TransformerDecoder: Returns: TransformerDecoder: Instantiation of Qwen3 8B instruct model """ - return qwen2( + return qwen3( vocab_size=151936, num_layers=36, num_heads=32, @@ -258,7 +259,7 @@ def qwen3_14b_base() -> TransformerDecoder: Returns: TransformerDecoder: Instantiation of Qwen3 14B model """ - return qwen2( + return qwen3( vocab_size=151936, num_layers=40, num_heads=40, @@ -286,7 +287,7 @@ def qwen3_14b_instruct() -> TransformerDecoder: Returns: TransformerDecoder: Instantiation of Qwen3 14B instruct model """ - return qwen2( + return qwen3( vocab_size=151936, num_layers=40, num_heads=40, @@ -314,7 +315,7 @@ def qwen3_32b() -> TransformerDecoder: Returns: TransformerDecoder: Instantiation of Qwen3 32B instruct model (there's no base variant for the 32B) """ - return qwen2( + return qwen3( vocab_size=151936, num_layers=64, num_heads=64, @@ -419,7 +420,7 @@ def lora_qwen3_0_6b_base( Note: Qwen3 0.6B-3B model builders will enable ``tie_word_embeddings`` by default (see :func:`~torchtune.models.qwen2.qwen2`) """ - return lora_qwen2( + return lora_qwen3( lora_attn_modules=lora_attn_modules, apply_lora_to_mlp=apply_lora_to_mlp, apply_lora_to_output=apply_lora_to_output, @@ -485,7 +486,7 @@ def lora_qwen3_0_6b_instruct( Note: The base and instruct versions have the exact same arch for all Qwen3 model sizes, except for `max_seq_len`. Make sure to select the correct model builder for the weights. """ - return lora_qwen2( + return lora_qwen3( lora_attn_modules=lora_attn_modules, apply_lora_to_mlp=apply_lora_to_mlp, apply_lora_to_output=apply_lora_to_output, @@ -552,7 +553,7 @@ def lora_qwen3_1_7b_base( The base and instruct versions have slightly different architectures for all Qwen3 model sizes except 0.5B and 3B. Make sure to select the correct model builder for the weights. """ - return lora_qwen2( + return lora_qwen3( lora_attn_modules=lora_attn_modules, apply_lora_to_mlp=apply_lora_to_mlp, apply_lora_to_output=apply_lora_to_output, @@ -619,7 +620,7 @@ def lora_qwen3_1_7b_instruct( The base and instruct versions have slightly different architectures for all Qwen3 model sizes except 0.5B and 3B. Make sure to select the correct model builder for the weights. """ - return lora_qwen2( + return lora_qwen3( lora_attn_modules=lora_attn_modules, apply_lora_to_mlp=apply_lora_to_mlp, apply_lora_to_output=apply_lora_to_output, @@ -685,7 +686,7 @@ def lora_qwen3_4b_base( The base and instruct versions have slightly different architectures for all Qwen3 model sizes except 0.5B and 3B. Make sure to select the correct model builder for the weights. """ - return lora_qwen2( + return lora_qwen3( lora_attn_modules=lora_attn_modules, apply_lora_to_mlp=apply_lora_to_mlp, apply_lora_to_output=apply_lora_to_output, @@ -751,7 +752,7 @@ def lora_qwen3_4b_instruct( The base and instruct versions have slightly different architectures for all Qwen3 model sizes except 0.5B and 3B. Make sure to select the correct model builder for the weights. """ - return lora_qwen2( + return lora_qwen3( lora_attn_modules=lora_attn_modules, apply_lora_to_mlp=apply_lora_to_mlp, apply_lora_to_output=apply_lora_to_output, @@ -817,7 +818,7 @@ def lora_qwen3_8b_base( The base and instruct versions have slightly different architectures for all Qwen3 model sizes except 0.5B and 3B. Make sure to select the correct model builder for the weights. """ - return lora_qwen2( + return lora_qwen3( lora_attn_modules=lora_attn_modules, apply_lora_to_mlp=apply_lora_to_mlp, apply_lora_to_output=apply_lora_to_output, @@ -882,7 +883,7 @@ def lora_qwen3_8b_instruct( The base and instruct versions have slightly different architectures for all Qwen3 model sizes except 0.5B and 3B. Make sure to select the correct model builder for the weights. """ - return lora_qwen2( + return lora_qwen3( lora_attn_modules=lora_attn_modules, apply_lora_to_mlp=apply_lora_to_mlp, apply_lora_to_output=apply_lora_to_output, @@ -947,7 +948,7 @@ def lora_qwen3_14b_base( The base and instruct versions have slightly different architectures for all Qwen3 model sizes except 0.5B and 3B. Make sure to select the correct model builder for the weights. """ - return lora_qwen2( + return lora_qwen3( lora_attn_modules=lora_attn_modules, apply_lora_to_mlp=apply_lora_to_mlp, apply_lora_to_output=apply_lora_to_output, @@ -1012,7 +1013,7 @@ def lora_qwen3_14b_instruct( The base and instruct versions have slightly different architectures for all Qwen3 model sizes except 0.5B and 3B. Make sure to select the correct model builder for the weights. """ - return lora_qwen2( + return lora_qwen3( lora_attn_modules=lora_attn_modules, apply_lora_to_mlp=apply_lora_to_mlp, apply_lora_to_output=apply_lora_to_output, @@ -1077,7 +1078,7 @@ def lora_qwen3_32b( The base and instruct versions have slightly different architectures for all Qwen3 model sizes except 0.5B and 3B. Make sure to select the correct model builder for the weights. """ - return lora_qwen2( + return lora_qwen3( lora_attn_modules=lora_attn_modules, apply_lora_to_mlp=apply_lora_to_mlp, apply_lora_to_output=apply_lora_to_output,