
 @dataclass
 class ModelArgs(BaseModelArgs):
+    # Required fields (no defaults)
     model_type: str
     vocab_size: int
     hidden_size: int
@@ -29,34 +30,42 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     num_key_value_heads: int
     attention_bias: bool
-
-    # Scalar multipliers
     embedding_multiplier: float
     attention_multiplier: float
     logits_scaling: float
     residual_multiplier: float
-
-    # MoE parameters
-    num_local_experts: int
-    num_experts_per_tok: int
-    shared_intermediate_size: int
-
-    # Mamba parameters
-    mamba_n_heads: int
-    mamba_d_head: int
-    mamba_proj_bias: bool
-    mamba_d_state: int
-    mamba_d_conv: int
-    mamba_n_groups: int
-    mamba_conv_bias: bool
-
     layer_types: List[str]
     rms_norm_eps: float
     rope_theta: float
+
+    # Optional fields (with defaults)
+    # MoE parameters (optional for dense mode)
+    num_local_experts: Optional[int] = None
+    num_experts_per_tok: Optional[int] = None
+    shared_intermediate_size: Optional[int] = None
+
+    # Mamba parameters (optional for non-hybrid mode)
+    mamba_n_heads: Optional[int] = None
+    mamba_d_head: Optional[int] = None
+    mamba_proj_bias: Optional[bool] = None
+    mamba_d_state: Optional[int] = None
+    mamba_d_conv: Optional[int] = None
+    mamba_n_groups: Optional[int] = None
+    mamba_conv_bias: Optional[bool] = None
+
+    # Dense MLP parameters (for non-MoE mode)
+    mlp_bias: bool = False
+
+    # Other optional parameters
     position_embedding_type: str = "rope"
     tie_word_embeddings: bool = True
     time_step_limit: Tuple[float, float] = (0.001, 100.0)

+    # Mode flags - inferred from num_local_experts
+    @property
+    def use_moe(self) -> bool:
+        return bool(self.num_local_experts)
+

 class GraniteMoeHybridRMSNormGated(nn.Module):
     def __init__(self, hidden_size: int, eps: float = 1e-6):
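(Aside: a minimal sketch of how the mode inference above is meant to behave. The stripped-down `_ArgsSketch` class below is illustrative only and omits nearly all fields of the real ModelArgs.)

```python
# Illustrative stand-in for ModelArgs: use_moe falls out of num_local_experts
# (None or 0 -> dense MLP path, a positive count -> MoE path).
from dataclasses import dataclass
from typing import Optional


@dataclass
class _ArgsSketch:
    num_local_experts: Optional[int] = None

    @property
    def use_moe(self) -> bool:
        return bool(self.num_local_experts)


print(_ArgsSketch().use_moe)                      # False -> dense model
print(_ArgsSketch(num_local_experts=8).use_moe)   # True  -> MoE model
```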
@@ -314,11 +323,27 @@ def __call__(self, x: mx.array) -> mx.array:
         return self.output_linear(nn.silu(gate) * up)


+class GraniteMoeHybridMLP(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        dim = args.hidden_size
+        hidden_dim = args.intermediate_size
+        mlp_bias = args.mlp_bias
+
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
 class GraniteMoeHybridLayer(nn.Module):
     def __init__(self, args: ModelArgs, layer_type: str):
         super().__init__()
         self.layer_type = layer_type
         self.residual_multiplier = args.residual_multiplier
+        self.use_moe = args.use_moe

         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

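(Aside: `GraniteMoeHybridMLP` above is a standard SwiGLU-style feed-forward block. A small standalone shape check, with toy dimensions invented purely for illustration:)

```python
# Toy shape check for a SwiGLU-style MLP like the one added above; bias=False
# mirrors the mlp_bias default. Dimensions here are made up.
import mlx.core as mx
import mlx.nn as nn

dim, hidden_dim = 8, 32
gate_proj = nn.Linear(dim, hidden_dim, bias=False)
up_proj = nn.Linear(dim, hidden_dim, bias=False)
down_proj = nn.Linear(hidden_dim, dim, bias=False)

x = mx.random.normal(shape=(1, 4, dim))            # (batch, seq, hidden)
y = down_proj(nn.silu(gate_proj(x)) * up_proj(x))   # silu(gate) gates up, then project down
print(y.shape)                                       # (1, 4, 8)
```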
@@ -329,8 +354,14 @@ def __init__(self, args: ModelArgs, layer_type: str):
         else:
             raise ValueError(f"Unknown layer type: {layer_type}")

-        self.shared_mlp = GraniteMoeHybridSharedMLP(args)
-        self.block_sparse_moe = GraniteMoeHybridMoE(args)
+        # MoE or dense MLP after attention/mamba
+        if self.use_moe:
+            self.shared_mlp = GraniteMoeHybridSharedMLP(args)
+            self.block_sparse_moe = GraniteMoeHybridMoE(args)
+        else:
+            # Dense MLP mode
+            self.mlp = GraniteMoeHybridMLP(args)
+
         self.post_attention_layernorm = nn.RMSNorm(
             args.hidden_size, eps=args.rms_norm_eps
         )
@@ -352,13 +383,16 @@ def __call__(

         hidden_states = residual + hidden_states * self.residual_multiplier

-        # Second block: MoE + shared_mlp
+        # Second block: MoE + shared_mlp OR dense MLP
         residual = hidden_states
         normed = self.post_attention_layernorm(hidden_states)

-        moe_out = self.block_sparse_moe(normed)
-        shared_out = self.shared_mlp(normed)
-        mlp_out = moe_out + shared_out
+        if self.use_moe:
+            moe_out = self.block_sparse_moe(normed)
+            shared_out = self.shared_mlp(normed)
+            mlp_out = moe_out + shared_out
+        else:
+            mlp_out = self.mlp(normed)

         hidden_states = residual + mlp_out * self.residual_multiplier

@@ -375,9 +409,16 @@ def __init__(self, args: ModelArgs):
         ]
         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.embedding_multiplier = args.embedding_multiplier
-        self.fa_idx = args.layer_types.index("attention")
-        self.ssm_idx = args.layer_types.index("mamba")
-        self.layer_types = args.layer_types
+
+        # Handle hybrid vs non-hybrid mode
+        self.fa_idx = (
+            args.layer_types.index("attention")
+            if "attention" in args.layer_types
+            else None
+        )
+        self.ssm_idx = (
+            args.layer_types.index("mamba") if "mamba" in args.layer_types else None
+        )

     def __call__(
         self,
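(Aside: the guard matters because `list.index` raises ValueError when a layer type is absent, e.g. in a hypothetical attention-free configuration. A quick sketch with made-up layer lists:)

```python
# Made-up layer_types lists: the guarded lookup yields None instead of raising
# ValueError when a layer type does not occur at all.
hybrid = ["mamba", "mamba", "attention", "mamba"]
mamba_only = ["mamba", "mamba", "mamba", "mamba"]

for layer_types in (hybrid, mamba_only):
    fa_idx = layer_types.index("attention") if "attention" in layer_types else None
    ssm_idx = layer_types.index("mamba") if "mamba" in layer_types else None
    print(fa_idx, ssm_idx)  # 2 0, then None 0
```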
@@ -389,11 +430,16 @@ def __call__(
         if cache is None:
             cache = [None] * len(self.layers)

-        attn_mask = create_attention_mask(hidden_states, cache[self.fa_idx])
-        mamba_mask = create_ssm_mask(hidden_states, cache[self.ssm_idx])
+        # Create masks based on what layer types exist
+        attn_mask = None
+        mamba_mask = None
+
+        if self.fa_idx is not None:
+            attn_mask = create_attention_mask(hidden_states, cache[self.fa_idx])
+        if self.ssm_idx is not None:
+            mamba_mask = create_ssm_mask(hidden_states, cache[self.ssm_idx])

-        cache_counter = 0
-        for layer, c, layer_type in zip(self.layers, cache, self.layer_types):
+        for layer, c in zip(self.layers, cache):
             mask = attn_mask if layer.layer_type == "attention" else mamba_mask
             hidden_states = layer(hidden_states, mask=mask, cache=c)

@@ -443,8 +489,11 @@ def sanitize(self, weights):
             if "conv1d.weight" in k and v.shape[-1] != 1:
                 weights[k] = v.moveaxis(2, 1)

-        # Handle MoE weight transformation to SwitchGLU format
-        if "model.layers.0.block_sparse_moe.input_linear.weight" in weights:
+        # Handle MoE weight transformation to SwitchGLU format (only for MoE models)
+        if (
+            self.args.use_moe
+            and "model.layers.0.block_sparse_moe.input_linear.weight" in weights
+        ):
             for l in range(self.args.num_hidden_layers):
                 prefix = f"model.layers.{l}.block_sparse_moe"

@@ -461,12 +510,31 @@ def sanitize(self, weights):
461510 f"{ prefix } .output_linear.weight"
462511 )
463512
513+ # Handle dense MLP weight transformation (for dense models)
514+ elif (
515+ not self .args .use_moe
516+ and "model.layers.0.shared_mlp.input_linear.weight" in weights
517+ ):
518+ for l in range (self .args .num_hidden_layers ):
519+ prefix = f"model.layers.{ l } .shared_mlp"
520+
521+ # Transform shared_mlp weights to standard mlp weights
522+ input_weight = weights .pop (f"{ prefix } .input_linear.weight" )
523+ # Split into gate and up projections (each half)
524+ gate_proj , up_proj = mx .split (input_weight , 2 , axis = 0 )
525+ weights [f"model.layers.{ l } .mlp.gate_proj.weight" ] = gate_proj
526+ weights [f"model.layers.{ l } .mlp.up_proj.weight" ] = up_proj
527+
528+ weights [f"model.layers.{ l } .mlp.down_proj.weight" ] = weights .pop (
529+ f"{ prefix } .output_linear.weight"
530+ )
531+
464532 return weights
465533
466534 @property
467535 def quant_predicate (self ):
468536 def predicate (path , _ ):
469- if path .endswith ("router.layer" ):
537+ if self . args . use_moe and path .endswith ("router.layer" ):
470538 return {"group_size" : 64 , "bits" : 8 }
471539 return True
472540
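(Aside: a toy check of the dense-path remapping above. The shapes are invented for illustration and assume, as the in-code comment says, that the fused `input_linear` weight stacks the gate rows ahead of the up rows.)

```python
# The fused shared_mlp.input_linear weight, shape (2 * intermediate_size, hidden_size),
# splits into equal gate/up halves along axis 0.
import mlx.core as mx

hidden_size, intermediate_size = 8, 32
fused = mx.random.normal(shape=(2 * intermediate_size, hidden_size))

gate_proj, up_proj = mx.split(fused, 2, axis=0)
print(gate_proj.shape, up_proj.shape)  # (32, 8) (32, 8)
```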