
Commit dcb4b9b

dnakov and awni authored
Add Code World Model support (#505)
* Add sliding-window support to LLaMA

* nits

* version

---------

Co-authored-by: dnakov <[email protected]>
Co-authored-by: Awni Hannun <[email protected]>
1 parent 358b4d2 commit dcb4b9b
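
For orientation (not part of the diff): a sliding-window layer lets each query attend only to the most recent W positions instead of the whole causal prefix. The snippet below is a minimal, illustrative sketch of that constraint in MLX, assuming the common convention that a token sees itself plus the W − 1 tokens before it; the repository's own masking lives in create_attention_mask and may differ in details.

import mlx.core as mx

def sliding_window_mask(n: int, window: int) -> mx.array:
    # True where a query position (row) may attend to a key position (column).
    q = mx.arange(n).reshape(n, 1)
    k = mx.arange(n).reshape(1, n)
    causal = k <= q          # never attend to future positions
    recent = k > q - window  # only the last `window` positions, self included
    return mx.logical_and(causal, recent)

print(sliding_window_mask(5, 3))  # each row has at most 3 True entries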

File tree: 3 files changed (+81, −7 lines)

  mlx_lm/_version.py
  mlx_lm/models/llama.py
  tests/test_models.py

mlx_lm/_version.py

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 # Copyright © 2023-2025 Apple Inc.
 
-__version__ = "0.28.0"
+__version__ = "0.28.1"

mlx_lm/models/llama.py

Lines changed: 37 additions & 6 deletions

@@ -1,12 +1,13 @@
 # Copyright © 2023-2024 Apple Inc.
 
 from dataclasses import dataclass
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .cache import KVCache, RotatingKVCache
 from .rope_utils import initialize_rope
 
 
@@ -28,11 +29,16 @@ class ModelArgs(BaseModelArgs):
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
     tie_word_embeddings: bool = True
+    layer_types: Optional[List[str]] = None
+    sliding_window: Optional[int] = None
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
             self.num_key_value_heads = self.num_attention_heads
 
+        if self.layer_types is None:
+            self.layer_types = ["full_attention"] * self.num_hidden_layers
+
 
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -114,10 +120,11 @@ def __call__(self, x) -> mx.array:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs, use_sliding: bool = False):
         super().__init__()
         self.num_attention_heads = args.num_attention_heads
         self.hidden_size = args.hidden_size
+        self.use_sliding = use_sliding
         self.self_attn = Attention(args)
         self.mlp = MLP(args)
         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
@@ -145,12 +152,21 @@ def __init__(self, args: ModelArgs):
         self.args = args
         self.vocab_size = args.vocab_size
         self.num_hidden_layers = args.num_hidden_layers
+        self.layer_types = args.layer_types
+        self.sliding_window = args.sliding_window
         assert self.vocab_size > 0
         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
         self.layers = [
-            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
+            TransformerBlock(args=args, use_sliding=layer_type == "sliding_attention")
+            for layer_type in self.layer_types
         ]
         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.fa_idx = self.layer_types.index("full_attention")
+        self.swa_idx = None
+        for e, l in enumerate(self.layers):
+            if l.use_sliding:
+                self.swa_idx = e
+                break
 
     def __call__(
         self,
@@ -166,10 +182,15 @@ def __call__(
         if cache is None:
             cache = [None] * len(self.layers)
 
-        mask = create_attention_mask(h, cache[0])
+        fa_mask = create_attention_mask(h, cache[self.fa_idx])
+        if self.swa_idx is not None:
+            swa_mask = create_attention_mask(
+                h, cache[self.swa_idx], window_size=self.sliding_window
+            )
 
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, cache=c)
+        for layer, c in zip(self.layers, cache):
+            mask = swa_mask if layer.use_sliding else fa_mask
+            h = layer(h, mask, cache=c)
 
         return self.norm(h)
 
@@ -208,3 +229,13 @@ def sanitize(self, weights):
     @property
     def layers(self):
         return self.model.layers
+
+    def make_cache(self):
+        return [
+            (
+                RotatingKVCache(max_size=self.model.sliding_window)
+                if layer.use_sliding
+                else KVCache()
+            )
+            for layer in self.layers
+        ]
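
A rough usage sketch of the new cache wiring (illustrative only; the toy configuration below is made up and mirrors the test added further down). make_cache pairs each layer with a matching cache: full-attention layers get an unbounded KVCache, while sliding layers get a RotatingKVCache capped at sliding_window entries, which keeps their key/value memory bounded during long generations.

import mlx.core as mx
from mlx_lm.models import llama

# Toy configuration (hypothetical values) with one full and one sliding layer.
args = llama.ModelArgs(
    model_type="llama",
    hidden_size=64,
    num_hidden_layers=2,
    intermediate_size=256,
    num_attention_heads=8,
    num_key_value_heads=4,
    rms_norm_eps=1e-5,
    vocab_size=128,
    sliding_window=4,
    layer_types=["full_attention", "sliding_attention"],
)
model = llama.Model(args)

cache = model.make_cache()  # one cache per layer: KVCache / RotatingKVCache

tokens = mx.array([[1, 2, 3, 4, 5]], dtype=mx.int32)
logits = model(tokens, cache=cache)              # prefill the caches
next_tok = mx.argmax(logits[:, -1, :], axis=-1)  # greedy pick of the next token
logits = model(next_tok.reshape(1, 1), cache=cache)  # one incremental decode step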

tests/test_models.py

Lines changed: 43 additions & 0 deletions

@@ -175,6 +175,49 @@ def test_mask_with_window(self):
         sums = mask.sum(axis=1)
         self.assertTrue(mx.array_equal(sums, expected_sums))
 
+    def test_llama_model_sliding_attention(self):
+        from mlx_lm.models import llama
+
+        args = llama.ModelArgs(
+            model_type="llama",
+            hidden_size=64,
+            num_hidden_layers=4,
+            intermediate_size=256,
+            num_attention_heads=8,
+            num_key_value_heads=4,
+            rms_norm_eps=1e-5,
+            vocab_size=128,
+            sliding_window=4,
+            layer_types=[
+                "full_attention",
+                "sliding_attention",
+                "sliding_attention",
+                "full_attention",
+            ],
+            tie_word_embeddings=False,
+            rope_theta=10000.0,
+        )
+        model = llama.Model(args)
+
+        tokens = mx.array([[1, 2, 3, 4, 5]], dtype=mx.int32)
+        out = model(tokens)
+        mx.eval(out)
+        self.assertEqual(out.shape, (1, 5, args.vocab_size))
+
+        caches = model.make_cache()
+        self.assertIsInstance(caches[0], KVCache)
+        self.assertIsInstance(caches[1], RotatingKVCache)
+        self.assertIsInstance(caches[2], RotatingKVCache)
+        self.assertIsInstance(caches[3], KVCache)
+
+        caches = model.make_cache()
+        step = model(tokens[:, :2], cache=caches)
+        mx.eval(step)
+        step = model(tokens[:, 2:3], cache=caches)
+        mx.eval(step)
+        self.assertEqual(caches[0].offset, 3)
+        self.assertEqual(caches[1].offset, 3)
+
     def test_rope(self):
         rope = rope_utils.initialize_rope(32, base=100, traditional=False)
         self.assertTrue(isinstance(rope, nn.RoPE))