
Commit 6c876ca

Shashikant86 and awni authored
Add Additional Features of GPT-OSS Model: LoRA, Alternating Attention, MoE Support (#357)
* Add LoRA, alternating attention, MoE support
* nits
* comment
* comment
* comment
* fix test

Co-authored-by: Awni Hannun <[email protected]>
1 parent cfa74ad commit 6c876ca

File tree

3 files changed: +56 −24 lines

mlx_lm/models/gpt_oss.py

Lines changed: 27 additions & 23 deletions
@@ -16,6 +16,7 @@
 
 @dataclass
 class ModelArgs(BaseModelArgs):
+    model_type: str = "gpt_oss"
     num_hidden_layers: int = 36
     num_local_experts: int = 128
     num_experts_per_tok: int = 4
@@ -29,6 +30,7 @@ class ModelArgs(BaseModelArgs):
     sliding_window: int = 128
     rope_theta: int = 150000
     rope_scaling: Any = None
+    layer_types: list = None
 
 
 # These operators emulate particular methods in torch that don't exist in MLX natively
@@ -47,9 +49,12 @@ def swiglu(x_linear, x_glu, alpha: float = 1.702, limit: float = 7.0):
     # Clamp the input values
     x_glu = mx.clip(x_glu, a_min=None, a_max=limit)
     x_linear = mx.clip(x_linear, a_min=-limit, a_max=limit)
-    glu_scaled = (alpha * x_glu.astype(mx.float32)).astype(mx.bfloat16)
+
+    # Preserve input dtype
+    input_dtype = x_glu.dtype
+    glu_scaled = (alpha * x_glu.astype(mx.float32)).astype(input_dtype)
     negative_glu = (-glu_scaled).astype(mx.float32)
-    sig = (1.0 / (1.0 + mx.exp(negative_glu))).astype(mx.bfloat16)
+    sig = (1.0 / (1.0 + mx.exp(negative_glu))).astype(input_dtype)
 
     out_glu = x_glu * sig
     # Note we add an extra bias of 1 to the linear layer
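The change above stops hard-coding bfloat16 and instead casts back to whatever dtype the activations arrived in, while still doing the intermediate sigmoid math in float32. A minimal standalone sketch of the same clamped-SwiGLU idea (the function name and the float16 test input are illustrative, not part of the diff):

import mlx.core as mx

def clamped_swiglu(x_linear, x_glu, alpha=1.702, limit=7.0):
    # Clamp both branches, as in the diff
    x_glu = mx.clip(x_glu, a_min=None, a_max=limit)
    x_linear = mx.clip(x_linear, a_min=-limit, a_max=limit)

    # Do the sigmoid in float32 for stability, then cast back to the input dtype
    input_dtype = x_glu.dtype
    glu_scaled = (alpha * x_glu.astype(mx.float32)).astype(input_dtype)
    sig = (1.0 / (1.0 + mx.exp((-glu_scaled).astype(mx.float32)))).astype(input_dtype)

    out_glu = x_glu * sig
    # Extra bias of 1 on the linear branch, per the comment in the original
    return out_glu * (x_linear + 1)

x = mx.random.normal((2, 8)).astype(mx.float16)
y = clamped_swiglu(x, x)
assert y.dtype == mx.float16  # dtype is preserved end to end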
@@ -168,11 +173,11 @@ def _make_mask(L, offset):
 
         return self._previous_mask[..., : min(L + offset, window_size + 1)]
 
-    def get_mask(self, x, cache, window_size, idx):
-        if idx % 2 == 1:
-            return self.get_causal_mask(x, cache)
-        else:
+    def get_mask(self, x, cache, window_size):
+        if window_size is not None:
             return self.get_sliding_window_mask(x, cache, window_size)
+        else:
+            return self.get_causal_mask(x, cache)
 
     def __call__(self, x: mx.array, mask: mx.array, cache=None) -> mx.array:
         B, L, _ = x.shape
@@ -225,18 +230,15 @@ def __init__(self, config: ModelArgs):
         self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=True)
 
     def __call__(self, x: mx.array) -> mx.array:
-        x = x.reshape(-1, self.hidden_size)
-
-        # N.B. As elsewhere, upcast is required in linear layers
-        g = self.router(x.astype(mx.float32)).astype(mx.bfloat16)
+        g = self.router(x)
         experts, indices = mlx_topk(g, k=self.num_experts_per_tok, axis=-1)
         expert_weights = mx.softmax(experts, axis=-1, precise=True)
 
         # Experts block
         x = self.experts(x, indices)
 
-        x = x * mx.expand_dims(expert_weights, axis=2)
-        return x.sum(axis=1)
+        x = x * mx.expand_dims(expert_weights, axis=-1)
+        return x.sum(axis=-2)
 
 
 class TransformerBlock(nn.Module):
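Because the router no longer flattens x to (tokens, hidden_size), the expert weights keep their leading batch/sequence axes, so the weighted reduction has to address the top-k axis relative to the end: expand on axis=-1, sum over axis=-2. A shape-only sketch of that reduction with made-up sizes and stand-in random tensors (mlx_topk and the experts module are replaced by plain mx ops here):

import mlx.core as mx

B, L, D, E, K = 2, 5, 16, 8, 2               # batch, sequence, hidden, experts, top-k

router_logits = mx.random.normal((B, L, E))                     # stand-in for self.router(x)
indices = mx.argsort(-router_logits, axis=-1)[..., :K]          # (B, L, K)
scores = mx.take_along_axis(router_logits, indices, axis=-1)    # (B, L, K)
weights = mx.softmax(scores, axis=-1)                           # (B, L, K)

expert_out = mx.random.normal((B, L, K, D))                     # stand-in for self.experts(x, indices)

# Scale each selected expert's output and reduce over the top-k axis (-2),
# which works for any number of leading dimensions.
y = (expert_out * mx.expand_dims(weights, axis=-1)).sum(axis=-2)
print(y.shape)                                                  # (2, 5, 16)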
@@ -267,6 +269,10 @@ def __init__(self, args: ModelArgs):
         super().__init__()
         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
         self.norm = nn.RMSNorm(args.hidden_size, args.rms_norm_eps)
+        self.layer_types = args.layer_types or [
+            "sliding_attention",
+            "full_attention",
+        ] * (args.num_hidden_layers // 2)
         self.layers = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
         self.window_size = args.sliding_window
 
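With layer_types left as None, the fallback reproduces the old even/odd alternation as an explicit list, which the mask and cache code can consume without re-deriving it from layer indices. A tiny sketch of what the default evaluates to and how it is used downstream (plain Python; the sizes are chosen for illustration, and the shipped config has an even num_hidden_layers, so the // 2 expansion covers every layer):

num_hidden_layers = 6
sliding_window = 128

layer_types = None or ["sliding_attention", "full_attention"] * (num_hidden_layers // 2)
print(layer_types)
# ['sliding_attention', 'full_attention', 'sliding_attention',
#  'full_attention', 'sliding_attention', 'full_attention']

# Each entry then selects the mask flavour for its layer, mirroring the
# window_size-or-None dispatch in get_mask:
per_layer_window = [
    sliding_window if lt == "sliding_attention" else None for lt in layer_types
]
print(per_layer_window)  # [128, None, 128, None, 128, None]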
@@ -287,8 +293,10 @@ def __call__(
 
         if mask is None:
             masks = [
-                l.self_attn.get_mask(x, c, self.window_size, i)
-                for i, (l, c) in enumerate(zip(self.layers, cache))
+                l.self_attn.get_mask(
+                    x, c, self.window_size if lt == "sliding_attention" else None
+                )
+                for (l, c, lt) in zip(self.layers, cache, self.layer_types)
             ]
         else:
             masks = [mask] * len(self.layers)
@@ -328,10 +336,9 @@ def convert_moe_packed_tensors(blocks, scales):
     )
 
     *prefix_shape, G, B = blocks.shape
-    rows_total = math.prod(prefix_shape) * G
 
-    blocks = blocks.reshape(rows_total, B)
-    scales = scales.reshape(rows_total, 1)
+    blocks = blocks.reshape(-1, B)
+    scales = scales.reshape(-1, 1)
 
     idx_lo = blocks & 0x0F
     idx_hi = blocks >> 4
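Dropping rows_total in favour of reshape(-1, B) folds all leading dimensions into one row axis before the packed bytes are split into 4-bit indices. A small self-contained sketch of that unpacking step (the 2x2 byte values are made up; the real function goes on to map the nibbles through an FP4 lookup table and apply the per-group scales, which is omitted here, and the interleaving order shown is only illustrative):

import mlx.core as mx

# Two "rows" of packed bytes; each byte holds two 4-bit indices.
blocks = mx.array([[0x21, 0x43], [0x65, 0x87]], dtype=mx.uint8)

blocks = blocks.reshape(-1, blocks.shape[-1])   # collapse all leading dims into rows

idx_lo = blocks & 0x0F      # low nibble of every byte
idx_hi = blocks >> 4        # high nibble of every byte

# Interleave low/high nibbles so each byte expands to two values.
unpacked = mx.stack([idx_lo, idx_hi], axis=-1).reshape(blocks.shape[0], -1)
print(unpacked.tolist())    # [[1, 2, 3, 4], [5, 6, 7, 8]]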
@@ -346,9 +353,7 @@ class Model(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.args = args
-        self.model_type = (
-            args.model_type if hasattr(args, "model_type") else "gpt_oss_moe"
-        )
+        self.model_type = args.model_type
         self.model = GptOssMoeModel(args)
         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
@@ -405,9 +410,8 @@ def layers(self):
 
     def make_cache(self):
         caches = []
-        for i in range(self.args.num_hidden_layers):
-            # full attn on odd indices, swa on even
-            if i % 2 == 1:
+        for lt in self.model.layer_types:
+            if lt == "full_attention":
                 caches.append(KVCache())
             else:
                 caches.append(
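The tail of this hunk is truncated above, but the shape of the new make_cache follows the layer_types list so that every layer's cache matches its mask. Roughly, as a hedged sketch (the RotatingKVCache constructor arguments are an assumption based on other sliding-window models in mlx_lm, not copied from this commit):

from mlx_lm.models.cache import KVCache, RotatingKVCache

def make_cache_sketch(layer_types, sliding_window):
    caches = []
    for lt in layer_types:
        if lt == "full_attention":
            # Unbounded KV cache for global-attention layers
            caches.append(KVCache())
        else:
            # Sliding-window layers only ever need the last `sliding_window` tokens
            caches.append(RotatingKVCache(max_size=sliding_window, keep=0))
    return caches

caches = make_cache_sketch(["sliding_attention", "full_attention"] * 2, sliding_window=128)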

mlx_lm/tuner/utils.py

Lines changed: 1 addition & 0 deletions
@@ -122,6 +122,7 @@ def to_lora(layer):
            "smollm3",
            "exaone4",
            "hunyuan_v1_dense",
+            "gpt_oss",
        }:
            keys = {"self_attn.q_proj", "self_attn.v_proj"}
        if model.model_type in ["mixtral", "phimoe"]:
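Adding "gpt_oss" to this set routes it through the same default LoRA targets as the other decoder models listed, i.e. the attention q_proj and v_proj linears. As a rough illustration of what converting such a projection to LoRA means (a generic sketch, not mlx_lm's actual LoRALinear class; the class name and the rank/scale defaults are made up):

import mlx.core as mx
import mlx.nn as nn

class TinyLoRALinear(nn.Module):
    """Wrap a frozen linear layer with a low-rank trainable update."""

    def __init__(self, linear: nn.Linear, rank: int = 8, scale: float = 2.0):
        super().__init__()
        out_dims, in_dims = linear.weight.shape
        self.linear = linear            # frozen base projection (e.g. q_proj / v_proj)
        self.scale = scale
        self.lora_a = mx.random.normal((in_dims, rank)) * (1.0 / in_dims**0.5)
        self.lora_b = mx.zeros((rank, out_dims))

    def __call__(self, x):
        # Base output plus the low-rank correction x @ A @ B
        return self.linear(x) + self.scale * ((x @ self.lora_a) @ self.lora_b)

base = nn.Linear(64, 64)
lora = TinyLoRALinear(base)
y = lora(mx.random.normal((2, 10, 64)))
print(y.shape)   # (2, 10, 64)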

tests/test_models.py

Lines changed: 28 additions & 1 deletion
@@ -221,7 +221,7 @@ def model_test_runner(self, model, model_type, vocab_size, num_layers):
            self.assertEqual(outputs.shape, (1, 2, vocab_size))
            self.assertEqual(outputs.dtype, t)
 
-            if model_type not in ("mamba", "plamo2"):
+            if model_type not in ("mamba", "plamo2", "gpt_oss"):
                mask = create_causal_mask(inputs.shape[1], 0).astype(t)
                outputs = model(inputs, mask=mask)
                self.assertEqual(outputs.shape, (1, 2, vocab_size))
@@ -1158,6 +1158,33 @@ def test_smollm3(self):
            model, "smollm3", args.vocab_size, args.num_hidden_layers
        )
 
+    def test_gpt_oss(self):
+        from mlx_lm.models import gpt_oss
+
+        args = gpt_oss.ModelArgs(
+            model_type="gpt_oss",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=8,
+            num_key_value_heads=2,
+            num_local_experts=16,
+            num_experts_per_tok=2,
+            sliding_window=128,
+            rope_theta=10000,
+            vocab_size=10_000,
+            layer_types=[
+                "sliding_attention",
+                "full_attention",
+                "sliding_attention",
+                "full_attention",
+            ],
+        )
+        model = gpt_oss.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
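To run just the new case locally, a standard unittest filter from the repository root should do (the exact invocation is an assumption about local workflow, not part of this commit):

python -m unittest tests.test_models -k test_gpt_oss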
