From 2356954d580765a5c24e9ef7221af9ada738da29 Mon Sep 17 00:00:00 2001 From: kxxx Date: Sat, 14 Jun 2025 01:22:07 +0800 Subject: [PATCH 01/28] mom --- fla/__init__.py | 13 +- fla/layers/__init__.py | 4 +- fla/layers/mom.py | 613 +++++++++++++++ fla/layers/mom_varlen.py | 724 ++++++++++++++++++ fla/models/__init__.py | 2 + fla/models/mom_gated_deltanet/__init__.py | 14 + .../configuration_mom_gated_deltanet.py | 91 +++ .../modeling_mom_gated_deltanet.py | 561 ++++++++++++++ 8 files changed, 2018 insertions(+), 4 deletions(-) create mode 100644 fla/layers/mom.py create mode 100644 fla/layers/mom_varlen.py create mode 100644 fla/models/mom_gated_deltanet/__init__.py create mode 100644 fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py create mode 100644 fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py diff --git a/fla/__init__.py b/fla/__init__.py index 0d78656de..fcf6ca501 100644 --- a/fla/__init__.py +++ b/fla/__init__.py @@ -20,7 +20,9 @@ PaTHAttention, ReBasedLinearAttention, RWKV6Attention, - RWKV7Attention + RWKV7Attention, + MomGatedDeltaNet, + ) from fla.models import ( ABCForCausalLM, @@ -61,7 +63,9 @@ RWKV7ForCausalLM, RWKV7Model, TransformerForCausalLM, - TransformerModel + TransformerModel, + MomGatedDeltaNetForCausalLM, + MomGatedDeltaNetModel ) __all__ = [ @@ -123,7 +127,10 @@ 'MesaNetConfig', 'MesaNetForCausalLM', 'MesaNetModel', - 'MesaNet' + 'MesaNet', + 'MomGatedDeltaNetForCausalLM', + 'MomGatedDeltaNetModel', + 'MomGatedDeltaNet' ] __version__ = '0.2.2' diff --git a/fla/layers/__init__.py b/fla/layers/__init__.py index 76f8b8cc8..553ef9a1a 100644 --- a/fla/layers/__init__.py +++ b/fla/layers/__init__.py @@ -25,6 +25,7 @@ from .rodimus import RodimusAttention, SlidingWindowSharedKeyAttention from .rwkv6 import RWKV6Attention from .rwkv7 import RWKV7Attention +from .mom import MomGatedDeltaNet __all__ = [ 'ABCAttention', @@ -51,5 +52,6 @@ 'RodimusAttention', 'SlidingWindowSharedKeyAttention', 'PaTHAttention', - 'MesaNet' + 'MesaNet', + 'MomGatedDeltaNet' ] diff --git a/fla/layers/mom.py b/fla/layers/mom.py new file mode 100644 index 000000000..c246e73b1 --- /dev/null +++ b/fla/layers/mom.py @@ -0,0 +1,613 @@ +# -*- coding: utf-8 -*- + + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange +from torch.nn import functional as F + +from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution +from fla.modules.l2norm import l2_norm +from fla.ops.gated_delta_rule import (chunk_gated_delta_rule, + fused_recurrent_gated_delta_rule) + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +def elu_p1(x): + return (F.elu(x, 1., False) + 1.).to(x) + + +def sum_norm(x): + return (x / x.sum(-1, keepdim=True)).to(x) + +# https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1 + +def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, selected_memories: torch.Tensor, capacity: float): + ''' + Transform input sequences into memory-organized chunks with capacity constraints. + + Processes input sequences by routing tokens to designated memory states according to routing_mask, + sorts tokens by memory assignments, handles token truncation/padding based on memory capacity, + and returns memory-aligned tensors for parallel processing. + + Key operations: + 1. 
Expands input tensors when multiple memories are selected per token (top-k routing) + 2. Sorts tokens globally by (batch_idx, memory_idx) to group memory-assigned tokens + 3. Applies capacity-aware truncation (left-truncate oldest tokens when exceeding capacity) + 4. Pads memory chunks to uniform length for tensorization + + Args: + x: Input hidden states + Shape: (batch_size, seq_len, hidden_size) + routing_mask: Binary mask indicating active memory assignments + Shape: (batch_size, seq_len, num_memories) + num_memories: Total number of memories per batch + selected_memories: Memory indices assigned to each token. When using top-k routing, + this contains k memory indices per token (k >= 1) + Shape: (batch_size, seq_len) for k=1 or (batch_size, seq_len, topk) for k>1 + capacity: Scaling factor for memory capacity calculation. Actual capacity per memory is + ceil(seq_len * capacity), maintaining proportional capacity to sequence length + + Returns: + transformed_x: Memory-organized tensor with zero-padded capacity alignment + Shape: (num_memories, batch_size, capacity_len, hidden_size) + truncation_indices: Original indices used for gathering tokens after capacity truncation + Shape: (batch*num_memories, max_len) + sorted_indices: Sorting indices used to group tokens by memory assignments + Shape: (batch_size*seq_len*topk) + max_len: Maximum tokens per memory + mask: Boolean mask indicating valid (non-padded) positions in transformed_x + Shape: (batch*num_memories, max_len) + ''' + if selected_memories.dim() == 3: + # (batch, seq, topk) + topk = selected_memories.shape[2] + # x (batch, seq, hidden) + x = x.repeat_interleave(topk, dim=1) + # x (batch, seq * topk, hidden) + # (batch, seq, topk) + selected_memories = selected_memories.reshape(selected_memories.shape[0], -1) + # (batch, seq * topk) + + b, s, d = x.shape + x_flat = x.reshape(b * s, d) # [b*s, d] + + with torch.no_grad(): + batch_indices = torch.arange(b, device=x.device).unsqueeze(-1) + batch_indices = batch_indices.expand(b, s).reshape(-1) + # (b * s) + memories_flat = selected_memories.reshape(-1) # [b*s] + + combined = batch_indices * (memories_flat.max() + 1) + memories_flat + sorted_indices = combined.argsort() + + x_sorted = x_flat[sorted_indices] # [b*s, d] + # (b*s, hidden) -> (b, s, hidd) + with torch.no_grad(): + # routing_mask (b, s, num_memories) + batch_memory_tokens = routing_mask.sum(dim=1) + # (b, num_memories) + offset = batch_memory_tokens.cumsum(dim=1) + memory_batch_offset = offset.transpose(0,1) + batch_offset = torch.arange(0, b*s, s, device=offset.device) + memory_batch_offset += batch_offset + flatten_offset = memory_batch_offset.transpose(0, 1).reshape(-1) + lengths = torch.concat([flatten_offset[:1], flatten_offset[1:] - flatten_offset[:-1]], dim=0) + max_len = lengths.max() + capacity_len = math.ceil(s / topk * capacity) + max_len = min(max_len, capacity_len) + + indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1) + # discard tokens exceed capacity and is far from now + # left pad + truncation_indices = indices + batch_memory_tokens.reshape((-1,)).unsqueeze(-1) - max_len + mask = torch.bitwise_and(truncation_indices < flatten_offset.unsqueeze(-1), truncation_indices >= 0) + mask = torch.bitwise_and(mask, truncation_indices >= torch.cat((torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device), flatten_offset[:-1])).unsqueeze(-1)) + 
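+    # At this point `truncation_indices` indexes into the globally sorted token stream `x_sorted`:
+    # for every (batch, memory) slot it keeps only the newest `max_len` positions that survive the
+    # capacity cut, while `mask` marks which of those positions are real tokens rather than padding.
+    # Invalid positions are redirected to index 0 below and zeroed out again after the gather.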
truncation_indices = torch.where(mask, truncation_indices, torch.zeros_like(truncation_indices)) + + gathered_x = torch.gather(x_sorted, 0, truncation_indices.reshape(-1).unsqueeze(-1).expand(-1, d)) + transformed_x = gathered_x.reshape(b * num_memories, -1, d) + transformed_x = transformed_x * mask.unsqueeze(-1).expand_as(transformed_x) + pad_x = torch.zeros((b * num_memories, capacity_len-max_len, d), dtype=transformed_x.dtype, device=transformed_x.device) + # left pad + transformed_x = torch.cat((pad_x, transformed_x), dim=1).reshape((b, num_memories, capacity_len, d)).transpose(0, 1) + # truncation_indices += capacity_len-max_len + + return transformed_x, truncation_indices, sorted_indices, max_len, mask + # (num_memories, batch, seq, hidden) + +# @torch.jit.script +def reconstruct(transformed_x, indices: torch.Tensor, sorted_indices: torch.Tensor, batch_size: int, seq_len: int, topk: int, routing_weights: torch.Tensor, mask: torch.Tensor): + ''' + Reconstruct and mix transformed outputs back into the original input sequence shape. + + Key operations: + 1. Reshapes and transposes `transformed_x` to prepare for scattering. + 2. Applies the `mask` to zero out invalid positions. + 3. Uses `torch.scatter_add_` to scatter and sum the transformed outputs back to their original positions based on `indices`. + 4. Rearranges the scattered outputs using `sorted_indices` to ensure correct ordering. + 5. Applies the `routing_weights` to weight the outputs. + 6. Sums over the `topk` dimension to produce the final reconstructed output. + + Args: + transformed_x (torch.Tensor): + The transformed output tensor from memory units or experts. + Shape: (num_memories, batch_size, capacity_len, hidden_size) + indices (torch.Tensor): + Indices used for scattering the transformed outputs back to their corresponding positions. + Shape: (batch*num_memories, max_len) + sorted_indices (torch.Tensor): + Sorting indices used to rearrange the scattered outputs back into the original sequence order. + Shape: (batch_size*seq_len*topk) + batch_size (int): + The size of the batch. + seq_len (int): + The length of the input sequence. + topk (int): + The number of top elements selected (`topk`) per token during the selection process. + routing_weights (torch.Tensor): + Routing weights assigned to the top-k selected outputs when reconstructing the final output. + Shape: (batch_size, seq_len, topk) + mask (torch.Tensor): + Boolean mask indicating valid positions in the sequence. + Shape: (batch*num_memories, max_len) + + Returns: + restored_x (torch.Tensor): + The reconstructed output tensor in the original input sequence shape. 
+ Shape: (batch_size, seq_len, hidden_size) + ''' + transformed_x = transformed_x.transpose(0, 1).reshape((-1, transformed_x.shape[2], transformed_x.shape[3], transformed_x.shape[4])) + b, s, k, h, d = batch_size, seq_len, topk, transformed_x.shape[2], transformed_x.shape[3] + gathered_x = transformed_x.reshape((transformed_x.shape[0] * transformed_x.shape[1], transformed_x.shape[2], transformed_x.shape[3])) + mask_expanded = mask.reshape(-1).unsqueeze(-1).unsqueeze(-1).expand_as(gathered_x) + gathered_x = gathered_x * mask_expanded + + assert (indices >= 0).all(), "Indices should be non-negative" + + resortd_x = torch.zeros((b * s * k, h, d) ,device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_( + 0, + indices.reshape(-1).unsqueeze(-1).unsqueeze(-1).expand(-1, h, d), + gathered_x, + ) + assert (indices < resortd_x.size(0)).all(), "Indices should be less than resortd_x size" + + inverse_indices = sorted_indices.argsort() + rearranged_x_flat = resortd_x[inverse_indices] + restored_x = rearranged_x_flat.reshape((b, s * k, h, d)) + restored_x = restored_x.reshape(b, s, k, h, d) * routing_weights.reshape(b, s, k).unsqueeze(-1).unsqueeze(-1) + restored_x = restored_x.sum(dim=2) + return restored_x + + +class MomGatedDeltaNet(nn.Module): + """ + The layer implementaion for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa + + Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters. + Parameter alloation when use_gate=True: + - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each + - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each + - Others are ignorably small. + - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size + NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim. + + Parameter allocation when use_gate=False: + - 1 * hidden_size * hidden_size for the q_proj and k_proj each + - 2 * hidden_size * hidden_size for the v_proj and o_proj each + - Others are ignorably small. + - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size + + Args: + hidden_size (int, Optional): + The hidden size of the input. Default: 2048. + expand_v (float, Optional): + The expansion ratio for the value dim. Default: 2.0. + head_dim (int, Optional): + The dimension of each head. Default: 256. + num_heads (int, Optional): + The number of heads. Default: 4. + mode (str, Optional): + Which Gated DeltaNet kernel to use. + Currently available: `chunk` and `fused_recurrent`. + Default: `chunk`. + use_beta (bool, Optional): + Whether to use beta. Default: `True`. + use_gate (bool, Optional): + Whether to use output gate. Default: `True`. + use_short_conv (bool, Optional): + Whether to use short convolutions. Default: `True`. + conv_size (int, Optional): + The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. + conv_bias (bool, Optional): + Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. + layer_idx (int, Optional): + The index of the layer. Default: None. + norm_eps (float, Optional): + The epsilon value for the normalization layer. Default: 1e-5. 
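+        num_memories (int, Optional):
+            The number of memory experts that the router can dispatch tokens to. Default: 8.
+        topk (int, Optional):
+            The number of memories each token is routed to. Default: 2.
+        capacity (float, Optional):
+            Scaling factor for the per-memory token capacity used by `transform`. Default: 1.0.
+        shared_mem (bool, Optional):
+            Whether to add an always-on shared memory on top of the routed memories. Default: `False`.
+        single_kv_proj (bool, Optional):
+            Whether to share a single k/v/beta/a projection across all memories instead of
+            per-memory projections. Default: `False`.
+
+    Example (illustrative sketch only; the argument values are assumptions and a CUDA device with
+    the FLA Triton kernels is required)::
+
+        import torch
+        from fla.layers.mom import MomGatedDeltaNet
+
+        layer = MomGatedDeltaNet(hidden_size=2048, head_dim=256, num_heads=6,
+                                 num_memories=8, topk=2).cuda().to(torch.bfloat16)
+        x = torch.randn(2, 128, 2048, device='cuda', dtype=torch.bfloat16)
+        o, _, past_key_values, router_logits = layer(x)
+        assert o.shape == (2, 128, 2048)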
+ """ + + def __init__( + self, + hidden_size: int = 2048, + expand_v: float = 2, + head_dim: int = 256, + num_heads: int = 6, + mode: str = 'chunk', + use_gate: bool = True, + use_short_conv: bool = True, + conv_size: int = 4, + conv_bias: bool = False, + layer_idx: int = None, + norm_eps: float = 1e-5, + num_memories: int = 8, + topk: int = 2, + capacity: float = 1.0, + shared_mem: bool = False, + single_kv_proj: bool = False, + **kwargs + ) -> MomGatedDeltaNet: + super().__init__() + self.num_memories = num_memories + self.topk = topk + self.capacity = capacity + self.shared_mem = shared_mem + self.single_kv_proj = single_kv_proj + + self.mode = mode + + self.hidden_size = hidden_size + self.expand_v = expand_v + + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.head_dim = head_dim + self.num_heads = num_heads + + self.key_dim = self.num_heads * self.head_dim + self.value_dim = self.key_dim * self.expand_v + self.head_qk_dim = head_dim + self.head_v_dim = head_dim * self.expand_v + self.layer_idx = layer_idx + self.silu = nn.SiLU() + + assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.gate = nn.Linear(self.hidden_size, self.num_memories, bias=False) + if self.single_kv_proj: + self.shared_k = nn.Linear(hidden_size, self.key_dim, bias=False) + self.shared_v = nn.Linear(hidden_size, self.value_dim, bias=False) + self.shared_b = nn.Linear(hidden_size, self.num_heads, bias=False) + self.shared_a = nn.Linear(hidden_size, self.num_heads, bias=False) + else: + self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.key_dim, bias=False) for _ in range(self.num_memories)]) + self.v_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.value_dim, bias=False) for _ in range(self.num_memories)]) + self.b_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)]) + self.a_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)]) + if self.shared_mem: + self.shared_k = nn.Linear(hidden_size, self.key_dim, bias=False) + self.shared_v = nn.Linear(hidden_size, self.value_dim, bias=False) + self.shared_b = nn.Linear(hidden_size, self.num_heads, bias=False) + self.shared_a = nn.Linear(hidden_size, self.num_heads, bias=False) + + A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) + A_log = torch.log(A) + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + # hard coded for now + dt_min = 0.001 + dt_max = 0.1 + dt_init_floor = 1e-4 + dt = torch.exp( + torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + # Just to be explicit. 
Without this we already don't put wd on dt_bias because of the check + # name.endswith("bias") in param_grouping.py + self.dt_bias._no_weight_decay = True + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation='silu' + ) + self.k_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation='silu' + ) + self.v_conv1d = ShortConvolution( + hidden_size=self.value_dim, + kernel_size=conv_size, + activation='silu' + ) + else: + raise UserWarning( + "ShortConvolution is crucial to the performance. " + "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing." + ) + if use_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) + else: + self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + if self.training: + assert mode == 'chunk', "Only chunk mode is supported in training." 
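+        # Overall flow of the MoM forward pass: (1) a linear router scores every token against all
+        # memories and the top-k are kept, (2) `transform` regroups tokens into per-memory chunks
+        # under the capacity limit, (3) each memory runs its own gated delta-rule recurrence, and
+        # (4) `reconstruct` scatters the per-memory outputs back to the original token order,
+        # weighted by the renormalized routing weights.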
+ + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + # 🔍 topk gating + router_logits = self.gate(hidden_states) # (bsz, q_len, num_memories) + scores = F.softmax(router_logits, dim=2, dtype=torch.float) + routing_weights, selected_memories = torch.topk(scores, self.topk, dim=-1) # (bsz, seq, topk) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) # we cast back to the input dtype + routing_weights_full = torch.zeros((routing_weights.shape[0], routing_weights.shape[1], self.num_memories), dtype=routing_weights.dtype, device=routing_weights.device).scatter(-1, selected_memories, routing_weights) + routing_mask = routing_weights_full.bool().int() + + if self.use_gate: + o_g = self.g_proj(hidden_states) + + batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1] + + shared_hidden_states = hidden_states + hidden_states, indices, sorted_indices, max_len, mask = transform(hidden_states, routing_mask, self.num_memories, selected_memories, self.capacity) + + q = self.q_proj(hidden_states) + if self.single_kv_proj: + k = self.shared_k(hidden_states) + v = self.shared_v(hidden_states) + beta = self.shared_b(hidden_states).sigmoid() + g = -self.A_log.float().exp() * F.softplus(self.shared_a(hidden_states).float() + self.dt_bias) + else: + k = torch.stack([k_expert(hidden_states[i]) for i, k_expert in enumerate(self.k_proj)], dim=0) + v = torch.stack([v_expert(hidden_states[i]) for i, v_expert in enumerate(self.v_proj)], dim=0) + beta = torch.stack([b_expert(hidden_states[i]).sigmoid() for i, b_expert in enumerate(self.b_proj)], dim=0) + g = torch.stack([-self.A_log.float().exp() * F.softplus(a_expert(hidden_states[i]).float() + self.dt_bias) for i, a_expert in enumerate(self.a_proj)], dim=0) + + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = [None, None], [None, None], [None, None] + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[2]:].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None + seq_idx=kwargs.get('seq_idx', None) + q, k, v = map(lambda x: rearrange(x, 'e b t d -> (e b) t d'), (q, k, v)) + q, conv_state_q[0] = self.q_conv1d(x=q, + mask=conv_mask, + cache=conv_state_q[0], + output_final_state=use_cache,seq_idx=seq_idx) + k, conv_state_k[0] = self.k_conv1d(x=k, + mask=conv_mask, + cache=conv_state_k[0], + output_final_state=use_cache,seq_idx=seq_idx) + v, conv_state_v[0] = self.v_conv1d(x=v, + mask=conv_mask, + cache=conv_state_v[0], + output_final_state=use_cache,seq_idx=seq_idx) + + q, k, v = map(lambda x: rearrange(x, '(e b) t d -> e b t d', b=batch_size), (q, k, v)) + + else: + q, k, v = self.silu(q), self.silu(k), self.silu(v), + + q, k, v = map(lambda x: rearrange(x, 'e b t (h d) -> e b t h d', h=self.num_heads), (q, k, v)) + + q = l2_norm(q) + k = l2_norm(k) + + # dealing with padding + if attention_mask is not None: + beta = beta.mul(attention_mask[None, :, -beta.shape[-2]:, None]) + g = g.mul(attention_mask[None, :, -g.shape[-2]:, None]) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else [None for _ in range(self.num_memories + self.shared_mem)] + offsets = kwargs.get('offsets', None) + # Note: In the updated version of FLA, "offset" has been renamed to "cu_seqlens". 
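+        # Each of the `num_memories` chunks is processed by its own gated delta-rule call below and
+        # keeps its own slot in `recurrent_state`; when `shared_mem` is enabled, one extra slot at
+        # the end of `recurrent_state` is reserved for the always-on shared memory (see `shared_o`).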
+ if mode == 'chunk': + o_list = [None for _ in range(self.num_memories)] + for e in range(self.num_memories): + o_e, state_e = chunk_gated_delta_rule( + q=q[e].to(dtype=torch.bfloat16), + k=k[e].to(dtype=torch.bfloat16), + v=v[e].to(dtype=torch.bfloat16), + g=g[e].to(dtype=torch.bfloat16), + beta=beta[e].to(dtype=torch.bfloat16), + initial_state=recurrent_state[e], + output_final_state=use_cache, + cu_seqlens=offsets, + head_first=False + ) + # chunk_gated_delta_rule(q=q[e].to(dtype=torch.bfloat16),k=k[e].to(dtype=torch.bfloat16),v=v[e].to(dtype=torch.bfloat16),g=g[e].to(dtype=torch.bfloat16),beta=beta[e].to(dtype=torch.bfloat16),initial_state=recurrent_state[e],output_final_state=use_cache,cu_seqlens=offsets,head_first=False) + o_e = o_e[:,-max_len:,:,:].to(dtype=q[e].dtype) + o_list[e] = o_e + recurrent_state[e] = state_e + o_list = torch.stack(o_list, dim=0) + o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=q.shape[1], seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + + elif mode == 'fused_recurrent': + o_list = [None for _ in range(self.num_memories)] + for e in range(self.num_memories): + # only activated memory updates + if not hidden_states[e, 0].any() and hidden_states.shape[1] == 1: + o_list[e] = torch.zeros_like(v[e,:,-max_len:,:,:]) + continue + o_e, state_e = fused_recurrent_gated_delta_rule( + q=q[e], + k=k[e], + v=v[e], + g=g[e], + beta=beta[e], + initial_state=recurrent_state[e], + output_final_state=use_cache, + cu_seqlens=offsets, + head_first=False + ) + o_e = o_e[:,-max_len:,:,:] + o_list[e] = o_e + # recurrent_state[e] = state_e + for batch in range(state_e.shape[0]): + if recurrent_state[e] is None: + recurrent_state[e] = state_e + elif hidden_states[e, batch].any(): + recurrent_state[e][batch] = state_e[batch] + o_list = torch.stack(o_list, dim=0) + o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=q.shape[1], seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + + if self.shared_mem: + shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, use_cache, conv_state_q, conv_state_k, conv_state_v) + o += shared_o + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[2] + ) + + if self.use_gate: + o_g = rearrange(o_g, '... (h d) -> ... h d', h=self.num_heads) + o = self.o_norm(o, o_g) + else: + o = self.o_norm(o) + o = rearrange(o, 'b t h d -> b t (h d)') + o = self.o_proj(o) + + return o, None, past_key_values, router_logits + + + def shared_o( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + recurrent_state = None, + use_cache: Optional[bool] = False, + conv_state_q = [None, None], + conv_state_k = [None, None], + conv_state_v = [None, None], + **kwargs + ) -> torch.Tensor: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + if self.training: + assert mode == 'chunk', "Only chunk mode is supported in training." 
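+        # The shared memory path ignores the router: the full, un-routed sequence is projected with
+        # the shared k/v/beta/a weights and run through a single gated delta-rule recurrence whose
+        # state lives in the last slot of `recurrent_state`.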
+ + if self.use_short_conv: + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None + seq_idx=kwargs.get('seq_idx', None) + q, conv_state_q[1] = self.q_conv1d(x=self.q_proj(hidden_states), + mask=conv_mask, + cache=conv_state_q[1], + output_final_state=use_cache,seq_idx=seq_idx) + k, conv_state_k[1] = self.k_conv1d(x=self.shared_k(hidden_states), + mask=conv_mask, + cache=conv_state_k[1], + output_final_state=use_cache,seq_idx=seq_idx) + v, conv_state_v[1] = self.v_conv1d(x=self.shared_v(hidden_states), + mask=conv_mask, + cache=conv_state_v[1], + output_final_state=use_cache,seq_idx=seq_idx) + else: + q = self.silu(self.q_proj(hidden_states)) + k = self.silu(self.shared_k(hidden_states)) + v = self.silu(self.shared_v(hidden_states)) + + q, k, v = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (q, k, v)) + q = l2_norm(q) + k = l2_norm(k) + beta = self.shared_b(hidden_states).sigmoid() + g = -self.A_log.float().exp() * F.softplus(self.shared_a(hidden_states).float() + self.dt_bias) + + # dealing with padding + if attention_mask is not None: + beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None]) + g = g.mul(attention_mask[:, -g.shape[-2]:, None]) + + offsets = kwargs.get('offsets', None) + if mode == 'chunk': + o, recurrent_state[-1] = chunk_gated_delta_rule( + q=q.to(dtype=torch.bfloat16), + k=k.to(dtype=torch.bfloat16), + v=v.to(dtype=torch.bfloat16), + g=g.to(dtype=torch.bfloat16), + beta=beta.to(dtype=torch.bfloat16), + initial_state=recurrent_state[-1], + output_final_state=use_cache, + cu_seqlens=offsets, + head_first=False + ) + o = o.to(dtype=q.dtype) + elif mode == 'fused_recurrent': + o, recurrent_state[-1] = fused_recurrent_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state[-1], + output_final_state=use_cache, + cu_seqlens=offsets, + head_first=False + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + return o \ No newline at end of file diff --git a/fla/layers/mom_varlen.py b/fla/layers/mom_varlen.py new file mode 100644 index 000000000..be6891d7f --- /dev/null +++ b/fla/layers/mom_varlen.py @@ -0,0 +1,724 @@ +# -*- coding: utf-8 -*- + + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange +from torch.nn import functional as F + +from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution +from fla.modules.l2norm import l2_norm +from fla.ops.gated_delta_rule import (chunk_gated_delta_rule, + fused_recurrent_gated_delta_rule) + + + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + +from transformers.utils import is_flash_attn_2_available +if is_flash_attn_2_available(): + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input + from flash_attn.bert_padding import pad_input +else: + print("flash_attn_2 is not available") + +def elu_p1(x): + return (F.elu(x, 1., False) + 1.).to(x) + +def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. 
+ cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + gate_layer: torch.Tensor, + beta_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). 
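+
+    Note:
+        Compared to the Arguments/Return lists above, this variant also takes and unpads
+        `gate_layer` and `beta_layer` (per-head gate and beta values) using the same `indices_k`,
+        and returns them between `value_layer` and `indices_q` in the output tuple.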
+ """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + gate_layer = index_first_axis(gate_layer.reshape(batch_size * kv_seq_len, num_key_value_heads), indices_k) + beta_layer = index_first_axis(beta_layer.reshape(batch_size * kv_seq_len, num_key_value_heads), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + gate_layer, + beta_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + + +def sum_norm(x): + return (x / x.sum(-1, keepdim=True)).to(x) + +# https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1 + +def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, selected_memories: torch.Tensor, capacity: float): + ''' + Transform input sequences into memory-organized chunks with capacity constraints. + + Processes input sequences by routing tokens to designated memory states according to routing_mask, + sorts tokens by memory assignments, handles token truncation/padding based on memory capacity, + and returns memory-aligned tensors for parallel processing. + + Key operations: + 1. Expands input tensors when multiple memories are selected per token (top-k routing) + 2. Sorts tokens globally by (batch_idx, memory_idx) to group memory-assigned tokens + 3. Applies capacity-aware truncation (left-truncate oldest tokens when exceeding capacity) + 4. Pads memory chunks to uniform length for tensorization + + Args: + x: Input hidden states + Shape: (batch_size, seq_len, hidden_size) + routing_mask: Binary mask indicating active memory assignments + Shape: (batch_size, seq_len, num_memories) + num_memories: Total number of memories per batch + selected_memories: Memory indices assigned to each token. When using top-k routing, + this contains k memory indices per token (k >= 1) + Shape: (batch_size, seq_len) for k=1 or (batch_size, seq_len, topk) for k>1 + capacity: Scaling factor for memory capacity calculation. 
Actual capacity per memory is + ceil(seq_len * capacity), maintaining proportional capacity to sequence length + + Returns: + transformed_x: Memory-organized tensor with zero-padded capacity alignment + Shape: (num_memories, batch_size, capacity_len, hidden_size) + truncation_indices: Original indices used for gathering tokens after capacity truncation + Shape: (batch*num_memories, max_len) + sorted_indices: Sorting indices used to group tokens by memory assignments + Shape: (batch_size*seq_len*topk) + max_len: Maximum tokens per memory + mask: Boolean mask indicating valid (non-padded) positions in transformed_x + Shape: (batch*num_memories, max_len) + ''' + if selected_memories.dim() == 3: + # (batch, seq, topk) + topk = selected_memories.shape[2] + # x (batch, seq, hidden) + x = x.repeat_interleave(topk, dim=1) + # x (batch, seq * topk, hidden) + # (batch, seq, topk) + selected_memories = selected_memories.reshape(selected_memories.shape[0], -1) + # (batch, seq * topk) + + b, s, d = x.shape + x_flat = x.reshape(b * s, d) # [b*s, d] + + with torch.no_grad(): + batch_indices = torch.arange(b, device=x.device).unsqueeze(-1) + batch_indices = batch_indices.expand(b, s).reshape(-1) + # (b * s) + memories_flat = selected_memories.reshape(-1) # [b*s] + + combined = batch_indices * (memories_flat.max() + 1) + memories_flat + sorted_indices = combined.argsort() + + x_sorted = x_flat[sorted_indices] # [b*s, d] + # (b*s, hidden) -> (b, s, hidd) + with torch.no_grad(): + # routing_mask (b, s, num_memories) + batch_memory_tokens = routing_mask.sum(dim=1) + # (b, num_memories) + offset = batch_memory_tokens.cumsum(dim=1) + memory_batch_offset = offset.transpose(0,1) + batch_offset = torch.arange(0, b*s, s, device=offset.device) + memory_batch_offset += batch_offset + flatten_offset = memory_batch_offset.transpose(0, 1).reshape(-1) + lengths = torch.concat([flatten_offset[:1], flatten_offset[1:] - flatten_offset[:-1]], dim=0) + max_len = lengths.max() + capacity_len = math.ceil(s / topk * capacity) + max_len = min(max_len, capacity_len) + + indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1) + # discard tokens exceed capacity and is far from now + # left pad + truncation_indices = indices + batch_memory_tokens.reshape((-1,)).unsqueeze(-1) - max_len + mask = torch.bitwise_and(truncation_indices < flatten_offset.unsqueeze(-1), truncation_indices >= 0) + mask = torch.bitwise_and(mask, truncation_indices >= torch.cat((torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device), flatten_offset[:-1])).unsqueeze(-1)) + truncation_indices = torch.where(mask, truncation_indices, torch.zeros_like(truncation_indices)) + + gathered_x = torch.gather(x_sorted, 0, truncation_indices.reshape(-1).unsqueeze(-1).expand(-1, d)) + transformed_x = gathered_x.reshape(b * num_memories, -1, d) + transformed_x = transformed_x * mask.unsqueeze(-1).expand_as(transformed_x) + pad_x = torch.zeros((b * num_memories, capacity_len-max_len, d), dtype=transformed_x.dtype, device=transformed_x.device) + pad_mask = torch.zeros((b * num_memories, capacity_len-max_len), dtype=transformed_x.dtype, device=transformed_x.device) + # left pad + transformed_x = torch.cat((pad_x, transformed_x), dim=1).reshape((b, num_memories, capacity_len, d)).transpose(0, 1) + mask_2 = torch.cat((pad_mask.bool(), mask), dim=1).reshape((b, num_memories, capacity_len)).transpose(0, 1) + # 
truncation_indices += capacity_len-max_len + + return transformed_x, truncation_indices, sorted_indices, max_len, mask,mask_2 + # (num_memories, batch, seq, hidden) + +# @torch.jit.script +def reconstruct(transformed_x, indices: torch.Tensor, sorted_indices: torch.Tensor, batch_size: int, seq_len: int, topk: int, routing_weights: torch.Tensor, mask: torch.Tensor): + ''' + Reconstruct and mix transformed outputs back into the original input sequence shape. + + Key operations: + 1. Reshapes and transposes `transformed_x` to prepare for scattering. + 2. Applies the `mask` to zero out invalid positions. + 3. Uses `torch.scatter_add_` to scatter and sum the transformed outputs back to their original positions based on `indices`. + 4. Rearranges the scattered outputs using `sorted_indices` to ensure correct ordering. + 5. Applies the `routing_weights` to weight the outputs. + 6. Sums over the `topk` dimension to produce the final reconstructed output. + + Args: + transformed_x (torch.Tensor): + The transformed output tensor from memory units or experts. + Shape: (num_memories, batch_size, capacity_len, hidden_size) + indices (torch.Tensor): + Indices used for scattering the transformed outputs back to their corresponding positions. + Shape: (batch*num_memories, max_len) + sorted_indices (torch.Tensor): + Sorting indices used to rearrange the scattered outputs back into the original sequence order. + Shape: (batch_size*seq_len*topk) + batch_size (int): + The size of the batch. + seq_len (int): + The length of the input sequence. + topk (int): + The number of top elements selected (`topk`) per token during the selection process. + routing_weights (torch.Tensor): + Routing weights assigned to the top-k selected outputs when reconstructing the final output. + Shape: (batch_size, seq_len, topk) + mask (torch.Tensor): + Boolean mask indicating valid positions in the sequence. + Shape: (batch*num_memories, max_len) + + Returns: + restored_x (torch.Tensor): + The reconstructed output tensor in the original input sequence shape. + Shape: (batch_size, seq_len, hidden_size) + ''' + transformed_x = transformed_x.transpose(0, 1).reshape((-1, transformed_x.shape[2], transformed_x.shape[3], transformed_x.shape[4])) + b, s, k, h, d = batch_size, seq_len, topk, transformed_x.shape[2], transformed_x.shape[3] + gathered_x = transformed_x.reshape((transformed_x.shape[0] * transformed_x.shape[1], transformed_x.shape[2], transformed_x.shape[3])) + mask_expanded = mask.reshape(-1).unsqueeze(-1).unsqueeze(-1).expand_as(gathered_x) + gathered_x = gathered_x * mask_expanded + + assert (indices >= 0).all(), "Indices should be non-negative" + + resortd_x = torch.zeros((b * s * k, h, d) ,device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_( + 0, + indices.reshape(-1).unsqueeze(-1).unsqueeze(-1).expand(-1, h, d), + gathered_x, + ) + assert (indices < resortd_x.size(0)).all(), "Indices should be less than resortd_x size" + + inverse_indices = sorted_indices.argsort() + rearranged_x_flat = resortd_x[inverse_indices] + restored_x = rearranged_x_flat.reshape((b, s * k, h, d)) + restored_x = restored_x.reshape(b, s, k, h, d) * routing_weights.reshape(b, s, k).unsqueeze(-1).unsqueeze(-1) + restored_x = restored_x.sum(dim=2) + return restored_x + + +class MomGatedDeltaNet(nn.Module): + """ + The layer implementaion for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa + + Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters. 
+ Parameter alloation when use_gate=True: + - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each + - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each + - Others are ignorably small. + - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size + NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim. + + Parameter allocation when use_gate=False: + - 1 * hidden_size * hidden_size for the q_proj and k_proj each + - 2 * hidden_size * hidden_size for the v_proj and o_proj each + - Others are ignorably small. + - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size + + Args: + hidden_size (int, Optional): + The hidden size of the input. Default: 2048. + expand_v (float, Optional): + The expansion ratio for the value dim. Default: 2.0. + head_dim (int, Optional): + The dimension of each head. Default: 256. + num_heads (int, Optional): + The number of heads. Default: 4. + mode (str, Optional): + Which Gated DeltaNet kernel to use. + Currently available: `chunk` and `fused_recurrent`. + Default: `chunk`. + use_beta (bool, Optional): + Whether to use beta. Default: `True`. + use_gate (bool, Optional): + Whether to use output gate. Default: `True`. + use_short_conv (bool, Optional): + Whether to use short convolutions. Default: `True`. + conv_size (int, Optional): + The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. + conv_bias (bool, Optional): + Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. + layer_idx (int, Optional): + The index of the layer. Default: None. + norm_eps (float, Optional): + The epsilon value for the normalization layer. Default: 1e-5. + """ + + def __init__( + self, + hidden_size: int = 2048, + expand_v: float = 2, + head_dim: int = 256, + num_heads: int = 6, + mode: str = 'chunk', + use_gate: bool = True, + use_short_conv: bool = True, + conv_size: int = 4, + conv_bias: bool = False, + layer_idx: int = None, + norm_eps: float = 1e-5, + num_memories: int = 8, + topk: int = 2, + capacity: float = 1.0, + shared_mem: bool = False, + single_kv_proj: bool = False, + **kwargs + ) -> MomGatedDeltaNet: + super().__init__() + self.num_memories = num_memories + self.topk = topk + self.capacity = capacity + self.shared_mem = shared_mem + self.single_kv_proj = single_kv_proj + + self.mode = mode + + self.hidden_size = hidden_size + self.expand_v = expand_v + + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.head_dim = head_dim + self.num_heads = num_heads + + self.key_dim = self.num_heads * self.head_dim + self.value_dim = self.key_dim * self.expand_v + self.head_qk_dim = head_dim + self.head_v_dim = head_dim * self.expand_v + self.layer_idx = layer_idx + self.silu = nn.SiLU() + + assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 
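+        # A single router (`gate`) and query projection are shared across memories; the k/v/beta/a
+        # projections are either one `nn.ModuleList` entry per memory or, with `single_kv_proj`,
+        # a single shared set of projections.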
+ + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.gate = nn.Linear(self.hidden_size, self.num_memories, bias=False) + if self.single_kv_proj: + self.shared_k = nn.Linear(hidden_size, self.key_dim, bias=False) + self.shared_v = nn.Linear(hidden_size, self.value_dim, bias=False) + self.shared_b = nn.Linear(hidden_size, self.num_heads, bias=False) + self.shared_a = nn.Linear(hidden_size, self.num_heads, bias=False) + else: + self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.key_dim, bias=False) for _ in range(self.num_memories)]) + self.v_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.value_dim, bias=False) for _ in range(self.num_memories)]) + self.b_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)]) + self.a_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)]) + if self.shared_mem: + self.shared_k = nn.Linear(hidden_size, self.key_dim, bias=False) + self.shared_v = nn.Linear(hidden_size, self.value_dim, bias=False) + self.shared_b = nn.Linear(hidden_size, self.num_heads, bias=False) + self.shared_a = nn.Linear(hidden_size, self.num_heads, bias=False) + + A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) + A_log = torch.log(A) + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + # hard coded for now + dt_min = 0.001 + dt_max = 0.1 + dt_init_floor = 1e-4 + dt = torch.exp( + torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + # Just to be explicit. Without this we already don't put wd on dt_bias because of the check + # name.endswith("bias") in param_grouping.py + self.dt_bias._no_weight_decay = True + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation='silu' + ) + self.k_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation='silu' + ) + self.v_conv1d = ShortConvolution( + hidden_size=self.value_dim, + kernel_size=conv_size, + activation='silu' + ) + else: + raise UserWarning( + "ShortConvolution is crucial to the performance. " + "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing." 
+ ) + if use_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) + else: + self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + if self.training: + assert mode == 'chunk', "Only chunk mode is supported in training." + + last_state = None + batchsize,q_len = hidden_states.shape[0],hidden_states.shape[1] + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + # 🔍 topk gating + router_logits = self.gate(hidden_states) # (bsz, q_len, num_memories) + scores = F.softmax(router_logits, dim=2, dtype=torch.float) + routing_weights, selected_memories = torch.topk(scores, self.topk, dim=-1) # (bsz, seq, topk) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) # we cast back to the input dtype + routing_weights_full = torch.zeros((routing_weights.shape[0], routing_weights.shape[1], self.num_memories), dtype=routing_weights.dtype, device=routing_weights.device).scatter(-1, selected_memories, routing_weights) + routing_mask = routing_weights_full.bool().int() + + if self.use_gate: + o_g = self.g_proj(hidden_states) + + batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1] + + shared_hidden_states = hidden_states + hidden_states, indices, sorted_indices, max_len, mask, mask2 = transform(hidden_states, routing_mask, self.num_memories, selected_memories, self.capacity) + + q = self.q_proj(hidden_states) + if self.single_kv_proj: + k = self.shared_k(hidden_states) + v = self.shared_v(hidden_states) + beta = self.shared_b(hidden_states).sigmoid() + g = -self.A_log.float().exp() * F.softplus(self.shared_a(hidden_states).float() + self.dt_bias) + else: + k = torch.stack([k_expert(hidden_states[i]) for i, k_expert in enumerate(self.k_proj)], dim=0) + v = torch.stack([v_expert(hidden_states[i]) for i, v_expert in enumerate(self.v_proj)], dim=0) + beta = torch.stack([b_expert(hidden_states[i]).sigmoid() for i, b_expert in enumerate(self.b_proj)], dim=0) + g = torch.stack([-self.A_log.float().exp() * F.softplus(a_expert(hidden_states[i]).float() + self.dt_bias) for i, a_expert in enumerate(self.a_proj)], dim=0) + + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = [None, None], [None, None], [None, None] + if last_state is not None: + conv_state_q, conv_state_k, 
conv_state_v = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[2]:].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None + seq_idx=kwargs.get('seq_idx', None) + q, k, v = map(lambda x: rearrange(x, 'e b t d -> (e b) t d'), (q, k, v)) + q, conv_state_q[0] = self.q_conv1d(x=q, + mask=conv_mask, + cache=conv_state_q[0], + output_final_state=use_cache,seq_idx=seq_idx) + k, conv_state_k[0] = self.k_conv1d(x=k, + mask=conv_mask, + cache=conv_state_k[0], + output_final_state=use_cache,seq_idx=seq_idx) + v, conv_state_v[0] = self.v_conv1d(x=v, + mask=conv_mask, + cache=conv_state_v[0], + output_final_state=use_cache,seq_idx=seq_idx) + + q, k, v = map(lambda x: rearrange(x, '(e b) t d -> e b t d', b=batch_size), (q, k, v)) + + else: + q, k, v = self.silu(q), self.silu(k), self.silu(v), + + q, k, v = map(lambda x: rearrange(x, 'e b t (h d) -> e b t h d', h=self.num_heads), (q, k, v)) + + q = l2_norm(q) + k = l2_norm(k) + + q,k,v,g,beta,mask2 = (rearrange(x, 'e b l ... -> (e b) l ...') for x in (q,k,v,g,beta,mask2)) + cu_q, cu_k, cu_v, cu_g, cu_beta, indices_q, cu_seqlen_all, max_seq_lens = _upad_input(q, k, v, g, beta, mask2, q_len) + cu_seqlen = cu_seqlen_all[0].to(torch.long) + cu_q,cu_k,cu_v,cu_g,cu_beta= (x.unsqueeze(0).contiguous() for x in (cu_q,cu_k,cu_v,cu_g,cu_beta)) + + # dealing with padding + if attention_mask is not None: + beta = beta.mul(attention_mask[None, :, -beta.shape[-2]:, None]) + g = g.mul(attention_mask[None, :, -g.shape[-2]:, None]) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else [None for _ in range(1 + self.shared_mem)] + offsets = kwargs.get('offsets', None) + # Note: In the updated version of FLA, "offset" has been renamed to "cu_seqlens". 
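+        # Unlike `fla/layers/mom.py`, which loops over memories and invokes the kernel once per
+        # memory, this variable-length variant flattens all (memory, batch) chunks into one ragged
+        # batch and makes a single kernel call, with `cu_seqlen` marking chunk boundaries so tokens
+        # from different chunks never mix; one recurrent state tensor then covers all memories.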
+ if mode == 'chunk': + o_, recurrent_state_ = chunk_gated_delta_rule( + q=cu_q, + k=cu_k, + v=cu_v, + g=cu_g, + beta=cu_beta, + initial_state=recurrent_state[0], + output_final_state=use_cache, + cu_seqlens=cu_seqlen, + head_first=False + ) + recurrent_state[0] = recurrent_state_ + o_ = o_.squeeze(0).contiguous() + o_list = pad_input(o_, indices_q, batch_size*self.num_memories, q_len) + o_list = rearrange(o_list, '(e b) l h d -> e b l h d',b=batch_size) + o_list = o_list[:,:,-max_len:] + + o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + + elif mode == 'fused_recurrent': + o_, recurrent_state_ = fused_recurrent_gated_delta_rule( + q=cu_q, + k=cu_k, + v=cu_v, + g=cu_g, + beta=cu_beta, + initial_state=recurrent_state[0], + output_final_state=use_cache, + cu_seqlens=cu_seqlen, + head_first=False + ) + recurrent_state[0] = recurrent_state_ + o_ = o_.squeeze(0).contiguous() + o_list = pad_input(o_, indices_q, batch_size*self.num_memories, max_len) + o_list = rearrange(o_list, '(e b) l h d -> e b l h d',b=batch_size) + o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + + if self.shared_mem: + shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, use_cache, conv_state_q, conv_state_k, conv_state_v) + o += shared_o + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[2] + ) + + if self.use_gate: + o_g = rearrange(o_g, '... (h d) -> ... h d', h=self.num_heads) + o = self.o_norm(o, o_g) + else: + o = self.o_norm(o) + o = rearrange(o, 'b t h d -> b t (h d)') + o = self.o_proj(o) + + return o, None, past_key_values, router_logits + + + def shared_o( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + recurrent_state = None, + use_cache: Optional[bool] = False, + conv_state_q = [None, None], + conv_state_k = [None, None], + conv_state_v = [None, None], + **kwargs + ) -> torch.Tensor: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + if self.training: + assert mode == 'chunk', "Only chunk mode is supported in training." 
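+        # `conv_state_*[0]` holds the convolution cache of the routed (per-memory) path, while
+        # index 1 is reserved for this shared path, so the two paths never overwrite each other's
+        # caches during decoding.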
+ + if self.use_short_conv: + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None + seq_idx=kwargs.get('seq_idx', None) + q, conv_state_q[1] = self.q_conv1d(x=self.q_proj(hidden_states), + mask=conv_mask, + cache=conv_state_q[1], + output_final_state=use_cache,seq_idx=seq_idx) + k, conv_state_k[1] = self.k_conv1d(x=self.shared_k(hidden_states), + mask=conv_mask, + cache=conv_state_k[1], + output_final_state=use_cache,seq_idx=seq_idx) + v, conv_state_v[1] = self.v_conv1d(x=self.shared_v(hidden_states), + mask=conv_mask, + cache=conv_state_v[1], + output_final_state=use_cache,seq_idx=seq_idx) + else: + q = self.silu(self.q_proj(hidden_states)) + k = self.silu(self.shared_k(hidden_states)) + v = self.silu(self.shared_v(hidden_states)) + + q, k, v = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (q, k, v)) + q = l2_norm(q) + k = l2_norm(k) + beta = self.shared_b(hidden_states).sigmoid() + g = -self.A_log.float().exp() * F.softplus(self.shared_a(hidden_states).float() + self.dt_bias) + + # dealing with padding + if attention_mask is not None: + beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None]) + g = g.mul(attention_mask[:, -g.shape[-2]:, None]) + + offsets = kwargs.get('offsets', None) + if mode == 'chunk': + o, recurrent_state[-1] = chunk_gated_delta_rule( + q=q.to(dtype=torch.bfloat16), + k=k.to(dtype=torch.bfloat16), + v=v.to(dtype=torch.bfloat16), + g=g.to(dtype=torch.bfloat16), + beta=beta.to(dtype=torch.bfloat16), + initial_state=recurrent_state[-1], + output_final_state=use_cache, + cu_seqlens=offsets, + head_first=False + ) + o = o.to(dtype=q.dtype) + elif mode == 'fused_recurrent': + o, recurrent_state[-1] = fused_recurrent_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state[-1], + output_final_state=use_cache, + cu_seqlens=offsets, + head_first=False + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + return o \ No newline at end of file diff --git a/fla/models/__init__.py b/fla/models/__init__.py index 0b58315f0..8e2976cae 100644 --- a/fla/models/__init__.py +++ b/fla/models/__init__.py @@ -27,6 +27,7 @@ from fla.models.rwkv7 import RWKV7Config, RWKV7ForCausalLM, RWKV7Model from fla.models.samba import SambaConfig, SambaForCausalLM, SambaModel from fla.models.transformer import TransformerConfig, TransformerForCausalLM, TransformerModel +from fla.models.mom_gated_deltanet import MomGatedDeltaNetConfig,MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel __all__ = [ 'ABCConfig', 'ABCForCausalLM', 'ABCModel', @@ -52,4 +53,5 @@ 'RodimusConfig', 'RodimusForCausalLM', 'RodimusModel', 'RodimusTokenizer', 'PaTHAttentionConfig', 'PaTHAttentionForCausalLM', 'PaTHAttentionModel', 'MesaNetConfig', 'MesaNetForCausalLM', 'MesaNetModel', + 'MomGatedDeltaNetConfig','MomGatedDeltaNetForCausalLM','MomGatedDeltaNetModel' ] diff --git a/fla/models/mom_gated_deltanet/__init__.py b/fla/models/mom_gated_deltanet/__init__.py new file mode 100644 index 000000000..8f1a34949 --- /dev/null +++ b/fla/models/mom_gated_deltanet/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.mom_gated_deltanet.configuration_mom_gated_deltanet import \ + MomGatedDeltaNetConfig +from fla.models.mom_gated_deltanet.modeling_mom_gated_deltanet import ( + MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel) + +AutoConfig.register(MomGatedDeltaNetConfig.model_type, MomGatedDeltaNetConfig) 
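+# Registering the model classes as well lets `AutoModel` / `AutoModelForCausalLM` resolve
+# checkpoints whose config declares `model_type = 'mom_gated_deltanet'`.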
+AutoModel.register(MomGatedDeltaNetConfig, MomGatedDeltaNetModel) +AutoModelForCausalLM.register(MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM) + +__all__ = ['MomGatedDeltaNetConfig', 'MomGatedDeltaNetForCausalLM', 'MomGatedDeltaNetModel'] diff --git a/fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py b/fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py new file mode 100644 index 000000000..72163c6ad --- /dev/null +++ b/fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class MomGatedDeltaNetConfig(PretrainedConfig): + model_type = 'mom_gated_deltanet' + keys_to_ignore_at_inference = ['past_key_values'] + def __init__( + self, + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_v: int = 2, + use_gate: bool = True, + use_short_conv: bool = True, + conv_size: int = 4, + head_dim: int = 256, + num_heads: int = 6, + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + num_hidden_layers: int = 21, + norm_first: bool = False, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + num_memories: int = 8, + topk: int = 2, + capacity: float = 1.0, + use_layer_wise_balance: bool=True, + aux_loss_scale: float=0.01, + shared_mem: bool = False, + single_kv_proj: bool = False, + **kwargs + ): + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_v = expand_v + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.head_dim = head_dim + self.num_heads = num_heads + self.max_position_embeddings = max_position_embeddings + + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_first = norm_first + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + self.num_memories = num_memories + self.topk = topk + self.capacity = capacity + self.use_layer_wise_balance = use_layer_wise_balance + self.aux_loss_scale = aux_loss_scale + self.shared_mem = shared_mem + self.single_kv_proj = single_kv_proj + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', None) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py b/fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py new file mode 100644 index 000000000..e2115596a --- /dev/null +++ 
b/fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py @@ -0,0 +1,561 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from dataclasses import dataclass +import math +import warnings +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.attn import Attention +from fla.layers import MomGatedDeltaNet +from fla.models.mom_gated_deltanet.configuration_mom_gated_deltanet import \ + MomGatedDeltaNetConfig +from fla.models.utils import Cache +from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, + RMSNorm) +from fla.modules.activations import swiglu_linear +from fla.modules.layernorm import rms_norm_linear + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +logger = logging.get_logger(__name__) + + +class MomGatedDeltaNetMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish', + norm_first: bool = True, + norm_eps: float = 1e-5 + ) -> MomGatedDeltaNetMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.norm_first = norm_first + + if norm_first: + self.norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward( + self, + x: torch.Tensor, + **kwargs: Unpack[Dict], + ) -> torch.Tensor: + if self.norm_first: + x = rms_norm_linear(x, self.norm.weight, self.norm.bias, self.gate_proj.weight, self.gate_proj.bias) + else: + x = self.gate_proj(x) + gate, y = x.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class MomGatedDeltaNetBlock(nn.Module): + def __init__(self, config: MomGatedDeltaNetConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if not config.norm_first: + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = MomGatedDeltaNet( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_v=config.expand_v, + head_dim=config.head_dim, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + 
norm_first=config.norm_first, + norm_eps=config.norm_eps, + layer_idx=layer_idx, + num_memories=config.num_memories, + topk=config.topk, + capacity=config.capacity, + shared_mem=config.shared_mem, + single_kv_proj=config.single_kv_proj + ) + if not config.norm_first: + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = MomGatedDeltaNetMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + norm_first=config.norm_first, + norm_eps=config.norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + if hasattr(self, 'attn_norm'): + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values, router_logits = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if hasattr(self, 'mlp_norm'): + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values, router_logits) + + return outputs + + +class MomGatedDeltaNetPreTrainedModel(PreTrainedModel): + + config_class = MomGatedDeltaNetConfig + supports_gradient_checkpointing = True + _no_split_modules = ['GatedDeltaNetBlock'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + +@dataclass +class MomGatedDeltaNetOutputWithPast(BaseModelOutputWithPast): + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + +class MomGatedDeltaNetModel(MomGatedDeltaNetPreTrainedModel): + + def __init__(self, config: MomGatedDeltaNetConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([MomGatedDeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + all_router_logits = () + + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values, router_logits = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values, router_logits = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + all_router_logits += (router_logits,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return MomGatedDeltaNetOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns, + router_logits=all_router_logits + ) + +@dataclass +class MomGatedDeltaNetCausalLMOutputWithPast(CausalLMOutputWithPast): + aux_loss: Optional[torch.FloatTensor] = None + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + +class MomGatedDeltaNetForCausalLM(MomGatedDeltaNetPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MomGatedDeltaNetModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.num_memories = config.num_memories + self.topk = config.topk + self.aux_loss_scale = config.aux_loss_scale + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + num_logits_to_keep: Optional[int] = 0, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. 
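Because each layer's cached state is a fixed-size recurrent memory plus short-conv caches rather than a growing KV cache, incremental decoding only ever needs the newest token as input, which is what the truncation below implements. A rough decoding loop for illustration; `model` is assumed to be a `MomGatedDeltaNetForCausalLM` in eval mode, `prompt_ids` and `max_new_tokens` are placeholder names, and greedy selection is used only to keep the sketch short:

import torch

prompt_ids = torch.randint(0, model.config.vocab_size, (1, 8))
max_new_tokens = 16

past, input_ids, generated = None, prompt_ids, [prompt_ids]
with torch.no_grad():
    for _ in range(max_new_tokens):
        out = model(input_ids=input_ids, past_key_values=past, use_cache=True)
        past = out.past_key_values                  # fixed-size recurrent + conv states
        next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy, for brevity
        generated.append(next_id)
        input_ids = next_id                         # only the newest token is fed back
sequence = torch.cat(generated, dim=-1)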
+ if past_key_values is not None: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if num_logits_to_keep is not None: + model_inputs['num_logits_to_keep'] = num_logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + 'num_logits_to_keep': num_logits_to_keep, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + num_logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -num_logits_to_keep:]) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + if fuse_linear_and_cross_entropy: + loss_fct = FusedLinearCrossEntropyLoss() + else: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = loss_fct(hidden_states.view(-1, self.config.hidden_size), + labels.view(-1), + self.lm_head.weight, + self.lm_head.bias) + else: + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + if loss==0: + breakpoint() + + valid_router_logits = tuple( + logits + for logits in (outputs.router_logits if return_dict else outputs[-1]) + if logits is not None + ) + aux_loss = load_balancing_loss_func( + valid_router_logits, + self.num_memories, + self.topk, + use_layer_wise_balance=self.config.use_layer_wise_balance, # ✨ + ) + aux_loss *= self.aux_loss_scale + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return MomGatedDeltaNetCausalLMOutputWithPast( + loss=loss, + logits=logits, + 
past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + aux_loss=aux_loss + ) + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple], + num_memories: torch.Tensor = None, + top_k=2, + use_layer_wise_balance=False, +) -> torch.FloatTensor: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_memories]. + num_memories (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. + """ + if gate_logits is None or ( + isinstance(gate_logits, Iterable) and len(gate_logits) == 0 + ): + return 0 + + # ✨ Here is the fix for balance loss in Mixtral. + # We should calculate the balance loss in a layer-wise manner otherwise it may lead to degenerated solutions. + if use_layer_wise_balance: + if not isinstance(gate_logits, Iterable): + gate_logits = (gate_logits,) + else: + if isinstance(gate_logits, Iterable): + gate_logits = (torch.cat(gate_logits, dim=0),) + else: + gate_logits = (gate_logits,) + + all_balance_losses = [] + + for logits in gate_logits: + routing_weights, selected_experts = torch.topk(logits, top_k, dim=-1) + routing_weights = routing_weights.softmax(dim=-1) + routing_weights_full = torch.zeros_like(logits).scatter(-1, selected_experts, routing_weights) + + + # cast the expert indices to int64, otherwise one-hot encoding will fail + if selected_experts.dtype != torch.int64: + selected_experts = selected_experts.to(torch.int64) + + if len(selected_experts.shape) == 2: + selected_experts = selected_experts.unsqueeze(2) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_memories) + + # For a given token, determine if it was routed to a given expert. 
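Concretely, the balance term computed below multiplies, for every memory, the fraction of tokens whose top-k set selected it (via the one-hot mask) by the mean routing weight it received, and scales by num_memories**2, so a perfectly uniform router lands around top_k. A toy single-layer sketch with a flattened (tokens, memories) logit matrix; the numbers are random and purely illustrative:

import torch
import torch.nn.functional as F

num_memories, top_k = 4, 2
logits = torch.randn(6, num_memories)        # 6 tokens, one layer

routing_weights, selected = torch.topk(logits, top_k, dim=-1)
routing_weights = routing_weights.softmax(dim=-1)
weights_full = torch.zeros_like(logits).scatter(-1, selected, routing_weights)

# f_i: fraction of tokens whose top-k set contains memory i.
f = F.one_hot(selected, num_memories).amax(dim=-2).float().mean(dim=0)
# P_i: mean routing weight assigned to memory i (zero where not selected).
p = weights_full.mean(dim=0)

aux = (f * p).mean() * num_memories ** 2
print(f, p, aux)

Routing that collapses onto a few memories drives both factors up for those memories, so the product, and hence the auxiliary loss added to the LM loss, increases.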
+ expert_mask = torch.max(expert_mask, axis=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) + + router_prob_per_group_and_expert = torch.mean(routing_weights_full, axis=-2) + + # ✨ balance loss for this layer + balance_loss = torch.mean( + tokens_per_group_and_expert * router_prob_per_group_and_expert + ) * (num_memories**2) + all_balance_losses.append(balance_loss.reshape(1)) + + all_balance_losses = torch.cat(all_balance_losses).mean() # ✨ + + return all_balance_losses From 75a2b9bd21565acf6d98e4b41be4be81576e0994 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 30 Jun 2025 14:20:51 +0800 Subject: [PATCH 02/28] Add `MomGatedDeltaNet` back to the `__init__.py` module exports --- fla/layers/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fla/layers/__init__.py b/fla/layers/__init__.py index cfa42cbae..010e99a6e 100644 --- a/fla/layers/__init__.py +++ b/fla/layers/__init__.py @@ -19,6 +19,7 @@ from .mamba import Mamba from .mamba2 import Mamba2 from .mesa_net import MesaNet +from .mom import MomGatedDeltaNet from .multiscale_retention import MultiScaleRetention from .nsa import NativeSparseAttention from .path_attn import PaTHAttention @@ -26,7 +27,6 @@ from .rodimus import RodimusAttention, SlidingWindowSharedKeyAttention from .rwkv6 import RWKV6Attention from .rwkv7 import RWKV7Attention -from .mom import MomGatedDeltaNet __all__ = [ 'ABCAttention', @@ -46,7 +46,7 @@ 'LinearAttention', 'Mamba', 'Mamba2', - 'MomGatedDeltaNet' + 'MomGatedDeltaNet', 'MultiScaleRetention', 'NativeSparseAttention', 'ReBasedLinearAttention', From e043a6765324a9992f36a00f22a686b8c39023e9 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Mon, 30 Jun 2025 14:26:49 +0800 Subject: [PATCH 03/28] Cleaned up import statements in __init__.py --- fla/__init__.py | 9 ++++----- fla/models/__init__.py | 4 ++-- fla/models/mom_gated_deltanet/__init__.py | 6 ++---- .../configuration_mom_gated_deltanet.py | 7 ++++--- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/fla/__init__.py b/fla/__init__.py index 7e79a16cd..a98afb7d6 100644 --- a/fla/__init__.py +++ b/fla/__init__.py @@ -23,8 +23,7 @@ ReBasedLinearAttention, RodimusAttention, RWKV6Attention, - RWKV7Attention, - + RWKV7Attention ) from fla.models import ( ABCForCausalLM, @@ -53,8 +52,8 @@ LinearAttentionModel, MesaNetForCausalLM, MesaNetModel, - MomGatedDeltaNetForCausalLM, - MomGatedDeltaNetModel + MomGatedDeltaNetForCausalLM, + MomGatedDeltaNetModel, NSAForCausalLM, NSAModel, PaTHAttentionForCausalLM, @@ -68,7 +67,7 @@ RWKV7ForCausalLM, RWKV7Model, TransformerForCausalLM, - TransformerModel, + TransformerModel ) __all__ = [ diff --git a/fla/models/__init__.py b/fla/models/__init__.py index f3a8c4901..4c4a0a4c2 100644 --- a/fla/models/__init__.py +++ b/fla/models/__init__.py @@ -20,6 +20,7 @@ from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel from fla.models.mamba2 import Mamba2Config, Mamba2ForCausalLM, Mamba2Model from fla.models.mesa_net import MesaNetConfig, MesaNetForCausalLM, MesaNetModel +from fla.models.mom_gated_deltanet import MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel from fla.models.nsa import NSAConfig, NSAForCausalLM, NSAModel from fla.models.path_attn import PaTHAttentionConfig, PaTHAttentionForCausalLM, PaTHAttentionModel from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel @@ -28,7 +29,6 @@ from fla.models.rwkv7 
import RWKV7Config, RWKV7ForCausalLM, RWKV7Model from fla.models.samba import SambaConfig, SambaForCausalLM, SambaModel from fla.models.transformer import TransformerConfig, TransformerForCausalLM, TransformerModel -from fla.models.mom_gated_deltanet import MomGatedDeltaNetConfig,MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel __all__ = [ 'ABCConfig', 'ABCForCausalLM', 'ABCModel', @@ -55,5 +55,5 @@ 'PaTHAttentionConfig', 'PaTHAttentionForCausalLM', 'PaTHAttentionModel', 'RodimusConfig', 'RodimusForCausalLM', 'RodimusModel', 'MesaNetConfig', 'MesaNetForCausalLM', 'MesaNetModel', - 'MomGatedDeltaNetConfig','MomGatedDeltaNetForCausalLM','MomGatedDeltaNetModel' + 'MomGatedDeltaNetConfig', 'MomGatedDeltaNetForCausalLM', 'MomGatedDeltaNetModel' ] diff --git a/fla/models/mom_gated_deltanet/__init__.py b/fla/models/mom_gated_deltanet/__init__.py index 8f1a34949..42798e8e3 100644 --- a/fla/models/mom_gated_deltanet/__init__.py +++ b/fla/models/mom_gated_deltanet/__init__.py @@ -2,10 +2,8 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM -from fla.models.mom_gated_deltanet.configuration_mom_gated_deltanet import \ - MomGatedDeltaNetConfig -from fla.models.mom_gated_deltanet.modeling_mom_gated_deltanet import ( - MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel) +from fla.models.mom_gated_deltanet.configuration_mom_gated_deltanet import MomGatedDeltaNetConfig +from fla.models.mom_gated_deltanet.modeling_mom_gated_deltanet import MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel AutoConfig.register(MomGatedDeltaNetConfig.model_type, MomGatedDeltaNetConfig) AutoModel.register(MomGatedDeltaNetConfig, MomGatedDeltaNetModel) diff --git a/fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py b/fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py index 72163c6ad..4d12bec37 100644 --- a/fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py +++ b/fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py @@ -8,6 +8,7 @@ class MomGatedDeltaNetConfig(PretrainedConfig): model_type = 'mom_gated_deltanet' keys_to_ignore_at_inference = ['past_key_values'] + def __init__( self, attn_mode: str = "chunk", @@ -37,8 +38,8 @@ def __init__( num_memories: int = 8, topk: int = 2, capacity: float = 1.0, - use_layer_wise_balance: bool=True, - aux_loss_scale: float=0.01, + use_layer_wise_balance: bool = True, + aux_loss_scale: float = 0.01, shared_mem: bool = False, single_kv_proj: bool = False, **kwargs @@ -88,4 +89,4 @@ def __init__( eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, - ) \ No newline at end of file + ) From 63e188540bfb8b3faea386490dd9a4dabd25c7cc Mon Sep 17 00:00:00 2001 From: Jusen Date: Wed, 9 Jul 2025 11:34:12 +0800 Subject: [PATCH 04/28] Fix the issue of gradients being NaN --- fla/layers/__init__.py | 2 +- fla/layers/mom_varlen.py | 2 +- .../modeling_mom_gated_deltanet.py | 29 ++++++++++--------- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/fla/layers/__init__.py b/fla/layers/__init__.py index 010e99a6e..afd878573 100644 --- a/fla/layers/__init__.py +++ b/fla/layers/__init__.py @@ -19,7 +19,7 @@ from .mamba import Mamba from .mamba2 import Mamba2 from .mesa_net import MesaNet -from .mom import MomGatedDeltaNet +from .mom_varlen import MomGatedDeltaNet from .multiscale_retention import MultiScaleRetention from .nsa import NativeSparseAttention from .path_attn import PaTHAttention diff --git a/fla/layers/mom_varlen.py b/fla/layers/mom_varlen.py index be6891d7f..1b651ad6b 100644 --- 
a/fla/layers/mom_varlen.py +++ b/fla/layers/mom_varlen.py @@ -567,7 +567,7 @@ def forward( q,k,v,g,beta,mask2 = (rearrange(x, 'e b l ... -> (e b) l ...') for x in (q,k,v,g,beta,mask2)) cu_q, cu_k, cu_v, cu_g, cu_beta, indices_q, cu_seqlen_all, max_seq_lens = _upad_input(q, k, v, g, beta, mask2, q_len) - cu_seqlen = cu_seqlen_all[0].to(torch.long) + cu_seqlen = cu_seqlen_all[0].to(torch.long).unique() cu_q,cu_k,cu_v,cu_g,cu_beta= (x.unsqueeze(0).contiguous() for x in (cu_q,cu_k,cu_v,cu_g,cu_beta)) # dealing with padding diff --git a/fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py b/fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py index e2115596a..e98026733 100644 --- a/fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py +++ b/fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py @@ -437,6 +437,7 @@ def forward( logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -num_logits_to_keep:]) loss = None + aux_loss = None if labels is not None: if self.config.fuse_cross_entropy: if fuse_linear_and_cross_entropy: @@ -455,21 +456,21 @@ def forward( self.lm_head.bias) else: loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) - if loss==0: - breakpoint() - valid_router_logits = tuple( - logits - for logits in (outputs.router_logits if return_dict else outputs[-1]) - if logits is not None - ) - aux_loss = load_balancing_loss_func( - valid_router_logits, - self.num_memories, - self.topk, - use_layer_wise_balance=self.config.use_layer_wise_balance, # ✨ - ) - aux_loss *= self.aux_loss_scale + valid_router_logits = tuple( + logits + for logits in (outputs.router_logits if return_dict else outputs[-1]) + if logits is not None + ) + aux_loss = load_balancing_loss_func( + valid_router_logits, + self.num_memories, + self.topk, + use_layer_wise_balance=self.config.use_layer_wise_balance, # ✨ + ) + aux_loss *= self.aux_loss_scale + + loss += aux_loss if not return_dict: output = (logits,) + outputs[1:] From 7f0fc0c9dc7cdadec5f0edd39489221ab77f182b Mon Sep 17 00:00:00 2001 From: Jusen Date: Wed, 9 Jul 2025 12:32:37 +0800 Subject: [PATCH 05/28] Cleanup code --- fla/layers/mom_varlen.py | 21 +++++++++++-------- .../modeling_mom_gated_deltanet.py | 3 ++- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/fla/layers/mom_varlen.py b/fla/layers/mom_varlen.py index 1b651ad6b..2ca88eb7f 100644 --- a/fla/layers/mom_varlen.py +++ b/fla/layers/mom_varlen.py @@ -17,7 +17,6 @@ fused_recurrent_gated_delta_rule) - if TYPE_CHECKING: from transformers.processing_utils import Unpack @@ -30,9 +29,11 @@ else: print("flash_attn_2 is not available") + def elu_p1(x): return (F.elu(x, 1., False) + 1.).to(x) + def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: """ Retrieves indexing data required to repad unpadded (ragged) tensors. 
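For readers unfamiliar with the varlen path: helpers like `_get_unpad_data` turn a 0/1 padding mask into the flat token indices and cumulative sequence lengths (`cu_seqlens`) that the variable-length kernels consume. A self-contained sketch of that bookkeeping, using the common flash-attention-style recipe rather than this file's exact implementation; the mask values are made up:

import torch
import torch.nn.functional as F

# Toy 0/1 padding mask: 2 sequences padded to length 5 (1 = real token).
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=torch.int32)

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)              # tensor([3, 5])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()   # positions of real tokens
max_seqlen_in_batch = int(seqlens_in_batch.max())                             # 5
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([0, 3, 8], dtype=torch.int32)

Boundaries like these are what the chunked delta-rule kernel receives as `cu_seqlens` (after the `.to(torch.long).unique()` normalization patched in above), so each memory's tokens are processed as one packed, variable-length batch.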
@@ -60,7 +61,6 @@ def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.T ) - def _upad_input( query_layer: torch.Tensor, key_layer: torch.Tensor, @@ -239,7 +239,7 @@ def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, se mask_2 = torch.cat((pad_mask.bool(), mask), dim=1).reshape((b, num_memories, capacity_len)).transpose(0, 1) # truncation_indices += capacity_len-max_len - return transformed_x, truncation_indices, sorted_indices, max_len, mask,mask_2 + return transformed_x, truncation_indices, sorted_indices, max_len, mask, mask_2 # (num_memories, batch, seq, hidden) # @torch.jit.script @@ -540,7 +540,7 @@ def forward( if last_state is not None: conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] conv_mask = attention_mask[:, -hidden_states.shape[2]:].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None - seq_idx=kwargs.get('seq_idx', None) + seq_idx = kwargs.get('seq_idx', None) q, k, v = map(lambda x: rearrange(x, 'e b t d -> (e b) t d'), (q, k, v)) q, conv_state_q[0] = self.q_conv1d(x=q, mask=conv_mask, @@ -565,10 +565,10 @@ def forward( q = l2_norm(q) k = l2_norm(k) - q,k,v,g,beta,mask2 = (rearrange(x, 'e b l ... -> (e b) l ...') for x in (q,k,v,g,beta,mask2)) + q, k, v, g, beta, mask2 = (rearrange(x, 'e b l ... -> (e b) l ...') for x in (q, k, v, g, beta, mask2)) cu_q, cu_k, cu_v, cu_g, cu_beta, indices_q, cu_seqlen_all, max_seq_lens = _upad_input(q, k, v, g, beta, mask2, q_len) cu_seqlen = cu_seqlen_all[0].to(torch.long).unique() - cu_q,cu_k,cu_v,cu_g,cu_beta= (x.unsqueeze(0).contiguous() for x in (cu_q,cu_k,cu_v,cu_g,cu_beta)) + cu_q, cu_k, cu_v, cu_g, cu_beta= (x.unsqueeze(0).contiguous() for x in (cu_q, cu_k, cu_v, cu_g, cu_beta)) # dealing with padding if attention_mask is not None: @@ -667,15 +667,18 @@ def shared_o( q, conv_state_q[1] = self.q_conv1d(x=self.q_proj(hidden_states), mask=conv_mask, cache=conv_state_q[1], - output_final_state=use_cache,seq_idx=seq_idx) + output_final_state=use_cache, + seq_idx=seq_idx) k, conv_state_k[1] = self.k_conv1d(x=self.shared_k(hidden_states), mask=conv_mask, cache=conv_state_k[1], - output_final_state=use_cache,seq_idx=seq_idx) + output_final_state=use_cache, + seq_idx=seq_idx) v, conv_state_v[1] = self.v_conv1d(x=self.shared_v(hidden_states), mask=conv_mask, cache=conv_state_v[1], - output_final_state=use_cache,seq_idx=seq_idx) + output_final_state=use_cache, + seq_idx=seq_idx) else: q = self.silu(self.q_proj(hidden_states)) k = self.silu(self.shared_k(hidden_states)) diff --git a/fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py b/fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py index e98026733..ed2906ea8 100644 --- a/fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py +++ b/fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py @@ -202,6 +202,7 @@ def _init_weights( with torch.no_grad(): p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + @dataclass class MomGatedDeltaNetOutputWithPast(BaseModelOutputWithPast): router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None @@ -312,6 +313,7 @@ def forward( router_logits=all_router_logits ) + @dataclass class MomGatedDeltaNetCausalLMOutputWithPast(CausalLMOutputWithPast): aux_loss: Optional[torch.FloatTensor] = None @@ -532,7 +534,6 @@ def load_balancing_loss_func( routing_weights = routing_weights.softmax(dim=-1) routing_weights_full = torch.zeros_like(logits).scatter(-1, selected_experts, routing_weights) - # cast the expert indices to 
int64, otherwise one-hot encoding will fail if selected_experts.dtype != torch.int64: selected_experts = selected_experts.to(torch.int64) From dc7cf66701aa37338fb6c936d5f5668db319dcfd Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Sun, 13 Jul 2025 14:01:07 +0800 Subject: [PATCH 06/28] Update mom.py --- fla/layers/mom.py | 29 ++++++++--------------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/fla/layers/mom.py b/fla/layers/mom.py index c246e73b1..26f177a85 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -22,15 +22,6 @@ from fla.models.utils import Cache -def elu_p1(x): - return (F.elu(x, 1., False) + 1.).to(x) - - -def sum_norm(x): - return (x / x.sum(-1, keepdim=True)).to(x) - -# https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1 - def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, selected_memories: torch.Tensor, capacity: float): ''' Transform input sequences into memory-organized chunks with capacity constraints. @@ -453,7 +444,7 @@ def forward( g = g.mul(attention_mask[None, :, -g.shape[-2]:, None]) recurrent_state = last_state['recurrent_state'] if last_state is not None else [None for _ in range(self.num_memories + self.shared_mem)] - offsets = kwargs.get('offsets', None) + cu_seqlens = kwargs.get('cu_seqlens', None) # Note: In the updated version of FLA, "offset" has been renamed to "cu_seqlens". if mode == 'chunk': o_list = [None for _ in range(self.num_memories)] @@ -466,10 +457,9 @@ def forward( beta=beta[e].to(dtype=torch.bfloat16), initial_state=recurrent_state[e], output_final_state=use_cache, - cu_seqlens=offsets, - head_first=False + cu_seqlens=cu_seqlens, ) - # chunk_gated_delta_rule(q=q[e].to(dtype=torch.bfloat16),k=k[e].to(dtype=torch.bfloat16),v=v[e].to(dtype=torch.bfloat16),g=g[e].to(dtype=torch.bfloat16),beta=beta[e].to(dtype=torch.bfloat16),initial_state=recurrent_state[e],output_final_state=use_cache,cu_seqlens=offsets,head_first=False) + # chunk_gated_delta_rule(q=q[e].to(dtype=torch.bfloat16),k=k[e].to(dtype=torch.bfloat16),v=v[e].to(dtype=torch.bfloat16),g=g[e].to(dtype=torch.bfloat16),beta=beta[e].to(dtype=torch.bfloat16),initial_state=recurrent_state[e],output_final_state=use_cache,cu_seqlens=cu_seqlens) o_e = o_e[:,-max_len:,:,:].to(dtype=q[e].dtype) o_list[e] = o_e recurrent_state[e] = state_e @@ -491,8 +481,7 @@ def forward( beta=beta[e], initial_state=recurrent_state[e], output_final_state=use_cache, - cu_seqlens=offsets, - head_first=False + cu_seqlens=cu_seqlens, ) o_e = o_e[:,-max_len:,:,:] o_list[e] = o_e @@ -581,7 +570,7 @@ def shared_o( beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None]) g = g.mul(attention_mask[:, -g.shape[-2]:, None]) - offsets = kwargs.get('offsets', None) + cu_seqlens = kwargs.get('cu_seqlens', None) if mode == 'chunk': o, recurrent_state[-1] = chunk_gated_delta_rule( q=q.to(dtype=torch.bfloat16), @@ -591,8 +580,7 @@ def shared_o( beta=beta.to(dtype=torch.bfloat16), initial_state=recurrent_state[-1], output_final_state=use_cache, - cu_seqlens=offsets, - head_first=False + cu_seqlens=cu_seqlens, ) o = o.to(dtype=q.dtype) elif mode == 'fused_recurrent': @@ -604,10 +592,9 @@ def shared_o( beta=beta, initial_state=recurrent_state[-1], output_final_state=use_cache, - cu_seqlens=offsets, - head_first=False + cu_seqlens=cu_seqlens, ) else: raise NotImplementedError(f"Not supported mode `{mode}`.") - return o \ No newline at end of file + return o From e274e441b7e31046fd1f31e3b1c7e74cdf5fa869 Mon Sep 17 00:00:00 2001 From: Jusen Date: Sun, 
13 Jul 2025 15:37:02 +0800 Subject: [PATCH 07/28] Change model name & Update conv api --- fla/__init__.py | 6 +- fla/layers/mom_varlen.py | 72 ++++++++++-------- fla/models/__init__.py | 4 +- fla/models/mom/__init__.py | 12 +++ .../configuration_mom.py} | 9 ++- .../modeling_mom.py} | 75 ++++++++++--------- fla/models/mom_gated_deltanet/__init__.py | 12 --- 7 files changed, 103 insertions(+), 87 deletions(-) create mode 100644 fla/models/mom/__init__.py rename fla/models/{mom_gated_deltanet/configuration_mom_gated_deltanet.py => mom/configuration_mom.py} (92%) rename fla/models/{mom_gated_deltanet/modeling_mom_gated_deltanet.py => mom/modeling_mom.py} (92%) delete mode 100644 fla/models/mom_gated_deltanet/__init__.py diff --git a/fla/__init__.py b/fla/__init__.py index 011dcfa28..c8625c07f 100644 --- a/fla/__init__.py +++ b/fla/__init__.py @@ -53,8 +53,8 @@ LinearAttentionModel, MesaNetForCausalLM, MesaNetModel, - MomGatedDeltaNetForCausalLM, - MomGatedDeltaNetModel, + MomForCausalLM, + MomModel, MLAForCausalLM, MLAModel, NSAForCausalLM, @@ -89,7 +89,7 @@ 'LightNetAttention', 'LightNetForCausalLM', 'LightNetModel', 'LinearAttention', 'LinearAttentionForCausalLM', 'LinearAttentionModel', 'MesaNet', 'MesaNetForCausalLM', 'MesaNetModel', - 'MomGatedDeltaNet', 'MomGatedDeltaNetForCausalLM', 'MomGatedDeltaNetModel', + 'MomGatedDeltaNet', 'MomForCausalLM', 'MomModel', 'MultiheadLatentAttention', 'MLAForCausalLM', 'MLAModel', 'MultiScaleRetention', 'RetNetForCausalLM', 'RetNetModel', 'NativeSparseAttention', 'NSAForCausalLM', 'NSAModel', diff --git a/fla/layers/mom_varlen.py b/fla/layers/mom_varlen.py index 2ca88eb7f..c12604834 100644 --- a/fla/layers/mom_varlen.py +++ b/fla/layers/mom_varlen.py @@ -447,16 +447,19 @@ def __init__( self.q_conv1d = ShortConvolution( hidden_size=self.key_dim, kernel_size=conv_size, + bias=conv_bias, activation='silu' ) self.k_conv1d = ShortConvolution( hidden_size=self.key_dim, kernel_size=conv_size, + bias=conv_bias, activation='silu' ) self.v_conv1d = ShortConvolution( hidden_size=self.value_dim, kernel_size=conv_size, + bias=conv_bias, activation='silu' ) else: @@ -534,26 +537,30 @@ def forward( v = torch.stack([v_expert(hidden_states[i]) for i, v_expert in enumerate(self.v_proj)], dim=0) beta = torch.stack([b_expert(hidden_states[i]).sigmoid() for i, b_expert in enumerate(self.b_proj)], dim=0) g = torch.stack([-self.A_log.float().exp() * F.softplus(a_expert(hidden_states[i]).float() + self.dt_bias) for i, a_expert in enumerate(self.a_proj)], dim=0) - + if self.use_short_conv: conv_state_q, conv_state_k, conv_state_v = [None, None], [None, None], [None, None] if last_state is not None: conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] - conv_mask = attention_mask[:, -hidden_states.shape[2]:].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None - seq_idx = kwargs.get('seq_idx', None) q, k, v = map(lambda x: rearrange(x, 'e b t d -> (e b) t d'), (q, k, v)) - q, conv_state_q[0] = self.q_conv1d(x=q, - mask=conv_mask, - cache=conv_state_q[0], - output_final_state=use_cache,seq_idx=seq_idx) - k, conv_state_k[0] = self.k_conv1d(x=k, - mask=conv_mask, - cache=conv_state_k[0], - output_final_state=use_cache,seq_idx=seq_idx) - v, conv_state_v[0] = self.v_conv1d(x=v, - mask=conv_mask, - cache=conv_state_v[0], - output_final_state=use_cache,seq_idx=seq_idx) + q, conv_state_q[0] = self.q_conv1d( + x=q, + cache=conv_state_q[0], + output_final_state=use_cache, + cu_seqlens=None + ) + k, conv_state_k[0] = self.k_conv1d( + x=k, + 
cache=conv_state_k[0], + output_final_state=use_cache, + cu_seqlens=None + ) + v, conv_state_v[0] = self.v_conv1d( + x=v, + cache=conv_state_v[0], + output_final_state=use_cache, + cu_seqlens=None + ) q, k, v = map(lambda x: rearrange(x, '(e b) t d -> e b t d', b=batch_size), (q, k, v)) @@ -662,23 +669,24 @@ def shared_o( assert mode == 'chunk', "Only chunk mode is supported in training." if self.use_short_conv: - conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None - seq_idx=kwargs.get('seq_idx', None) - q, conv_state_q[1] = self.q_conv1d(x=self.q_proj(hidden_states), - mask=conv_mask, - cache=conv_state_q[1], - output_final_state=use_cache, - seq_idx=seq_idx) - k, conv_state_k[1] = self.k_conv1d(x=self.shared_k(hidden_states), - mask=conv_mask, - cache=conv_state_k[1], - output_final_state=use_cache, - seq_idx=seq_idx) - v, conv_state_v[1] = self.v_conv1d(x=self.shared_v(hidden_states), - mask=conv_mask, - cache=conv_state_v[1], - output_final_state=use_cache, - seq_idx=seq_idx) + q, conv_state_q[1] = self.q_conv1d( + x=self.q_proj(hidden_states), + cache=conv_state_q[1], + output_final_state=use_cache, + cu_seqlens=None + ) + k, conv_state_k[1] = self.k_conv1d( + x=self.shared_k(hidden_states), + cache=conv_state_k[1], + output_final_state=use_cache, + cu_seqlens=None + ) + v, conv_state_v[1] = self.v_conv1d( + x=self.shared_v(hidden_states), + cache=conv_state_v[1], + output_final_state=use_cache, + cu_seqlens=None + ) else: q = self.silu(self.q_proj(hidden_states)) k = self.silu(self.shared_k(hidden_states)) diff --git a/fla/models/__init__.py b/fla/models/__init__.py index 1590d16e6..53d861b73 100644 --- a/fla/models/__init__.py +++ b/fla/models/__init__.py @@ -20,7 +20,7 @@ from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel from fla.models.mamba2 import Mamba2Config, Mamba2ForCausalLM, Mamba2Model from fla.models.mesa_net import MesaNetConfig, MesaNetForCausalLM, MesaNetModel -from fla.models.mom_gated_deltanet import MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel +from fla.models.mom import MomConfig, MomForCausalLM, MomModel from fla.models.mla import MLAConfig, MLAForCausalLM, MLAModel from fla.models.nsa import NSAConfig, NSAForCausalLM, NSAModel from fla.models.path_attn import PaTHAttentionConfig, PaTHAttentionForCausalLM, PaTHAttentionModel @@ -48,7 +48,7 @@ 'MambaConfig', 'MambaForCausalLM', 'MambaModel', 'Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model', 'MesaNetConfig', 'MesaNetForCausalLM', 'MesaNetModel', - 'MomGatedDeltaNetConfig', 'MomGatedDeltaNetForCausalLM', 'MomGatedDeltaNetModel', + 'MomConfig', 'MomForCausalLM', 'MomModel', 'MLAConfig', 'MLAForCausalLM', 'MLAModel', 'NSAConfig', 'NSAForCausalLM', 'NSAModel', 'PaTHAttentionConfig', 'PaTHAttentionForCausalLM', 'PaTHAttentionModel', diff --git a/fla/models/mom/__init__.py b/fla/models/mom/__init__.py new file mode 100644 index 000000000..70e12a0aa --- /dev/null +++ b/fla/models/mom/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.mom.configuration_mom import MomConfig +from fla.models.mom.modeling_mom import MomForCausalLM, MomModel + +AutoConfig.register(MomConfig.model_type, MomConfig) +AutoModel.register(MomConfig, MomModel) +AutoModelForCausalLM.register(MomConfig, MomForCausalLM) + +__all__ = ['MomConfig', 'MomForCausalLM', 'MomModel'] diff --git a/fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py 
b/fla/models/mom/configuration_mom.py similarity index 92% rename from fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py rename to fla/models/mom/configuration_mom.py index 4d12bec37..0691d5d5f 100644 --- a/fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py +++ b/fla/models/mom/configuration_mom.py @@ -5,8 +5,8 @@ from transformers.configuration_utils import PretrainedConfig -class MomGatedDeltaNetConfig(PretrainedConfig): - model_type = 'mom_gated_deltanet' +class MomConfig(PretrainedConfig): + model_type = 'mom' keys_to_ignore_at_inference = ['past_key_values'] def __init__( @@ -42,6 +42,7 @@ def __init__( aux_loss_scale: float = 0.01, shared_mem: bool = False, single_kv_proj: bool = False, + mom_backend: str = 'GDN', **kwargs ): self.attn_mode = attn_mode @@ -72,6 +73,10 @@ def __init__( self.aux_loss_scale = aux_loss_scale self.shared_mem = shared_mem self.single_kv_proj = single_kv_proj + self.mom_backend = mom_backend + + if not self.mom_backend in ['GDN']: + raise NotImplementedError("The MoM backend is not currently implemented.") if attn is not None: if not isinstance(attn, Dict): diff --git a/fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py b/fla/models/mom/modeling_mom.py similarity index 92% rename from fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py rename to fla/models/mom/modeling_mom.py index ed2906ea8..8273b86e3 100644 --- a/fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py +++ b/fla/models/mom/modeling_mom.py @@ -19,8 +19,8 @@ from fla.layers.attn import Attention from fla.layers import MomGatedDeltaNet -from fla.models.mom_gated_deltanet.configuration_mom_gated_deltanet import \ - MomGatedDeltaNetConfig +from fla.models.mom.configuration_mom import \ + MomConfig from fla.models.utils import Cache from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm) @@ -34,7 +34,7 @@ logger = logging.get_logger(__name__) -class MomGatedDeltaNetMLP(nn.Module): +class MomMLP(nn.Module): def __init__( self, @@ -44,7 +44,7 @@ def __init__( hidden_act: str = 'swish', norm_first: bool = True, norm_eps: float = 1e-5 - ) -> MomGatedDeltaNetMLP: + ) -> MomMLP: super().__init__() self.hidden_size = hidden_size @@ -79,8 +79,8 @@ def forward( return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) -class MomGatedDeltaNetBlock(nn.Module): - def __init__(self, config: MomGatedDeltaNetConfig, layer_idx: int): +class MomBlock(nn.Module): + def __init__(self, config: MomConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -96,27 +96,30 @@ def __init__(self, config: MomGatedDeltaNetConfig, layer_idx: int): layer_idx=layer_idx ) else: - self.attn = MomGatedDeltaNet( - mode=config.attn_mode, - hidden_size=config.hidden_size, - expand_v=config.expand_v, - head_dim=config.head_dim, - num_heads=config.num_heads, - use_gate=config.use_gate, - use_short_conv=config.use_short_conv, - conv_size=config.conv_size, - norm_first=config.norm_first, - norm_eps=config.norm_eps, - layer_idx=layer_idx, - num_memories=config.num_memories, - topk=config.topk, - capacity=config.capacity, - shared_mem=config.shared_mem, - single_kv_proj=config.single_kv_proj - ) + if config.mom_backend == 'GDN': + self.attn = MomGatedDeltaNet( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_v=config.expand_v, + head_dim=config.head_dim, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + 
norm_first=config.norm_first, + norm_eps=config.norm_eps, + layer_idx=layer_idx, + num_memories=config.num_memories, + topk=config.topk, + capacity=config.capacity, + shared_mem=config.shared_mem, + single_kv_proj=config.single_kv_proj + ) + else: + raise NotImplementedError("The MoM backend is not currently implemented.") if not config.norm_first: self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) - self.mlp = MomGatedDeltaNetMLP( + self.mlp = MomMLP( hidden_size=config.hidden_size, hidden_ratio=config.hidden_ratio, intermediate_size=config.intermediate_size, @@ -158,9 +161,9 @@ def forward( return outputs -class MomGatedDeltaNetPreTrainedModel(PreTrainedModel): +class MomPreTrainedModel(PreTrainedModel): - config_class = MomGatedDeltaNetConfig + config_class = MomConfig supports_gradient_checkpointing = True _no_split_modules = ['GatedDeltaNetBlock'] @@ -204,18 +207,18 @@ def _init_weights( @dataclass -class MomGatedDeltaNetOutputWithPast(BaseModelOutputWithPast): +class MomOutputWithPast(BaseModelOutputWithPast): router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None -class MomGatedDeltaNetModel(MomGatedDeltaNetPreTrainedModel): +class MomModel(MomPreTrainedModel): - def __init__(self, config: MomGatedDeltaNetConfig): + def __init__(self, config: MomConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([MomGatedDeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList([MomBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) self.gradient_checkpointing = False @@ -305,7 +308,7 @@ def forward( if not return_dict: return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) - return MomGatedDeltaNetOutputWithPast( + return MomOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, @@ -315,17 +318,17 @@ def forward( @dataclass -class MomGatedDeltaNetCausalLMOutputWithPast(CausalLMOutputWithPast): +class MomCausalLMOutputWithPast(CausalLMOutputWithPast): aux_loss: Optional[torch.FloatTensor] = None router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None -class MomGatedDeltaNetForCausalLM(MomGatedDeltaNetPreTrainedModel, GenerationMixin): +class MomForCausalLM(MomPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) - self.model = MomGatedDeltaNetModel(config) + self.model = MomModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.num_memories = config.num_memories @@ -478,7 +481,7 @@ def forward( output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - return MomGatedDeltaNetCausalLMOutputWithPast( + return MomCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, diff --git a/fla/models/mom_gated_deltanet/__init__.py b/fla/models/mom_gated_deltanet/__init__.py deleted file mode 100644 index 42798e8e3..000000000 --- a/fla/models/mom_gated_deltanet/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# -*- coding: utf-8 -*- - -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM - -from 
fla.models.mom_gated_deltanet.configuration_mom_gated_deltanet import MomGatedDeltaNetConfig -from fla.models.mom_gated_deltanet.modeling_mom_gated_deltanet import MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel - -AutoConfig.register(MomGatedDeltaNetConfig.model_type, MomGatedDeltaNetConfig) -AutoModel.register(MomGatedDeltaNetConfig, MomGatedDeltaNetModel) -AutoModelForCausalLM.register(MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM) - -__all__ = ['MomGatedDeltaNetConfig', 'MomGatedDeltaNetForCausalLM', 'MomGatedDeltaNetModel'] From 7ab9b43168cc5e1300c6f275aa6cde996260da2d Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Sun, 13 Jul 2025 10:04:12 +0000 Subject: [PATCH 08/28] Fix format issues --- fla/__init__.py | 8 +- fla/layers/__init__.py | 4 +- fla/layers/mom.py | 122 ++++++++++++++++------------ fla/layers/mom_varlen.py | 96 ++++++++++++---------- fla/models/mom/configuration_mom.py | 6 +- fla/models/mom/modeling_mom.py | 18 ++-- 6 files changed, 141 insertions(+), 113 deletions(-) diff --git a/fla/__init__.py b/fla/__init__.py index c8625c07f..557a6860b 100644 --- a/fla/__init__.py +++ b/fla/__init__.py @@ -16,7 +16,7 @@ LightNetAttention, LinearAttention, MesaNet, - MomGatedDeltaNet, + MomAttention, MultiheadLatentAttention, MultiScaleRetention, NativeSparseAttention, @@ -53,10 +53,10 @@ LinearAttentionModel, MesaNetForCausalLM, MesaNetModel, - MomForCausalLM, - MomModel, MLAForCausalLM, MLAModel, + MomForCausalLM, + MomModel, NSAForCausalLM, NSAModel, PaTHAttentionForCausalLM, @@ -89,7 +89,7 @@ 'LightNetAttention', 'LightNetForCausalLM', 'LightNetModel', 'LinearAttention', 'LinearAttentionForCausalLM', 'LinearAttentionModel', 'MesaNet', 'MesaNetForCausalLM', 'MesaNetModel', - 'MomGatedDeltaNet', 'MomForCausalLM', 'MomModel', + 'MomAttention', 'MomForCausalLM', 'MomModel', 'MultiheadLatentAttention', 'MLAForCausalLM', 'MLAModel', 'MultiScaleRetention', 'RetNetForCausalLM', 'RetNetModel', 'NativeSparseAttention', 'NSAForCausalLM', 'NSAModel', diff --git a/fla/layers/__init__.py b/fla/layers/__init__.py index 92451b7de..a602bd3fb 100644 --- a/fla/layers/__init__.py +++ b/fla/layers/__init__.py @@ -19,8 +19,8 @@ from .mamba import Mamba from .mamba2 import Mamba2 from .mesa_net import MesaNet -from .mom_varlen import MomGatedDeltaNet from .mla import MultiheadLatentAttention +from .mom_varlen import MomAttention from .multiscale_retention import MultiScaleRetention from .nsa import NativeSparseAttention from .path_attn import PaTHAttention @@ -47,8 +47,8 @@ 'LinearAttention', 'Mamba', 'Mamba2', - 'MomGatedDeltaNet', 'MesaNet', + 'MomAttention', 'MultiheadLatentAttention', 'MultiScaleRetention', 'NativeSparseAttention', diff --git a/fla/layers/mom.py b/fla/layers/mom.py index 26f177a85..c89ffb54b 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -13,8 +13,7 @@ from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution from fla.modules.l2norm import l2_norm -from fla.ops.gated_delta_rule import (chunk_gated_delta_rule, - fused_recurrent_gated_delta_rule) +from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule if TYPE_CHECKING: from transformers.processing_utils import Unpack @@ -25,7 +24,7 @@ def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, selected_memories: torch.Tensor, capacity: float): ''' Transform input sequences into memory-organized chunks with capacity constraints. 
- + Processes input sequences by routing tokens to designated memory states according to routing_mask, sorts tokens by memory assignments, handles token truncation/padding based on memory capacity, and returns memory-aligned tensors for parallel processing. @@ -88,7 +87,7 @@ def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, se batch_memory_tokens = routing_mask.sum(dim=1) # (b, num_memories) offset = batch_memory_tokens.cumsum(dim=1) - memory_batch_offset = offset.transpose(0,1) + memory_batch_offset = offset.transpose(0, 1) batch_offset = torch.arange(0, b*s, s, device=offset.device) memory_batch_offset += batch_offset flatten_offset = memory_batch_offset.transpose(0, 1).reshape(-1) @@ -97,12 +96,14 @@ def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, se capacity_len = math.ceil(s / topk * capacity) max_len = min(max_len, capacity_len) - indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1) + indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand( + b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1) # discard tokens exceed capacity and is far from now # left pad truncation_indices = indices + batch_memory_tokens.reshape((-1,)).unsqueeze(-1) - max_len mask = torch.bitwise_and(truncation_indices < flatten_offset.unsqueeze(-1), truncation_indices >= 0) - mask = torch.bitwise_and(mask, truncation_indices >= torch.cat((torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device), flatten_offset[:-1])).unsqueeze(-1)) + mask = torch.bitwise_and(mask, truncation_indices >= torch.cat( + (torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device), flatten_offset[:-1])).unsqueeze(-1)) truncation_indices = torch.where(mask, truncation_indices, torch.zeros_like(truncation_indices)) gathered_x = torch.gather(x_sorted, 0, truncation_indices.reshape(-1).unsqueeze(-1).expand(-1, d)) @@ -117,6 +118,8 @@ def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, se # (num_memories, batch, seq, hidden) # @torch.jit.script + + def reconstruct(transformed_x, indices: torch.Tensor, sorted_indices: torch.Tensor, batch_size: int, seq_len: int, topk: int, routing_weights: torch.Tensor, mask: torch.Tensor): ''' Reconstruct and mix transformed outputs back into the original input sequence shape. @@ -157,15 +160,17 @@ def reconstruct(transformed_x, indices: torch.Tensor, sorted_indices: torch.Tens The reconstructed output tensor in the original input sequence shape. 
Shape: (batch_size, seq_len, hidden_size) ''' - transformed_x = transformed_x.transpose(0, 1).reshape((-1, transformed_x.shape[2], transformed_x.shape[3], transformed_x.shape[4])) + transformed_x = transformed_x.transpose(0, 1).reshape( + (-1, transformed_x.shape[2], transformed_x.shape[3], transformed_x.shape[4])) b, s, k, h, d = batch_size, seq_len, topk, transformed_x.shape[2], transformed_x.shape[3] - gathered_x = transformed_x.reshape((transformed_x.shape[0] * transformed_x.shape[1], transformed_x.shape[2], transformed_x.shape[3])) + gathered_x = transformed_x.reshape( + (transformed_x.shape[0] * transformed_x.shape[1], transformed_x.shape[2], transformed_x.shape[3])) mask_expanded = mask.reshape(-1).unsqueeze(-1).unsqueeze(-1).expand_as(gathered_x) gathered_x = gathered_x * mask_expanded assert (indices >= 0).all(), "Indices should be non-negative" - resortd_x = torch.zeros((b * s * k, h, d) ,device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_( + resortd_x = torch.zeros((b * s * k, h, d), device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_( 0, indices.reshape(-1).unsqueeze(-1).unsqueeze(-1).expand(-1, h, d), gathered_x, @@ -180,7 +185,7 @@ def reconstruct(transformed_x, indices: torch.Tensor, sorted_indices: torch.Tens return restored_x -class MomGatedDeltaNet(nn.Module): +class MomAttention(nn.Module): """ The layer implementaion for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa @@ -246,7 +251,7 @@ def __init__( shared_mem: bool = False, single_kv_proj: bool = False, **kwargs - ) -> MomGatedDeltaNet: + ) -> MomAttention: super().__init__() self.num_memories = num_memories self.topk = topk @@ -284,10 +289,14 @@ def __init__( self.shared_b = nn.Linear(hidden_size, self.num_heads, bias=False) self.shared_a = nn.Linear(hidden_size, self.num_heads, bias=False) else: - self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.key_dim, bias=False) for _ in range(self.num_memories)]) - self.v_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.value_dim, bias=False) for _ in range(self.num_memories)]) - self.b_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)]) - self.a_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)]) + self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.key_dim, bias=False) + for _ in range(self.num_memories)]) + self.v_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.value_dim, bias=False) + for _ in range(self.num_memories)]) + self.b_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) + for _ in range(self.num_memories)]) + self.a_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) + for _ in range(self.num_memories)]) if self.shared_mem: self.shared_k = nn.Linear(hidden_size, self.key_dim, bias=False) self.shared_v = nn.Linear(hidden_size, self.value_dim, bias=False) @@ -385,16 +394,18 @@ def forward( routing_weights, selected_memories = torch.topk(scores, self.topk, dim=-1) # (bsz, seq, topk) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) routing_weights = routing_weights.to(hidden_states.dtype) # we cast back to the input dtype - routing_weights_full = torch.zeros((routing_weights.shape[0], routing_weights.shape[1], self.num_memories), dtype=routing_weights.dtype, device=routing_weights.device).scatter(-1, selected_memories, routing_weights) + routing_weights_full = 
torch.zeros((routing_weights.shape[0], routing_weights.shape[1], self.num_memories), + dtype=routing_weights.dtype, device=routing_weights.device).scatter(-1, selected_memories, routing_weights) routing_mask = routing_weights_full.bool().int() if self.use_gate: o_g = self.g_proj(hidden_states) - + batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1] shared_hidden_states = hidden_states - hidden_states, indices, sorted_indices, max_len, mask = transform(hidden_states, routing_mask, self.num_memories, selected_memories, self.capacity) + hidden_states, indices, sorted_indices, max_len, mask = transform( + hidden_states, routing_mask, self.num_memories, selected_memories, self.capacity) q = self.q_proj(hidden_states) if self.single_kv_proj: @@ -406,27 +417,29 @@ def forward( k = torch.stack([k_expert(hidden_states[i]) for i, k_expert in enumerate(self.k_proj)], dim=0) v = torch.stack([v_expert(hidden_states[i]) for i, v_expert in enumerate(self.v_proj)], dim=0) beta = torch.stack([b_expert(hidden_states[i]).sigmoid() for i, b_expert in enumerate(self.b_proj)], dim=0) - g = torch.stack([-self.A_log.float().exp() * F.softplus(a_expert(hidden_states[i]).float() + self.dt_bias) for i, a_expert in enumerate(self.a_proj)], dim=0) + g = torch.stack([-self.A_log.float().exp() * F.softplus(a_expert(hidden_states[i]).float() + self.dt_bias) + for i, a_expert in enumerate(self.a_proj)], dim=0) if self.use_short_conv: conv_state_q, conv_state_k, conv_state_v = [None, None], [None, None], [None, None] if last_state is not None: conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] - conv_mask = attention_mask[:, -hidden_states.shape[2]:].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None - seq_idx=kwargs.get('seq_idx', None) + conv_mask = attention_mask[:, -hidden_states.shape[2] + :].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None + seq_idx = kwargs.get('seq_idx', None) q, k, v = map(lambda x: rearrange(x, 'e b t d -> (e b) t d'), (q, k, v)) q, conv_state_q[0] = self.q_conv1d(x=q, - mask=conv_mask, - cache=conv_state_q[0], - output_final_state=use_cache,seq_idx=seq_idx) + mask=conv_mask, + cache=conv_state_q[0], + output_final_state=use_cache, seq_idx=seq_idx) k, conv_state_k[0] = self.k_conv1d(x=k, - mask=conv_mask, - cache=conv_state_k[0], - output_final_state=use_cache,seq_idx=seq_idx) + mask=conv_mask, + cache=conv_state_k[0], + output_final_state=use_cache, seq_idx=seq_idx) v, conv_state_v[0] = self.v_conv1d(x=v, - mask=conv_mask, - cache=conv_state_v[0], - output_final_state=use_cache,seq_idx=seq_idx) + mask=conv_mask, + cache=conv_state_v[0], + output_final_state=use_cache, seq_idx=seq_idx) q, k, v = map(lambda x: rearrange(x, '(e b) t d -> e b t d', b=batch_size), (q, k, v)) @@ -434,7 +447,7 @@ def forward( q, k, v = self.silu(q), self.silu(k), self.silu(v), q, k, v = map(lambda x: rearrange(x, 'e b t (h d) -> e b t h d', h=self.num_heads), (q, k, v)) - + q = l2_norm(q) k = l2_norm(k) @@ -443,7 +456,8 @@ def forward( beta = beta.mul(attention_mask[None, :, -beta.shape[-2]:, None]) g = g.mul(attention_mask[None, :, -g.shape[-2]:, None]) - recurrent_state = last_state['recurrent_state'] if last_state is not None else [None for _ in range(self.num_memories + self.shared_mem)] + recurrent_state = last_state['recurrent_state'] if last_state is not None else [ + None for _ in range(self.num_memories + self.shared_mem)] cu_seqlens = kwargs.get('cu_seqlens', None) # Note: In the updated version of FLA, "offset" has 
been renamed to "cu_seqlens". if mode == 'chunk': @@ -460,18 +474,19 @@ def forward( cu_seqlens=cu_seqlens, ) # chunk_gated_delta_rule(q=q[e].to(dtype=torch.bfloat16),k=k[e].to(dtype=torch.bfloat16),v=v[e].to(dtype=torch.bfloat16),g=g[e].to(dtype=torch.bfloat16),beta=beta[e].to(dtype=torch.bfloat16),initial_state=recurrent_state[e],output_final_state=use_cache,cu_seqlens=cu_seqlens) - o_e = o_e[:,-max_len:,:,:].to(dtype=q[e].dtype) + o_e = o_e[:, -max_len:, :, :].to(dtype=q[e].dtype) o_list[e] = o_e recurrent_state[e] = state_e o_list = torch.stack(o_list, dim=0) - o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=q.shape[1], seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, + batch_size=q.shape[1], seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) elif mode == 'fused_recurrent': o_list = [None for _ in range(self.num_memories)] for e in range(self.num_memories): # only activated memory updates if not hidden_states[e, 0].any() and hidden_states.shape[1] == 1: - o_list[e] = torch.zeros_like(v[e,:,-max_len:,:,:]) + o_list[e] = torch.zeros_like(v[e, :, -max_len:, :, :]) continue o_e, state_e = fused_recurrent_gated_delta_rule( q=q[e], @@ -483,7 +498,7 @@ def forward( output_final_state=use_cache, cu_seqlens=cu_seqlens, ) - o_e = o_e[:,-max_len:,:,:] + o_e = o_e[:, -max_len:, :, :] o_list[e] = o_e # recurrent_state[e] = state_e for batch in range(state_e.shape[0]): @@ -492,10 +507,12 @@ def forward( elif hidden_states[e, batch].any(): recurrent_state[e][batch] = state_e[batch] o_list = torch.stack(o_list, dim=0) - o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=q.shape[1], seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, + batch_size=q.shape[1], seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) if self.shared_mem: - shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, use_cache, conv_state_q, conv_state_k, conv_state_v) + shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, + use_cache, conv_state_q, conv_state_k, conv_state_v) o += shared_o if past_key_values is not None: @@ -516,16 +533,15 @@ def forward( return o, None, past_key_values, router_logits - def shared_o( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - recurrent_state = None, + recurrent_state=None, use_cache: Optional[bool] = False, - conv_state_q = [None, None], - conv_state_k = [None, None], - conv_state_v = [None, None], + conv_state_q=[None, None], + conv_state_k=[None, None], + conv_state_v=[None, None], **kwargs ) -> torch.Tensor: if attention_mask is not None: @@ -541,19 +557,19 @@ def shared_o( if self.use_short_conv: conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None - seq_idx=kwargs.get('seq_idx', None) + seq_idx = kwargs.get('seq_idx', None) q, conv_state_q[1] = self.q_conv1d(x=self.q_proj(hidden_states), - mask=conv_mask, - cache=conv_state_q[1], - output_final_state=use_cache,seq_idx=seq_idx) + mask=conv_mask, + cache=conv_state_q[1], + output_final_state=use_cache, seq_idx=seq_idx) k, conv_state_k[1] = self.k_conv1d(x=self.shared_k(hidden_states), - mask=conv_mask, - cache=conv_state_k[1], - output_final_state=use_cache,seq_idx=seq_idx) + mask=conv_mask, + 
cache=conv_state_k[1], + output_final_state=use_cache, seq_idx=seq_idx) v, conv_state_v[1] = self.v_conv1d(x=self.shared_v(hidden_states), - mask=conv_mask, - cache=conv_state_v[1], - output_final_state=use_cache,seq_idx=seq_idx) + mask=conv_mask, + cache=conv_state_v[1], + output_final_state=use_cache, seq_idx=seq_idx) else: q = self.silu(self.q_proj(hidden_states)) k = self.silu(self.shared_k(hidden_states)) diff --git a/fla/layers/mom_varlen.py b/fla/layers/mom_varlen.py index c12604834..c2c381cf2 100644 --- a/fla/layers/mom_varlen.py +++ b/fla/layers/mom_varlen.py @@ -13,9 +13,7 @@ from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution from fla.modules.l2norm import l2_norm -from fla.ops.gated_delta_rule import (chunk_gated_delta_rule, - fused_recurrent_gated_delta_rule) - +from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule if TYPE_CHECKING: from transformers.processing_utils import Unpack @@ -23,9 +21,9 @@ from fla.models.utils import Cache from transformers.utils import is_flash_attn_2_available + if is_flash_attn_2_available(): - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input - from flash_attn.bert_padding import pad_input + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input else: print("flash_attn_2 is not available") @@ -140,16 +138,16 @@ def _upad_input( ) - def sum_norm(x): return (x / x.sum(-1, keepdim=True)).to(x) # https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1 + def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, selected_memories: torch.Tensor, capacity: float): ''' Transform input sequences into memory-organized chunks with capacity constraints. - + Processes input sequences by routing tokens to designated memory states according to routing_mask, sorts tokens by memory assignments, handles token truncation/padding based on memory capacity, and returns memory-aligned tensors for parallel processing. 
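As a worked illustration of the capacity rule this docstring describes (each memory keeps at most ceil(seq_len / topk * capacity) tokens, with the oldest ones left-truncated), here is a small standalone sketch; the numbers are made up and this is not part of the patch:

import math

# Hypothetical example: 8 tokens, top-2 routing, capacity factor 1.0.
seq_len, topk, capacity = 8, 2, 1.0
capacity_len = math.ceil(seq_len / topk * capacity)   # ceil(8 / 2 * 1.0) -> 4 slots per memory
busiest = 6                                           # suppose one memory received 6 tokens
max_len = min(busiest, capacity_len)                  # -> 4: only the 4 most recent tokens survive
print(capacity_len, max_len)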
@@ -212,7 +210,7 @@ def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, se batch_memory_tokens = routing_mask.sum(dim=1) # (b, num_memories) offset = batch_memory_tokens.cumsum(dim=1) - memory_batch_offset = offset.transpose(0,1) + memory_batch_offset = offset.transpose(0, 1) batch_offset = torch.arange(0, b*s, s, device=offset.device) memory_batch_offset += batch_offset flatten_offset = memory_batch_offset.transpose(0, 1).reshape(-1) @@ -221,12 +219,14 @@ def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, se capacity_len = math.ceil(s / topk * capacity) max_len = min(max_len, capacity_len) - indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1) + indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand( + b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1) # discard tokens exceed capacity and is far from now # left pad truncation_indices = indices + batch_memory_tokens.reshape((-1,)).unsqueeze(-1) - max_len mask = torch.bitwise_and(truncation_indices < flatten_offset.unsqueeze(-1), truncation_indices >= 0) - mask = torch.bitwise_and(mask, truncation_indices >= torch.cat((torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device), flatten_offset[:-1])).unsqueeze(-1)) + mask = torch.bitwise_and(mask, truncation_indices >= torch.cat( + (torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device), flatten_offset[:-1])).unsqueeze(-1)) truncation_indices = torch.where(mask, truncation_indices, torch.zeros_like(truncation_indices)) gathered_x = torch.gather(x_sorted, 0, truncation_indices.reshape(-1).unsqueeze(-1).expand(-1, d)) @@ -243,6 +243,8 @@ def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, se # (num_memories, batch, seq, hidden) # @torch.jit.script + + def reconstruct(transformed_x, indices: torch.Tensor, sorted_indices: torch.Tensor, batch_size: int, seq_len: int, topk: int, routing_weights: torch.Tensor, mask: torch.Tensor): ''' Reconstruct and mix transformed outputs back into the original input sequence shape. @@ -283,15 +285,17 @@ def reconstruct(transformed_x, indices: torch.Tensor, sorted_indices: torch.Tens The reconstructed output tensor in the original input sequence shape. 
Shape: (batch_size, seq_len, hidden_size) ''' - transformed_x = transformed_x.transpose(0, 1).reshape((-1, transformed_x.shape[2], transformed_x.shape[3], transformed_x.shape[4])) + transformed_x = transformed_x.transpose(0, 1).reshape( + (-1, transformed_x.shape[2], transformed_x.shape[3], transformed_x.shape[4])) b, s, k, h, d = batch_size, seq_len, topk, transformed_x.shape[2], transformed_x.shape[3] - gathered_x = transformed_x.reshape((transformed_x.shape[0] * transformed_x.shape[1], transformed_x.shape[2], transformed_x.shape[3])) + gathered_x = transformed_x.reshape( + (transformed_x.shape[0] * transformed_x.shape[1], transformed_x.shape[2], transformed_x.shape[3])) mask_expanded = mask.reshape(-1).unsqueeze(-1).unsqueeze(-1).expand_as(gathered_x) gathered_x = gathered_x * mask_expanded assert (indices >= 0).all(), "Indices should be non-negative" - resortd_x = torch.zeros((b * s * k, h, d) ,device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_( + resortd_x = torch.zeros((b * s * k, h, d), device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_( 0, indices.reshape(-1).unsqueeze(-1).unsqueeze(-1).expand(-1, h, d), gathered_x, @@ -306,7 +310,7 @@ def reconstruct(transformed_x, indices: torch.Tensor, sorted_indices: torch.Tens return restored_x -class MomGatedDeltaNet(nn.Module): +class MomAttention(nn.Module): """ The layer implementaion for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa @@ -372,7 +376,7 @@ def __init__( shared_mem: bool = False, single_kv_proj: bool = False, **kwargs - ) -> MomGatedDeltaNet: + ) -> MomAttention: super().__init__() self.num_memories = num_memories self.topk = topk @@ -410,10 +414,14 @@ def __init__( self.shared_b = nn.Linear(hidden_size, self.num_heads, bias=False) self.shared_a = nn.Linear(hidden_size, self.num_heads, bias=False) else: - self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.key_dim, bias=False) for _ in range(self.num_memories)]) - self.v_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.value_dim, bias=False) for _ in range(self.num_memories)]) - self.b_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)]) - self.a_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)]) + self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.key_dim, bias=False) + for _ in range(self.num_memories)]) + self.v_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.value_dim, bias=False) + for _ in range(self.num_memories)]) + self.b_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) + for _ in range(self.num_memories)]) + self.a_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) + for _ in range(self.num_memories)]) if self.shared_mem: self.shared_k = nn.Linear(hidden_size, self.key_dim, bias=False) self.shared_v = nn.Linear(hidden_size, self.value_dim, bias=False) @@ -505,7 +513,7 @@ def forward( assert mode == 'chunk', "Only chunk mode is supported in training." 
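The router that feeds `transform` (visible in the forward passes patched above and below) renormalizes the top-k scores and scatters them back into a dense per-memory mask; a self-contained sketch of that computation with made-up sizes, not part of the patch:

import torch

torch.manual_seed(0)
bsz, seq, num_memories, topk = 2, 5, 4, 2
scores = torch.randn(bsz, seq, num_memories).softmax(dim=-1)
weights, selected = torch.topk(scores, topk, dim=-1)        # (bsz, seq, topk)
weights = weights / weights.sum(dim=-1, keepdim=True)       # renormalize over the chosen memories
full = torch.zeros_like(scores).scatter(-1, selected, weights)
routing_mask = full.bool().int()                            # 1 where a memory is active for a token
assert routing_mask.sum(-1).eq(topk).all()                  # exactly topk memories fire per token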
last_state = None - batchsize,q_len = hidden_states.shape[0],hidden_states.shape[1] + batchsize, q_len = hidden_states.shape[0], hidden_states.shape[1] if past_key_values is not None and len(past_key_values) > self.layer_idx: last_state = past_key_values[self.layer_idx] @@ -515,16 +523,18 @@ def forward( routing_weights, selected_memories = torch.topk(scores, self.topk, dim=-1) # (bsz, seq, topk) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) routing_weights = routing_weights.to(hidden_states.dtype) # we cast back to the input dtype - routing_weights_full = torch.zeros((routing_weights.shape[0], routing_weights.shape[1], self.num_memories), dtype=routing_weights.dtype, device=routing_weights.device).scatter(-1, selected_memories, routing_weights) + routing_weights_full = torch.zeros((routing_weights.shape[0], routing_weights.shape[1], self.num_memories), + dtype=routing_weights.dtype, device=routing_weights.device).scatter(-1, selected_memories, routing_weights) routing_mask = routing_weights_full.bool().int() if self.use_gate: o_g = self.g_proj(hidden_states) - + batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1] shared_hidden_states = hidden_states - hidden_states, indices, sorted_indices, max_len, mask, mask2 = transform(hidden_states, routing_mask, self.num_memories, selected_memories, self.capacity) + hidden_states, indices, sorted_indices, max_len, mask, mask2 = transform( + hidden_states, routing_mask, self.num_memories, selected_memories, self.capacity) q = self.q_proj(hidden_states) if self.single_kv_proj: @@ -536,8 +546,9 @@ def forward( k = torch.stack([k_expert(hidden_states[i]) for i, k_expert in enumerate(self.k_proj)], dim=0) v = torch.stack([v_expert(hidden_states[i]) for i, v_expert in enumerate(self.v_proj)], dim=0) beta = torch.stack([b_expert(hidden_states[i]).sigmoid() for i, b_expert in enumerate(self.b_proj)], dim=0) - g = torch.stack([-self.A_log.float().exp() * F.softplus(a_expert(hidden_states[i]).float() + self.dt_bias) for i, a_expert in enumerate(self.a_proj)], dim=0) - + g = torch.stack([-self.A_log.float().exp() * F.softplus(a_expert(hidden_states[i]).float() + self.dt_bias) + for i, a_expert in enumerate(self.a_proj)], dim=0) + if self.use_short_conv: conv_state_q, conv_state_k, conv_state_v = [None, None], [None, None], [None, None] if last_state is not None: @@ -568,21 +579,22 @@ def forward( q, k, v = self.silu(q), self.silu(k), self.silu(v), q, k, v = map(lambda x: rearrange(x, 'e b t (h d) -> e b t h d', h=self.num_heads), (q, k, v)) - + q = l2_norm(q) k = l2_norm(k) q, k, v, g, beta, mask2 = (rearrange(x, 'e b l ... 
-> (e b) l ...') for x in (q, k, v, g, beta, mask2)) cu_q, cu_k, cu_v, cu_g, cu_beta, indices_q, cu_seqlen_all, max_seq_lens = _upad_input(q, k, v, g, beta, mask2, q_len) cu_seqlen = cu_seqlen_all[0].to(torch.long).unique() - cu_q, cu_k, cu_v, cu_g, cu_beta= (x.unsqueeze(0).contiguous() for x in (cu_q, cu_k, cu_v, cu_g, cu_beta)) + cu_q, cu_k, cu_v, cu_g, cu_beta = (x.unsqueeze(0).contiguous() for x in (cu_q, cu_k, cu_v, cu_g, cu_beta)) # dealing with padding if attention_mask is not None: beta = beta.mul(attention_mask[None, :, -beta.shape[-2]:, None]) g = g.mul(attention_mask[None, :, -g.shape[-2]:, None]) - recurrent_state = last_state['recurrent_state'] if last_state is not None else [None for _ in range(1 + self.shared_mem)] + recurrent_state = last_state['recurrent_state'] if last_state is not None else [ + None for _ in range(1 + self.shared_mem)] offsets = kwargs.get('offsets', None) # Note: In the updated version of FLA, "offset" has been renamed to "cu_seqlens". if mode == 'chunk': @@ -600,10 +612,11 @@ def forward( recurrent_state[0] = recurrent_state_ o_ = o_.squeeze(0).contiguous() o_list = pad_input(o_, indices_q, batch_size*self.num_memories, q_len) - o_list = rearrange(o_list, '(e b) l h d -> e b l h d',b=batch_size) - o_list = o_list[:,:,-max_len:] - - o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + o_list = rearrange(o_list, '(e b) l h d -> e b l h d', b=batch_size) + o_list = o_list[:, :, -max_len:] + + o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, + seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) elif mode == 'fused_recurrent': o_, recurrent_state_ = fused_recurrent_gated_delta_rule( @@ -620,11 +633,13 @@ def forward( recurrent_state[0] = recurrent_state_ o_ = o_.squeeze(0).contiguous() o_list = pad_input(o_, indices_q, batch_size*self.num_memories, max_len) - o_list = rearrange(o_list, '(e b) l h d -> e b l h d',b=batch_size) - o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + o_list = rearrange(o_list, '(e b) l h d -> e b l h d', b=batch_size) + o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, + seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) if self.shared_mem: - shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, use_cache, conv_state_q, conv_state_k, conv_state_v) + shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, + use_cache, conv_state_q, conv_state_k, conv_state_v) o += shared_o if past_key_values is not None: @@ -645,16 +660,15 @@ def forward( return o, None, past_key_values, router_logits - def shared_o( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - recurrent_state = None, + recurrent_state=None, use_cache: Optional[bool] = False, - conv_state_q = [None, None], - conv_state_k = [None, None], - conv_state_v = [None, None], + conv_state_q=[None, None], + conv_state_k=[None, None], + conv_state_v=[None, None], **kwargs ) -> torch.Tensor: if attention_mask is not None: @@ -732,4 +746,4 @@ def shared_o( else: raise NotImplementedError(f"Not supported mode `{mode}`.") - return o \ No newline at end of file + return o diff --git a/fla/models/mom/configuration_mom.py 
b/fla/models/mom/configuration_mom.py index 0691d5d5f..9e32e6196 100644 --- a/fla/models/mom/configuration_mom.py +++ b/fla/models/mom/configuration_mom.py @@ -42,7 +42,7 @@ def __init__( aux_loss_scale: float = 0.01, shared_mem: bool = False, single_kv_proj: bool = False, - mom_backend: str = 'GDN', + mom_backend: str = 'gated_deltanet', **kwargs ): self.attn_mode = attn_mode @@ -75,8 +75,8 @@ def __init__( self.single_kv_proj = single_kv_proj self.mom_backend = mom_backend - if not self.mom_backend in ['GDN']: - raise NotImplementedError("The MoM backend is not currently implemented.") + if self.mom_backend not in ['gated_deltanet']: + raise NotImplementedError(f"The MoM backend {mom_backend} is not currently supported.") if attn is not None: if not isinstance(attn, Dict): diff --git a/fla/models/mom/modeling_mom.py b/fla/models/mom/modeling_mom.py index 8273b86e3..bd8d36aed 100644 --- a/fla/models/mom/modeling_mom.py +++ b/fla/models/mom/modeling_mom.py @@ -2,28 +2,24 @@ from __future__ import annotations -from dataclasses import dataclass import math import warnings +from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union import torch import torch.nn as nn -import torch.utils.checkpoint from transformers.activations import ACT2FN from transformers.generation import GenerationMixin -from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging +from fla.layers import MomAttention from fla.layers.attn import Attention -from fla.layers import MomGatedDeltaNet -from fla.models.mom.configuration_mom import \ - MomConfig +from fla.models.mom.configuration_mom import MomConfig from fla.models.utils import Cache -from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, - RMSNorm) +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm from fla.modules.activations import swiglu_linear from fla.modules.layernorm import rms_norm_linear @@ -97,7 +93,7 @@ def __init__(self, config: MomConfig, layer_idx: int): ) else: if config.mom_backend == 'GDN': - self.attn = MomGatedDeltaNet( + self.attn = MomAttention( mode=config.attn_mode, hidden_size=config.hidden_size, expand_v=config.expand_v, @@ -210,6 +206,7 @@ def _init_weights( class MomOutputWithPast(BaseModelOutputWithPast): router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + class MomModel(MomPreTrainedModel): def __init__(self, config: MomConfig): @@ -322,6 +319,7 @@ class MomCausalLMOutputWithPast(CausalLMOutputWithPast): aux_loss: Optional[torch.FloatTensor] = None router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + class MomForCausalLM(MomPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] From 11fa1b6db5f90c4ef0262c5b4c3a143cdeefc799 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Sun, 13 Jul 2025 10:12:29 +0000 Subject: [PATCH 09/28] Replace old modules --- fla/layers/mom.py | 21 +++++++++--------- fla/layers/mom_varlen.py | 48 +++++++++++++++------------------------- 2 files changed, 28 insertions(+), 41 deletions(-) diff --git a/fla/layers/mom.py b/fla/layers/mom.py index c89ffb54b..f27c560db 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- - from __future__ import annotations import math @@ -11,7 +10,7 @@ from einops import 
rearrange from torch.nn import functional as F -from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution +from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution from fla.modules.l2norm import l2_norm from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule @@ -190,14 +189,14 @@ class MomAttention(nn.Module): The layer implementaion for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters. - Parameter alloation when use_gate=True: + Parameter alloation when use_output_gate=True: - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each - Others are ignorably small. - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim. - Parameter allocation when use_gate=False: + Parameter allocation when use_output_gate=False: - 1 * hidden_size * hidden_size for the q_proj and k_proj each - 2 * hidden_size * hidden_size for the v_proj and o_proj each - Others are ignorably small. @@ -218,7 +217,7 @@ class MomAttention(nn.Module): Default: `chunk`. use_beta (bool, Optional): Whether to use beta. Default: `True`. - use_gate (bool, Optional): + use_output_gate (bool, Optional): Whether to use output gate. Default: `True`. use_short_conv (bool, Optional): Whether to use short convolutions. Default: `True`. @@ -239,7 +238,7 @@ def __init__( head_dim: int = 256, num_heads: int = 6, mode: str = 'chunk', - use_gate: bool = True, + use_output_gate: bool = True, use_short_conv: bool = True, conv_size: int = 4, conv_bias: bool = False, @@ -264,7 +263,7 @@ def __init__( self.hidden_size = hidden_size self.expand_v = expand_v - self.use_gate = use_gate + self.use_output_gate = use_output_gate self.use_short_conv = use_short_conv self.conv_size = conv_size self.conv_bias = conv_bias @@ -347,9 +346,9 @@ def __init__( "ShortConvolution is crucial to the performance. " "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing." ) - if use_gate: + if use_output_gate: self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) - self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) + self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps) else: self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) @@ -398,7 +397,7 @@ def forward( dtype=routing_weights.dtype, device=routing_weights.device).scatter(-1, selected_memories, routing_weights) routing_mask = routing_weights_full.bool().int() - if self.use_gate: + if self.use_output_gate: o_g = self.g_proj(hidden_states) batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1] @@ -523,7 +522,7 @@ def forward( offset=q.shape[2] ) - if self.use_gate: + if self.use_output_gate: o_g = rearrange(o_g, '... (h d) -> ... 
h d', h=self.num_heads) o = self.o_norm(o, o_g) else: diff --git a/fla/layers/mom_varlen.py b/fla/layers/mom_varlen.py index c2c381cf2..0378f0443 100644 --- a/fla/layers/mom_varlen.py +++ b/fla/layers/mom_varlen.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- - from __future__ import annotations import math @@ -11,7 +10,7 @@ from einops import rearrange from torch.nn import functional as F -from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution +from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution from fla.modules.l2norm import l2_norm from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule @@ -138,12 +137,6 @@ def _upad_input( ) -def sum_norm(x): - return (x / x.sum(-1, keepdim=True)).to(x) - -# https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1 - - def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, selected_memories: torch.Tensor, capacity: float): ''' Transform input sequences into memory-organized chunks with capacity constraints. @@ -315,14 +308,14 @@ class MomAttention(nn.Module): The layer implementaion for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters. - Parameter alloation when use_gate=True: + Parameter alloation when use_output_gate=True: - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each - Others are ignorably small. - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim. - Parameter allocation when use_gate=False: + Parameter allocation when use_output_gate=False: - 1 * hidden_size * hidden_size for the q_proj and k_proj each - 2 * hidden_size * hidden_size for the v_proj and o_proj each - Others are ignorably small. @@ -343,7 +336,7 @@ class MomAttention(nn.Module): Default: `chunk`. use_beta (bool, Optional): Whether to use beta. Default: `True`. - use_gate (bool, Optional): + use_output_gate (bool, Optional): Whether to use output gate. Default: `True`. use_short_conv (bool, Optional): Whether to use short convolutions. Default: `True`. @@ -364,7 +357,7 @@ def __init__( head_dim: int = 256, num_heads: int = 6, mode: str = 'chunk', - use_gate: bool = True, + use_output_gate: bool = True, use_short_conv: bool = True, conv_size: int = 4, conv_bias: bool = False, @@ -389,7 +382,7 @@ def __init__( self.hidden_size = hidden_size self.expand_v = expand_v - self.use_gate = use_gate + self.use_output_gate = use_output_gate self.use_short_conv = use_short_conv self.conv_size = conv_size self.conv_bias = conv_bias @@ -475,9 +468,9 @@ def __init__( "ShortConvolution is crucial to the performance. " "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing." ) - if use_gate: + if use_output_gate: self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) - self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) + self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps) else: self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) @@ -513,7 +506,7 @@ def forward( assert mode == 'chunk', "Only chunk mode is supported in training." 
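The output-gate path being renamed here (`use_output_gate`, `FusedRMSNormGated`) amounts to a gated RMSNorm applied head-wise before `o_proj`. A plain-PyTorch approximation of what `self.o_norm(o, o_g)` is assumed to compute, based on the SwishGate module it replaces; the fused Triton kernel and its exact gate activation may differ:

import torch
import torch.nn.functional as F

def gated_rms_norm(o, g, weight, eps=1e-5):
    # RMS-normalize the per-head output, then gate it with swish(g).
    o = o * torch.rsqrt(o.pow(2).mean(-1, keepdim=True) + eps) * weight
    return o * F.silu(g)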
last_state = None - batchsize, q_len = hidden_states.shape[0], hidden_states.shape[1] + _, q_len = hidden_states.shape[0], hidden_states.shape[1] if past_key_values is not None and len(past_key_values) > self.layer_idx: last_state = past_key_values[self.layer_idx] @@ -527,7 +520,7 @@ def forward( dtype=routing_weights.dtype, device=routing_weights.device).scatter(-1, selected_memories, routing_weights) routing_mask = routing_weights_full.bool().int() - if self.use_gate: + if self.use_output_gate: o_g = self.g_proj(hidden_states) batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1] @@ -585,7 +578,7 @@ def forward( q, k, v, g, beta, mask2 = (rearrange(x, 'e b l ... -> (e b) l ...') for x in (q, k, v, g, beta, mask2)) cu_q, cu_k, cu_v, cu_g, cu_beta, indices_q, cu_seqlen_all, max_seq_lens = _upad_input(q, k, v, g, beta, mask2, q_len) - cu_seqlen = cu_seqlen_all[0].to(torch.long).unique() + cu_seqlens = cu_seqlen_all[0].to(torch.long).unique() cu_q, cu_k, cu_v, cu_g, cu_beta = (x.unsqueeze(0).contiguous() for x in (cu_q, cu_k, cu_v, cu_g, cu_beta)) # dealing with padding @@ -595,8 +588,7 @@ def forward( recurrent_state = last_state['recurrent_state'] if last_state is not None else [ None for _ in range(1 + self.shared_mem)] - offsets = kwargs.get('offsets', None) - # Note: In the updated version of FLA, "offset" has been renamed to "cu_seqlens". + cu_seqlens = kwargs.get('cu_seqlens', None) if mode == 'chunk': o_, recurrent_state_ = chunk_gated_delta_rule( q=cu_q, @@ -606,8 +598,7 @@ def forward( beta=cu_beta, initial_state=recurrent_state[0], output_final_state=use_cache, - cu_seqlens=cu_seqlen, - head_first=False + cu_seqlens=cu_seqlens, ) recurrent_state[0] = recurrent_state_ o_ = o_.squeeze(0).contiguous() @@ -627,8 +618,7 @@ def forward( beta=cu_beta, initial_state=recurrent_state[0], output_final_state=use_cache, - cu_seqlens=cu_seqlen, - head_first=False + cu_seqlens=cu_seqlens, ) recurrent_state[0] = recurrent_state_ o_ = o_.squeeze(0).contiguous() @@ -650,7 +640,7 @@ def forward( offset=q.shape[2] ) - if self.use_gate: + if self.use_output_gate: o_g = rearrange(o_g, '... (h d) -> ... 
h d', h=self.num_heads) o = self.o_norm(o, o_g) else: @@ -717,7 +707,7 @@ def shared_o( beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None]) g = g.mul(attention_mask[:, -g.shape[-2]:, None]) - offsets = kwargs.get('offsets', None) + cu_seqlens = kwargs.get('cu_seqlens', None) if mode == 'chunk': o, recurrent_state[-1] = chunk_gated_delta_rule( q=q.to(dtype=torch.bfloat16), @@ -727,8 +717,7 @@ def shared_o( beta=beta.to(dtype=torch.bfloat16), initial_state=recurrent_state[-1], output_final_state=use_cache, - cu_seqlens=offsets, - head_first=False + cu_seqlens=cu_seqlens, ) o = o.to(dtype=q.dtype) elif mode == 'fused_recurrent': @@ -740,8 +729,7 @@ def shared_o( beta=beta, initial_state=recurrent_state[-1], output_final_state=use_cache, - cu_seqlens=offsets, - head_first=False + cu_seqlens=cu_seqlens, ) else: raise NotImplementedError(f"Not supported mode `{mode}`.") From 9f62f28bf83e02b20474a5796af0da44404072ff Mon Sep 17 00:00:00 2001 From: Jusen Date: Sun, 13 Jul 2025 18:59:18 +0800 Subject: [PATCH 10/28] Update docstring and default config --- fla/layers/mom_varlen.py | 44 +---------------------------- fla/models/mom/configuration_mom.py | 10 +++---- fla/models/mom/modeling_mom.py | 8 +++--- 3 files changed, 10 insertions(+), 52 deletions(-) diff --git a/fla/layers/mom_varlen.py b/fla/layers/mom_varlen.py index 0378f0443..de9a457ad 100644 --- a/fla/layers/mom_varlen.py +++ b/fla/layers/mom_varlen.py @@ -305,49 +305,7 @@ def reconstruct(transformed_x, indices: torch.Tensor, sorted_indices: torch.Tens class MomAttention(nn.Module): """ - The layer implementaion for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa - - Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters. - Parameter alloation when use_output_gate=True: - - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each - - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each - - Others are ignorably small. - - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size - NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim. - - Parameter allocation when use_output_gate=False: - - 1 * hidden_size * hidden_size for the q_proj and k_proj each - - 2 * hidden_size * hidden_size for the v_proj and o_proj each - - Others are ignorably small. - - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size - - Args: - hidden_size (int, Optional): - The hidden size of the input. Default: 2048. - expand_v (float, Optional): - The expansion ratio for the value dim. Default: 2.0. - head_dim (int, Optional): - The dimension of each head. Default: 256. - num_heads (int, Optional): - The number of heads. Default: 4. - mode (str, Optional): - Which Gated DeltaNet kernel to use. - Currently available: `chunk` and `fused_recurrent`. - Default: `chunk`. - use_beta (bool, Optional): - Whether to use beta. Default: `True`. - use_output_gate (bool, Optional): - Whether to use output gate. Default: `True`. - use_short_conv (bool, Optional): - Whether to use short convolutions. Default: `True`. - conv_size (int, Optional): - The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. - conv_bias (bool, Optional): - Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. - layer_idx (int, Optional): - The index of the layer. Default: None. 
- norm_eps (float, Optional): - The epsilon value for the normalization layer. Default: 1e-5. + The layer implementaion for [MoM: Linear Sequence Modeling with Mixture-of-Memories](https://arxiv.org/abs/2502.13685). """ def __init__( diff --git a/fla/models/mom/configuration_mom.py b/fla/models/mom/configuration_mom.py index 9e32e6196..9e075987e 100644 --- a/fla/models/mom/configuration_mom.py +++ b/fla/models/mom/configuration_mom.py @@ -12,18 +12,18 @@ class MomConfig(PretrainedConfig): def __init__( self, attn_mode: str = "chunk", - hidden_size: int = 2048, - expand_v: int = 2, + hidden_size: int = 1024, + expand_v: int = 1, use_gate: bool = True, use_short_conv: bool = True, conv_size: int = 4, head_dim: int = 256, - num_heads: int = 6, + num_heads: int = 4, max_position_embeddings: int = 2048, hidden_ratio: Optional[int] = 4, intermediate_size: Optional[int] = None, hidden_act: str = "swish", - num_hidden_layers: int = 21, + num_hidden_layers: int = 24, norm_first: bool = False, norm_eps: float = 1e-6, attn: Optional[Dict] = None, @@ -35,7 +35,7 @@ def __init__( initializer_range: float = 0.02, fuse_cross_entropy: bool = True, vocab_size: int = 32000, - num_memories: int = 8, + num_memories: int = 4, topk: int = 2, capacity: float = 1.0, use_layer_wise_balance: bool = True, diff --git a/fla/models/mom/modeling_mom.py b/fla/models/mom/modeling_mom.py index bd8d36aed..53ee6ef52 100644 --- a/fla/models/mom/modeling_mom.py +++ b/fla/models/mom/modeling_mom.py @@ -92,7 +92,7 @@ def __init__(self, config: MomConfig, layer_idx: int): layer_idx=layer_idx ) else: - if config.mom_backend == 'GDN': + if config.mom_backend == 'gated_deltanet': self.attn = MomAttention( mode=config.attn_mode, hidden_size=config.hidden_size, @@ -112,7 +112,7 @@ def __init__(self, config: MomConfig, layer_idx: int): single_kv_proj=config.single_kv_proj ) else: - raise NotImplementedError("The MoM backend is not currently implemented.") + raise NotImplementedError(f"The MoM backend {config.mom_backend} is not currently supported.") if not config.norm_first: self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) self.mlp = MomMLP( @@ -469,7 +469,7 @@ def forward( valid_router_logits, self.num_memories, self.topk, - use_layer_wise_balance=self.config.use_layer_wise_balance, # ✨ + use_layer_wise_balance=self.config.use_layer_wise_balance, ) aux_loss *= self.aux_loss_scale @@ -532,7 +532,7 @@ def load_balancing_loss_func( for logits in gate_logits: routing_weights, selected_experts = torch.topk(logits, top_k, dim=-1) - routing_weights = routing_weights.softmax(dim=-1) + routing_weights = routing_weights.softmax(dim=-1).to(logits.dtype) routing_weights_full = torch.zeros_like(logits).scatter(-1, selected_experts, routing_weights) # cast the expert indices to int64, otherwise one-hot encoding will fail From 4e3a997eca7c99b2e8ba3190eddb2e08d68b43fe Mon Sep 17 00:00:00 2001 From: Jusen Date: Sun, 13 Jul 2025 20:51:14 +0800 Subject: [PATCH 11/28] Remove old ops --- fla/layers/__init__.py | 2 +- fla/layers/mom.py | 324 +++++++++++------- fla/layers/mom_varlen.py | 695 --------------------------------------- 3 files changed, 203 insertions(+), 818 deletions(-) delete mode 100644 fla/layers/mom_varlen.py diff --git a/fla/layers/__init__.py b/fla/layers/__init__.py index a602bd3fb..ec3f4b26b 100644 --- a/fla/layers/__init__.py +++ b/fla/layers/__init__.py @@ -20,7 +20,7 @@ from .mamba2 import Mamba2 from .mesa_net import MesaNet from .mla import MultiheadLatentAttention -from .mom_varlen import MomAttention 
+from .mom import MomAttention from .multiscale_retention import MultiScaleRetention from .nsa import NativeSparseAttention from .path_attn import PaTHAttention diff --git a/fla/layers/mom.py b/fla/layers/mom.py index f27c560db..de9a457ad 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -19,6 +19,123 @@ from fla.models.utils import Cache +from transformers.utils import is_flash_attn_2_available + +if is_flash_attn_2_available(): + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +else: + print("flash_attn_2 is not available") + + +def elu_p1(x): + return (F.elu(x, 1., False) + 1.).to(x) + + +def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + gate_layer: torch.Tensor, + beta_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). 
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + gate_layer = index_first_axis(gate_layer.reshape(batch_size * kv_seq_len, num_key_value_heads), indices_k) + beta_layer = index_first_axis(beta_layer.reshape(batch_size * kv_seq_len, num_key_value_heads), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + gate_layer, + beta_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, selected_memories: torch.Tensor, capacity: float): ''' @@ -109,11 +226,13 @@ def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, se transformed_x = gathered_x.reshape(b * num_memories, -1, d) transformed_x = transformed_x * mask.unsqueeze(-1).expand_as(transformed_x) pad_x = torch.zeros((b * num_memories, capacity_len-max_len, d), dtype=transformed_x.dtype, device=transformed_x.device) + pad_mask = torch.zeros((b * num_memories, capacity_len-max_len), dtype=transformed_x.dtype, device=transformed_x.device) # left pad transformed_x = torch.cat((pad_x, transformed_x), dim=1).reshape((b, num_memories, capacity_len, d)).transpose(0, 1) + mask_2 = torch.cat((pad_mask.bool(), mask), dim=1).reshape((b, num_memories, capacity_len)).transpose(0, 1) # truncation_indices += capacity_len-max_len - return transformed_x, truncation_indices, sorted_indices, max_len, mask + return transformed_x, truncation_indices, sorted_indices, max_len, mask, mask_2 # (num_memories, batch, seq, hidden) # @torch.jit.script @@ -186,49 +305,7 @@ def reconstruct(transformed_x, indices: torch.Tensor, sorted_indices: torch.Tens class MomAttention(nn.Module): """ - The layer implementaion for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa - - Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters. - Parameter alloation when use_output_gate=True: - - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each - - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each - - Others are ignorably small. 
- - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size - NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim. - - Parameter allocation when use_output_gate=False: - - 1 * hidden_size * hidden_size for the q_proj and k_proj each - - 2 * hidden_size * hidden_size for the v_proj and o_proj each - - Others are ignorably small. - - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size - - Args: - hidden_size (int, Optional): - The hidden size of the input. Default: 2048. - expand_v (float, Optional): - The expansion ratio for the value dim. Default: 2.0. - head_dim (int, Optional): - The dimension of each head. Default: 256. - num_heads (int, Optional): - The number of heads. Default: 4. - mode (str, Optional): - Which Gated DeltaNet kernel to use. - Currently available: `chunk` and `fused_recurrent`. - Default: `chunk`. - use_beta (bool, Optional): - Whether to use beta. Default: `True`. - use_output_gate (bool, Optional): - Whether to use output gate. Default: `True`. - use_short_conv (bool, Optional): - Whether to use short convolutions. Default: `True`. - conv_size (int, Optional): - The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. - conv_bias (bool, Optional): - Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. - layer_idx (int, Optional): - The index of the layer. Default: None. - norm_eps (float, Optional): - The epsilon value for the normalization layer. Default: 1e-5. + The layer implementaion for [MoM: Linear Sequence Modeling with Mixture-of-Memories](https://arxiv.org/abs/2502.13685). """ def __init__( @@ -329,16 +406,19 @@ def __init__( self.q_conv1d = ShortConvolution( hidden_size=self.key_dim, kernel_size=conv_size, + bias=conv_bias, activation='silu' ) self.k_conv1d = ShortConvolution( hidden_size=self.key_dim, kernel_size=conv_size, + bias=conv_bias, activation='silu' ) self.v_conv1d = ShortConvolution( hidden_size=self.value_dim, kernel_size=conv_size, + bias=conv_bias, activation='silu' ) else: @@ -384,6 +464,7 @@ def forward( assert mode == 'chunk', "Only chunk mode is supported in training." 
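`_get_unpad_data`, added earlier in this patch, derives the cumulative sequence lengths that the varlen kernels below consume; a quick standalone check of the expected values for a toy padding mask (illustrative only):

import torch
import torch.nn.functional as F

mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.int32)
seqlens = mask.sum(dim=-1, dtype=torch.int32)                       # tensor([3, 5])
indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()   # positions of valid tokens
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)   # tensor([0, 3, 8]) -- boundaries of each unpadded sequence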
last_state = None + _, q_len = hidden_states.shape[0], hidden_states.shape[1] if past_key_values is not None and len(past_key_values) > self.layer_idx: last_state = past_key_values[self.layer_idx] @@ -403,7 +484,7 @@ def forward( batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1] shared_hidden_states = hidden_states - hidden_states, indices, sorted_indices, max_len, mask = transform( + hidden_states, indices, sorted_indices, max_len, mask, mask2 = transform( hidden_states, routing_mask, self.num_memories, selected_memories, self.capacity) q = self.q_proj(hidden_states) @@ -423,22 +504,25 @@ def forward( conv_state_q, conv_state_k, conv_state_v = [None, None], [None, None], [None, None] if last_state is not None: conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] - conv_mask = attention_mask[:, -hidden_states.shape[2] - :].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None - seq_idx = kwargs.get('seq_idx', None) q, k, v = map(lambda x: rearrange(x, 'e b t d -> (e b) t d'), (q, k, v)) - q, conv_state_q[0] = self.q_conv1d(x=q, - mask=conv_mask, - cache=conv_state_q[0], - output_final_state=use_cache, seq_idx=seq_idx) - k, conv_state_k[0] = self.k_conv1d(x=k, - mask=conv_mask, - cache=conv_state_k[0], - output_final_state=use_cache, seq_idx=seq_idx) - v, conv_state_v[0] = self.v_conv1d(x=v, - mask=conv_mask, - cache=conv_state_v[0], - output_final_state=use_cache, seq_idx=seq_idx) + q, conv_state_q[0] = self.q_conv1d( + x=q, + cache=conv_state_q[0], + output_final_state=use_cache, + cu_seqlens=None + ) + k, conv_state_k[0] = self.k_conv1d( + x=k, + cache=conv_state_k[0], + output_final_state=use_cache, + cu_seqlens=None + ) + v, conv_state_v[0] = self.v_conv1d( + x=v, + cache=conv_state_v[0], + output_final_state=use_cache, + cu_seqlens=None + ) q, k, v = map(lambda x: rearrange(x, '(e b) t d -> e b t d', b=batch_size), (q, k, v)) @@ -450,64 +534,56 @@ def forward( q = l2_norm(q) k = l2_norm(k) + q, k, v, g, beta, mask2 = (rearrange(x, 'e b l ... -> (e b) l ...') for x in (q, k, v, g, beta, mask2)) + cu_q, cu_k, cu_v, cu_g, cu_beta, indices_q, cu_seqlen_all, max_seq_lens = _upad_input(q, k, v, g, beta, mask2, q_len) + cu_seqlens = cu_seqlen_all[0].to(torch.long).unique() + cu_q, cu_k, cu_v, cu_g, cu_beta = (x.unsqueeze(0).contiguous() for x in (cu_q, cu_k, cu_v, cu_g, cu_beta)) + # dealing with padding if attention_mask is not None: beta = beta.mul(attention_mask[None, :, -beta.shape[-2]:, None]) g = g.mul(attention_mask[None, :, -g.shape[-2]:, None]) recurrent_state = last_state['recurrent_state'] if last_state is not None else [ - None for _ in range(self.num_memories + self.shared_mem)] + None for _ in range(1 + self.shared_mem)] cu_seqlens = kwargs.get('cu_seqlens', None) - # Note: In the updated version of FLA, "offset" has been renamed to "cu_seqlens". 
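The replacement below drops the per-memory loop: all active (memory, token) pairs are packed into one ragged batch via `_upad_input`, a single kernel call processes them, and `pad_input` scatters the outputs back. A torch-only sketch of that pack/unpack round trip, with hypothetical shapes and the kernel call omitted:

import torch

e_times_b, seqlen, h, d = 4, 6, 2, 3
x = torch.randn(e_times_b, seqlen, h, d)
mask = torch.zeros(e_times_b, seqlen, dtype=torch.bool)
mask[:, -3:] = True                                    # pretend each row kept its last 3 tokens

indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
packed = x.reshape(-1, h, d)[indices]                  # "unpad": only valid tokens remain
# ... a single varlen kernel call would run on `packed` here ...
restored = torch.zeros(e_times_b * seqlen, h, d).index_copy(0, indices, packed)
restored = restored.reshape(e_times_b, seqlen, h, d)   # "pad": back to the padded layout
assert torch.equal(restored[mask], x[mask])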
if mode == 'chunk': - o_list = [None for _ in range(self.num_memories)] - for e in range(self.num_memories): - o_e, state_e = chunk_gated_delta_rule( - q=q[e].to(dtype=torch.bfloat16), - k=k[e].to(dtype=torch.bfloat16), - v=v[e].to(dtype=torch.bfloat16), - g=g[e].to(dtype=torch.bfloat16), - beta=beta[e].to(dtype=torch.bfloat16), - initial_state=recurrent_state[e], - output_final_state=use_cache, - cu_seqlens=cu_seqlens, - ) - # chunk_gated_delta_rule(q=q[e].to(dtype=torch.bfloat16),k=k[e].to(dtype=torch.bfloat16),v=v[e].to(dtype=torch.bfloat16),g=g[e].to(dtype=torch.bfloat16),beta=beta[e].to(dtype=torch.bfloat16),initial_state=recurrent_state[e],output_final_state=use_cache,cu_seqlens=cu_seqlens) - o_e = o_e[:, -max_len:, :, :].to(dtype=q[e].dtype) - o_list[e] = o_e - recurrent_state[e] = state_e - o_list = torch.stack(o_list, dim=0) - o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, - batch_size=q.shape[1], seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + o_, recurrent_state_ = chunk_gated_delta_rule( + q=cu_q, + k=cu_k, + v=cu_v, + g=cu_g, + beta=cu_beta, + initial_state=recurrent_state[0], + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + recurrent_state[0] = recurrent_state_ + o_ = o_.squeeze(0).contiguous() + o_list = pad_input(o_, indices_q, batch_size*self.num_memories, q_len) + o_list = rearrange(o_list, '(e b) l h d -> e b l h d', b=batch_size) + o_list = o_list[:, :, -max_len:] + + o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, + seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) elif mode == 'fused_recurrent': - o_list = [None for _ in range(self.num_memories)] - for e in range(self.num_memories): - # only activated memory updates - if not hidden_states[e, 0].any() and hidden_states.shape[1] == 1: - o_list[e] = torch.zeros_like(v[e, :, -max_len:, :, :]) - continue - o_e, state_e = fused_recurrent_gated_delta_rule( - q=q[e], - k=k[e], - v=v[e], - g=g[e], - beta=beta[e], - initial_state=recurrent_state[e], - output_final_state=use_cache, - cu_seqlens=cu_seqlens, - ) - o_e = o_e[:, -max_len:, :, :] - o_list[e] = o_e - # recurrent_state[e] = state_e - for batch in range(state_e.shape[0]): - if recurrent_state[e] is None: - recurrent_state[e] = state_e - elif hidden_states[e, batch].any(): - recurrent_state[e][batch] = state_e[batch] - o_list = torch.stack(o_list, dim=0) - o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, - batch_size=q.shape[1], seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + o_, recurrent_state_ = fused_recurrent_gated_delta_rule( + q=cu_q, + k=cu_k, + v=cu_v, + g=cu_g, + beta=cu_beta, + initial_state=recurrent_state[0], + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + recurrent_state[0] = recurrent_state_ + o_ = o_.squeeze(0).contiguous() + o_list = pad_input(o_, indices_q, batch_size*self.num_memories, max_len) + o_list = rearrange(o_list, '(e b) l h d -> e b l h d', b=batch_size) + o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, + seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) if self.shared_mem: shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, @@ -555,20 +631,24 @@ def shared_o( assert mode == 'chunk', "Only chunk mode is supported in training." 
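Both branches above now make a single kernel call over the packed sequence instead of looping over memories. As a reading aid, here is a naive per-token reference of the recurrence the gated delta-rule kernels are commonly described as computing (decay the state by exp(g), apply a beta-weighted delta-rule update, read out with q). This is an assumption for illustration, not the kernel source, and it ignores chunking, multiple heads and cu_seqlens.

# Naive single-head reference of the (assumed) gated delta-rule recurrence; for reading only.
import torch

def gated_delta_rule_reference(q, k, v, g, beta):
    # q, k: (T, d_k), v: (T, d_v), g, beta: (T,); the state S maps keys to values.
    d_k, d_v = k.shape[-1], v.shape[-1]
    S = torch.zeros(d_v, d_k)
    outs = []
    for t in range(len(q)):
        a = g[t].exp()                                                  # decay factor in (0, 1] since g <= 0
        S = a * S + beta[t] * torch.outer(v[t] - a * (S @ k[t]), k[t])  # delta-rule update on the decayed state
        outs.append(S @ q[t])                                           # read-out
    return torch.stack(outs)

T, d_k, d_v = 16, 8, 8
q = torch.nn.functional.normalize(torch.randn(T, d_k), dim=-1)  # q/k are L2-normalized upstream
k = torch.nn.functional.normalize(torch.randn(T, d_k), dim=-1)
o = gated_delta_rule_reference(q, k, torch.randn(T, d_v), -torch.rand(T), torch.rand(T))
print(o.shape)  # torch.Size([16, 8])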
if self.use_short_conv: - conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None - seq_idx = kwargs.get('seq_idx', None) - q, conv_state_q[1] = self.q_conv1d(x=self.q_proj(hidden_states), - mask=conv_mask, - cache=conv_state_q[1], - output_final_state=use_cache, seq_idx=seq_idx) - k, conv_state_k[1] = self.k_conv1d(x=self.shared_k(hidden_states), - mask=conv_mask, - cache=conv_state_k[1], - output_final_state=use_cache, seq_idx=seq_idx) - v, conv_state_v[1] = self.v_conv1d(x=self.shared_v(hidden_states), - mask=conv_mask, - cache=conv_state_v[1], - output_final_state=use_cache, seq_idx=seq_idx) + q, conv_state_q[1] = self.q_conv1d( + x=self.q_proj(hidden_states), + cache=conv_state_q[1], + output_final_state=use_cache, + cu_seqlens=None + ) + k, conv_state_k[1] = self.k_conv1d( + x=self.shared_k(hidden_states), + cache=conv_state_k[1], + output_final_state=use_cache, + cu_seqlens=None + ) + v, conv_state_v[1] = self.v_conv1d( + x=self.shared_v(hidden_states), + cache=conv_state_v[1], + output_final_state=use_cache, + cu_seqlens=None + ) else: q = self.silu(self.q_proj(hidden_states)) k = self.silu(self.shared_k(hidden_states)) diff --git a/fla/layers/mom_varlen.py b/fla/layers/mom_varlen.py deleted file mode 100644 index de9a457ad..000000000 --- a/fla/layers/mom_varlen.py +++ /dev/null @@ -1,695 +0,0 @@ -# -*- coding: utf-8 -*- - -from __future__ import annotations - -import math -from typing import TYPE_CHECKING, Dict, Optional, Tuple - -import torch -import torch.nn as nn -from einops import rearrange -from torch.nn import functional as F - -from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution -from fla.modules.l2norm import l2_norm -from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule - -if TYPE_CHECKING: - from transformers.processing_utils import Unpack - - from fla.models.utils import Cache - -from transformers.utils import is_flash_attn_2_available - -if is_flash_attn_2_available(): - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input -else: - print("flash_attn_2 is not available") - - -def elu_p1(x): - return (F.elu(x, 1., False) + 1.).to(x) - - -def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: - """ - Retrieves indexing data required to repad unpadded (ragged) tensors. - - Arguments: - attention_mask (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - - Return: - indices (`torch.Tensor`): - The indices of non-masked tokens from the flattened input sequence. - cu_seqlens (`torch.Tensor`): - The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - max_seqlen_in_batch (`int`): - Maximum sequence length in batch. 
- """ - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def _upad_input( - query_layer: torch.Tensor, - key_layer: torch.Tensor, - value_layer: torch.Tensor, - gate_layer: torch.Tensor, - beta_layer: torch.Tensor, - attention_mask: torch.Tensor, - query_length: int, -): - """ - Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. - - This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary - tensors for query, key, value tensors. - - Arguments: - query_layer (`torch.Tensor`): - Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). - key_layer (`torch.Tensor`): - Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - value_layer (`torch.Tensor`): - Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - attention_mask (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - query_length (`int`): - Target length. - - Return: - query_layer (`torch.Tensor`): - Query state without padding. Shape: (total_target_length, num_heads, head_dim). - key_layer (`torch.Tensor`): - Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - value_layer (`torch.Tensor`): - Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - indices_q (`torch.Tensor`): - The indices of non-masked tokens from the flattened input target sequence. - (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): - The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): - Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). - """ - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - gate_layer = index_first_axis(gate_layer.reshape(batch_size * kv_seq_len, num_key_value_heads), indices_k) - beta_layer = index_first_axis(beta_layer.reshape(batch_size * kv_seq_len, num_key_value_heads), indices_k) - if query_length == kv_seq_len: - query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. 
- attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - gate_layer, - beta_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, selected_memories: torch.Tensor, capacity: float): - ''' - Transform input sequences into memory-organized chunks with capacity constraints. - - Processes input sequences by routing tokens to designated memory states according to routing_mask, - sorts tokens by memory assignments, handles token truncation/padding based on memory capacity, - and returns memory-aligned tensors for parallel processing. - - Key operations: - 1. Expands input tensors when multiple memories are selected per token (top-k routing) - 2. Sorts tokens globally by (batch_idx, memory_idx) to group memory-assigned tokens - 3. Applies capacity-aware truncation (left-truncate oldest tokens when exceeding capacity) - 4. Pads memory chunks to uniform length for tensorization - - Args: - x: Input hidden states - Shape: (batch_size, seq_len, hidden_size) - routing_mask: Binary mask indicating active memory assignments - Shape: (batch_size, seq_len, num_memories) - num_memories: Total number of memories per batch - selected_memories: Memory indices assigned to each token. When using top-k routing, - this contains k memory indices per token (k >= 1) - Shape: (batch_size, seq_len) for k=1 or (batch_size, seq_len, topk) for k>1 - capacity: Scaling factor for memory capacity calculation. Actual capacity per memory is - ceil(seq_len * capacity), maintaining proportional capacity to sequence length - - Returns: - transformed_x: Memory-organized tensor with zero-padded capacity alignment - Shape: (num_memories, batch_size, capacity_len, hidden_size) - truncation_indices: Original indices used for gathering tokens after capacity truncation - Shape: (batch*num_memories, max_len) - sorted_indices: Sorting indices used to group tokens by memory assignments - Shape: (batch_size*seq_len*topk) - max_len: Maximum tokens per memory - mask: Boolean mask indicating valid (non-padded) positions in transformed_x - Shape: (batch*num_memories, max_len) - ''' - if selected_memories.dim() == 3: - # (batch, seq, topk) - topk = selected_memories.shape[2] - # x (batch, seq, hidden) - x = x.repeat_interleave(topk, dim=1) - # x (batch, seq * topk, hidden) - # (batch, seq, topk) - selected_memories = selected_memories.reshape(selected_memories.shape[0], -1) - # (batch, seq * topk) - - b, s, d = x.shape - x_flat = x.reshape(b * s, d) # [b*s, d] - - with torch.no_grad(): - batch_indices = torch.arange(b, device=x.device).unsqueeze(-1) - batch_indices = batch_indices.expand(b, s).reshape(-1) - # (b * s) - memories_flat = selected_memories.reshape(-1) # [b*s] - - combined = batch_indices * (memories_flat.max() + 1) + memories_flat - sorted_indices = combined.argsort() - - x_sorted = x_flat[sorted_indices] # [b*s, d] - # (b*s, hidden) -> (b, s, hidd) - with torch.no_grad(): - # routing_mask (b, s, num_memories) - batch_memory_tokens = routing_mask.sum(dim=1) - # (b, num_memories) - offset = batch_memory_tokens.cumsum(dim=1) - memory_batch_offset = offset.transpose(0, 1) - batch_offset = torch.arange(0, b*s, s, device=offset.device) - memory_batch_offset += batch_offset - flatten_offset = memory_batch_offset.transpose(0, 1).reshape(-1) - lengths 
= torch.concat([flatten_offset[:1], flatten_offset[1:] - flatten_offset[:-1]], dim=0) - max_len = lengths.max() - capacity_len = math.ceil(s / topk * capacity) - max_len = min(max_len, capacity_len) - - indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand( - b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1) - # discard tokens exceed capacity and is far from now - # left pad - truncation_indices = indices + batch_memory_tokens.reshape((-1,)).unsqueeze(-1) - max_len - mask = torch.bitwise_and(truncation_indices < flatten_offset.unsqueeze(-1), truncation_indices >= 0) - mask = torch.bitwise_and(mask, truncation_indices >= torch.cat( - (torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device), flatten_offset[:-1])).unsqueeze(-1)) - truncation_indices = torch.where(mask, truncation_indices, torch.zeros_like(truncation_indices)) - - gathered_x = torch.gather(x_sorted, 0, truncation_indices.reshape(-1).unsqueeze(-1).expand(-1, d)) - transformed_x = gathered_x.reshape(b * num_memories, -1, d) - transformed_x = transformed_x * mask.unsqueeze(-1).expand_as(transformed_x) - pad_x = torch.zeros((b * num_memories, capacity_len-max_len, d), dtype=transformed_x.dtype, device=transformed_x.device) - pad_mask = torch.zeros((b * num_memories, capacity_len-max_len), dtype=transformed_x.dtype, device=transformed_x.device) - # left pad - transformed_x = torch.cat((pad_x, transformed_x), dim=1).reshape((b, num_memories, capacity_len, d)).transpose(0, 1) - mask_2 = torch.cat((pad_mask.bool(), mask), dim=1).reshape((b, num_memories, capacity_len)).transpose(0, 1) - # truncation_indices += capacity_len-max_len - - return transformed_x, truncation_indices, sorted_indices, max_len, mask, mask_2 - # (num_memories, batch, seq, hidden) - -# @torch.jit.script - - -def reconstruct(transformed_x, indices: torch.Tensor, sorted_indices: torch.Tensor, batch_size: int, seq_len: int, topk: int, routing_weights: torch.Tensor, mask: torch.Tensor): - ''' - Reconstruct and mix transformed outputs back into the original input sequence shape. - - Key operations: - 1. Reshapes and transposes `transformed_x` to prepare for scattering. - 2. Applies the `mask` to zero out invalid positions. - 3. Uses `torch.scatter_add_` to scatter and sum the transformed outputs back to their original positions based on `indices`. - 4. Rearranges the scattered outputs using `sorted_indices` to ensure correct ordering. - 5. Applies the `routing_weights` to weight the outputs. - 6. Sums over the `topk` dimension to produce the final reconstructed output. - - Args: - transformed_x (torch.Tensor): - The transformed output tensor from memory units or experts. - Shape: (num_memories, batch_size, capacity_len, hidden_size) - indices (torch.Tensor): - Indices used for scattering the transformed outputs back to their corresponding positions. - Shape: (batch*num_memories, max_len) - sorted_indices (torch.Tensor): - Sorting indices used to rearrange the scattered outputs back into the original sequence order. - Shape: (batch_size*seq_len*topk) - batch_size (int): - The size of the batch. - seq_len (int): - The length of the input sequence. - topk (int): - The number of top elements selected (`topk`) per token during the selection process. - routing_weights (torch.Tensor): - Routing weights assigned to the top-k selected outputs when reconstructing the final output. 
- Shape: (batch_size, seq_len, topk) - mask (torch.Tensor): - Boolean mask indicating valid positions in the sequence. - Shape: (batch*num_memories, max_len) - - Returns: - restored_x (torch.Tensor): - The reconstructed output tensor in the original input sequence shape. - Shape: (batch_size, seq_len, hidden_size) - ''' - transformed_x = transformed_x.transpose(0, 1).reshape( - (-1, transformed_x.shape[2], transformed_x.shape[3], transformed_x.shape[4])) - b, s, k, h, d = batch_size, seq_len, topk, transformed_x.shape[2], transformed_x.shape[3] - gathered_x = transformed_x.reshape( - (transformed_x.shape[0] * transformed_x.shape[1], transformed_x.shape[2], transformed_x.shape[3])) - mask_expanded = mask.reshape(-1).unsqueeze(-1).unsqueeze(-1).expand_as(gathered_x) - gathered_x = gathered_x * mask_expanded - - assert (indices >= 0).all(), "Indices should be non-negative" - - resortd_x = torch.zeros((b * s * k, h, d), device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_( - 0, - indices.reshape(-1).unsqueeze(-1).unsqueeze(-1).expand(-1, h, d), - gathered_x, - ) - assert (indices < resortd_x.size(0)).all(), "Indices should be less than resortd_x size" - - inverse_indices = sorted_indices.argsort() - rearranged_x_flat = resortd_x[inverse_indices] - restored_x = rearranged_x_flat.reshape((b, s * k, h, d)) - restored_x = restored_x.reshape(b, s, k, h, d) * routing_weights.reshape(b, s, k).unsqueeze(-1).unsqueeze(-1) - restored_x = restored_x.sum(dim=2) - return restored_x - - -class MomAttention(nn.Module): - """ - The layer implementaion for [MoM: Linear Sequence Modeling with Mixture-of-Memories](https://arxiv.org/abs/2502.13685). - """ - - def __init__( - self, - hidden_size: int = 2048, - expand_v: float = 2, - head_dim: int = 256, - num_heads: int = 6, - mode: str = 'chunk', - use_output_gate: bool = True, - use_short_conv: bool = True, - conv_size: int = 4, - conv_bias: bool = False, - layer_idx: int = None, - norm_eps: float = 1e-5, - num_memories: int = 8, - topk: int = 2, - capacity: float = 1.0, - shared_mem: bool = False, - single_kv_proj: bool = False, - **kwargs - ) -> MomAttention: - super().__init__() - self.num_memories = num_memories - self.topk = topk - self.capacity = capacity - self.shared_mem = shared_mem - self.single_kv_proj = single_kv_proj - - self.mode = mode - - self.hidden_size = hidden_size - self.expand_v = expand_v - - self.use_output_gate = use_output_gate - self.use_short_conv = use_short_conv - self.conv_size = conv_size - self.conv_bias = conv_bias - - self.head_dim = head_dim - self.num_heads = num_heads - - self.key_dim = self.num_heads * self.head_dim - self.value_dim = self.key_dim * self.expand_v - self.head_qk_dim = head_dim - self.head_v_dim = head_dim * self.expand_v - self.layer_idx = layer_idx - self.silu = nn.SiLU() - - assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 
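Two pieces of arithmetic in these layer definitions are worth spelling out. With the defaults in this file (num_heads=6, head_dim=256, expand_v=2), key_dim = 6 * 256 = 1536, value_dim = 1536 * 2 = 3072 and head_v_dim = 512; and in transform, a 1024-token sequence routed with topk=2 and capacity=1.0 gives each memory a budget of ceil(1024 / 2 * 1.0) = 512 tokens, beyond which the oldest tokens are left-truncated. The snippet below only restates that arithmetic; the numbers are the stated defaults, used for illustration.

# Dimension and capacity arithmetic with the defaults shown above (illustrative).
import math

num_heads, head_dim, expand_v = 6, 256, 2
key_dim = num_heads * head_dim             # 1536
value_dim = int(key_dim * expand_v)        # 3072
head_v_dim = int(head_dim * expand_v)      # 512

seq_len, topk, capacity = 1024, 2, 1.0
capacity_len = math.ceil(seq_len / topk * capacity)  # 512 tokens per memory before left-truncation
print(key_dim, value_dim, head_v_dim, capacity_len)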
- - self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) - self.gate = nn.Linear(self.hidden_size, self.num_memories, bias=False) - if self.single_kv_proj: - self.shared_k = nn.Linear(hidden_size, self.key_dim, bias=False) - self.shared_v = nn.Linear(hidden_size, self.value_dim, bias=False) - self.shared_b = nn.Linear(hidden_size, self.num_heads, bias=False) - self.shared_a = nn.Linear(hidden_size, self.num_heads, bias=False) - else: - self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.key_dim, bias=False) - for _ in range(self.num_memories)]) - self.v_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.value_dim, bias=False) - for _ in range(self.num_memories)]) - self.b_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) - for _ in range(self.num_memories)]) - self.a_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) - for _ in range(self.num_memories)]) - if self.shared_mem: - self.shared_k = nn.Linear(hidden_size, self.key_dim, bias=False) - self.shared_v = nn.Linear(hidden_size, self.value_dim, bias=False) - self.shared_b = nn.Linear(hidden_size, self.num_heads, bias=False) - self.shared_a = nn.Linear(hidden_size, self.num_heads, bias=False) - - A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) - A_log = torch.log(A) - self.A_log = nn.Parameter(A_log) - self.A_log._no_weight_decay = True - self.D = nn.Parameter(torch.ones(self.num_heads)) - self.D._no_weight_decay = True - # hard coded for now - dt_min = 0.001 - dt_max = 0.1 - dt_init_floor = 1e-4 - dt = torch.exp( - torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ) - dt = torch.clamp(dt, min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - self.dt_bias = nn.Parameter(inv_dt) - # Just to be explicit. Without this we already don't put wd on dt_bias because of the check - # name.endswith("bias") in param_grouping.py - self.dt_bias._no_weight_decay = True - - if use_short_conv: - self.conv_size = conv_size - self.q_conv1d = ShortConvolution( - hidden_size=self.key_dim, - kernel_size=conv_size, - bias=conv_bias, - activation='silu' - ) - self.k_conv1d = ShortConvolution( - hidden_size=self.key_dim, - kernel_size=conv_size, - bias=conv_bias, - activation='silu' - ) - self.v_conv1d = ShortConvolution( - hidden_size=self.value_dim, - kernel_size=conv_size, - bias=conv_bias, - activation='silu' - ) - else: - raise UserWarning( - "ShortConvolution is crucial to the performance. " - "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing." 
- ) - if use_output_gate: - self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) - self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps) - else: - self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) - self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) - self.apply(self._initialize_weights) - - def _initialize_weights(self, module: nn.Module): - if getattr(module, "_is_hf_initialized", False): - return - if isinstance(module, nn.Linear): - nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) - if module.bias is not None: - nn.init.zeros_(module.bias) - module._is_hf_initialized = True - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Cache] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - **kwargs: Unpack[Dict] - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: - if attention_mask is not None: - assert len(attention_mask.shape) == 2, ( - "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " - "for padding purposes (0 indicating padding). " - "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." - ) - - mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode - if self.training: - assert mode == 'chunk', "Only chunk mode is supported in training." - - last_state = None - _, q_len = hidden_states.shape[0], hidden_states.shape[1] - if past_key_values is not None and len(past_key_values) > self.layer_idx: - last_state = past_key_values[self.layer_idx] - - # 🔍 topk gating - router_logits = self.gate(hidden_states) # (bsz, q_len, num_memories) - scores = F.softmax(router_logits, dim=2, dtype=torch.float) - routing_weights, selected_memories = torch.topk(scores, self.topk, dim=-1) # (bsz, seq, topk) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(hidden_states.dtype) # we cast back to the input dtype - routing_weights_full = torch.zeros((routing_weights.shape[0], routing_weights.shape[1], self.num_memories), - dtype=routing_weights.dtype, device=routing_weights.device).scatter(-1, selected_memories, routing_weights) - routing_mask = routing_weights_full.bool().int() - - if self.use_output_gate: - o_g = self.g_proj(hidden_states) - - batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1] - - shared_hidden_states = hidden_states - hidden_states, indices, sorted_indices, max_len, mask, mask2 = transform( - hidden_states, routing_mask, self.num_memories, selected_memories, self.capacity) - - q = self.q_proj(hidden_states) - if self.single_kv_proj: - k = self.shared_k(hidden_states) - v = self.shared_v(hidden_states) - beta = self.shared_b(hidden_states).sigmoid() - g = -self.A_log.float().exp() * F.softplus(self.shared_a(hidden_states).float() + self.dt_bias) - else: - k = torch.stack([k_expert(hidden_states[i]) for i, k_expert in enumerate(self.k_proj)], dim=0) - v = torch.stack([v_expert(hidden_states[i]) for i, v_expert in enumerate(self.v_proj)], dim=0) - beta = torch.stack([b_expert(hidden_states[i]).sigmoid() for i, b_expert in enumerate(self.b_proj)], dim=0) - g = torch.stack([-self.A_log.float().exp() * F.softplus(a_expert(hidden_states[i]).float() + self.dt_bias) - for i, a_expert in enumerate(self.a_proj)], dim=0) - - if self.use_short_conv: - conv_state_q, conv_state_k, conv_state_v = [None, None], [None, None], [None, None] - if last_state is not None: - conv_state_q, 
conv_state_k, conv_state_v = last_state['conv_state'] - q, k, v = map(lambda x: rearrange(x, 'e b t d -> (e b) t d'), (q, k, v)) - q, conv_state_q[0] = self.q_conv1d( - x=q, - cache=conv_state_q[0], - output_final_state=use_cache, - cu_seqlens=None - ) - k, conv_state_k[0] = self.k_conv1d( - x=k, - cache=conv_state_k[0], - output_final_state=use_cache, - cu_seqlens=None - ) - v, conv_state_v[0] = self.v_conv1d( - x=v, - cache=conv_state_v[0], - output_final_state=use_cache, - cu_seqlens=None - ) - - q, k, v = map(lambda x: rearrange(x, '(e b) t d -> e b t d', b=batch_size), (q, k, v)) - - else: - q, k, v = self.silu(q), self.silu(k), self.silu(v), - - q, k, v = map(lambda x: rearrange(x, 'e b t (h d) -> e b t h d', h=self.num_heads), (q, k, v)) - - q = l2_norm(q) - k = l2_norm(k) - - q, k, v, g, beta, mask2 = (rearrange(x, 'e b l ... -> (e b) l ...') for x in (q, k, v, g, beta, mask2)) - cu_q, cu_k, cu_v, cu_g, cu_beta, indices_q, cu_seqlen_all, max_seq_lens = _upad_input(q, k, v, g, beta, mask2, q_len) - cu_seqlens = cu_seqlen_all[0].to(torch.long).unique() - cu_q, cu_k, cu_v, cu_g, cu_beta = (x.unsqueeze(0).contiguous() for x in (cu_q, cu_k, cu_v, cu_g, cu_beta)) - - # dealing with padding - if attention_mask is not None: - beta = beta.mul(attention_mask[None, :, -beta.shape[-2]:, None]) - g = g.mul(attention_mask[None, :, -g.shape[-2]:, None]) - - recurrent_state = last_state['recurrent_state'] if last_state is not None else [ - None for _ in range(1 + self.shared_mem)] - cu_seqlens = kwargs.get('cu_seqlens', None) - if mode == 'chunk': - o_, recurrent_state_ = chunk_gated_delta_rule( - q=cu_q, - k=cu_k, - v=cu_v, - g=cu_g, - beta=cu_beta, - initial_state=recurrent_state[0], - output_final_state=use_cache, - cu_seqlens=cu_seqlens, - ) - recurrent_state[0] = recurrent_state_ - o_ = o_.squeeze(0).contiguous() - o_list = pad_input(o_, indices_q, batch_size*self.num_memories, q_len) - o_list = rearrange(o_list, '(e b) l h d -> e b l h d', b=batch_size) - o_list = o_list[:, :, -max_len:] - - o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, - seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) - - elif mode == 'fused_recurrent': - o_, recurrent_state_ = fused_recurrent_gated_delta_rule( - q=cu_q, - k=cu_k, - v=cu_v, - g=cu_g, - beta=cu_beta, - initial_state=recurrent_state[0], - output_final_state=use_cache, - cu_seqlens=cu_seqlens, - ) - recurrent_state[0] = recurrent_state_ - o_ = o_.squeeze(0).contiguous() - o_list = pad_input(o_, indices_q, batch_size*self.num_memories, max_len) - o_list = rearrange(o_list, '(e b) l h d -> e b l h d', b=batch_size) - o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, - seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) - - if self.shared_mem: - shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, - use_cache, conv_state_q, conv_state_k, conv_state_v) - o += shared_o - - if past_key_values is not None: - past_key_values.update( - recurrent_state=recurrent_state, - conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, - layer_idx=self.layer_idx, - offset=q.shape[2] - ) - - if self.use_output_gate: - o_g = rearrange(o_g, '... (h d) -> ... 
h d', h=self.num_heads) - o = self.o_norm(o, o_g) - else: - o = self.o_norm(o) - o = rearrange(o, 'b t h d -> b t (h d)') - o = self.o_proj(o) - - return o, None, past_key_values, router_logits - - def shared_o( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - recurrent_state=None, - use_cache: Optional[bool] = False, - conv_state_q=[None, None], - conv_state_k=[None, None], - conv_state_v=[None, None], - **kwargs - ) -> torch.Tensor: - if attention_mask is not None: - assert len(attention_mask.shape) == 2, ( - "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " - "for padding purposes (0 indicating padding). " - "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." - ) - - mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode - if self.training: - assert mode == 'chunk', "Only chunk mode is supported in training." - - if self.use_short_conv: - q, conv_state_q[1] = self.q_conv1d( - x=self.q_proj(hidden_states), - cache=conv_state_q[1], - output_final_state=use_cache, - cu_seqlens=None - ) - k, conv_state_k[1] = self.k_conv1d( - x=self.shared_k(hidden_states), - cache=conv_state_k[1], - output_final_state=use_cache, - cu_seqlens=None - ) - v, conv_state_v[1] = self.v_conv1d( - x=self.shared_v(hidden_states), - cache=conv_state_v[1], - output_final_state=use_cache, - cu_seqlens=None - ) - else: - q = self.silu(self.q_proj(hidden_states)) - k = self.silu(self.shared_k(hidden_states)) - v = self.silu(self.shared_v(hidden_states)) - - q, k, v = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (q, k, v)) - q = l2_norm(q) - k = l2_norm(k) - beta = self.shared_b(hidden_states).sigmoid() - g = -self.A_log.float().exp() * F.softplus(self.shared_a(hidden_states).float() + self.dt_bias) - - # dealing with padding - if attention_mask is not None: - beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None]) - g = g.mul(attention_mask[:, -g.shape[-2]:, None]) - - cu_seqlens = kwargs.get('cu_seqlens', None) - if mode == 'chunk': - o, recurrent_state[-1] = chunk_gated_delta_rule( - q=q.to(dtype=torch.bfloat16), - k=k.to(dtype=torch.bfloat16), - v=v.to(dtype=torch.bfloat16), - g=g.to(dtype=torch.bfloat16), - beta=beta.to(dtype=torch.bfloat16), - initial_state=recurrent_state[-1], - output_final_state=use_cache, - cu_seqlens=cu_seqlens, - ) - o = o.to(dtype=q.dtype) - elif mode == 'fused_recurrent': - o, recurrent_state[-1] = fused_recurrent_gated_delta_rule( - q=q, - k=k, - v=v, - g=g, - beta=beta, - initial_state=recurrent_state[-1], - output_final_state=use_cache, - cu_seqlens=cu_seqlens, - ) - else: - raise NotImplementedError(f"Not supported mode `{mode}`.") - - return o From 9e2b4b6e25917577e0d08e10e509c00354c5a24d Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Sun, 13 Jul 2025 13:38:33 +0000 Subject: [PATCH 12/28] Update some old settings --- fla/layers/mom.py | 75 +++++++++++++++++++---------- fla/models/mom/configuration_mom.py | 36 ++++++++------ fla/models/mom/modeling_mom.py | 68 ++++---------------------- 3 files changed, 78 insertions(+), 101 deletions(-) diff --git a/fla/layers/mom.py b/fla/layers/mom.py index de9a457ad..3a82d3248 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -11,7 +11,6 @@ from torch.nn import functional as F from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution -from fla.modules.l2norm import l2_norm from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule if TYPE_CHECKING: @@ 
-137,7 +136,13 @@ def _upad_input( ) -def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, selected_memories: torch.Tensor, capacity: float): +def transform( + x: torch.Tensor, + routing_mask: torch.Tensor, + num_memories: int, + selected_memories: torch.Tensor, + capacity: float +): ''' Transform input sequences into memory-organized chunks with capacity constraints. @@ -233,12 +238,18 @@ def transform(x: torch.Tensor, routing_mask: torch.Tensor, num_memories: int, se # truncation_indices += capacity_len-max_len return transformed_x, truncation_indices, sorted_indices, max_len, mask, mask_2 - # (num_memories, batch, seq, hidden) - -# @torch.jit.script -def reconstruct(transformed_x, indices: torch.Tensor, sorted_indices: torch.Tensor, batch_size: int, seq_len: int, topk: int, routing_weights: torch.Tensor, mask: torch.Tensor): +def reconstruct( + transformed_x, + indices: torch.Tensor, + sorted_indices: torch.Tensor, + batch_size: int, + seq_len: int, + topk: int, + routing_weights: torch.Tensor, + mask: torch.Tensor +): ''' Reconstruct and mix transformed outputs back into the original input sequence shape. @@ -311,9 +322,9 @@ class MomAttention(nn.Module): def __init__( self, hidden_size: int = 2048, - expand_v: float = 2, head_dim: int = 256, - num_heads: int = 6, + num_heads: int = 4, + expand_v: float = 2, mode: str = 'chunk', use_output_gate: bool = True, use_short_conv: bool = True, @@ -348,10 +359,10 @@ def __init__( self.head_dim = head_dim self.num_heads = num_heads - self.key_dim = self.num_heads * self.head_dim - self.value_dim = self.key_dim * self.expand_v + self.key_dim = int(self.num_heads * self.head_dim) + self.value_dim = int(self.key_dim * self.expand_v) self.head_qk_dim = head_dim - self.head_v_dim = head_dim * self.expand_v + self.head_v_dim = int(head_dim * self.expand_v) self.layer_idx = layer_idx self.silu = nn.SiLU() @@ -365,14 +376,22 @@ def __init__( self.shared_b = nn.Linear(hidden_size, self.num_heads, bias=False) self.shared_a = nn.Linear(hidden_size, self.num_heads, bias=False) else: - self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.key_dim, bias=False) - for _ in range(self.num_memories)]) - self.v_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.value_dim, bias=False) - for _ in range(self.num_memories)]) - self.b_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) - for _ in range(self.num_memories)]) - self.a_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) - for _ in range(self.num_memories)]) + self.k_proj = nn.ModuleList([ + nn.Linear(self.hidden_size, self.key_dim, bias=False) + for _ in range(self.num_memories) + ]) + self.v_proj = nn.ModuleList([ + nn.Linear(self.hidden_size, self.value_dim, bias=False) + for _ in range(self.num_memories) + ]) + self.b_proj = nn.ModuleList([ + nn.Linear(self.hidden_size, self.num_heads, bias=False) + for _ in range(self.num_memories) + ]) + self.a_proj = nn.ModuleList([ + nn.Linear(self.hidden_size, self.num_heads, bias=False) + for _ in range(self.num_memories) + ]) if self.shared_mem: self.shared_k = nn.Linear(hidden_size, self.key_dim, bias=False) self.shared_v = nn.Linear(hidden_size, self.value_dim, bias=False) @@ -474,8 +493,13 @@ def forward( routing_weights, selected_memories = torch.topk(scores, self.topk, dim=-1) # (bsz, seq, topk) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) routing_weights = routing_weights.to(hidden_states.dtype) # we cast back to the input dtype - routing_weights_full = 
torch.zeros((routing_weights.shape[0], routing_weights.shape[1], self.num_memories), - dtype=routing_weights.dtype, device=routing_weights.device).scatter(-1, selected_memories, routing_weights) + routing_weights_full = torch.zeros( + routing_weights.shape[0], + routing_weights.shape[1], + self.num_memories, + dtype=routing_weights.dtype, + device=routing_weights.device + ).scatter(-1, selected_memories, routing_weights) routing_mask = routing_weights_full.bool().int() if self.use_output_gate: @@ -531,9 +555,6 @@ def forward( q, k, v = map(lambda x: rearrange(x, 'e b t (h d) -> e b t h d', h=self.num_heads), (q, k, v)) - q = l2_norm(q) - k = l2_norm(k) - q, k, v, g, beta, mask2 = (rearrange(x, 'e b l ... -> (e b) l ...') for x in (q, k, v, g, beta, mask2)) cu_q, cu_k, cu_v, cu_g, cu_beta, indices_q, cu_seqlen_all, max_seq_lens = _upad_input(q, k, v, g, beta, mask2, q_len) cu_seqlens = cu_seqlen_all[0].to(torch.long).unique() @@ -556,6 +577,7 @@ def forward( beta=cu_beta, initial_state=recurrent_state[0], output_final_state=use_cache, + use_qk_l2norm_in_kernel=True, cu_seqlens=cu_seqlens, ) recurrent_state[0] = recurrent_state_ @@ -576,6 +598,7 @@ def forward( beta=cu_beta, initial_state=recurrent_state[0], output_final_state=use_cache, + use_qk_l2norm_in_kernel=True, cu_seqlens=cu_seqlens, ) recurrent_state[0] = recurrent_state_ @@ -655,8 +678,6 @@ def shared_o( v = self.silu(self.shared_v(hidden_states)) q, k, v = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (q, k, v)) - q = l2_norm(q) - k = l2_norm(k) beta = self.shared_b(hidden_states).sigmoid() g = -self.A_log.float().exp() * F.softplus(self.shared_a(hidden_states).float() + self.dt_bias) @@ -675,6 +696,7 @@ def shared_o( beta=beta.to(dtype=torch.bfloat16), initial_state=recurrent_state[-1], output_final_state=use_cache, + use_qk_l2norm_in_kernel=True, cu_seqlens=cu_seqlens, ) o = o.to(dtype=q.dtype) @@ -687,6 +709,7 @@ def shared_o( beta=beta, initial_state=recurrent_state[-1], output_final_state=use_cache, + use_qk_l2norm_in_kernel=True, cu_seqlens=cu_seqlens, ) else: diff --git a/fla/models/mom/configuration_mom.py b/fla/models/mom/configuration_mom.py index 9e075987e..98020e0e6 100644 --- a/fla/models/mom/configuration_mom.py +++ b/fla/models/mom/configuration_mom.py @@ -12,29 +12,26 @@ class MomConfig(PretrainedConfig): def __init__( self, attn_mode: str = "chunk", - hidden_size: int = 1024, - expand_v: int = 1, - use_gate: bool = True, - use_short_conv: bool = True, + hidden_size: int = 2048, conv_size: int = 4, - head_dim: int = 256, num_heads: int = 4, + head_dim: int = 256, + expand_v: float = 1., + use_output_gate: bool = True, + use_short_conv: bool = True, max_position_embeddings: int = 2048, hidden_ratio: Optional[int] = 4, intermediate_size: Optional[int] = None, hidden_act: str = "swish", num_hidden_layers: int = 24, - norm_first: bool = False, norm_eps: float = 1e-6, attn: Optional[Dict] = None, use_cache: bool = True, - pad_token_id: int = None, + pad_token_id: Optional[int] = None, bos_token_id: int = 1, eos_token_id: int = 2, tie_word_embeddings: bool = False, initializer_range: float = 0.02, - fuse_cross_entropy: bool = True, - vocab_size: int = 32000, num_memories: int = 4, topk: int = 2, capacity: float = 1.0, @@ -43,29 +40,31 @@ def __init__( shared_mem: bool = False, single_kv_proj: bool = False, mom_backend: str = 'gated_deltanet', + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, **kwargs ): self.attn_mode = attn_mode 
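The renamed configuration fields above (use_output_gate replacing use_gate, expand_v as a float, the explicit fuse_* switches) are what a downstream user now passes. A hedged construction sketch follows; the values are illustrative and only the field names are taken from this diff.

# Hedged sketch of building the updated config (illustrative values).
from fla.models.mom import MomConfig

config = MomConfig(
    attn_mode='chunk',
    hidden_size=2048,
    num_heads=4,
    head_dim=256,
    expand_v=1.0,
    use_output_gate=True,
    use_short_conv=True,
    num_memories=4,
    topk=2,
    capacity=1.0,
    shared_mem=False,
    mom_backend='gated_deltanet',
    fuse_norm=True,
    fuse_swiglu=True,
    fuse_cross_entropy=True,
    vocab_size=32000,
)
print(config.num_memories, config.use_output_gate)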
self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = head_dim self.expand_v = expand_v - self.use_gate = use_gate - self.use_short_conv = use_short_conv self.conv_size = conv_size - self.head_dim = head_dim - self.num_heads = num_heads + self.use_output_gate = use_output_gate + self.use_short_conv = use_short_conv self.max_position_embeddings = max_position_embeddings self.hidden_ratio = hidden_ratio self.intermediate_size = intermediate_size self.hidden_act = hidden_act self.num_hidden_layers = num_hidden_layers - self.norm_first = norm_first self.norm_eps = norm_eps self.attn = attn self.use_cache = use_cache self.initializer_range = initializer_range - self.fuse_cross_entropy = fuse_cross_entropy - self.vocab_size = vocab_size + self.num_memories = num_memories self.topk = topk self.capacity = capacity @@ -75,6 +74,11 @@ def __init__( self.single_kv_proj = single_kv_proj self.mom_backend = mom_backend + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + if self.mom_backend not in ['gated_deltanet']: raise NotImplementedError(f"The MoM backend {mom_backend} is not currently supported.") diff --git a/fla/models/mom/modeling_mom.py b/fla/models/mom/modeling_mom.py index 53ee6ef52..395f473ea 100644 --- a/fla/models/mom/modeling_mom.py +++ b/fla/models/mom/modeling_mom.py @@ -9,7 +9,6 @@ import torch import torch.nn as nn -from transformers.activations import ACT2FN from transformers.generation import GenerationMixin from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel @@ -19,9 +18,9 @@ from fla.layers.attn import Attention from fla.models.mom.configuration_mom import MomConfig from fla.models.utils import Cache -from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm -from fla.modules.activations import swiglu_linear -from fla.modules.layernorm import rms_norm_linear +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as MomMLP +from fla.modules import RMSNorm if TYPE_CHECKING: from transformers.processing_utils import Unpack @@ -30,58 +29,12 @@ logger = logging.get_logger(__name__) -class MomMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - hidden_ratio: Optional[int] = None, - intermediate_size: Optional[int] = None, - hidden_act: str = 'swish', - norm_first: bool = True, - norm_eps: float = 1e-5 - ) -> MomMLP: - super().__init__() - - self.hidden_size = hidden_size - # the final number of params is `hidden_ratio * hidden_size^2` - # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` - if hidden_ratio is None: - hidden_ratio = 4 - if intermediate_size is None: - intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) - intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) - self.hidden_ratio = hidden_ratio - self.intermediate_size = intermediate_size - self.norm_first = norm_first - - if norm_first: - self.norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps) - - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[hidden_act] - - def forward( - self, - x: torch.Tensor, - **kwargs: Unpack[Dict], - ) -> torch.Tensor: - if self.norm_first: - x = rms_norm_linear(x, self.norm.weight, 
self.norm.bias, self.gate_proj.weight, self.gate_proj.bias) - else: - x = self.gate_proj(x) - gate, y = x.chunk(2, -1) - return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) - - class MomBlock(nn.Module): def __init__(self, config: MomConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - if not config.norm_first: - self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) if config.attn is not None and layer_idx in config.attn['layers']: self.attn = Attention( hidden_size=config.hidden_size, @@ -99,10 +52,9 @@ def __init__(self, config: MomConfig, layer_idx: int): expand_v=config.expand_v, head_dim=config.head_dim, num_heads=config.num_heads, - use_gate=config.use_gate, + use_output_gate=config.use_output_gate, use_short_conv=config.use_short_conv, conv_size=config.conv_size, - norm_first=config.norm_first, norm_eps=config.norm_eps, layer_idx=layer_idx, num_memories=config.num_memories, @@ -113,15 +65,13 @@ def __init__(self, config: MomConfig, layer_idx: int): ) else: raise NotImplementedError(f"The MoM backend {config.mom_backend} is not currently supported.") - if not config.norm_first: - self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) self.mlp = MomMLP( hidden_size=config.hidden_size, hidden_ratio=config.hidden_ratio, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - norm_first=config.norm_first, - norm_eps=config.norm_eps + fuse_swiglu=config.fuse_swiglu ) def forward( @@ -161,7 +111,7 @@ class MomPreTrainedModel(PreTrainedModel): config_class = MomConfig supports_gradient_checkpointing = True - _no_split_modules = ['GatedDeltaNetBlock'] + _no_split_modules = ['MomBlock'] def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -241,7 +191,7 @@ def forward( **kwargs: Unpack[Dict] ) -> Union[Tuple, BaseModelOutputWithPast]: if output_attentions: - warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.") + warnings.warn("`MomModel` does not `output_attentions` now, setting it to `False`.") output_attentions = False output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states From 2b51b28282799fd7f7f0dce4d2f36c4e766abfb1 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Sun, 13 Jul 2025 13:44:00 +0000 Subject: [PATCH 13/28] Fix isort --- fla/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fla/models/__init__.py b/fla/models/__init__.py index 53d861b73..6995f13d0 100644 --- a/fla/models/__init__.py +++ b/fla/models/__init__.py @@ -20,8 +20,8 @@ from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel from fla.models.mamba2 import Mamba2Config, Mamba2ForCausalLM, Mamba2Model from fla.models.mesa_net import MesaNetConfig, MesaNetForCausalLM, MesaNetModel -from fla.models.mom import MomConfig, MomForCausalLM, MomModel from fla.models.mla import MLAConfig, MLAForCausalLM, MLAModel +from fla.models.mom import MomConfig, MomForCausalLM, MomModel from fla.models.nsa import NSAConfig, NSAForCausalLM, NSAModel from fla.models.path_attn import PaTHAttentionConfig, PaTHAttentionForCausalLM, PaTHAttentionModel from fla.models.retnet import RetNetConfig, 
RetNetForCausalLM, RetNetModel From 706d6fda5687c041185074a2a19cc5b0e903108f Mon Sep 17 00:00:00 2001 From: Jusen Date: Tue, 22 Jul 2025 20:19:24 +0800 Subject: [PATCH 14/28] support inference --- fla/layers/mom.py | 103 +++++++++++++++++++++++----------------------- 1 file changed, 52 insertions(+), 51 deletions(-) diff --git a/fla/layers/mom.py b/fla/layers/mom.py index 3a82d3248..34ce9b572 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -18,45 +18,13 @@ from fla.models.utils import Cache -from transformers.utils import is_flash_attn_2_available - -if is_flash_attn_2_available(): - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input -else: - print("flash_attn_2 is not available") +from fla.layers.utils import get_unpad_data, index_first_axis, pad_input def elu_p1(x): return (F.elu(x, 1., False) + 1.).to(x) -def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: - """ - Retrieves indexing data required to repad unpadded (ragged) tensors. - - Arguments: - attention_mask (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - - Return: - indices (`torch.Tensor`): - The indices of non-masked tokens from the flattened input sequence. - cu_seqlens (`torch.Tensor`): - The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - max_seqlen_in_batch (`int`): - Maximum sequence length in batch. - """ - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - def _upad_input( query_layer: torch.Tensor, key_layer: torch.Tensor, @@ -98,7 +66,7 @@ def _upad_input( (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). 
""" - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) @@ -528,24 +496,25 @@ def forward( conv_state_q, conv_state_k, conv_state_v = [None, None], [None, None], [None, None] if last_state is not None: conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[2]:].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None q, k, v = map(lambda x: rearrange(x, 'e b t d -> (e b) t d'), (q, k, v)) q, conv_state_q[0] = self.q_conv1d( x=q, cache=conv_state_q[0], output_final_state=use_cache, - cu_seqlens=None + mask=conv_mask ) k, conv_state_k[0] = self.k_conv1d( x=k, cache=conv_state_k[0], output_final_state=use_cache, - cu_seqlens=None + mask=conv_mask ) v, conv_state_v[0] = self.v_conv1d( x=v, cache=conv_state_v[0], output_final_state=use_cache, - cu_seqlens=None + mask=conv_mask ) q, k, v = map(lambda x: rearrange(x, '(e b) t d -> e b t d', b=batch_size), (q, k, v)) @@ -555,19 +524,19 @@ def forward( q, k, v = map(lambda x: rearrange(x, 'e b t (h d) -> e b t h d', h=self.num_heads), (q, k, v)) + # dealing with padding + if attention_mask is not None: + beta = beta.mul(attention_mask[None, None, :, -beta.shape[-2]:, None]) + g = g.mul(attention_mask[None, None, :, -g.shape[-2]:, None]) + q, k, v, g, beta, mask2 = (rearrange(x, 'e b l ... -> (e b) l ...') for x in (q, k, v, g, beta, mask2)) cu_q, cu_k, cu_v, cu_g, cu_beta, indices_q, cu_seqlen_all, max_seq_lens = _upad_input(q, k, v, g, beta, mask2, q_len) - cu_seqlens = cu_seqlen_all[0].to(torch.long).unique() + cu_seqlens, reverse_indices = cu_seqlen_all[0].to(torch.long).unique(return_inverse=True) cu_q, cu_k, cu_v, cu_g, cu_beta = (x.unsqueeze(0).contiguous() for x in (cu_q, cu_k, cu_v, cu_g, cu_beta)) - # dealing with padding - if attention_mask is not None: - beta = beta.mul(attention_mask[None, :, -beta.shape[-2]:, None]) - g = g.mul(attention_mask[None, :, -g.shape[-2]:, None]) - recurrent_state = last_state['recurrent_state'] if last_state is not None else [ None for _ in range(1 + self.shared_mem)] - cu_seqlens = kwargs.get('cu_seqlens', None) + if mode == 'chunk': o_, recurrent_state_ = chunk_gated_delta_rule( q=cu_q, @@ -580,7 +549,15 @@ def forward( use_qk_l2norm_in_kernel=True, cu_seqlens=cu_seqlens, ) - recurrent_state[0] = recurrent_state_ + total_len = len(cu_seqlen_all[0]) + if use_cache and len(cu_seqlens) != total_len: + # handle the case where some memories are not used + recurrent_state[0] = recurrent_state_[reverse_indices[1:]-1] + for i in range(total_len-1): + if cu_seqlen_all[0][i] == cu_seqlen_all[0][i+1]: + recurrent_state[0][i] = torch.zeros_like(recurrent_state[0][i]) + else: + recurrent_state[0] = recurrent_state_ o_ = o_.squeeze(0).contiguous() o_list = pad_input(o_, indices_q, batch_size*self.num_memories, q_len) o_list = rearrange(o_list, '(e b) l h d -> e b l h d', b=batch_size) @@ -590,21 +567,44 @@ def forward( seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) elif mode == 'fused_recurrent': + total_len = len(cu_seqlen_all[0]) + if use_cache and len(cu_seqlens) != total_len: + # select memories that are activated + memories = torch.zeros_like(recurrent_state[0][:self.topk*batch_size]) + mem_id = 0 + for 
i in range(total_len-1): + if cu_seqlen_all[0][i] != cu_seqlen_all[0][i+1]: + memories[mem_id] = recurrent_state[0][reverse_indices[i+1]-1] + mem_id += 1 + assert q_len != 1 or mem_id == self.topk * batch_size, "The number of memories is not correct." + else: + memories = recurrent_state[0] + o_, recurrent_state_ = fused_recurrent_gated_delta_rule( q=cu_q, k=cu_k, v=cu_v, g=cu_g, beta=cu_beta, - initial_state=recurrent_state[0], + initial_state=memories, output_final_state=use_cache, use_qk_l2norm_in_kernel=True, cu_seqlens=cu_seqlens, ) - recurrent_state[0] = recurrent_state_ + if use_cache and len(cu_seqlens) != total_len: + if recurrent_state[0] is None: + recurrent_state[0] = torch.zeros_like(recurrent_state_[reverse_indices[1:]-1]) + # handle the case where some memories are not used + for i in range(total_len-1): + if cu_seqlen_all[0][i] != cu_seqlen_all[0][i+1]: + recurrent_state[0][i] = recurrent_state_[reverse_indices[i+1]-1] + else: + recurrent_state[0] = recurrent_state_ o_ = o_.squeeze(0).contiguous() - o_list = pad_input(o_, indices_q, batch_size*self.num_memories, max_len) + o_list = pad_input(o_, indices_q, batch_size*self.num_memories, q_len) o_list = rearrange(o_list, '(e b) l h d -> e b l h d', b=batch_size) + o_list = o_list[:, :, -max_len:] + o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) @@ -654,23 +654,24 @@ def shared_o( assert mode == 'chunk', "Only chunk mode is supported in training." if self.use_short_conv: + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None q, conv_state_q[1] = self.q_conv1d( x=self.q_proj(hidden_states), cache=conv_state_q[1], output_final_state=use_cache, - cu_seqlens=None + mask=conv_mask ) k, conv_state_k[1] = self.k_conv1d( x=self.shared_k(hidden_states), cache=conv_state_k[1], output_final_state=use_cache, - cu_seqlens=None + mask=conv_mask ) v, conv_state_v[1] = self.v_conv1d( x=self.shared_v(hidden_states), cache=conv_state_v[1], output_final_state=use_cache, - cu_seqlens=None + mask=conv_mask ) else: q = self.silu(self.q_proj(hidden_states)) From cea11c43bdc7fd146d8683a55208f71b5c0a6280 Mon Sep 17 00:00:00 2001 From: Jusen Date: Wed, 23 Jul 2025 16:27:02 +0800 Subject: [PATCH 15/28] Fix inference & support expand_v --- fla/layers/mom.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fla/layers/mom.py b/fla/layers/mom.py index 34ce9b572..524493cac 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -68,10 +68,11 @@ def _upad_input( """ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + head_v_dim = value_layer.shape[-1] key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_v_dim), indices_k ) gate_layer = index_first_axis(gate_layer.reshape(batch_size * kv_seq_len, num_key_value_heads), indices_k) beta_layer = index_first_axis(beta_layer.reshape(batch_size * kv_seq_len, num_key_value_heads), indices_k) @@ -568,7 +569,9 @@ def forward( elif mode == 'fused_recurrent': total_len = len(cu_seqlen_all[0]) - if use_cache and len(cu_seqlens) != total_len: + if use_cache and recurrent_state[0] is 
not None and len(cu_seqlens) != total_len: + if recurrent_state[0] is None: + recurrent_state[0] = torch.zeros_like(recurrent_state_[reverse_indices[1:]-1]) # select memories that are activated memories = torch.zeros_like(recurrent_state[0][:self.topk*batch_size]) mem_id = 0 From 8a8b6c2864b56c0cbb5840b28c781b3ea51c5b75 Mon Sep 17 00:00:00 2001 From: Jusen Date: Wed, 23 Jul 2025 16:32:21 +0800 Subject: [PATCH 16/28] Fix rescale_prenorm_residual --- fla/models/mom/modeling_mom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fla/models/mom/modeling_mom.py b/fla/models/mom/modeling_mom.py index 395f473ea..3e28c77f4 100644 --- a/fla/models/mom/modeling_mom.py +++ b/fla/models/mom/modeling_mom.py @@ -119,7 +119,7 @@ def __init__(self, *inputs, **kwargs): def _init_weights( self, module: nn.Module, - rescale_prenorm_residual: bool = True, + rescale_prenorm_residual: bool = False, num_residuals_per_layer: int = 2, ): if isinstance(module, (nn.Linear, nn.Conv1d)): From b9f958ae1029866e0c7be5174aa6510cd7831ca6 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Thu, 24 Jul 2025 07:38:00 +0000 Subject: [PATCH 17/28] Add generation testing --- tests/models/test_modeling_mom.py | 57 +++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 tests/models/test_modeling_mom.py diff --git a/tests/models/test_modeling_mom.py b/tests/models/test_modeling_mom.py new file mode 100644 index 000000000..8c89ac32d --- /dev/null +++ b/tests/models/test_modeling_mom.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- + +import pytest +import torch + +from fla.models import MomConfig + +from .test_modeling_base import run_test_generation, run_test_model_forward_backward + + +# =================================================================================== +# Test for Modeling (Forward/Backward Pass) +# =================================================================================== +@pytest.mark.parametrize( + ['L', 'B', 'T', 'H', 'D', 'use_l2warp', 'dtype'], + [ + pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-use_l2warp{}-{}".format(*test)) + for test in [ + (4, 4, 1024, 4, 64, True, torch.bfloat16), + (4, 4, 1024, 4, 64, False, torch.bfloat16), + (4, 4, 1024, 4, 128, False, torch.bfloat16), + ] + ] +) +def test_modeling( + L: int, + B: int, + T: int, + H: int, + D: int, + use_l2warp: bool, + dtype: torch.dtype, +): + run_test_model_forward_backward(L, B, T, H, D, MomConfig, use_l2warp=use_l2warp, dtype=dtype) + + +# =================================================================================== +# Test for Generation +# =================================================================================== +@pytest.mark.parametrize( + ['L', 'B', 'T', 'H', 'D', 'dtype'], + [ + pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-{}".format(*test)) + for test in [ + (2, 4, 2000, 8, 64, torch.float16), + ] + ] +) +def test_generation( + L: int, + B: int, + T: int, + H: int, + D: int, + dtype: torch.dtype, +): + run_test_generation(L, B, T, H, D, MomConfig, dtype) From b105c83f7b597212aa156b09c287e06b3c765dbc Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Thu, 24 Jul 2025 07:46:43 +0000 Subject: [PATCH 18/28] Add proper pythonpath for pytest --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 88a568246..4c6ff1585 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,3 +40,6 @@ multi_line_output = 3 [tool.pytest.ini_options] log_cli = true log_cli_level = "INFO" +pythonpath = [ + "." 
+] From 51d014b54976a23e9d58cc1a3487eaee7e53c63e Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Thu, 24 Jul 2025 07:48:25 +0000 Subject: [PATCH 19/28] Update registration of MomConfig and related models to allow for existing entries --- fla/models/mom/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fla/models/mom/__init__.py b/fla/models/mom/__init__.py index 70e12a0aa..bcc101c64 100644 --- a/fla/models/mom/__init__.py +++ b/fla/models/mom/__init__.py @@ -5,8 +5,8 @@ from fla.models.mom.configuration_mom import MomConfig from fla.models.mom.modeling_mom import MomForCausalLM, MomModel -AutoConfig.register(MomConfig.model_type, MomConfig) -AutoModel.register(MomConfig, MomModel) -AutoModelForCausalLM.register(MomConfig, MomForCausalLM) +AutoConfig.register(MomConfig.model_type, MomConfig, exist_ok=True) +AutoModel.register(MomConfig, MomModel, exist_ok=True) +AutoModelForCausalLM.register(MomConfig, MomForCausalLM, exist_ok=True) __all__ = ['MomConfig', 'MomForCausalLM', 'MomModel'] From 95dc4187dd618580c99a95ed393fd25a9437f252 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Thu, 24 Jul 2025 07:52:05 +0000 Subject: [PATCH 20/28] Delete unused act --- fla/layers/mom.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/fla/layers/mom.py b/fla/layers/mom.py index 524493cac..6becc97da 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -21,10 +21,6 @@ from fla.layers.utils import get_unpad_data, index_first_axis, pad_input -def elu_p1(x): - return (F.elu(x, 1., False) + 1.).to(x) - - def _upad_input( query_layer: torch.Tensor, key_layer: torch.Tensor, @@ -370,9 +366,7 @@ def __init__( A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) A_log = torch.log(A) self.A_log = nn.Parameter(A_log) - self.A_log._no_weight_decay = True self.D = nn.Parameter(torch.ones(self.num_heads)) - self.D._no_weight_decay = True # hard coded for now dt_min = 0.001 dt_max = 0.1 @@ -387,7 +381,6 @@ def __init__( self.dt_bias = nn.Parameter(inv_dt) # Just to be explicit. Without this we already don't put wd on dt_bias because of the check # name.endswith("bias") in param_grouping.py - self.dt_bias._no_weight_decay = True if use_short_conv: self.conv_size = conv_size @@ -497,7 +490,8 @@ def forward( conv_state_q, conv_state_k, conv_state_v = [None, None], [None, None], [None, None] if last_state is not None: conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] - conv_mask = attention_mask[:, -hidden_states.shape[2]:].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None + conv_mask = attention_mask[:, -hidden_states.shape[2] + :].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None q, k, v = map(lambda x: rearrange(x, 'e b t d -> (e b) t d'), (q, k, v)) q, conv_state_q[0] = self.q_conv1d( x=q, From 2d0f20322ae64c1f0c0c40c750cdb051df7992f5 Mon Sep 17 00:00:00 2001 From: Jusen Date: Sat, 26 Jul 2025 14:32:15 +0800 Subject: [PATCH 21/28] Refactor the code using cu_seqlen. 
test_generation passed wo o_norm --- fla/layers/mom.py | 336 ++++++++++++++++++++++++++++------------------ 1 file changed, 207 insertions(+), 129 deletions(-) diff --git a/fla/layers/mom.py b/fla/layers/mom.py index 6becc97da..cb37a09be 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -18,7 +18,7 @@ from fla.models.utils import Cache -from fla.layers.utils import get_unpad_data, index_first_axis, pad_input +from fla.layers.utils import get_unpad_data, index_first_axis, pad_input, unpad_input def _upad_input( @@ -28,7 +28,6 @@ def _upad_input( gate_layer: torch.Tensor, beta_layer: torch.Tensor, attention_mask: torch.Tensor, - query_length: int, ): """ Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. @@ -62,18 +61,19 @@ def _upad_input( (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). """ + query_length = query_layer.shape[1] indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - head_v_dim = value_layer.shape[-1] + batch_size, kv_seq_len, dim = key_layer.shape + v_dim = value_layer.shape[-1] - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, dim), indices_k) value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_v_dim), indices_k + value_layer.reshape(batch_size * kv_seq_len, v_dim), indices_k ) - gate_layer = index_first_axis(gate_layer.reshape(batch_size * kv_seq_len, num_key_value_heads), indices_k) - beta_layer = index_first_axis(beta_layer.reshape(batch_size * kv_seq_len, num_key_value_heads), indices_k) + gate_layer = index_first_axis(gate_layer.reshape(batch_size * kv_seq_len, -1), indices_k) + beta_layer = index_first_axis(beta_layer.reshape(batch_size * kv_seq_len, -1), indices_k) if query_length == kv_seq_len: - query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, dim), indices_k) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k @@ -106,7 +106,8 @@ def transform( routing_mask: torch.Tensor, num_memories: int, selected_memories: torch.Tensor, - capacity: float + capacity: float, + attention_mask: torch.Tensor, ): ''' Transform input sequences into memory-organized chunks with capacity constraints. 
@@ -153,6 +154,11 @@ def transform( # (batch, seq, topk) selected_memories = selected_memories.reshape(selected_memories.shape[0], -1) # (batch, seq * topk) + + if attention_mask is not None: + attention_mask = attention_mask[:, -routing_mask.shape[1]:] + # mask out the masked tokens + routing_mask[attention_mask.bitwise_not().unsqueeze(-1).expand(-1, -1, num_memories)] = 0 b, s, d = x.shape x_flat = x.reshape(b * s, d) # [b*s, d] @@ -160,6 +166,9 @@ def transform( with torch.no_grad(): batch_indices = torch.arange(b, device=x.device).unsqueeze(-1) batch_indices = batch_indices.expand(b, s).reshape(-1) + if attention_mask is not None: + # sort the masked tokens to the end + batch_indices[attention_mask.repeat_interleave(topk, dim=1).bitwise_not().flatten()] = b # (b * s) memories_flat = selected_memories.reshape(-1) # [b*s] @@ -172,35 +181,24 @@ def transform( # routing_mask (b, s, num_memories) batch_memory_tokens = routing_mask.sum(dim=1) # (b, num_memories) - offset = batch_memory_tokens.cumsum(dim=1) - memory_batch_offset = offset.transpose(0, 1) - batch_offset = torch.arange(0, b*s, s, device=offset.device) - memory_batch_offset += batch_offset - flatten_offset = memory_batch_offset.transpose(0, 1).reshape(-1) - lengths = torch.concat([flatten_offset[:1], flatten_offset[1:] - flatten_offset[:-1]], dim=0) - max_len = lengths.max() - capacity_len = math.ceil(s / topk * capacity) - max_len = min(max_len, capacity_len) - + flatten_offset = batch_memory_tokens.flatten().cumsum(dim=0) + max_len = batch_memory_tokens.max() indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand( b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1) - # discard tokens exceed capacity and is far from now - # left pad - truncation_indices = indices + batch_memory_tokens.reshape((-1,)).unsqueeze(-1) - max_len - mask = torch.bitwise_and(truncation_indices < flatten_offset.unsqueeze(-1), truncation_indices >= 0) - mask = torch.bitwise_and(mask, truncation_indices >= torch.cat( - (torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device), flatten_offset[:-1])).unsqueeze(-1)) - truncation_indices = torch.where(mask, truncation_indices, torch.zeros_like(truncation_indices)) + mask = indices < flatten_offset.unsqueeze(-1) + truncation_indices = torch.where(mask, indices, torch.zeros_like(indices)) gathered_x = torch.gather(x_sorted, 0, truncation_indices.reshape(-1).unsqueeze(-1).expand(-1, d)) - transformed_x = gathered_x.reshape(b * num_memories, -1, d) - transformed_x = transformed_x * mask.unsqueeze(-1).expand_as(transformed_x) - pad_x = torch.zeros((b * num_memories, capacity_len-max_len, d), dtype=transformed_x.dtype, device=transformed_x.device) - pad_mask = torch.zeros((b * num_memories, capacity_len-max_len), dtype=transformed_x.dtype, device=transformed_x.device) + transformed_x = gathered_x.reshape(b * num_memories, -1, d).reshape((b, num_memories, max_len, d)).transpose(0, 1) + # transformed_x = transformed_x * mask.unsqueeze(-1).expand_as(transformed_x) + # pad_x = torch.zeros((b * num_memories, capacity_len-max_len, d), dtype=transformed_x.dtype, device=transformed_x.device) + # pad_mask = torch.zeros((b * num_memories, capacity_len-max_len), dtype=transformed_x.dtype, device=transformed_x.device) # left pad - transformed_x = torch.cat((pad_x, transformed_x), dim=1).reshape((b, num_memories, capacity_len, d)).transpose(0, 1) - mask_2 = torch.cat((pad_mask.bool(), mask), dim=1).reshape((b, 
num_memories, capacity_len)).transpose(0, 1) + # transformed_x = torch.cat((pad_x, transformed_x), dim=1).reshape((b, num_memories, capacity_len, d)).transpose(0, 1) + mask_2 = mask.reshape((b, num_memories, max_len)).transpose(0, 1) # truncation_indices += capacity_len-max_len + # if attention_mask is not None: + # mask_2 return transformed_x, truncation_indices, sorted_indices, max_len, mask, mask_2 @@ -255,26 +253,26 @@ def reconstruct( Shape: (batch_size, seq_len, hidden_size) ''' transformed_x = transformed_x.transpose(0, 1).reshape( - (-1, transformed_x.shape[2], transformed_x.shape[3], transformed_x.shape[4])) - b, s, k, h, d = batch_size, seq_len, topk, transformed_x.shape[2], transformed_x.shape[3] + (-1, transformed_x.shape[2], transformed_x.shape[3])) + b, s, k, d = batch_size, seq_len, topk, transformed_x.shape[2] gathered_x = transformed_x.reshape( - (transformed_x.shape[0] * transformed_x.shape[1], transformed_x.shape[2], transformed_x.shape[3])) - mask_expanded = mask.reshape(-1).unsqueeze(-1).unsqueeze(-1).expand_as(gathered_x) + (transformed_x.shape[0] * transformed_x.shape[1], transformed_x.shape[2])) + mask_expanded = mask.reshape(-1).unsqueeze(-1).expand_as(gathered_x) gathered_x = gathered_x * mask_expanded assert (indices >= 0).all(), "Indices should be non-negative" - resortd_x = torch.zeros((b * s * k, h, d), device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_( + resortd_x = torch.zeros((b * s * k, d), device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_( 0, - indices.reshape(-1).unsqueeze(-1).unsqueeze(-1).expand(-1, h, d), + indices.reshape(-1).unsqueeze(-1).expand(-1, d), gathered_x, ) assert (indices < resortd_x.size(0)).all(), "Indices should be less than resortd_x size" inverse_indices = sorted_indices.argsort() rearranged_x_flat = resortd_x[inverse_indices] - restored_x = rearranged_x_flat.reshape((b, s * k, h, d)) - restored_x = restored_x.reshape(b, s, k, h, d) * routing_weights.reshape(b, s, k).unsqueeze(-1).unsqueeze(-1) + restored_x = rearranged_x_flat.reshape((b, s * k, d)) + restored_x = restored_x.reshape(b, s, k, d) * routing_weights.reshape(b, s, k).unsqueeze(-1) restored_x = restored_x.sum(dim=2) return restored_x @@ -445,7 +443,7 @@ def forward( assert mode == 'chunk', "Only chunk mode is supported in training." 
last_state = None - _, q_len = hidden_states.shape[0], hidden_states.shape[1] + # _, q_len = hidden_states.shape[0], hidden_states.shape[1] if past_key_values is not None and len(past_key_values) > self.layer_idx: last_state = past_key_values[self.layer_idx] @@ -464,14 +462,14 @@ def forward( ).scatter(-1, selected_memories, routing_weights) routing_mask = routing_weights_full.bool().int() - if self.use_output_gate: - o_g = self.g_proj(hidden_states) + # if self.use_output_gate: + # o_g = self.g_proj(hidden_states) batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1] shared_hidden_states = hidden_states - hidden_states, indices, sorted_indices, max_len, mask, mask2 = transform( - hidden_states, routing_mask, self.num_memories, selected_memories, self.capacity) + hidden_states, indices, sorted_indices, max_len, mask, mask_2 = transform( + hidden_states, routing_mask, self.num_memories, selected_memories, self.capacity, attention_mask) q = self.q_proj(hidden_states) if self.single_kv_proj: @@ -486,54 +484,57 @@ def forward( g = torch.stack([-self.A_log.float().exp() * F.softplus(a_expert(hidden_states[i]).float() + self.dt_bias) for i, a_expert in enumerate(self.a_proj)], dim=0) + q, k, v, g, beta, mask_2 = (rearrange(x, 'e b l ... -> (e b) l ...') for x in (q, k, v, g, beta, mask_2)) + cu_q, cu_k, cu_v, cu_g, cu_beta, indices_q, cu_seqlen_all, max_seq_lens = _upad_input(q, k, v, g, beta, mask_2) + cu_seqlens, reverse_indices = cu_seqlen_all[0].to(torch.long).unique(return_inverse=True) + cu_q, cu_k, cu_v, cu_g, cu_beta = (x.unsqueeze(0).contiguous() for x in (cu_q, cu_k, cu_v, cu_g, cu_beta)) + if self.use_short_conv: conv_state_q, conv_state_k, conv_state_v = [None, None], [None, None], [None, None] if last_state is not None: conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] - conv_mask = attention_mask[:, -hidden_states.shape[2] - :].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None - q, k, v = map(lambda x: rearrange(x, 'e b t d -> (e b) t d'), (q, k, v)) - q, conv_state_q[0] = self.q_conv1d( - x=q, - cache=conv_state_q[0], + + ocnv_q = self.prepare_recurrent_state(conv_state_q[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) + cu_q, conv_q_new = self.q_conv1d( + x=cu_q, + cache=ocnv_q, output_final_state=use_cache, - mask=conv_mask + cu_seqlens=cu_seqlens, ) - k, conv_state_k[0] = self.k_conv1d( - x=k, - cache=conv_state_k[0], + conv_state_q[0] = self.handle_recurrent_state(conv_state_q[0], conv_q_new, cu_seqlens, cu_seqlen_all[0], reverse_indices) + ocnv_k = self.prepare_recurrent_state(conv_state_k[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) + cu_k, conv_k_new = self.k_conv1d( + x=cu_k, + cache=ocnv_k, output_final_state=use_cache, - mask=conv_mask + cu_seqlens=cu_seqlens, ) - v, conv_state_v[0] = self.v_conv1d( - x=v, - cache=conv_state_v[0], + conv_state_k[0] = self.handle_recurrent_state(conv_state_k[0], conv_k_new, cu_seqlens, cu_seqlen_all[0], reverse_indices) + conv_v = self.prepare_recurrent_state(conv_state_v[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) + cu_v, conv_v_new = self.v_conv1d( + x=cu_v, + cache=conv_v, output_final_state=use_cache, - mask=conv_mask + cu_seqlens=cu_seqlens, ) - - q, k, v = map(lambda x: rearrange(x, '(e b) t d -> e b t d', b=batch_size), (q, k, v)) + conv_state_v[0] = self.handle_recurrent_state(conv_state_v[0], conv_v_new, cu_seqlens, cu_seqlen_all[0], reverse_indices) else: q, k, v = self.silu(q), self.silu(k), self.silu(v), - q, k, v 
= map(lambda x: rearrange(x, 'e b t (h d) -> e b t h d', h=self.num_heads), (q, k, v)) + cu_q, cu_k, cu_v = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (cu_q, cu_k, cu_v)) # dealing with padding - if attention_mask is not None: - beta = beta.mul(attention_mask[None, None, :, -beta.shape[-2]:, None]) - g = g.mul(attention_mask[None, None, :, -g.shape[-2]:, None]) - - q, k, v, g, beta, mask2 = (rearrange(x, 'e b l ... -> (e b) l ...') for x in (q, k, v, g, beta, mask2)) - cu_q, cu_k, cu_v, cu_g, cu_beta, indices_q, cu_seqlen_all, max_seq_lens = _upad_input(q, k, v, g, beta, mask2, q_len) - cu_seqlens, reverse_indices = cu_seqlen_all[0].to(torch.long).unique(return_inverse=True) - cu_q, cu_k, cu_v, cu_g, cu_beta = (x.unsqueeze(0).contiguous() for x in (cu_q, cu_k, cu_v, cu_g, cu_beta)) + # if attention_mask is not None: + # v = v.mul(attention_mask[None, :, -v.shape[-3]:, None, None]) + # k = k.mul(attention_mask[None, :, -k.shape[-3]:, None, None]) + # beta = beta.mul(attention_mask[None, :, -beta.shape[-2]:, None]) + # g = g.mul(attention_mask[None, :, -g.shape[-2]:, None]) recurrent_state = last_state['recurrent_state'] if last_state is not None else [ None for _ in range(1 + self.shared_mem)] - if mode == 'chunk': - o_, recurrent_state_ = chunk_gated_delta_rule( + o, recurrent_state_ = chunk_gated_delta_rule( q=cu_q, k=cu_k, v=cu_v, @@ -544,40 +545,43 @@ def forward( use_qk_l2norm_in_kernel=True, cu_seqlens=cu_seqlens, ) - total_len = len(cu_seqlen_all[0]) - if use_cache and len(cu_seqlens) != total_len: - # handle the case where some memories are not used - recurrent_state[0] = recurrent_state_[reverse_indices[1:]-1] - for i in range(total_len-1): - if cu_seqlen_all[0][i] == cu_seqlen_all[0][i+1]: - recurrent_state[0][i] = torch.zeros_like(recurrent_state[0][i]) - else: - recurrent_state[0] = recurrent_state_ - o_ = o_.squeeze(0).contiguous() - o_list = pad_input(o_, indices_q, batch_size*self.num_memories, q_len) - o_list = rearrange(o_list, '(e b) l h d -> e b l h d', b=batch_size) - o_list = o_list[:, :, -max_len:] - - o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, - seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + recurrent_state[0] = self.handle_recurrent_state(recurrent_state[0], recurrent_state_, cu_seqlens, cu_seqlen_all[0], reverse_indices) + # total_len = len(cu_seqlen_all[0]) + # if use_cache and len(cu_seqlens) != total_len: + # # handle the case where some memories are not used + # recurrent_state[0] = recurrent_state_[reverse_indices[1:]-1] + # for i in range(total_len-1): + # if cu_seqlen_all[0][i] == cu_seqlen_all[0][i+1]: + # recurrent_state[0][i] = torch.zeros_like(recurrent_state[0][i]) + # else: + # recurrent_state[0] = recurrent_state_ + + # o_ = o_.squeeze(0).contiguous() + # o = pad_input(o_, indices_q, batch_size*self.num_memories, max_len) + # o = rearrange(o, '(e b) l h d -> e b l h d', b=batch_size) + # o_list = o_list[:, :, -max_len:] + + # o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, + # seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) elif mode == 'fused_recurrent': - total_len = len(cu_seqlen_all[0]) - if use_cache and recurrent_state[0] is not None and len(cu_seqlens) != total_len: - if recurrent_state[0] is None: - recurrent_state[0] = torch.zeros_like(recurrent_state_[reverse_indices[1:]-1]) - # select memories that are activated - memories = 
torch.zeros_like(recurrent_state[0][:self.topk*batch_size]) - mem_id = 0 - for i in range(total_len-1): - if cu_seqlen_all[0][i] != cu_seqlen_all[0][i+1]: - memories[mem_id] = recurrent_state[0][reverse_indices[i+1]-1] - mem_id += 1 - assert q_len != 1 or mem_id == self.topk * batch_size, "The number of memories is not correct." - else: - memories = recurrent_state[0] - - o_, recurrent_state_ = fused_recurrent_gated_delta_rule( + # total_len = len(cu_seqlen_all[0]) + # if use_cache and recurrent_state[0] is not None and len(cu_seqlens) != total_len: + # if recurrent_state[0] is None: + # recurrent_state[0] = torch.zeros_like(recurrent_state_[reverse_indices[1:]-1]) + # # select memories that are activated + # memories = torch.zeros_like(recurrent_state[0][:self.topk*batch_size]) + # mem_id = 0 + # for i in range(total_len-1): + # if cu_seqlen_all[0][i] != cu_seqlen_all[0][i+1]: + # memories[mem_id] = recurrent_state[0][i] + # mem_id += 1 + # assert seq_len != 1 or mem_id == self.topk * batch_size, "The number of memories is not correct." + # else: + # memories = recurrent_state[0] + memories = self.prepare_recurrent_state(recurrent_state[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) + + o, recurrent_state_ = fused_recurrent_gated_delta_rule( q=cu_q, k=cu_k, v=cu_v, @@ -588,27 +592,26 @@ def forward( use_qk_l2norm_in_kernel=True, cu_seqlens=cu_seqlens, ) - if use_cache and len(cu_seqlens) != total_len: - if recurrent_state[0] is None: - recurrent_state[0] = torch.zeros_like(recurrent_state_[reverse_indices[1:]-1]) - # handle the case where some memories are not used - for i in range(total_len-1): - if cu_seqlen_all[0][i] != cu_seqlen_all[0][i+1]: - recurrent_state[0][i] = recurrent_state_[reverse_indices[i+1]-1] - else: - recurrent_state[0] = recurrent_state_ - o_ = o_.squeeze(0).contiguous() - o_list = pad_input(o_, indices_q, batch_size*self.num_memories, q_len) - o_list = rearrange(o_list, '(e b) l h d -> e b l h d', b=batch_size) - o_list = o_list[:, :, -max_len:] - - o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, - seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) - - if self.shared_mem: - shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, - use_cache, conv_state_q, conv_state_k, conv_state_v) - o += shared_o + recurrent_state[0] = self.handle_recurrent_state(recurrent_state[0], recurrent_state_, cu_seqlens, cu_seqlen_all[0], reverse_indices) + # if use_cache and len(cu_seqlens) != total_len: + # if recurrent_state[0] is None: + # recurrent_state[0] = torch.zeros_like(recurrent_state_[reverse_indices[1:]-1]) + # # handle the case where some memories are not used + # for i in range(total_len-1): + # if cu_seqlen_all[0][i] != cu_seqlen_all[0][i+1]: + # recurrent_state[0][i] = recurrent_state_[reverse_indices[i+1]-1] + # else: + # recurrent_state[0] = recurrent_state_ + + # o_list = o_list[:, :, -max_len:] + + # o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, + # seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + + # if self.shared_mem: + # shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, + # use_cache, conv_state_q, conv_state_k, conv_state_v) + # o += shared_o if past_key_values is not None: past_key_values.update( @@ -619,13 +622,20 @@ def forward( ) if self.use_output_gate: - o_g = rearrange(o_g, '... (h d) -> ... 
h d', h=self.num_heads) - o = self.o_norm(o, o_g) + hidden_states = index_first_axis(rearrange(hidden_states, "e b s ... -> (e b s) ..."), indices_q).unsqueeze(0) + g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim) + o = self.o_norm(o, g) else: o = self.o_norm(o) o = rearrange(o, 'b t h d -> b t (h d)') o = self.o_proj(o) + o = o.squeeze(0).contiguous() + o = pad_input(o, indices_q, batch_size*self.num_memories, max_len) + o = rearrange(o, '(e b) l d -> e b l d', b=batch_size) + o = reconstruct(o, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, + seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + return o, None, past_key_values, router_logits def shared_o( @@ -714,3 +724,71 @@ def shared_o( raise NotImplementedError(f"Not supported mode `{mode}`.") return o + + def prepare_conv_state(self, conv_state, cu_seqlens, cu_seqlen_all, reverse_indices, batch_size): + if conv_state is None: + return None + + total_len = len(cu_seqlen_all) + if len(cu_seqlens) != total_len: + # select memories that are activated + memories = torch.zeros_like(conv_state[:self.topk*batch_size]) + mem_id = 0 + for i in range(total_len-1): + if cu_seqlen_all[i] != cu_seqlen_all[i+1]: + memories[mem_id] = conv_state[i] + mem_id += 1 + assert mem_id == self.topk * batch_size, "The number of memories is not correct." + else: + memories = conv_state + + return memories + + def prepare_recurrent_state(self, recurrent_state, cu_seqlens, cu_seqlen_all, reverse_indices, batch_size): + if recurrent_state is None: + return None + + total_len = len(cu_seqlen_all) + if len(cu_seqlens) != total_len: + # select memories that are activated + memories = torch.zeros_like(recurrent_state[:self.topk*batch_size]) + mem_id = 0 + for i in range(total_len-1): + if cu_seqlen_all[i] != cu_seqlen_all[i+1]: + memories[mem_id] = recurrent_state[i] + mem_id += 1 + assert mem_id == self.topk * batch_size, "The number of memories is not correct." 
+ else: + memories = recurrent_state + + return memories + + + def handle_conv_state(self, conv_state, conv_state_new, cu_seqlens, cu_seqlen_all, reverse_indices): + if conv_state_new is None: + return None + if conv_state is None: + conv_state = torch.zeros_like(conv_state_new[reverse_indices[1:]-1]) + total_len = len(cu_seqlen_all) + if len(cu_seqlens) != total_len: + # handle the case where some memories are not used + for i in range(total_len-1): + if cu_seqlen_all[i] != cu_seqlen_all[i+1]: + conv_state[i] = conv_state_new[reverse_indices[i+1]-1] + else: + conv_state = conv_state_new + return conv_state + + def handle_recurrent_state(self, recurrent_state, recurrent_state_new, cu_seqlens, cu_seqlen_all, reverse_indices): + if recurrent_state_new is None: + return None + if recurrent_state is None: + recurrent_state = torch.zeros_like(recurrent_state_new[reverse_indices[1:]-1]) + total_len = len(cu_seqlen_all) + if len(cu_seqlens) != total_len: + for i in range(total_len-1): + if cu_seqlen_all[i] != cu_seqlen_all[i+1]: + recurrent_state[i] = recurrent_state_new[reverse_indices[i+1]-1] + else: + recurrent_state = recurrent_state_new + return recurrent_state \ No newline at end of file From da5272238399818f01e6ad425b59eebb3ebf9fe1 Mon Sep 17 00:00:00 2001 From: Jusen Date: Sun, 27 Jul 2025 13:17:19 +0800 Subject: [PATCH 22/28] Support cu_seqlens --- fla/layers/mom.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/fla/layers/mom.py b/fla/layers/mom.py index cb37a09be..40507c810 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -438,6 +438,10 @@ def forward( "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." ) + origin_cu_seqlens = kwargs.get('cu_seqlens', None) + if origin_cu_seqlens is not None: + hidden_states, attention_mask = self.cu2pad(hidden_states, origin_cu_seqlens) + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode if self.training: assert mode == 'chunk', "Only chunk mode is supported in training." @@ -636,6 +640,9 @@ def forward( o = reconstruct(o, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + if origin_cu_seqlens is not None: + indices, _, _ = get_unpad_data(attention_mask[:, -seq_len:]) + o = index_first_axis(rearrange(o, "b s ... 
-> (b s) ..."), indices).unsqueeze(0) return o, None, past_key_values, router_logits def shared_o( @@ -725,6 +732,22 @@ def shared_o( return o + def cu2pad(self, x, cu_seqlens): + batch_size = cu_seqlens.shape[0] - 1 + max_len = (cu_seqlens[1:] - cu_seqlens [:-1]).max().item() + indices = torch.tensor([], dtype=torch.long, device=x.device) + attention_mask = torch.ones((batch_size, max_len), dtype=torch.bool, device=x.device) + for i in range(batch_size): + seq_len = cu_seqlens[i+1] - cu_seqlens[i] + pad_len = max_len - seq_len + batch_indices = torch.arange(pad_len, max_len, device=x.device) + batch_indices = batch_indices + i * max_len + indices = torch.cat([indices, batch_indices]) + attention_mask[i, :pad_len] = False + x = pad_input(x.squeeze(0), indices, batch_size, max_len) + return x, attention_mask + + def prepare_conv_state(self, conv_state, cu_seqlens, cu_seqlen_all, reverse_indices, batch_size): if conv_state is None: return None From 273cce65f9a50686cd8688e8123233299f108236 Mon Sep 17 00:00:00 2001 From: Jusen Date: Sun, 27 Jul 2025 13:29:08 +0800 Subject: [PATCH 23/28] Support shared memory & test generatation passed wo norm --- fla/layers/mom.py | 78 ++++++++++++----------------------------------- 1 file changed, 19 insertions(+), 59 deletions(-) diff --git a/fla/layers/mom.py b/fla/layers/mom.py index 40507c810..8bd2fe1a4 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -612,10 +612,10 @@ def forward( # o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, # seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) - # if self.shared_mem: - # shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, - # use_cache, conv_state_q, conv_state_k, conv_state_v) - # o += shared_o + if self.shared_mem: + shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, + use_cache, conv_state_q, conv_state_k, conv_state_v) + o += shared_o if past_key_values is not None: past_key_values.update( @@ -667,25 +667,28 @@ def shared_o( if self.training: assert mode == 'chunk', "Only chunk mode is supported in training." + if attention_mask is not None: + indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) + hidden_states = index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices).unsqueeze(0) + if self.use_short_conv: - conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None q, conv_state_q[1] = self.q_conv1d( x=self.q_proj(hidden_states), cache=conv_state_q[1], output_final_state=use_cache, - mask=conv_mask + cu_seqlens=cu_seqlens ) k, conv_state_k[1] = self.k_conv1d( x=self.shared_k(hidden_states), cache=conv_state_k[1], output_final_state=use_cache, - mask=conv_mask + cu_seqlens=cu_seqlens ) v, conv_state_v[1] = self.v_conv1d( x=self.shared_v(hidden_states), cache=conv_state_v[1], output_final_state=use_cache, - mask=conv_mask + cu_seqlens=cu_seqlens ) else: q = self.silu(self.q_proj(hidden_states)) @@ -696,25 +699,18 @@ def shared_o( beta = self.shared_b(hidden_states).sigmoid() g = -self.A_log.float().exp() * F.softplus(self.shared_a(hidden_states).float() + self.dt_bias) - # dealing with padding - if attention_mask is not None: - beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None]) - g = g.mul(attention_mask[:, -g.shape[-2]:, None]) - - cu_seqlens = kwargs.get('cu_seqlens', None) if mode == 'chunk': o, recurrent_state[-1] = chunk_gated_delta_rule( - q=q.to(dtype=torch.bfloat16), - k=k.to(dtype=torch.bfloat16), - v=v.to(dtype=torch.bfloat16), - g=g.to(dtype=torch.bfloat16), - beta=beta.to(dtype=torch.bfloat16), + q=q, + k=k, + v=v, + g=g, + beta=beta, initial_state=recurrent_state[-1], output_final_state=use_cache, - use_qk_l2norm_in_kernel=True, cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True ) - o = o.to(dtype=q.dtype) elif mode == 'fused_recurrent': o, recurrent_state[-1] = fused_recurrent_gated_delta_rule( q=q, @@ -724,8 +720,8 @@ def shared_o( beta=beta, initial_state=recurrent_state[-1], output_final_state=use_cache, - use_qk_l2norm_in_kernel=True, cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True ) else: raise NotImplementedError(f"Not supported mode `{mode}`.") @@ -746,27 +742,7 @@ def cu2pad(self, x, cu_seqlens): attention_mask[i, :pad_len] = False x = pad_input(x.squeeze(0), indices, batch_size, max_len) return x, attention_mask - - - def prepare_conv_state(self, conv_state, cu_seqlens, cu_seqlen_all, reverse_indices, batch_size): - if conv_state is None: - return None - - total_len = len(cu_seqlen_all) - if len(cu_seqlens) != total_len: - # select memories that are activated - memories = torch.zeros_like(conv_state[:self.topk*batch_size]) - mem_id = 0 - for i in range(total_len-1): - if cu_seqlen_all[i] != cu_seqlen_all[i+1]: - memories[mem_id] = conv_state[i] - mem_id += 1 - assert mem_id == self.topk * batch_size, "The number of memories is not correct." 
- else: - memories = conv_state - - return memories - + def prepare_recurrent_state(self, recurrent_state, cu_seqlens, cu_seqlen_all, reverse_indices, batch_size): if recurrent_state is None: return None @@ -785,22 +761,6 @@ def prepare_recurrent_state(self, recurrent_state, cu_seqlens, cu_seqlen_all, re memories = recurrent_state return memories - - - def handle_conv_state(self, conv_state, conv_state_new, cu_seqlens, cu_seqlen_all, reverse_indices): - if conv_state_new is None: - return None - if conv_state is None: - conv_state = torch.zeros_like(conv_state_new[reverse_indices[1:]-1]) - total_len = len(cu_seqlen_all) - if len(cu_seqlens) != total_len: - # handle the case where some memories are not used - for i in range(total_len-1): - if cu_seqlen_all[i] != cu_seqlen_all[i+1]: - conv_state[i] = conv_state_new[reverse_indices[i+1]-1] - else: - conv_state = conv_state_new - return conv_state def handle_recurrent_state(self, recurrent_state, recurrent_state_new, cu_seqlens, cu_seqlen_all, reverse_indices): if recurrent_state_new is None: From bf5b511e5285422767fc4ab60a2c28bc2d8a1c26 Mon Sep 17 00:00:00 2001 From: Jusen Date: Sun, 27 Jul 2025 19:42:42 +0800 Subject: [PATCH 24/28] Fix bugs --- fla/layers/mom.py | 43 ++++++++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/fla/layers/mom.py b/fla/layers/mom.py index 8bd2fe1a4..aa97d0ffe 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -498,20 +498,26 @@ def forward( if last_state is not None: conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] - ocnv_q = self.prepare_recurrent_state(conv_state_q[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) + conv_cu_seqlens = cu_seqlens + if self.training and (cu_seqlens[1:] - cu_seqlens[:-1]).min().item() < self.q_conv1d.kernel_size[0]: + # During training, some memories may degrade and the numbers of tokens routed to them are smaller than + # short conv kernel size which is not supported, so we set cu_seqlens to None during pretraining. 
+ conv_cu_seqlens = None + + conv_q = self.prepare_recurrent_state(conv_state_q[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) cu_q, conv_q_new = self.q_conv1d( x=cu_q, - cache=ocnv_q, + cache=conv_q, output_final_state=use_cache, - cu_seqlens=cu_seqlens, + cu_seqlens=conv_cu_seqlens, ) conv_state_q[0] = self.handle_recurrent_state(conv_state_q[0], conv_q_new, cu_seqlens, cu_seqlen_all[0], reverse_indices) - ocnv_k = self.prepare_recurrent_state(conv_state_k[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) + conv_k = self.prepare_recurrent_state(conv_state_k[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) cu_k, conv_k_new = self.k_conv1d( x=cu_k, - cache=ocnv_k, + cache=conv_k, output_final_state=use_cache, - cu_seqlens=cu_seqlens, + cu_seqlens=conv_cu_seqlens, ) conv_state_k[0] = self.handle_recurrent_state(conv_state_k[0], conv_k_new, cu_seqlens, cu_seqlen_all[0], reverse_indices) conv_v = self.prepare_recurrent_state(conv_state_v[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) @@ -519,7 +525,7 @@ def forward( x=cu_v, cache=conv_v, output_final_state=use_cache, - cu_seqlens=cu_seqlens, + cu_seqlens=conv_cu_seqlens, ) conv_state_v[0] = self.handle_recurrent_state(conv_state_v[0], conv_v_new, cu_seqlens, cu_seqlen_all[0], reverse_indices) @@ -612,6 +618,13 @@ def forward( # o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, # seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + o = o.squeeze(0).contiguous() + o = pad_input(o, indices_q, batch_size*self.num_memories, max_len) + o = rearrange(o, '(e b) l h d -> e b l (h d)', b=batch_size) + o = reconstruct(o, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, + seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + o = rearrange(o, 'b l (h d) -> b l h d', h=self.num_heads) + if self.shared_mem: shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, use_cache, conv_state_q, conv_state_k, conv_state_v) @@ -626,20 +639,16 @@ def forward( ) if self.use_output_gate: - hidden_states = index_first_axis(rearrange(hidden_states, "e b s ... -> (e b s) ..."), indices_q).unsqueeze(0) - g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim) + if attention_mask is not None: + indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -seq_len:]) + shared_hidden_states = index_first_axis(rearrange(shared_hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) + g = rearrange(self.g_proj(shared_hidden_states), '... (h d) -> ... h d', d=self.head_v_dim) o = self.o_norm(o, g) else: o = self.o_norm(o) o = rearrange(o, 'b t h d -> b t (h d)') o = self.o_proj(o) - o = o.squeeze(0).contiguous() - o = pad_input(o, indices_q, batch_size*self.num_memories, max_len) - o = rearrange(o, '(e b) l d -> e b l d', b=batch_size) - o = reconstruct(o, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, - seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) - if origin_cu_seqlens is not None: indices, _, _ = get_unpad_data(attention_mask[:, -seq_len:]) o = index_first_axis(rearrange(o, "b s ... -> (b s) ..."), indices).unsqueeze(0) @@ -667,7 +676,9 @@ def shared_o( if self.training: assert mode == 'chunk', "Only chunk mode is supported in training." 
+ cu_seqlens = None if attention_mask is not None: + batch_size, q_len = hidden_states.shape[0], hidden_states.shape[1] indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) @@ -726,6 +737,8 @@ def shared_o( else: raise NotImplementedError(f"Not supported mode `{mode}`.") + if attention_mask is not None: + o = pad_input(o.squeeze(0), indices, batch_size, q_len) return o def cu2pad(self, x, cu_seqlens): From 0f80d59ef39d568f9325820e3dc6b3b5f6293b20 Mon Sep 17 00:00:00 2001 From: Jusen Date: Mon, 28 Jul 2025 10:16:59 +0800 Subject: [PATCH 25/28] Fix bugs --- fla/layers/mom.py | 4 ++++ fla/models/mom/configuration_mom.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/fla/layers/mom.py b/fla/layers/mom.py index aa97d0ffe..045c129ec 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -642,6 +642,7 @@ def forward( if attention_mask is not None: indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -seq_len:]) shared_hidden_states = index_first_axis(rearrange(shared_hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) + o = index_first_axis(rearrange(o, "b s ... -> (b s) ..."), indices).unsqueeze(0) g = rearrange(self.g_proj(shared_hidden_states), '... (h d) -> ... h d', d=self.head_v_dim) o = self.o_norm(o, g) else: @@ -649,6 +650,9 @@ def forward( o = rearrange(o, 'b t h d -> b t (h d)') o = self.o_proj(o) + if attention_mask is not None: + o = pad_input(o.squeeze(0), indices, batch_size, seq_len) + if origin_cu_seqlens is not None: indices, _, _ = get_unpad_data(attention_mask[:, -seq_len:]) o = index_first_axis(rearrange(o, "b s ... -> (b s) ..."), indices).unsqueeze(0) diff --git a/fla/models/mom/configuration_mom.py b/fla/models/mom/configuration_mom.py index 98020e0e6..b8573e362 100644 --- a/fla/models/mom/configuration_mom.py +++ b/fla/models/mom/configuration_mom.py @@ -37,7 +37,7 @@ def __init__( capacity: float = 1.0, use_layer_wise_balance: bool = True, aux_loss_scale: float = 0.01, - shared_mem: bool = False, + shared_mem: bool = True, single_kv_proj: bool = False, mom_backend: str = 'gated_deltanet', fuse_norm: bool = True, From 8ba6bed3bb6597c95d68f957a7d05c7ccf0b4421 Mon Sep 17 00:00:00 2001 From: Jusen Date: Mon, 28 Jul 2025 12:54:08 +0800 Subject: [PATCH 26/28] Remove router_logits in output & fix bugs --- fla/layers/mom.py | 47 +++++++++++++++++++++++++++------- fla/models/mom/modeling_mom.py | 47 +++++++++++++++++----------------- 2 files changed, 62 insertions(+), 32 deletions(-) diff --git a/fla/layers/mom.py b/fla/layers/mom.py index 045c129ec..f0aac9a5e 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -300,6 +300,7 @@ def __init__( capacity: float = 1.0, shared_mem: bool = False, single_kv_proj: bool = False, + aux_loss_scale: float = 0.01, **kwargs ) -> MomAttention: super().__init__() @@ -499,35 +500,36 @@ def forward( conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] conv_cu_seqlens = cu_seqlens - if self.training and (cu_seqlens[1:] - cu_seqlens[:-1]).min().item() < self.q_conv1d.kernel_size[0]: + if seq_len != 1 and (cu_seqlens[1:] - cu_seqlens[:-1]).min().item() < self.q_conv1d.kernel_size[0]: + assert self.training, "Memory degradation should only happen during training." # During training, some memories may degrade and the numbers of tokens routed to them are smaller than # short conv kernel size which is not supported, so we set cu_seqlens to None during pretraining. 
conv_cu_seqlens = None - conv_q = self.prepare_recurrent_state(conv_state_q[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) + conv_q = self.prepare_recurrent_state(conv_state_q[0], conv_cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) cu_q, conv_q_new = self.q_conv1d( x=cu_q, cache=conv_q, output_final_state=use_cache, cu_seqlens=conv_cu_seqlens, ) - conv_state_q[0] = self.handle_recurrent_state(conv_state_q[0], conv_q_new, cu_seqlens, cu_seqlen_all[0], reverse_indices) - conv_k = self.prepare_recurrent_state(conv_state_k[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) + conv_state_q[0] = self.handle_recurrent_state(conv_state_q[0], conv_q_new, conv_cu_seqlens, cu_seqlen_all[0], reverse_indices) + conv_k = self.prepare_recurrent_state(conv_state_k[0], conv_cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) cu_k, conv_k_new = self.k_conv1d( x=cu_k, cache=conv_k, output_final_state=use_cache, cu_seqlens=conv_cu_seqlens, ) - conv_state_k[0] = self.handle_recurrent_state(conv_state_k[0], conv_k_new, cu_seqlens, cu_seqlen_all[0], reverse_indices) - conv_v = self.prepare_recurrent_state(conv_state_v[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) + conv_state_k[0] = self.handle_recurrent_state(conv_state_k[0], conv_k_new, conv_cu_seqlens, cu_seqlen_all[0], reverse_indices) + conv_v = self.prepare_recurrent_state(conv_state_v[0], conv_cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) cu_v, conv_v_new = self.v_conv1d( x=cu_v, cache=conv_v, output_final_state=use_cache, cu_seqlens=conv_cu_seqlens, ) - conv_state_v[0] = self.handle_recurrent_state(conv_state_v[0], conv_v_new, cu_seqlens, cu_seqlen_all[0], reverse_indices) + conv_state_v[0] = self.handle_recurrent_state(conv_state_v[0], conv_v_new, conv_cu_seqlens, cu_seqlen_all[0], reverse_indices) else: q, k, v = self.silu(q), self.silu(k), self.silu(v), @@ -656,7 +658,9 @@ def forward( if origin_cu_seqlens is not None: indices, _, _ = get_unpad_data(attention_mask[:, -seq_len:]) o = index_first_axis(rearrange(o, "b s ... -> (b s) ..."), indices).unsqueeze(0) - return o, None, past_key_values, router_logits + + aux_loss = self.cal_aux_loss(routing_weights_full, selected_memories) if self.training else None + return o, None, past_key_values, aux_loss def shared_o( self, @@ -791,4 +795,29 @@ def handle_recurrent_state(self, recurrent_state, recurrent_state_new, cu_seqlen recurrent_state[i] = recurrent_state_new[reverse_indices[i+1]-1] else: recurrent_state = recurrent_state_new - return recurrent_state \ No newline at end of file + return recurrent_state + + def cal_aux_loss(self, routing_weights_full, selected_experts): + # cast the expert indices to int64, otherwise one-hot encoding will fail + if selected_experts.dtype != torch.int64: + selected_experts = selected_experts.to(torch.int64) + + if len(selected_experts.shape) == 2: + selected_experts = selected_experts.unsqueeze(2) + + expert_mask = torch.nn.functional.one_hot(selected_experts, self.num_memories) + + # For a given token, determine if it was routed to a given expert. 
+ expert_mask = torch.max(expert_mask, axis=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) + + router_prob_per_group_and_expert = torch.mean(routing_weights_full, axis=-2) + + # ✨ balance loss for this layer + balance_loss = torch.mean( + tokens_per_group_and_expert * router_prob_per_group_and_expert + ) * (self.num_memories**2) + return balance_loss \ No newline at end of file diff --git a/fla/models/mom/modeling_mom.py b/fla/models/mom/modeling_mom.py index 3e28c77f4..33a993694 100644 --- a/fla/models/mom/modeling_mom.py +++ b/fla/models/mom/modeling_mom.py @@ -61,7 +61,8 @@ def __init__(self, config: MomConfig, layer_idx: int): topk=config.topk, capacity=config.capacity, shared_mem=config.shared_mem, - single_kv_proj=config.single_kv_proj + single_kv_proj=config.single_kv_proj, + aux_loss_scale=config.aux_loss_scale, ) else: raise NotImplementedError(f"The MoM backend {config.mom_backend} is not currently supported.") @@ -86,7 +87,7 @@ def forward( residual = hidden_states if hasattr(self, 'attn_norm'): hidden_states = self.attn_norm(hidden_states) - hidden_states, attentions, past_key_values, router_logits = self.attn( + hidden_states, attentions, past_key_values, aux_loss = self.attn( hidden_states=hidden_states, attention_mask=attention_mask, past_key_values=past_key_values, @@ -102,7 +103,7 @@ def forward( hidden_states = self.mlp(hidden_states, **kwargs) hidden_states = residual + hidden_states - outputs = (hidden_states, attentions, past_key_values, router_logits) + outputs = (hidden_states, attentions, past_key_values, aux_loss) return outputs @@ -154,7 +155,7 @@ def _init_weights( @dataclass class MomOutputWithPast(BaseModelOutputWithPast): - router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + aux_loss: Optional[torch.FloatTensor] = None class MomModel(MomPreTrainedModel): @@ -217,14 +218,14 @@ def forward( all_hidden_states = () if output_hidden_states else None all_attns = () if output_attentions else None - all_router_logits = () + total_aux_loss = 0. 
for layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - hidden_states, attentions, past_key_values, router_logits = self._gradient_checkpointing_func( + hidden_states, attentions, past_key_values, aux_loss = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, @@ -234,7 +235,7 @@ def forward( **kwargs ) else: - hidden_states, attentions, past_key_values, router_logits = layer( + hidden_states, attentions, past_key_values, aux_loss = layer( hidden_states, attention_mask=attention_mask, past_key_values=past_key_values, @@ -245,7 +246,8 @@ def forward( if output_attentions: all_attns += (attentions,) - all_router_logits += (router_logits,) + if aux_loss is not None: + total_aux_loss = total_aux_loss + aux_loss hidden_states = self.norm(hidden_states) @@ -260,14 +262,13 @@ def forward( past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attns, - router_logits=all_router_logits + aux_loss=total_aux_loss ) @dataclass class MomCausalLMOutputWithPast(CausalLMOutputWithPast): aux_loss: Optional[torch.FloatTensor] = None - router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None class MomForCausalLM(MomPreTrainedModel, GenerationMixin): @@ -282,6 +283,7 @@ def __init__(self, config): self.num_memories = config.num_memories self.topk = config.topk self.aux_loss_scale = config.aux_loss_scale + self.num_hidden_layers = config.num_hidden_layers # Initialize weights and apply final processing self.post_init() @@ -410,18 +412,18 @@ def forward( else: loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) - valid_router_logits = tuple( - logits - for logits in (outputs.router_logits if return_dict else outputs[-1]) - if logits is not None - ) - aux_loss = load_balancing_loss_func( - valid_router_logits, - self.num_memories, - self.topk, - use_layer_wise_balance=self.config.use_layer_wise_balance, - ) - aux_loss *= self.aux_loss_scale + # valid_router_logits = tuple( + # logits + # for logits in (outputs.router_logits if return_dict else outputs[-1]) + # if logits is not None + # ) + # aux_loss = load_balancing_loss_func( + # valid_router_logits, + # self.num_memories, + # self.topk, + # use_layer_wise_balance=self.config.use_layer_wise_balance, + # ) + aux_loss = self.aux_loss_scale * outputs.aux_loss / self.num_hidden_layers loss += aux_loss @@ -435,7 +437,6 @@ def forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - router_logits=outputs.router_logits, aux_loss=aux_loss ) From f5df22a45800492ce0bcfee88211036262764dcc Mon Sep 17 00:00:00 2001 From: Jusen Date: Mon, 4 Aug 2025 15:23:33 +0800 Subject: [PATCH 27/28] Fix bugs --- fla/layers/mom.py | 128 ++++++++++++++++++++++------------------------ 1 file changed, 61 insertions(+), 67 deletions(-) diff --git a/fla/layers/mom.py b/fla/layers/mom.py index f0aac9a5e..dce2368ac 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -165,7 +165,7 @@ def transform( with torch.no_grad(): batch_indices = torch.arange(b, device=x.device).unsqueeze(-1) - batch_indices = batch_indices.expand(b, s).reshape(-1) + batch_indices = batch_indices.repeat(1, s).reshape(-1) if attention_mask is not None: # sort the masked tokens to the end batch_indices[attention_mask.repeat_interleave(topk, dim=1).bitwise_not().flatten()] = b @@ -433,6 +433,7 @@ def forward( **kwargs: Unpack[Dict] ) -> Tuple[torch.Tensor, Optional[torch.Tensor], 
Optional[Cache]]: if attention_mask is not None: + attention_mask = (attention_mask == 1) assert len(attention_mask.shape) == 2, ( "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " "for padding purposes (0 indicating padding). " @@ -498,13 +499,15 @@ def forward( conv_state_q, conv_state_k, conv_state_v = [None, None], [None, None], [None, None] if last_state is not None: conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] - + conv_cu_seqlens = cu_seqlens - if seq_len != 1 and (cu_seqlens[1:] - cu_seqlens[:-1]).min().item() < self.q_conv1d.kernel_size[0]: - assert self.training, "Memory degradation should only happen during training." - # During training, some memories may degrade and the numbers of tokens routed to them are smaller than - # short conv kernel size which is not supported, so we set cu_seqlens to None during pretraining. - conv_cu_seqlens = None + padded = False + if seq_len != 1 and (cu_seqlens[1:] - cu_seqlens[:-1]).min().item() < self.conv_size: + if self.training: + conv_cu_seqlens = None + else: + padded = True + conv_cu_seqlens, cu_q, cu_k, cu_v, pad_lengths = self.pad_for_conv(cu_seqlens, cu_q, cu_k, cu_v) conv_q = self.prepare_recurrent_state(conv_state_q[0], conv_cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) cu_q, conv_q_new = self.q_conv1d( @@ -531,18 +534,14 @@ def forward( ) conv_state_v[0] = self.handle_recurrent_state(conv_state_v[0], conv_v_new, conv_cu_seqlens, cu_seqlen_all[0], reverse_indices) + if padded: + cu_q, cu_k, cu_v = self.unpad_after_conv(conv_cu_seqlens, cu_seqlens, cu_q, cu_k, cu_v, pad_lengths) + else: q, k, v = self.silu(q), self.silu(k), self.silu(v), cu_q, cu_k, cu_v = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (cu_q, cu_k, cu_v)) - # dealing with padding - # if attention_mask is not None: - # v = v.mul(attention_mask[None, :, -v.shape[-3]:, None, None]) - # k = k.mul(attention_mask[None, :, -k.shape[-3]:, None, None]) - # beta = beta.mul(attention_mask[None, :, -beta.shape[-2]:, None]) - # g = g.mul(attention_mask[None, :, -g.shape[-2]:, None]) - recurrent_state = last_state['recurrent_state'] if last_state is not None else [ None for _ in range(1 + self.shared_mem)] if mode == 'chunk': @@ -558,39 +557,8 @@ def forward( cu_seqlens=cu_seqlens, ) recurrent_state[0] = self.handle_recurrent_state(recurrent_state[0], recurrent_state_, cu_seqlens, cu_seqlen_all[0], reverse_indices) - # total_len = len(cu_seqlen_all[0]) - # if use_cache and len(cu_seqlens) != total_len: - # # handle the case where some memories are not used - # recurrent_state[0] = recurrent_state_[reverse_indices[1:]-1] - # for i in range(total_len-1): - # if cu_seqlen_all[0][i] == cu_seqlen_all[0][i+1]: - # recurrent_state[0][i] = torch.zeros_like(recurrent_state[0][i]) - # else: - # recurrent_state[0] = recurrent_state_ - - # o_ = o_.squeeze(0).contiguous() - # o = pad_input(o_, indices_q, batch_size*self.num_memories, max_len) - # o = rearrange(o, '(e b) l h d -> e b l h d', b=batch_size) - # o_list = o_list[:, :, -max_len:] - - # o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, - # seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) elif mode == 'fused_recurrent': - # total_len = len(cu_seqlen_all[0]) - # if use_cache and recurrent_state[0] is not None and len(cu_seqlens) != total_len: - # if recurrent_state[0] is None: - # recurrent_state[0] = torch.zeros_like(recurrent_state_[reverse_indices[1:]-1]) - # # select memories 
that are activated - # memories = torch.zeros_like(recurrent_state[0][:self.topk*batch_size]) - # mem_id = 0 - # for i in range(total_len-1): - # if cu_seqlen_all[0][i] != cu_seqlen_all[0][i+1]: - # memories[mem_id] = recurrent_state[0][i] - # mem_id += 1 - # assert seq_len != 1 or mem_id == self.topk * batch_size, "The number of memories is not correct." - # else: - # memories = recurrent_state[0] memories = self.prepare_recurrent_state(recurrent_state[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) o, recurrent_state_ = fused_recurrent_gated_delta_rule( @@ -605,20 +573,6 @@ def forward( cu_seqlens=cu_seqlens, ) recurrent_state[0] = self.handle_recurrent_state(recurrent_state[0], recurrent_state_, cu_seqlens, cu_seqlen_all[0], reverse_indices) - # if use_cache and len(cu_seqlens) != total_len: - # if recurrent_state[0] is None: - # recurrent_state[0] = torch.zeros_like(recurrent_state_[reverse_indices[1:]-1]) - # # handle the case where some memories are not used - # for i in range(total_len-1): - # if cu_seqlen_all[0][i] != cu_seqlen_all[0][i+1]: - # recurrent_state[0][i] = recurrent_state_[reverse_indices[i+1]-1] - # else: - # recurrent_state[0] = recurrent_state_ - - # o_list = o_list[:, :, -max_len:] - - # o = reconstruct(o_list, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, - # seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) o = o.squeeze(0).contiguous() o = pad_input(o, indices_q, batch_size*self.num_memories, max_len) @@ -641,10 +595,6 @@ def forward( ) if self.use_output_gate: - if attention_mask is not None: - indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -seq_len:]) - shared_hidden_states = index_first_axis(rearrange(shared_hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) - o = index_first_axis(rearrange(o, "b s ... -> (b s) ..."), indices).unsqueeze(0) g = rearrange(self.g_proj(shared_hidden_states), '... (h d) -> ... h d', d=self.head_v_dim) o = self.o_norm(o, g) else: @@ -652,9 +602,6 @@ def forward( o = rearrange(o, 'b t h d -> b t (h d)') o = self.o_proj(o) - if attention_mask is not None: - o = pad_input(o.squeeze(0), indices, batch_size, seq_len) - if origin_cu_seqlens is not None: indices, _, _ = get_unpad_data(attention_mask[:, -seq_len:]) o = index_first_axis(rearrange(o, "b s ... -> (b s) ..."), indices).unsqueeze(0) @@ -764,10 +711,55 @@ def cu2pad(self, x, cu_seqlens): x = pad_input(x.squeeze(0), indices, batch_size, max_len) return x, attention_mask + def pad_for_conv(self, cu_seqlens, cu_q, cu_k, cu_v): + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + pad_lengths = torch.clamp(self.conv_size - lengths, min=0) + new_lengths = lengths + pad_lengths + new_cu_seqlens = torch.cat([ + torch.tensor([0], device=cu_seqlens.device, dtype=cu_seqlens.dtype), + torch.cumsum(new_lengths, dim=0) + ]) + final_total_len = new_cu_seqlens[-1].item() + new_q = torch.zeros((1, final_total_len, cu_q.shape[-1]), dtype=cu_q.dtype, device=cu_q.device) + new_k = torch.zeros((1, final_total_len, cu_k.shape[-1]), dtype=cu_k.dtype, device=cu_k.device) + new_v = torch.zeros((1, final_total_len, cu_v.shape[-1]), dtype=cu_v.dtype, device=cu_v.device) + num_sequences = len(lengths) + for i in range(num_sequences): + src_start = cu_seqlens[i] + src_end = cu_seqlens[i+1] + dest_start = new_cu_seqlens[i] + pad_lengths[i] + dest_end = new_cu_seqlens[i+1] + new_q[:, dest_start:dest_end, ...] = cu_q[:, src_start:src_end, ...] + new_k[:, dest_start:dest_end, ...] = cu_k[:, src_start:src_end, ...] 
+ new_v[:, dest_start:dest_end, ...] = cu_v[:, src_start:src_end, ...] + + return new_cu_seqlens, new_q, new_k, new_v, pad_lengths + + def unpad_after_conv(self, conv_cu_seqlens, cu_seqlens, cu_q, cu_k, cu_v, pad_lengths): + original_total_len = cu_seqlens[-1].item() + orig_q = torch.empty((1, original_total_len, cu_q.shape[-1]), dtype=cu_q.dtype, device=cu_q.device) + orig_k = torch.empty((1, original_total_len, cu_k.shape[-1]), dtype=cu_k.dtype, device=cu_k.device) + orig_v = torch.empty((1, original_total_len, cu_v.shape[-1]), dtype=cu_v.dtype, device=cu_v.device) + + num_sequences = len(pad_lengths) + for i in range(num_sequences): + dest_start = cu_seqlens[i] + dest_end = cu_seqlens[i+1] + src_start = conv_cu_seqlens[i] + pad_lengths[i] + src_end = conv_cu_seqlens[i+1] + + orig_q[:, dest_start:dest_end, ...] = cu_q[:, src_start:src_end, ...] + orig_k[:, dest_start:dest_end, ...] = cu_k[:, src_start:src_end, ...] + orig_v[:, dest_start:dest_end, ...] = cu_v[:, src_start:src_end, ...] + return orig_q, orig_v, orig_v + def prepare_recurrent_state(self, recurrent_state, cu_seqlens, cu_seqlen_all, reverse_indices, batch_size): if recurrent_state is None: return None + if cu_seqlens is None: + return recurrent_state + total_len = len(cu_seqlen_all) if len(cu_seqlens) != total_len: # select memories that are activated @@ -777,7 +769,7 @@ def prepare_recurrent_state(self, recurrent_state, cu_seqlens, cu_seqlen_all, re if cu_seqlen_all[i] != cu_seqlen_all[i+1]: memories[mem_id] = recurrent_state[i] mem_id += 1 - assert mem_id == self.topk * batch_size, "The number of memories is not correct." + assert mem_id == self.topk * batch_size, f"The number of memories {mem_id} is not correct." else: memories = recurrent_state @@ -786,6 +778,8 @@ def prepare_recurrent_state(self, recurrent_state, cu_seqlens, cu_seqlen_all, re def handle_recurrent_state(self, recurrent_state, recurrent_state_new, cu_seqlens, cu_seqlen_all, reverse_indices): if recurrent_state_new is None: return None + if cu_seqlens is None: + return recurrent_state_new if recurrent_state is None: recurrent_state = torch.zeros_like(recurrent_state_new[reverse_indices[1:]-1]) total_len = len(cu_seqlen_all) From ad802b0efdda3380aba31bb057fc8e7231db32d0 Mon Sep 17 00:00:00 2001 From: Jusen Date: Mon, 18 Aug 2025 17:04:04 +0800 Subject: [PATCH 28/28] Fix bugs --- fla/layers/mom.py | 51 ++------- fla/models/mom/modeling_mom.py | 204 +++++++++++++++++---------------- 2 files changed, 115 insertions(+), 140 deletions(-) diff --git a/fla/layers/mom.py b/fla/layers/mom.py index dce2368ac..cea23d17a 100644 --- a/fla/layers/mom.py +++ b/fla/layers/mom.py @@ -300,7 +300,6 @@ def __init__( capacity: float = 1.0, shared_mem: bool = False, single_kv_proj: bool = False, - aux_loss_scale: float = 0.01, **kwargs ) -> MomAttention: super().__init__() @@ -363,9 +362,8 @@ def __init__( self.shared_a = nn.Linear(hidden_size, self.num_heads, bias=False) A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) - A_log = torch.log(A) - self.A_log = nn.Parameter(A_log) - self.D = nn.Parameter(torch.ones(self.num_heads)) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True # hard coded for now dt_min = 0.001 dt_max = 0.1 @@ -380,6 +378,7 @@ def __init__( self.dt_bias = nn.Parameter(inv_dt) # Just to be explicit. 
Without this we already don't put wd on dt_bias because of the check # name.endswith("bias") in param_grouping.py + self.dt_bias._no_weight_decay = True if use_short_conv: self.conv_size = conv_size @@ -502,12 +501,11 @@ def forward( conv_cu_seqlens = cu_seqlens padded = False - if seq_len != 1 and (cu_seqlens[1:] - cu_seqlens[:-1]).min().item() < self.conv_size: - if self.training: - conv_cu_seqlens = None - else: - padded = True - conv_cu_seqlens, cu_q, cu_k, cu_v, pad_lengths = self.pad_for_conv(cu_seqlens, cu_q, cu_k, cu_v) + if self.training: + conv_cu_seqlens = None + elif seq_len != 1 and (cu_seqlens[1:] - cu_seqlens[:-1]).min().item() < self.conv_size: + padded = True + conv_cu_seqlens, cu_q, cu_k, cu_v, pad_lengths = self.pad_for_conv(cu_seqlens, cu_q, cu_k, cu_v) conv_q = self.prepare_recurrent_state(conv_state_q[0], conv_cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) cu_q, conv_q_new = self.q_conv1d( @@ -560,7 +558,6 @@ def forward( elif mode == 'fused_recurrent': memories = self.prepare_recurrent_state(recurrent_state[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) - o, recurrent_state_ = fused_recurrent_gated_delta_rule( q=cu_q, k=cu_k, @@ -606,8 +603,7 @@ def forward( indices, _, _ = get_unpad_data(attention_mask[:, -seq_len:]) o = index_first_axis(rearrange(o, "b s ... -> (b s) ..."), indices).unsqueeze(0) - aux_loss = self.cal_aux_loss(routing_weights_full, selected_memories) if self.training else None - return o, None, past_key_values, aux_loss + return o, None, past_key_values, router_logits.view(-1, self.num_memories) def shared_o( self, @@ -751,7 +747,7 @@ def unpad_after_conv(self, conv_cu_seqlens, cu_seqlens, cu_q, cu_k, cu_v, pad_le orig_q[:, dest_start:dest_end, ...] = cu_q[:, src_start:src_end, ...] orig_k[:, dest_start:dest_end, ...] = cu_k[:, src_start:src_end, ...] orig_v[:, dest_start:dest_end, ...] = cu_v[:, src_start:src_end, ...] - return orig_q, orig_v, orig_v + return orig_q, orig_k, orig_v def prepare_recurrent_state(self, recurrent_state, cu_seqlens, cu_seqlen_all, reverse_indices, batch_size): if recurrent_state is None: @@ -789,29 +785,4 @@ def handle_recurrent_state(self, recurrent_state, recurrent_state_new, cu_seqlen recurrent_state[i] = recurrent_state_new[reverse_indices[i+1]-1] else: recurrent_state = recurrent_state_new - return recurrent_state - - def cal_aux_loss(self, routing_weights_full, selected_experts): - # cast the expert indices to int64, otherwise one-hot encoding will fail - if selected_experts.dtype != torch.int64: - selected_experts = selected_experts.to(torch.int64) - - if len(selected_experts.shape) == 2: - selected_experts = selected_experts.unsqueeze(2) - - expert_mask = torch.nn.functional.one_hot(selected_experts, self.num_memories) - - # For a given token, determine if it was routed to a given expert. 
- expert_mask = torch.max(expert_mask, axis=-2).values - - # cast to float32 otherwise mean will fail - expert_mask = expert_mask.to(torch.float32) - tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) - - router_prob_per_group_and_expert = torch.mean(routing_weights_full, axis=-2) - - # ✨ balance loss for this layer - balance_loss = torch.mean( - tokens_per_group_and_expert * router_prob_per_group_and_expert - ) * (self.num_memories**2) - return balance_loss \ No newline at end of file + return recurrent_state \ No newline at end of file diff --git a/fla/models/mom/modeling_mom.py b/fla/models/mom/modeling_mom.py index 33a993694..26e8901a5 100644 --- a/fla/models/mom/modeling_mom.py +++ b/fla/models/mom/modeling_mom.py @@ -29,6 +29,89 @@ logger = logging.get_logger(__name__) + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + class MomBlock(nn.Module): def __init__(self, config: MomConfig, layer_idx: int): super().__init__() @@ -62,7 +145,6 @@ def __init__(self, config: MomConfig, layer_idx: int): capacity=config.capacity, shared_mem=config.shared_mem, single_kv_proj=config.single_kv_proj, - aux_loss_scale=config.aux_loss_scale, ) else: raise NotImplementedError(f"The MoM backend {config.mom_backend} is not currently supported.") @@ -87,7 +169,7 @@ def forward( residual = hidden_states if hasattr(self, 'attn_norm'): hidden_states = self.attn_norm(hidden_states) - hidden_states, attentions, past_key_values, aux_loss = self.attn( + hidden_states, attentions, past_key_values, router_logits = self.attn( hidden_states=hidden_states, attention_mask=attention_mask, past_key_values=past_key_values, @@ -103,7 +185,7 @@ def forward( hidden_states = self.mlp(hidden_states, **kwargs) hidden_states = residual + hidden_states - outputs = (hidden_states, attentions, past_key_values, aux_loss) + outputs = (hidden_states, attentions, past_key_values, router_logits) return outputs @@ -155,7 +237,7 @@ def _init_weights( @dataclass class MomOutputWithPast(BaseModelOutputWithPast): - aux_loss: Optional[torch.FloatTensor] = None + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None class MomModel(MomPreTrainedModel): @@ -218,14 +300,14 @@ def forward( all_hidden_states = () if output_hidden_states else None all_attns = () if output_attentions else None - total_aux_loss = 0. 
+ all_router_logits = () for layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - hidden_states, attentions, past_key_values, aux_loss = self._gradient_checkpointing_func( + hidden_states, attentions, past_key_values, router_logits = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, @@ -235,7 +317,7 @@ def forward( **kwargs ) else: - hidden_states, attentions, past_key_values, aux_loss = layer( + hidden_states, attentions, past_key_values, router_logits = layer( hidden_states, attention_mask=attention_mask, past_key_values=past_key_values, @@ -246,8 +328,7 @@ def forward( if output_attentions: all_attns += (attentions,) - if aux_loss is not None: - total_aux_loss = total_aux_loss + aux_loss + all_router_logits += (router_logits,) hidden_states = self.norm(hidden_states) @@ -262,13 +343,14 @@ def forward( past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attns, - aux_loss=total_aux_loss + router_logits=all_router_logits ) @dataclass class MomCausalLMOutputWithPast(CausalLMOutputWithPast): aux_loss: Optional[torch.FloatTensor] = None + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None class MomForCausalLM(MomPreTrainedModel, GenerationMixin): @@ -283,7 +365,6 @@ def __init__(self, config): self.num_memories = config.num_memories self.topk = config.topk self.aux_loss_scale = config.aux_loss_scale - self.num_hidden_layers = config.num_hidden_layers # Initialize weights and apply final processing self.post_init() @@ -412,20 +493,16 @@ def forward( else: loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) - # valid_router_logits = tuple( - # logits - # for logits in (outputs.router_logits if return_dict else outputs[-1]) - # if logits is not None - # ) - # aux_loss = load_balancing_loss_func( - # valid_router_logits, - # self.num_memories, - # self.topk, - # use_layer_wise_balance=self.config.use_layer_wise_balance, - # ) - aux_loss = self.aux_loss_scale * outputs.aux_loss / self.num_hidden_layers - - loss += aux_loss + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_memories, + self.topk, + attention_mask, + ) + + # print(aux_loss) + + loss += aux_loss.to(loss.device) * self.aux_loss_scale if not return_dict: output = (logits,) + outputs[1:] @@ -437,79 +514,6 @@ def forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + router_logits=outputs.router_logits, aux_loss=aux_loss - ) - - -def load_balancing_loss_func( - gate_logits: Union[torch.Tensor, Tuple], - num_memories: torch.Tensor = None, - top_k=2, - use_layer_wise_balance=False, -) -> torch.FloatTensor: - r""" - Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. - - See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss - function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between - experts is too unbalanced. - - Args: - gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): - Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_memories]. - num_memories (`int`, *optional*): - Number of experts - - Returns: - The auxiliary loss. 
- """ - if gate_logits is None or ( - isinstance(gate_logits, Iterable) and len(gate_logits) == 0 - ): - return 0 - - # ✨ Here is the fix for balance loss in Mixtral. - # We should calculate the balance loss in a layer-wise manner otherwise it may lead to degenerated solutions. - if use_layer_wise_balance: - if not isinstance(gate_logits, Iterable): - gate_logits = (gate_logits,) - else: - if isinstance(gate_logits, Iterable): - gate_logits = (torch.cat(gate_logits, dim=0),) - else: - gate_logits = (gate_logits,) - - all_balance_losses = [] - - for logits in gate_logits: - routing_weights, selected_experts = torch.topk(logits, top_k, dim=-1) - routing_weights = routing_weights.softmax(dim=-1).to(logits.dtype) - routing_weights_full = torch.zeros_like(logits).scatter(-1, selected_experts, routing_weights) - - # cast the expert indices to int64, otherwise one-hot encoding will fail - if selected_experts.dtype != torch.int64: - selected_experts = selected_experts.to(torch.int64) - - if len(selected_experts.shape) == 2: - selected_experts = selected_experts.unsqueeze(2) - - expert_mask = torch.nn.functional.one_hot(selected_experts, num_memories) - - # For a given token, determine if it was routed to a given expert. - expert_mask = torch.max(expert_mask, axis=-2).values - - # cast to float32 otherwise mean will fail - expert_mask = expert_mask.to(torch.float32) - tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) - - router_prob_per_group_and_expert = torch.mean(routing_weights_full, axis=-2) - - # ✨ balance loss for this layer - balance_loss = torch.mean( - tokens_per_group_and_expert * router_prob_per_group_and_expert - ) * (num_memories**2) - all_balance_losses.append(balance_loss.reshape(1)) - - all_balance_losses = torch.cat(all_balance_losses).mean() # ✨ - - return all_balance_losses + ) \ No newline at end of file
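
The pad_for_conv/unpad_after_conv pair added to fla/layers/mom.py left-pads every packed sequence that is shorter than conv_size with zeros before the short convolution and strips that padding again afterwards; the one-line fix in PATCH 28 makes the unpad step hand back q, k and v in order instead of returning v twice. Below is a minimal standalone sketch of that round trip, simplified to a single packed tensor of shape (1, total_len, d); the helper names are illustrative and not the layer's API.

import torch

def left_pad_for_conv(cu_seqlens, x, conv_size):
    # x: (1, total_len, d) packed sequences; cu_seqlens: (num_seqs + 1,) offsets
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    pad = torch.clamp(conv_size - lengths, min=0)            # how far each sequence falls short
    new_cu = torch.cat([cu_seqlens.new_zeros(1), torch.cumsum(lengths + pad, dim=0)])
    out = x.new_zeros(1, int(new_cu[-1]), x.shape[-1])
    for i in range(len(lengths)):
        # copy sequence i to the right end of its slot, i.e. zero-pad on the left
        out[:, new_cu[i] + pad[i]:new_cu[i + 1]] = x[:, cu_seqlens[i]:cu_seqlens[i + 1]]
    return new_cu, out, pad

def unpad_after_conv_ref(conv_cu, cu_seqlens, y, pad):
    # inverse of left_pad_for_conv: drop the left padding and restore the original packing
    out = y.new_empty(1, int(cu_seqlens[-1]), y.shape[-1])
    for i in range(len(pad)):
        out[:, cu_seqlens[i]:cu_seqlens[i + 1]] = y[:, conv_cu[i] + pad[i]:conv_cu[i + 1]]
    return out

cu = torch.tensor([0, 2, 7])                                 # two sequences of length 2 and 5
x = torch.randn(1, 7, 4)
new_cu, padded, pad = left_pad_for_conv(cu, x, conv_size=4)  # only the length-2 sequence is padded
assert torch.equal(unpad_after_conv_ref(new_cu, cu, padded, pad), x)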
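
prepare_recurrent_state and handle_recurrent_state in the same file gather only the memory slots whose packed segment is non-empty (cu_seqlen_all[i] != cu_seqlen_all[i+1]) before the recurrent kernel runs, then scatter the updated states back so that inactive memories keep their previous state. A stripped-down sketch of that gather/scatter follows, without the reverse_indices bookkeeping of the real layer; gather_active and scatter_active are illustrative names.

import torch

def gather_active(states, cu_seqlen_all):
    # states: (num_slots, ...) holding one recurrent state per (batch, memory) slot
    active = (cu_seqlen_all[1:] - cu_seqlen_all[:-1]) > 0    # slots that received at least one token
    return states[active], active

def scatter_active(states, new_states, active):
    out = states.clone()
    out[active] = new_states                                 # inactive slots keep their old state
    return out

states = torch.zeros(4, 2, 3)                                # 4 slots, toy state shape (2, 3)
cu_seqlen_all = torch.tensor([0, 5, 5, 9, 12])               # slot 1 received no tokens
active_states, active = gather_active(states, cu_seqlen_all)
updated = scatter_active(states, active_states + 1.0, active)   # pretend the kernel updated them
assert torch.equal(updated[1], states[1]) and bool(updated[0].eq(1).all())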
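
The _no_weight_decay = True flags set on A_log and dt_bias only take effect if the optimizer builds its parameter groups accordingly, which is what the comment about the name.endswith("bias") check in param_grouping.py alludes to. One way such grouping could look, assuming a plain AdamW setup, is sketched here; build_param_groups is a hypothetical helper and not part of this repository.

import torch

def build_param_groups(model, weight_decay):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if getattr(param, "_no_weight_decay", False) or name.endswith("bias"):
            no_decay.append(param)        # e.g. A_log, dt_bias and all bias parameters
        else:
            decay.append(param)
    return [{"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0}]

model = torch.nn.Linear(8, 8)             # stand-in for a model whose parameters may carry the flag
optimizer = torch.optim.AdamW(build_param_groups(model, weight_decay=0.01), lr=3e-4)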
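
With PATCH 28 the attention layer returns raw router logits instead of a per-layer aux loss, MomModel collects one logits tensor per layer, and MomForCausalLM passes the tuple to load_balancing_loss_func and adds the result to the language-modelling loss scaled by aux_loss_scale. The end-to-end sketch below mirrors the unmasked branch of that function on random data; the shapes and the 0.01 scale are placeholders, and balance_loss_ref stands in for the real function rather than reproducing it.

import torch
import torch.nn.functional as F

def balance_loss_ref(gate_logits, num_experts, top_k):
    # concatenate the (tokens, num_experts) logits of every layer, as the patch does
    probs = F.softmax(torch.cat(gate_logits, dim=0), dim=-1)
    _, selected = torch.topk(probs, top_k, dim=-1)
    expert_mask = F.one_hot(selected, num_experts)           # (tokens, top_k, num_experts)
    tokens_per_expert = expert_mask.float().mean(dim=0)      # fraction of tokens routed to each expert
    prob_per_expert = probs.mean(dim=0)                      # mean router probability per expert
    return torch.sum(tokens_per_expert * prob_per_expert.unsqueeze(0)) * num_experts

num_memories, topk, aux_loss_scale = 4, 2, 0.01
router_logits = tuple(torch.randn(16, num_memories) for _ in range(3))   # 3 layers, 16 tokens each
lm_loss = torch.tensor(2.31)                                             # placeholder CE loss
aux_loss = balance_loss_ref(router_logits, num_memories, topk)
total_loss = lm_loss + aux_loss_scale * aux_loss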