diff --git a/fla/__init__.py b/fla/__init__.py index 63e70d782..55d3078b6 100644 --- a/fla/__init__.py +++ b/fla/__init__.py @@ -16,6 +16,7 @@ LightNetAttention, LinearAttention, MesaNet, + MomAttention, MultiheadLatentAttention, MultiScaleRetention, NativeSparseAttention, @@ -54,6 +55,8 @@ MesaNetModel, MLAForCausalLM, MLAModel, + MomForCausalLM, + MomModel, NSAForCausalLM, NSAModel, PaTHAttentionForCausalLM, @@ -86,6 +89,7 @@ 'LightNetAttention', 'LightNetForCausalLM', 'LightNetModel', 'LinearAttention', 'LinearAttentionForCausalLM', 'LinearAttentionModel', 'MesaNet', 'MesaNetForCausalLM', 'MesaNetModel', + 'MomAttention', 'MomForCausalLM', 'MomModel', 'MultiheadLatentAttention', 'MLAForCausalLM', 'MLAModel', 'MultiScaleRetention', 'RetNetForCausalLM', 'RetNetModel', 'NativeSparseAttention', 'NSAForCausalLM', 'NSAModel', diff --git a/fla/layers/__init__.py b/fla/layers/__init__.py index 9df865d87..ec3f4b26b 100644 --- a/fla/layers/__init__.py +++ b/fla/layers/__init__.py @@ -20,6 +20,7 @@ from .mamba2 import Mamba2 from .mesa_net import MesaNet from .mla import MultiheadLatentAttention +from .mom import MomAttention from .multiscale_retention import MultiScaleRetention from .nsa import NativeSparseAttention from .path_attn import PaTHAttention @@ -47,6 +48,7 @@ 'Mamba', 'Mamba2', 'MesaNet', + 'MomAttention', 'MultiheadLatentAttention', 'MultiScaleRetention', 'NativeSparseAttention', diff --git a/fla/layers/mom.py b/fla/layers/mom.py new file mode 100644 index 000000000..cea23d17a --- /dev/null +++ b/fla/layers/mom.py @@ -0,0 +1,788 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange +from torch.nn import functional as F + +from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution +from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + +from fla.layers.utils import get_unpad_data, index_first_axis, pad_input, unpad_input + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + gate_layer: torch.Tensor, + beta_layer: torch.Tensor, + attention_mask: torch.Tensor, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. 
Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + query_length = query_layer.shape[1] + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask) + batch_size, kv_seq_len, dim = key_layer.shape + v_dim = value_layer.shape[-1] + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, v_dim), indices_k + ) + gate_layer = index_first_axis(gate_layer.reshape(batch_size * kv_seq_len, -1), indices_k) + beta_layer = index_first_axis(beta_layer.reshape(batch_size * kv_seq_len, -1), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + gate_layer, + beta_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def transform( + x: torch.Tensor, + routing_mask: torch.Tensor, + num_memories: int, + selected_memories: torch.Tensor, + capacity: float, + attention_mask: torch.Tensor, +): + ''' + Transform input sequences into memory-organized chunks with capacity constraints. + + Processes input sequences by routing tokens to designated memory states according to routing_mask, + sorts tokens by memory assignments, handles token truncation/padding based on memory capacity, + and returns memory-aligned tensors for parallel processing. + + Key operations: + 1. Expands input tensors when multiple memories are selected per token (top-k routing) + 2. Sorts tokens globally by (batch_idx, memory_idx) to group memory-assigned tokens + 3. Applies capacity-aware truncation (left-truncate oldest tokens when exceeding capacity) + 4. Pads memory chunks to uniform length for tensorization + + Args: + x: Input hidden states + Shape: (batch_size, seq_len, hidden_size) + routing_mask: Binary mask indicating active memory assignments + Shape: (batch_size, seq_len, num_memories) + num_memories: Total number of memories per batch + selected_memories: Memory indices assigned to each token. When using top-k routing, + this contains k memory indices per token (k >= 1) + Shape: (batch_size, seq_len) for k=1 or (batch_size, seq_len, topk) for k>1 + capacity: Scaling factor for memory capacity calculation. 
Actual capacity per memory is + ceil(seq_len * capacity), maintaining proportional capacity to sequence length + + Returns: + transformed_x: Memory-organized tensor with zero-padded capacity alignment + Shape: (num_memories, batch_size, capacity_len, hidden_size) + truncation_indices: Original indices used for gathering tokens after capacity truncation + Shape: (batch*num_memories, max_len) + sorted_indices: Sorting indices used to group tokens by memory assignments + Shape: (batch_size*seq_len*topk) + max_len: Maximum tokens per memory + mask: Boolean mask indicating valid (non-padded) positions in transformed_x + Shape: (batch*num_memories, max_len) + ''' + if selected_memories.dim() == 3: + # (batch, seq, topk) + topk = selected_memories.shape[2] + # x (batch, seq, hidden) + x = x.repeat_interleave(topk, dim=1) + # x (batch, seq * topk, hidden) + # (batch, seq, topk) + selected_memories = selected_memories.reshape(selected_memories.shape[0], -1) + # (batch, seq * topk) + + if attention_mask is not None: + attention_mask = attention_mask[:, -routing_mask.shape[1]:] + # mask out the masked tokens + routing_mask[attention_mask.bitwise_not().unsqueeze(-1).expand(-1, -1, num_memories)] = 0 + + b, s, d = x.shape + x_flat = x.reshape(b * s, d) # [b*s, d] + + with torch.no_grad(): + batch_indices = torch.arange(b, device=x.device).unsqueeze(-1) + batch_indices = batch_indices.repeat(1, s).reshape(-1) + if attention_mask is not None: + # sort the masked tokens to the end + batch_indices[attention_mask.repeat_interleave(topk, dim=1).bitwise_not().flatten()] = b + # (b * s) + memories_flat = selected_memories.reshape(-1) # [b*s] + + combined = batch_indices * (memories_flat.max() + 1) + memories_flat + sorted_indices = combined.argsort() + + x_sorted = x_flat[sorted_indices] # [b*s, d] + # (b*s, hidden) -> (b, s, hidd) + with torch.no_grad(): + # routing_mask (b, s, num_memories) + batch_memory_tokens = routing_mask.sum(dim=1) + # (b, num_memories) + flatten_offset = batch_memory_tokens.flatten().cumsum(dim=0) + max_len = batch_memory_tokens.max() + indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand( + b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1) + mask = indices < flatten_offset.unsqueeze(-1) + truncation_indices = torch.where(mask, indices, torch.zeros_like(indices)) + + gathered_x = torch.gather(x_sorted, 0, truncation_indices.reshape(-1).unsqueeze(-1).expand(-1, d)) + transformed_x = gathered_x.reshape(b * num_memories, -1, d).reshape((b, num_memories, max_len, d)).transpose(0, 1) + # transformed_x = transformed_x * mask.unsqueeze(-1).expand_as(transformed_x) + # pad_x = torch.zeros((b * num_memories, capacity_len-max_len, d), dtype=transformed_x.dtype, device=transformed_x.device) + # pad_mask = torch.zeros((b * num_memories, capacity_len-max_len), dtype=transformed_x.dtype, device=transformed_x.device) + # left pad + # transformed_x = torch.cat((pad_x, transformed_x), dim=1).reshape((b, num_memories, capacity_len, d)).transpose(0, 1) + mask_2 = mask.reshape((b, num_memories, max_len)).transpose(0, 1) + # truncation_indices += capacity_len-max_len + # if attention_mask is not None: + # mask_2 + + return transformed_x, truncation_indices, sorted_indices, max_len, mask, mask_2 + + +def reconstruct( + transformed_x, + indices: torch.Tensor, + sorted_indices: torch.Tensor, + batch_size: int, + seq_len: int, + topk: int, + routing_weights: torch.Tensor, + mask: torch.Tensor +): + ''' + 
Reconstruct and mix transformed outputs back into the original input sequence shape.
+
+    Key operations:
+    1. Reshapes and transposes `transformed_x` to prepare for scattering.
+    2. Applies the `mask` to zero out invalid positions.
+    3. Uses `torch.scatter_add_` to scatter and sum the transformed outputs back to their original positions based on `indices`.
+    4. Rearranges the scattered outputs using `sorted_indices` to ensure correct ordering.
+    5. Applies the `routing_weights` to weight the outputs.
+    6. Sums over the `topk` dimension to produce the final reconstructed output.
+
+    Args:
+        transformed_x (torch.Tensor):
+            The transformed output tensor from memory units or experts.
+            Shape: (num_memories, batch_size, capacity_len, hidden_size)
+        indices (torch.Tensor):
+            Indices used for scattering the transformed outputs back to their corresponding positions.
+            Shape: (batch*num_memories, max_len)
+        sorted_indices (torch.Tensor):
+            Sorting indices used to rearrange the scattered outputs back into the original sequence order.
+            Shape: (batch_size*seq_len*topk)
+        batch_size (int):
+            The size of the batch.
+        seq_len (int):
+            The length of the input sequence.
+        topk (int):
+            The number of top elements selected (`topk`) per token during the selection process.
+        routing_weights (torch.Tensor):
+            Routing weights assigned to the top-k selected outputs when reconstructing the final output.
+            Shape: (batch_size, seq_len, topk)
+        mask (torch.Tensor):
+            Boolean mask indicating valid positions in the sequence.
+            Shape: (batch*num_memories, max_len)
+
+    Returns:
+        restored_x (torch.Tensor):
+            The reconstructed output tensor in the original input sequence shape.
+            Shape: (batch_size, seq_len, hidden_size)
+    '''
+    transformed_x = transformed_x.transpose(0, 1).reshape(
+        (-1, transformed_x.shape[2], transformed_x.shape[3]))
+    b, s, k, d = batch_size, seq_len, topk, transformed_x.shape[2]
+    gathered_x = transformed_x.reshape(
+        (transformed_x.shape[0] * transformed_x.shape[1], transformed_x.shape[2]))
+    mask_expanded = mask.reshape(-1).unsqueeze(-1).expand_as(gathered_x)
+    gathered_x = gathered_x * mask_expanded
+
+    assert (indices >= 0).all(), "Indices should be non-negative"
+
+    resorted_x = torch.zeros((b * s * k, d), device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_(
+        0,
+        indices.reshape(-1).unsqueeze(-1).expand(-1, d),
+        gathered_x,
+    )
+    assert (indices < resorted_x.size(0)).all(), "Indices should be less than resorted_x size"
+
+    inverse_indices = sorted_indices.argsort()
+    rearranged_x_flat = resorted_x[inverse_indices]
+    restored_x = rearranged_x_flat.reshape((b, s * k, d))
+    restored_x = restored_x.reshape(b, s, k, d) * routing_weights.reshape(b, s, k).unsqueeze(-1)
+    restored_x = restored_x.sum(dim=2)
+    return restored_x
+
+
+class MomAttention(nn.Module):
+    """
+    The layer implementation for [MoM: Linear Sequence Modeling with Mixture-of-Memories](https://arxiv.org/abs/2502.13685).
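+
+    A rough usage sketch (the sizes below are illustrative only, and a CUDA device is
+    assumed since the underlying gated-delta-rule kernels are Triton-based):
+        >>> import torch
+        >>> from fla.layers import MomAttention
+        >>> layer = MomAttention(hidden_size=512, head_dim=64, num_heads=4, num_memories=4, topk=2).cuda().to(torch.bfloat16)
+        >>> x = torch.randn(2, 128, 512, dtype=torch.bfloat16, device='cuda')
+        >>> o, _, _, router_logits = layer(x)
+        >>> o.shape, router_logits.shape
+        (torch.Size([2, 128, 512]), torch.Size([256, 4]))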
+ """ + + def __init__( + self, + hidden_size: int = 2048, + head_dim: int = 256, + num_heads: int = 4, + expand_v: float = 2, + mode: str = 'chunk', + use_output_gate: bool = True, + use_short_conv: bool = True, + conv_size: int = 4, + conv_bias: bool = False, + layer_idx: int = None, + norm_eps: float = 1e-5, + num_memories: int = 8, + topk: int = 2, + capacity: float = 1.0, + shared_mem: bool = False, + single_kv_proj: bool = False, + **kwargs + ) -> MomAttention: + super().__init__() + self.num_memories = num_memories + self.topk = topk + self.capacity = capacity + self.shared_mem = shared_mem + self.single_kv_proj = single_kv_proj + + self.mode = mode + + self.hidden_size = hidden_size + self.expand_v = expand_v + + self.use_output_gate = use_output_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.head_dim = head_dim + self.num_heads = num_heads + + self.key_dim = int(self.num_heads * self.head_dim) + self.value_dim = int(self.key_dim * self.expand_v) + self.head_qk_dim = head_dim + self.head_v_dim = int(head_dim * self.expand_v) + self.layer_idx = layer_idx + self.silu = nn.SiLU() + + assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.gate = nn.Linear(self.hidden_size, self.num_memories, bias=False) + if self.single_kv_proj: + self.shared_k = nn.Linear(hidden_size, self.key_dim, bias=False) + self.shared_v = nn.Linear(hidden_size, self.value_dim, bias=False) + self.shared_b = nn.Linear(hidden_size, self.num_heads, bias=False) + self.shared_a = nn.Linear(hidden_size, self.num_heads, bias=False) + else: + self.k_proj = nn.ModuleList([ + nn.Linear(self.hidden_size, self.key_dim, bias=False) + for _ in range(self.num_memories) + ]) + self.v_proj = nn.ModuleList([ + nn.Linear(self.hidden_size, self.value_dim, bias=False) + for _ in range(self.num_memories) + ]) + self.b_proj = nn.ModuleList([ + nn.Linear(self.hidden_size, self.num_heads, bias=False) + for _ in range(self.num_memories) + ]) + self.a_proj = nn.ModuleList([ + nn.Linear(self.hidden_size, self.num_heads, bias=False) + for _ in range(self.num_memories) + ]) + if self.shared_mem: + self.shared_k = nn.Linear(hidden_size, self.key_dim, bias=False) + self.shared_v = nn.Linear(hidden_size, self.value_dim, bias=False) + self.shared_b = nn.Linear(hidden_size, self.num_heads, bias=False) + self.shared_a = nn.Linear(hidden_size, self.num_heads, bias=False) + + A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + # hard coded for now + dt_min = 0.001 + dt_max = 0.1 + dt_init_floor = 1e-4 + dt = torch.exp( + torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + # Just to be explicit. 
Without this we already don't put wd on dt_bias because of the check + # name.endswith("bias") in param_grouping.py + self.dt_bias._no_weight_decay = True + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + bias=conv_bias, + activation='silu' + ) + self.k_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + bias=conv_bias, + activation='silu' + ) + self.v_conv1d = ShortConvolution( + hidden_size=self.value_dim, + kernel_size=conv_size, + bias=conv_bias, + activation='silu' + ) + else: + raise UserWarning( + "ShortConvolution is crucial to the performance. " + "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing." + ) + if use_output_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps) + else: + self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + attention_mask = (attention_mask == 1) + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + origin_cu_seqlens = kwargs.get('cu_seqlens', None) + if origin_cu_seqlens is not None: + hidden_states, attention_mask = self.cu2pad(hidden_states, origin_cu_seqlens) + + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + if self.training: + assert mode == 'chunk', "Only chunk mode is supported in training." 
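+        # Overall flow of MoM (matching the code below):
+        #   1. the router (`self.gate`) scores every token and keeps its top-k memories;
+        #   2. `transform` regroups tokens into per-memory, variable-length chunks;
+        #   3. each chunk is processed by the gated delta rule with per-memory
+        #      (or shared) key/value/beta/gate projections;
+        #   4. `reconstruct` scatters the chunk outputs back to the original token
+        #      order and mixes them with the routing weights; an always-on shared
+        #      memory is added on top when `shared_mem` is enabled.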
+ + last_state = None + # _, q_len = hidden_states.shape[0], hidden_states.shape[1] + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + # 🔍 topk gating + router_logits = self.gate(hidden_states) # (bsz, q_len, num_memories) + scores = F.softmax(router_logits, dim=2, dtype=torch.float) + routing_weights, selected_memories = torch.topk(scores, self.topk, dim=-1) # (bsz, seq, topk) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) # we cast back to the input dtype + routing_weights_full = torch.zeros( + routing_weights.shape[0], + routing_weights.shape[1], + self.num_memories, + dtype=routing_weights.dtype, + device=routing_weights.device + ).scatter(-1, selected_memories, routing_weights) + routing_mask = routing_weights_full.bool().int() + + # if self.use_output_gate: + # o_g = self.g_proj(hidden_states) + + batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1] + + shared_hidden_states = hidden_states + hidden_states, indices, sorted_indices, max_len, mask, mask_2 = transform( + hidden_states, routing_mask, self.num_memories, selected_memories, self.capacity, attention_mask) + + q = self.q_proj(hidden_states) + if self.single_kv_proj: + k = self.shared_k(hidden_states) + v = self.shared_v(hidden_states) + beta = self.shared_b(hidden_states).sigmoid() + g = -self.A_log.float().exp() * F.softplus(self.shared_a(hidden_states).float() + self.dt_bias) + else: + k = torch.stack([k_expert(hidden_states[i]) for i, k_expert in enumerate(self.k_proj)], dim=0) + v = torch.stack([v_expert(hidden_states[i]) for i, v_expert in enumerate(self.v_proj)], dim=0) + beta = torch.stack([b_expert(hidden_states[i]).sigmoid() for i, b_expert in enumerate(self.b_proj)], dim=0) + g = torch.stack([-self.A_log.float().exp() * F.softplus(a_expert(hidden_states[i]).float() + self.dt_bias) + for i, a_expert in enumerate(self.a_proj)], dim=0) + + q, k, v, g, beta, mask_2 = (rearrange(x, 'e b l ... 
-> (e b) l ...') for x in (q, k, v, g, beta, mask_2)) + cu_q, cu_k, cu_v, cu_g, cu_beta, indices_q, cu_seqlen_all, max_seq_lens = _upad_input(q, k, v, g, beta, mask_2) + cu_seqlens, reverse_indices = cu_seqlen_all[0].to(torch.long).unique(return_inverse=True) + cu_q, cu_k, cu_v, cu_g, cu_beta = (x.unsqueeze(0).contiguous() for x in (cu_q, cu_k, cu_v, cu_g, cu_beta)) + + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = [None, None], [None, None], [None, None] + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + + conv_cu_seqlens = cu_seqlens + padded = False + if self.training: + conv_cu_seqlens = None + elif seq_len != 1 and (cu_seqlens[1:] - cu_seqlens[:-1]).min().item() < self.conv_size: + padded = True + conv_cu_seqlens, cu_q, cu_k, cu_v, pad_lengths = self.pad_for_conv(cu_seqlens, cu_q, cu_k, cu_v) + + conv_q = self.prepare_recurrent_state(conv_state_q[0], conv_cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) + cu_q, conv_q_new = self.q_conv1d( + x=cu_q, + cache=conv_q, + output_final_state=use_cache, + cu_seqlens=conv_cu_seqlens, + ) + conv_state_q[0] = self.handle_recurrent_state(conv_state_q[0], conv_q_new, conv_cu_seqlens, cu_seqlen_all[0], reverse_indices) + conv_k = self.prepare_recurrent_state(conv_state_k[0], conv_cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) + cu_k, conv_k_new = self.k_conv1d( + x=cu_k, + cache=conv_k, + output_final_state=use_cache, + cu_seqlens=conv_cu_seqlens, + ) + conv_state_k[0] = self.handle_recurrent_state(conv_state_k[0], conv_k_new, conv_cu_seqlens, cu_seqlen_all[0], reverse_indices) + conv_v = self.prepare_recurrent_state(conv_state_v[0], conv_cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) + cu_v, conv_v_new = self.v_conv1d( + x=cu_v, + cache=conv_v, + output_final_state=use_cache, + cu_seqlens=conv_cu_seqlens, + ) + conv_state_v[0] = self.handle_recurrent_state(conv_state_v[0], conv_v_new, conv_cu_seqlens, cu_seqlen_all[0], reverse_indices) + + if padded: + cu_q, cu_k, cu_v = self.unpad_after_conv(conv_cu_seqlens, cu_seqlens, cu_q, cu_k, cu_v, pad_lengths) + + else: + q, k, v = self.silu(q), self.silu(k), self.silu(v), + + cu_q, cu_k, cu_v = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (cu_q, cu_k, cu_v)) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else [ + None for _ in range(1 + self.shared_mem)] + if mode == 'chunk': + o, recurrent_state_ = chunk_gated_delta_rule( + q=cu_q, + k=cu_k, + v=cu_v, + g=cu_g, + beta=cu_beta, + initial_state=recurrent_state[0], + output_final_state=use_cache, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) + recurrent_state[0] = self.handle_recurrent_state(recurrent_state[0], recurrent_state_, cu_seqlens, cu_seqlen_all[0], reverse_indices) + + elif mode == 'fused_recurrent': + memories = self.prepare_recurrent_state(recurrent_state[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size) + o, recurrent_state_ = fused_recurrent_gated_delta_rule( + q=cu_q, + k=cu_k, + v=cu_v, + g=cu_g, + beta=cu_beta, + initial_state=memories, + output_final_state=use_cache, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) + recurrent_state[0] = self.handle_recurrent_state(recurrent_state[0], recurrent_state_, cu_seqlens, cu_seqlen_all[0], reverse_indices) + + o = o.squeeze(0).contiguous() + o = pad_input(o, indices_q, batch_size*self.num_memories, max_len) + o = rearrange(o, '(e b) l h d -> e b l (h d)', b=batch_size) + o = reconstruct(o, 
indices=indices, sorted_indices=sorted_indices, batch_size=batch_size, + seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask) + o = rearrange(o, 'b l (h d) -> b l h d', h=self.num_heads) + + if self.shared_mem: + shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, + use_cache, conv_state_q, conv_state_k, conv_state_v) + o += shared_o + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[2] + ) + + if self.use_output_gate: + g = rearrange(self.g_proj(shared_hidden_states), '... (h d) -> ... h d', d=self.head_v_dim) + o = self.o_norm(o, g) + else: + o = self.o_norm(o) + o = rearrange(o, 'b t h d -> b t (h d)') + o = self.o_proj(o) + + if origin_cu_seqlens is not None: + indices, _, _ = get_unpad_data(attention_mask[:, -seq_len:]) + o = index_first_axis(rearrange(o, "b s ... -> (b s) ..."), indices).unsqueeze(0) + + return o, None, past_key_values, router_logits.view(-1, self.num_memories) + + def shared_o( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + recurrent_state=None, + use_cache: Optional[bool] = False, + conv_state_q=[None, None], + conv_state_k=[None, None], + conv_state_v=[None, None], + **kwargs + ) -> torch.Tensor: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + if self.training: + assert mode == 'chunk', "Only chunk mode is supported in training." + + cu_seqlens = None + if attention_mask is not None: + batch_size, q_len = hidden_states.shape[0], hidden_states.shape[1] + indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) + hidden_states = index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices).unsqueeze(0) + + if self.use_short_conv: + q, conv_state_q[1] = self.q_conv1d( + x=self.q_proj(hidden_states), + cache=conv_state_q[1], + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k[1] = self.k_conv1d( + x=self.shared_k(hidden_states), + cache=conv_state_k[1], + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v[1] = self.v_conv1d( + x=self.shared_v(hidden_states), + cache=conv_state_v[1], + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.silu(self.q_proj(hidden_states)) + k = self.silu(self.shared_k(hidden_states)) + v = self.silu(self.shared_v(hidden_states)) + + q, k, v = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (q, k, v)) + beta = self.shared_b(hidden_states).sigmoid() + g = -self.A_log.float().exp() * F.softplus(self.shared_a(hidden_states).float() + self.dt_bias) + + if mode == 'chunk': + o, recurrent_state[-1] = chunk_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state[-1], + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True + ) + elif mode == 'fused_recurrent': + o, recurrent_state[-1] = fused_recurrent_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state[-1], + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if attention_mask is not None: + o = pad_input(o.squeeze(0), indices, batch_size, q_len) + return o + + def cu2pad(self, x, cu_seqlens): + batch_size = cu_seqlens.shape[0] - 1 + max_len = (cu_seqlens[1:] - cu_seqlens [:-1]).max().item() + indices = torch.tensor([], dtype=torch.long, device=x.device) + attention_mask = torch.ones((batch_size, max_len), dtype=torch.bool, device=x.device) + for i in range(batch_size): + seq_len = cu_seqlens[i+1] - cu_seqlens[i] + pad_len = max_len - seq_len + batch_indices = torch.arange(pad_len, max_len, device=x.device) + batch_indices = batch_indices + i * max_len + indices = torch.cat([indices, batch_indices]) + attention_mask[i, :pad_len] = False + x = pad_input(x.squeeze(0), indices, batch_size, max_len) + return x, attention_mask + + def pad_for_conv(self, cu_seqlens, cu_q, cu_k, cu_v): + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + pad_lengths = torch.clamp(self.conv_size - lengths, min=0) + new_lengths = lengths + pad_lengths + new_cu_seqlens = torch.cat([ + torch.tensor([0], device=cu_seqlens.device, dtype=cu_seqlens.dtype), + torch.cumsum(new_lengths, dim=0) + ]) + final_total_len = new_cu_seqlens[-1].item() + new_q = torch.zeros((1, final_total_len, cu_q.shape[-1]), dtype=cu_q.dtype, device=cu_q.device) + new_k = torch.zeros((1, final_total_len, cu_k.shape[-1]), dtype=cu_k.dtype, device=cu_k.device) + new_v = torch.zeros((1, final_total_len, cu_v.shape[-1]), dtype=cu_v.dtype, device=cu_v.device) + num_sequences = len(lengths) + for i in range(num_sequences): + src_start = cu_seqlens[i] + src_end = cu_seqlens[i+1] + dest_start = new_cu_seqlens[i] + pad_lengths[i] + dest_end = new_cu_seqlens[i+1] + new_q[:, dest_start:dest_end, ...] = cu_q[:, src_start:src_end, ...] + new_k[:, dest_start:dest_end, ...] = cu_k[:, src_start:src_end, ...] + new_v[:, dest_start:dest_end, ...] = cu_v[:, src_start:src_end, ...] 
+ + return new_cu_seqlens, new_q, new_k, new_v, pad_lengths + + def unpad_after_conv(self, conv_cu_seqlens, cu_seqlens, cu_q, cu_k, cu_v, pad_lengths): + original_total_len = cu_seqlens[-1].item() + orig_q = torch.empty((1, original_total_len, cu_q.shape[-1]), dtype=cu_q.dtype, device=cu_q.device) + orig_k = torch.empty((1, original_total_len, cu_k.shape[-1]), dtype=cu_k.dtype, device=cu_k.device) + orig_v = torch.empty((1, original_total_len, cu_v.shape[-1]), dtype=cu_v.dtype, device=cu_v.device) + + num_sequences = len(pad_lengths) + for i in range(num_sequences): + dest_start = cu_seqlens[i] + dest_end = cu_seqlens[i+1] + src_start = conv_cu_seqlens[i] + pad_lengths[i] + src_end = conv_cu_seqlens[i+1] + + orig_q[:, dest_start:dest_end, ...] = cu_q[:, src_start:src_end, ...] + orig_k[:, dest_start:dest_end, ...] = cu_k[:, src_start:src_end, ...] + orig_v[:, dest_start:dest_end, ...] = cu_v[:, src_start:src_end, ...] + return orig_q, orig_k, orig_v + + def prepare_recurrent_state(self, recurrent_state, cu_seqlens, cu_seqlen_all, reverse_indices, batch_size): + if recurrent_state is None: + return None + + if cu_seqlens is None: + return recurrent_state + + total_len = len(cu_seqlen_all) + if len(cu_seqlens) != total_len: + # select memories that are activated + memories = torch.zeros_like(recurrent_state[:self.topk*batch_size]) + mem_id = 0 + for i in range(total_len-1): + if cu_seqlen_all[i] != cu_seqlen_all[i+1]: + memories[mem_id] = recurrent_state[i] + mem_id += 1 + assert mem_id == self.topk * batch_size, f"The number of memories {mem_id} is not correct." + else: + memories = recurrent_state + + return memories + + def handle_recurrent_state(self, recurrent_state, recurrent_state_new, cu_seqlens, cu_seqlen_all, reverse_indices): + if recurrent_state_new is None: + return None + if cu_seqlens is None: + return recurrent_state_new + if recurrent_state is None: + recurrent_state = torch.zeros_like(recurrent_state_new[reverse_indices[1:]-1]) + total_len = len(cu_seqlen_all) + if len(cu_seqlens) != total_len: + for i in range(total_len-1): + if cu_seqlen_all[i] != cu_seqlen_all[i+1]: + recurrent_state[i] = recurrent_state_new[reverse_indices[i+1]-1] + else: + recurrent_state = recurrent_state_new + return recurrent_state \ No newline at end of file diff --git a/fla/models/__init__.py b/fla/models/__init__.py index 9cc61e369..6995f13d0 100644 --- a/fla/models/__init__.py +++ b/fla/models/__init__.py @@ -21,6 +21,7 @@ from fla.models.mamba2 import Mamba2Config, Mamba2ForCausalLM, Mamba2Model from fla.models.mesa_net import MesaNetConfig, MesaNetForCausalLM, MesaNetModel from fla.models.mla import MLAConfig, MLAForCausalLM, MLAModel +from fla.models.mom import MomConfig, MomForCausalLM, MomModel from fla.models.nsa import NSAConfig, NSAForCausalLM, NSAModel from fla.models.path_attn import PaTHAttentionConfig, PaTHAttentionForCausalLM, PaTHAttentionModel from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel @@ -47,6 +48,7 @@ 'MambaConfig', 'MambaForCausalLM', 'MambaModel', 'Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model', 'MesaNetConfig', 'MesaNetForCausalLM', 'MesaNetModel', + 'MomConfig', 'MomForCausalLM', 'MomModel', 'MLAConfig', 'MLAForCausalLM', 'MLAModel', 'NSAConfig', 'NSAForCausalLM', 'NSAModel', 'PaTHAttentionConfig', 'PaTHAttentionForCausalLM', 'PaTHAttentionModel', diff --git a/fla/models/mom/__init__.py b/fla/models/mom/__init__.py new file mode 100644 index 000000000..bcc101c64 --- /dev/null +++ b/fla/models/mom/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: 
utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.mom.configuration_mom import MomConfig +from fla.models.mom.modeling_mom import MomForCausalLM, MomModel + +AutoConfig.register(MomConfig.model_type, MomConfig, exist_ok=True) +AutoModel.register(MomConfig, MomModel, exist_ok=True) +AutoModelForCausalLM.register(MomConfig, MomForCausalLM, exist_ok=True) + +__all__ = ['MomConfig', 'MomForCausalLM', 'MomModel'] diff --git a/fla/models/mom/configuration_mom.py b/fla/models/mom/configuration_mom.py new file mode 100644 index 000000000..b8573e362 --- /dev/null +++ b/fla/models/mom/configuration_mom.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class MomConfig(PretrainedConfig): + model_type = 'mom' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + hidden_size: int = 2048, + conv_size: int = 4, + num_heads: int = 4, + head_dim: int = 256, + expand_v: float = 1., + use_output_gate: bool = True, + use_short_conv: bool = True, + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + num_hidden_layers: int = 24, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: Optional[int] = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + num_memories: int = 4, + topk: int = 2, + capacity: float = 1.0, + use_layer_wise_balance: bool = True, + aux_loss_scale: float = 0.01, + shared_mem: bool = True, + single_kv_proj: bool = False, + mom_backend: str = 'gated_deltanet', + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + **kwargs + ): + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = head_dim + self.expand_v = expand_v + self.conv_size = conv_size + self.use_output_gate = use_output_gate + self.use_short_conv = use_short_conv + self.max_position_embeddings = max_position_embeddings + + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + + self.num_memories = num_memories + self.topk = topk + self.capacity = capacity + self.use_layer_wise_balance = use_layer_wise_balance + self.aux_loss_scale = aux_loss_scale + self.shared_mem = shared_mem + self.single_kv_proj = single_kv_proj + self.mom_backend = mom_backend + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + if self.mom_backend not in ['gated_deltanet']: + raise NotImplementedError(f"The MoM backend {mom_backend} is not currently supported.") + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['window_size'] = attn.get('window_size', 
None) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla/models/mom/modeling_mom.py b/fla/models/mom/modeling_mom.py new file mode 100644 index 000000000..26e8901a5 --- /dev/null +++ b/fla/models/mom/modeling_mom.py @@ -0,0 +1,519 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers import MomAttention +from fla.layers.attn import Attention +from fla.models.mom.configuration_mom import MomConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as MomMLP +from fla.modules import RMSNorm + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +logger = logging.get_logger(__name__) + + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
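+
+    A tiny illustrative example (sizes and logits are made up):
+        >>> import torch
+        >>> gate_logits = tuple(torch.randn(4, 8) for _ in range(2))  # 2 layers, 4 tokens, 8 experts
+        >>> aux = load_balancing_loss_func(gate_logits, num_experts=8, top_k=2)
+        >>> aux.dim()  # scalar auxiliary loss
+        0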
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class MomBlock(nn.Module): + def __init__(self, config: MomConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + window_size=config.attn['window_size'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + if config.mom_backend == 'gated_deltanet': + self.attn = MomAttention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_v=config.expand_v, + head_dim=config.head_dim, + num_heads=config.num_heads, + use_output_gate=config.use_output_gate, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + norm_eps=config.norm_eps, + layer_idx=layer_idx, + num_memories=config.num_memories, + topk=config.topk, + capacity=config.capacity, + shared_mem=config.shared_mem, + single_kv_proj=config.single_kv_proj, + ) + else: + raise NotImplementedError(f"The MoM backend {config.mom_backend} is not currently supported.") + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = MomMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + 
fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + if hasattr(self, 'attn_norm'): + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values, router_logits = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if hasattr(self, 'mlp_norm'): + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values, router_logits) + + return outputs + + +class MomPreTrainedModel(PreTrainedModel): + + config_class = MomConfig + supports_gradient_checkpointing = True + _no_split_modules = ['MomBlock'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = False, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +@dataclass +class MomOutputWithPast(BaseModelOutputWithPast): + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + + +class MomModel(MomPreTrainedModel): + + def __init__(self, config: MomConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([MomBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`MomModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + all_router_logits = () + + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values, router_logits = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values, router_logits = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + all_router_logits += (router_logits,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return MomOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns, + router_logits=all_router_logits + ) + + +@dataclass +class MomCausalLMOutputWithPast(CausalLMOutputWithPast): + aux_loss: Optional[torch.FloatTensor] = None + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + + +class MomForCausalLM(MomPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MomModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.num_memories = config.num_memories + self.topk = config.topk + self.aux_loss_scale = config.aux_loss_scale + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + num_logits_to_keep: Optional[int] = 0, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. 
+ if past_key_values is not None: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if num_logits_to_keep is not None: + model_inputs['num_logits_to_keep'] = num_logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + 'num_logits_to_keep': num_logits_to_keep, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + num_logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -num_logits_to_keep:]) + + loss = None + aux_loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + if fuse_linear_and_cross_entropy: + loss_fct = FusedLinearCrossEntropyLoss() + else: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = loss_fct(hidden_states.view(-1, self.config.hidden_size), + labels.view(-1), + self.lm_head.weight, + self.lm_head.bias) + else: + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_memories, + self.topk, + attention_mask, + ) + + # print(aux_loss) + + loss += aux_loss.to(loss.device) * self.aux_loss_scale + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return MomCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + 
aux_loss=aux_loss + ) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 88a568246..4c6ff1585 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,3 +40,6 @@ multi_line_output = 3 [tool.pytest.ini_options] log_cli = true log_cli_level = "INFO" +pythonpath = [ + "." +] diff --git a/tests/models/test_modeling_mom.py b/tests/models/test_modeling_mom.py new file mode 100644 index 000000000..8c89ac32d --- /dev/null +++ b/tests/models/test_modeling_mom.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- + +import pytest +import torch + +from fla.models import MomConfig + +from .test_modeling_base import run_test_generation, run_test_model_forward_backward + + +# =================================================================================== +# Test for Modeling (Forward/Backward Pass) +# =================================================================================== +@pytest.mark.parametrize( + ['L', 'B', 'T', 'H', 'D', 'use_l2warp', 'dtype'], + [ + pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-use_l2warp{}-{}".format(*test)) + for test in [ + (4, 4, 1024, 4, 64, True, torch.bfloat16), + (4, 4, 1024, 4, 64, False, torch.bfloat16), + (4, 4, 1024, 4, 128, False, torch.bfloat16), + ] + ] +) +def test_modeling( + L: int, + B: int, + T: int, + H: int, + D: int, + use_l2warp: bool, + dtype: torch.dtype, +): + run_test_model_forward_backward(L, B, T, H, D, MomConfig, use_l2warp=use_l2warp, dtype=dtype) + + +# =================================================================================== +# Test for Generation +# =================================================================================== +@pytest.mark.parametrize( + ['L', 'B', 'T', 'H', 'D', 'dtype'], + [ + pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-{}".format(*test)) + for test in [ + (2, 4, 2000, 8, 64, torch.float16), + ] + ] +) +def test_generation( + L: int, + B: int, + T: int, + H: int, + D: int, + dtype: torch.dtype, +): + run_test_generation(L, B, T, H, D, MomConfig, dtype)
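+
+
+# A minimal end-to-end usage sketch (not collected by pytest): the tiny config below
+# is illustrative only, and a CUDA device is assumed since the fused kernels are
+# Triton-based.
+if __name__ == '__main__':
+    from fla.models import MomForCausalLM
+
+    config = MomConfig(
+        hidden_size=256, num_hidden_layers=2, num_heads=2, head_dim=64,
+        num_memories=4, topk=2, vocab_size=1000
+    )
+    model = MomForCausalLM(config).cuda().to(torch.bfloat16)
+    input_ids = torch.randint(0, config.vocab_size, (2, 128), device='cuda')
+    outputs = model(input_ids, labels=input_ids)
+    print(outputs.loss)  # language-modeling loss plus the scaled load-balancing aux loss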