@@ -1,12 +1,13 @@
 # Copyright © 2023-2024 Apple Inc.
 
 from dataclasses import dataclass
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .cache import KVCache, RotatingKVCache
 from .rope_utils import initialize_rope
 
 
@@ -28,11 +29,16 @@ class ModelArgs(BaseModelArgs):
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
     tie_word_embeddings: bool = True
+    layer_types: Optional[List[str]] = None
+    sliding_window: Optional[int] = None
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
             self.num_key_value_heads = self.num_attention_heads
 
+        if self.layer_types is None:
+            self.layer_types = ["full_attention"] * self.num_hidden_layers
+
 
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -114,10 +120,11 @@ def __call__(self, x) -> mx.array:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs, use_sliding: bool = False):
         super().__init__()
         self.num_attention_heads = args.num_attention_heads
         self.hidden_size = args.hidden_size
+        self.use_sliding = use_sliding
         self.self_attn = Attention(args)
         self.mlp = MLP(args)
         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
@@ -145,12 +152,21 @@ def __init__(self, args: ModelArgs):
         self.args = args
         self.vocab_size = args.vocab_size
         self.num_hidden_layers = args.num_hidden_layers
+        self.layer_types = args.layer_types
+        self.sliding_window = args.sliding_window
         assert self.vocab_size > 0
         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
         self.layers = [
-            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
+            TransformerBlock(args=args, use_sliding=layer_type == "sliding_attention")
+            for layer_type in self.layer_types
         ]
         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.fa_idx = self.layer_types.index("full_attention")
+        self.swa_idx = None
+        for e, l in enumerate(self.layers):
+            if l.use_sliding:
+                self.swa_idx = e
+                break
 
     def __call__(
         self,
@@ -166,10 +182,15 @@ def __call__(
         if cache is None:
             cache = [None] * len(self.layers)
 
-        mask = create_attention_mask(h, cache[0])
+        fa_mask = create_attention_mask(h, cache[self.fa_idx])
+        if self.swa_idx is not None:
+            swa_mask = create_attention_mask(
+                h, cache[self.swa_idx], window_size=self.sliding_window
+            )
 
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, cache=c)
+        for layer, cache in zip(self.layers, cache):
+            mask = swa_mask if layer.use_sliding else fa_mask
+            h = layer(h, mask, cache=cache)
 
         return self.norm(h)
 
@@ -208,3 +229,13 @@ def sanitize(self, weights):
     @property
     def layers(self):
         return self.model.layers
+
+    def make_cache(self):
+        return [
+            (
+                RotatingKVCache(max_size=self.model.sliding_window)
+                if layer.use_sliding
+                else KVCache()
+            )
+            for layer in self.layers
+        ]
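
For context, a minimal usage sketch of the new per-layer cache (not part of this diff). It assumes the checkpoint's config provides `layer_types` and `sliding_window`; the repo id and prompt below are placeholders.

```python
# Sketch only: the repo id and prompt are placeholders, not from this change.
import mlx.core as mx
from mlx_lm import load

model, tokenizer = load("<some-sliding-window-checkpoint>")

# One cache entry per layer: RotatingKVCache (bounded by sliding_window)
# for sliding layers, plain KVCache for full-attention layers.
cache = model.make_cache()

tokens = mx.array(tokenizer.encode("Hello"))[None]
logits = model(tokens, cache=cache)  # prefill; rotating caches stay bounded

# Greedy single-token decode step reusing the same per-layer caches.
next_token = mx.argmax(logits[:, -1, :], axis=-1)
logits = model(next_token[None], cache=cache)
```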