Commit edc656a

dnakov and awni authored

Add support for nanochat (#554)

* Add support for nanochat
* format
* compile softcap
* add test

Co-authored-by: dnakov <[email protected]>
Co-authored-by: Awni Hannun <[email protected]>
1 parent a4c6470 commit edc656a

File tree

2 files changed: +240 −0 lines


mlx_lm/models/nanochat.py

Lines changed: 233 additions & 0 deletions
@@ -0,0 +1,233 @@
# Copyright © 2025 Apple Inc.

import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str = "nanochat"
    hidden_size: int = 1280
    num_hidden_layers: int = 20
    num_attention_heads: int = 10
    num_key_value_heads: int = 10
    vocab_size: int = 65536
    max_position_embeddings: int = 2048
    intermediate_size: int = 5120  # 4 * hidden_size
    rope_theta: float = 10000.0


def rms_norm(x):
    """Functional RMSNorm with no learnable parameters."""
    return mx.fast.rms_norm(x, None, 1e-5)


def apply_rotary_emb(x, offset, base=10000.0, freqs=None):
    """Apply RoPE with blocked layout.

    Args:
        x: Input tensor in (B, H, T, D) format
        offset: Position offset for KV caching
        base: RoPE base frequency (default 10000.0)
        freqs: Precomputed negated frequencies (optional)

    Returns:
        Tensor with RoPE applied, same shape as input
    """
    head_dim = x.shape[-1]

    if freqs is None:
        # Compute negated frequencies
        half_D = head_dim // 2
        freqs = -mx.exp(
            mx.arange(0.0, half_D, dtype=mx.float32) * (math.log(base) / half_D)
        )

    # Use traditional=False + negated freqs
    return mx.fast.rope(
        x,
        dims=head_dim,
        traditional=False,
        base=None,
        freqs=freqs,
        scale=1.0,
        offset=offset,
    )


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.hidden_size = args.hidden_size
        self.num_heads = args.num_attention_heads
        self.num_kv_heads = args.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.scale = self.head_dim**-0.5
        self.rope_theta = args.rope_theta

        self.c_q = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.c_k = nn.Linear(
            self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
        )
        self.c_v = nn.Linear(
            self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
        )
        self.c_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        # Precompute negated RoPE frequencies for awni's approach
        half_D = self.head_dim // 2
        self._rope_freqs = -mx.exp(
            mx.arange(0.0, half_D, dtype=mx.float32)
            * (math.log(self.rope_theta) / half_D)
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, _ = x.shape

        queries = self.c_q(x)
        keys = self.c_k(x)
        values = self.c_v(x)

        # Reshape to (B, L, H, D) then transpose to (B, H, L, D)
        queries = queries.reshape(B, L, self.num_heads, self.head_dim).transpose(
            0, 2, 1, 3
        )
        keys = keys.reshape(B, L, self.num_kv_heads, self.head_dim).transpose(
            0, 2, 1, 3
        )
        values = values.reshape(B, L, self.num_kv_heads, self.head_dim).transpose(
            0, 2, 1, 3
        )

        # Apply RoPE using precomputed frequencies (expects B, H, T, D format)
        offset = cache.offset if cache is not None else 0
        queries = apply_rotary_emb(
            queries, offset=offset, base=self.rope_theta, freqs=self._rope_freqs
        )
        keys = apply_rotary_emb(
            keys, offset=offset, base=self.rope_theta, freqs=self._rope_freqs
        )

        # QK norm (critical feature of nanochat!)
        queries = rms_norm(queries)
        keys = rms_norm(keys)

        # Handle KV cache after transpose
        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        # Reshape back
        output = output.transpose(0, 2, 1, 3).reshape(B, L, self.hidden_size)
        return self.c_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.c_fc = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
        self.c_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        # Critical: nanochat uses ReLU^2, not GELU!
        x = self.c_fc(x)
        x = nn.relu2(x)
        return self.c_proj(x)


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attn = Attention(args)
        self.mlp = MLP(args)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        # Pre-norm architecture with functional RMSNorm
        h = x + self.attn(rms_norm(x), mask=mask, cache=cache)
        out = h + self.mlp(rms_norm(h))
        return out


class NanoChatModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
        self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ) -> mx.array:
        h = self.wte(inputs)
        # Critical: norm after token embedding
        h = rms_norm(h)

        if cache is None:
            cache = [None] * len(self.h)

        mask = create_attention_mask(h, cache[0])

        for layer, c in zip(self.h, cache):
            h = layer(h, mask=mask, cache=c)

        # Critical: final norm before lm_head
        h = rms_norm(h)

        return h


@partial(mx.compile, shapeless=True)
def softcap(logits, cap=15.0):
    return cap * mx.tanh(logits / cap)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.transformer = NanoChatModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ) -> mx.array:
        out = self.transformer(inputs, cache=cache)
        logits = self.lm_head(out)

        # Critical: logits softcap (nanochat uses softcap=15)
        logits = softcap(logits)

        return logits

    @property
    def layers(self):
        return self.transformer.h
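The softcap keeps every logit inside (−cap, cap): a raw logit of 30.0 maps to 15 · tanh(2) ≈ 14.46, while small values pass through almost unchanged (1.0 maps to ≈ 0.9985). For orientation, here is a minimal usage sketch through the standard mlx_lm load/generate entry points. The checkpoint path is hypothetical (any nanochat weights converted to MLX format whose config sets "model_type": "nanochat"), and the exact generate keyword arguments may vary slightly between mlx_lm versions.

    from mlx_lm import load, generate

    # Hypothetical local path to a nanochat checkpoint already converted to MLX
    # format (safetensors weights plus a config.json with "model_type": "nanochat").
    model, tokenizer = load("path/to/nanochat-mlx")

    # Ask for a short completion; generate() forwards extra keyword arguments
    # such as max_tokens to the underlying token generator.
    text = generate(
        model,
        tokenizer,
        prompt="Why is the sky blue?",
        max_tokens=64,
    )
    print(text)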

tests/test_models.py

Lines changed: 7 additions & 0 deletions
@@ -1932,6 +1932,13 @@ def test_all_models(self):
                 "max_position_embeddings": 1000,
                 "vocab_size": 1000,
             },
+            {
+                "model_type": "nanochat",
+                "hidden_size": 1280,
+                "num_hidden_layers": 20,
+                "vocab_size": 32,
+                "intermediate_size": 128,
+            },
         ]
         for config in test_configs:
             model_type = config["model_type"]
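The new test entry overrides only a handful of fields; everything else falls back to the ModelArgs defaults above. A rough sketch of what the test_all_models loop effectively does for this config follows; the dummy token batch and shape check are illustrative assumptions, not the literal test body.

    import mlx.core as mx
    from mlx_lm.models import nanochat

    config = {
        "model_type": "nanochat",
        "hidden_size": 1280,
        "num_hidden_layers": 20,
        "vocab_size": 32,
        "intermediate_size": 128,
    }

    # BaseModelArgs.from_dict keeps only the fields the dataclass declares,
    # so unrelated config keys are ignored and missing ones use the defaults.
    args = nanochat.ModelArgs.from_dict(config)
    model = nanochat.Model(args)

    # Forward a small dummy batch of token ids and check the logits shape.
    tokens = mx.zeros((1, 8), dtype=mx.int32)
    logits = model(tokens)
    assert logits.shape == (1, 8, args.vocab_size)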
