 
 @dataclass
 class ModelArgs(BaseModelArgs):
+    model_type: str = "gpt_oss"
     num_hidden_layers: int = 36
     num_local_experts: int = 128
     num_experts_per_tok: int = 4
@@ -29,6 +30,7 @@ class ModelArgs(BaseModelArgs):
     sliding_window: int = 128
     rope_theta: int = 150000
     rope_scaling: Any = None
+    layer_types: list = None
 
 
 # These operators emulate particular methods in torch that don't exist in MLX natively
@@ -47,9 +49,12 @@ def swiglu(x_linear, x_glu, alpha: float = 1.702, limit: float = 7.0):
     # Clamp the input values
     x_glu = mx.clip(x_glu, a_min=None, a_max=limit)
     x_linear = mx.clip(x_linear, a_min=-limit, a_max=limit)
-    glu_scaled = (alpha * x_glu.astype(mx.float32)).astype(mx.bfloat16)
+
+    # Preserve input dtype
+    input_dtype = x_glu.dtype
+    glu_scaled = (alpha * x_glu.astype(mx.float32)).astype(input_dtype)
     negative_glu = (-glu_scaled).astype(mx.float32)
-    sig = (1.0 / (1.0 + mx.exp(negative_glu))).astype(mx.bfloat16)
+    sig = (1.0 / (1.0 + mx.exp(negative_glu))).astype(input_dtype)
 
     out_glu = x_glu * sig
     # Note we add an extra bias of 1 to the linear layer
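
The dtype change above matters for non-bfloat16 weights. A quick usage sketch, not part of the commit; the import path and the fact that the function's return (not shown in this hunk) keeps the gated product's dtype are assumptions:

```python
import mlx.core as mx
# Assumed import path, for illustration only
from mlx_lm.models.gpt_oss import swiglu

x = mx.random.normal((2, 8)).astype(mx.float16)
y = swiglu(x, x)               # linear and gate branches share the same input here
assert y.dtype == mx.float16   # with the old hard-coded bfloat16 casts, the result dtype did not follow the input
```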
@@ -168,11 +173,11 @@ def _make_mask(L, offset):
 
         return self._previous_mask[..., : min(L + offset, window_size + 1)]
 
-    def get_mask(self, x, cache, window_size, idx):
-        if idx % 2 == 1:
-            return self.get_causal_mask(x, cache)
-        else:
+    def get_mask(self, x, cache, window_size):
+        if window_size is not None:
             return self.get_sliding_window_mask(x, cache, window_size)
+        else:
+            return self.get_causal_mask(x, cache)
 
     def __call__(self, x: mx.array, mask: mx.array, cache=None) -> mx.array:
         B, L, _ = x.shape
@@ -225,18 +230,15 @@ def __init__(self, config: ModelArgs):
         self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=True)
 
     def __call__(self, x: mx.array) -> mx.array:
-        x = x.reshape(-1, self.hidden_size)
-
-        # N.B. As elsewhere, upcast is required in linear layers
-        g = self.router(x.astype(mx.float32)).astype(mx.bfloat16)
+        g = self.router(x)
         experts, indices = mlx_topk(g, k=self.num_experts_per_tok, axis=-1)
         expert_weights = mx.softmax(experts, axis=-1, precise=True)
 
         # Experts block
         x = self.experts(x, indices)
 
-        x = x * mx.expand_dims(expert_weights, axis=2)
-        return x.sum(axis=1)
+        x = x * mx.expand_dims(expert_weights, axis=-1)
+        return x.sum(axis=-2)
 
 
 class TransformerBlock(nn.Module):
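
With the flattening reshape gone, the router and the weighted mixture operate directly on (batch, sequence, ...) shaped arrays: the expert weights broadcast over the last axis and the reduction runs over the expert axis. A self-contained sketch of that broadcasting (shapes are illustrative, not from the commit):

```python
import mlx.core as mx

B, L, k, hidden = 2, 5, 4, 16
expert_out = mx.random.normal((B, L, k, hidden))      # per-token outputs of the top-k experts
gates = mx.softmax(mx.random.normal((B, L, k)), axis=-1, precise=True)

mixed = (expert_out * mx.expand_dims(gates, axis=-1)).sum(axis=-2)
assert tuple(mixed.shape) == (B, L, hidden)
```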
@@ -267,6 +269,10 @@ def __init__(self, args: ModelArgs):
         super().__init__()
         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
         self.norm = nn.RMSNorm(args.hidden_size, args.rms_norm_eps)
+        self.layer_types = args.layer_types or [
+            "sliding_attention",
+            "full_attention",
+        ] * (args.num_hidden_layers // 2)
         self.layers = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
         self.window_size = args.sliding_window
 
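
For a config without layer_types, the fallback reproduces the old parity rule (sliding window on even layer indices, full attention on odd). A minimal sketch with six layers:

```python
num_hidden_layers = 6
layer_types = ["sliding_attention", "full_attention"] * (num_hidden_layers // 2)
# -> ['sliding_attention', 'full_attention', 'sliding_attention',
#     'full_attention', 'sliding_attention', 'full_attention']
```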
@@ -287,8 +293,10 @@ def __call__(
 
         if mask is None:
             masks = [
-                l.self_attn.get_mask(x, c, self.window_size, i)
-                for i, (l, c) in enumerate(zip(self.layers, cache))
+                l.self_attn.get_mask(
+                    x, c, self.window_size if lt == "sliding_attention" else None
+                )
+                for (l, c, lt) in zip(self.layers, cache, self.layer_types)
             ]
         else:
             masks = [mask] * len(self.layers)
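
A small sketch (not from the commit) of how that comprehension resolves the window size per layer; None falls through to the causal mask in get_mask:

```python
window_size = 128
layer_types = ["sliding_attention", "full_attention"] * 2
per_layer_window = [
    window_size if lt == "sliding_attention" else None for lt in layer_types
]
# -> [128, None, 128, None]
```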
@@ -328,10 +336,9 @@ def convert_moe_packed_tensors(blocks, scales):
     )
 
     *prefix_shape, G, B = blocks.shape
-    rows_total = math.prod(prefix_shape) * G
 
-    blocks = blocks.reshape(rows_total, B)
-    scales = scales.reshape(rows_total, 1)
+    blocks = blocks.reshape(-1, B)
+    scales = scales.reshape(-1, 1)
 
     idx_lo = blocks & 0x0F
     idx_hi = blocks >> 4
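
reshape(-1, B) collapses every leading dimension, which is exactly the row count the removed rows_total line computed by hand. A sketch verifying the equivalence on a dummy packed tensor:

```python
import math
import mlx.core as mx

blocks = mx.zeros((3, 4, 7, 16), dtype=mx.uint8)    # (*prefix_shape, G, B)
*prefix_shape, G, B = blocks.shape
rows_total = math.prod(prefix_shape) * G            # the removed explicit computation

assert blocks.reshape(-1, B).shape[0] == rows_total
```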
@@ -346,9 +353,7 @@ class Model(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.args = args
-        self.model_type = (
-            args.model_type if hasattr(args, "model_type") else "gpt_oss_moe"
-        )
+        self.model_type = args.model_type
         self.model = GptOssMoeModel(args)
         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
@@ -405,9 +410,8 @@ def layers(self):
 
     def make_cache(self):
         caches = []
-        for i in range(self.args.num_hidden_layers):
-            # full attn on odd indices, swa on even
-            if i % 2 == 1:
+        for lt in self.model.layer_types:
+            if lt == "full_attention":
                 caches.append(KVCache())
             else:
                 caches.append(
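
Cache construction now keys off the same layer_types list as the masks. The sliding-window branch is truncated in this hunk, so the rotating cache below is an assumption about what it appends; a sketch of the intended pairing, assuming mlx_lm's cache classes:

```python
# Assumed imports and constructor arguments, for illustration only
from mlx_lm.models.cache import KVCache, RotatingKVCache

def make_cache_sketch(layer_types, window_size):
    caches = []
    for lt in layer_types:
        if lt == "full_attention":
            caches.append(KVCache())
        else:
            # Assumed: a rotating cache bounded by the sliding window
            caches.append(RotatingKVCache(max_size=window_size))
    return caches
```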