
Commit b264da7

gabe-l-hart and awni authored
feat: Refactor granitemoehybrid to support dense and non-hybrid variants (#518)
* feat: Refactor granitemoehybrid to support dense and non-hybrid variants

  Written with Claude Code. Initial prompt:

  I need to modify the model support implemented in `mlx_lm/models/granitemoehybrid.py` in two ways:

  * Support optionally using a dense block in place of MoE. The dense block should look like `mlx_lm/models/granite.py` instead of `mlx_lm/models/granitemoe.py`.
  * Support the case where there are no `mamba` layers (ie non-hybrid). This should devolve to exactly `granite.py` or `granitemoe.py` depending on whether the block after attention is dense or MoE.

  You can test this using the following two models:

  * Dense w/ hybrid: /Users/ghart/models/dmf_models/granite-4.0-h-micro-r250918a
  * Dense w/ non-hybrid: /Users/ghart/models/dmf_models/granite-4.0-micro-r250918a

  Branch: GraniteFourDense
  Signed-off-by: Gabe Goodhart <[email protected]>

* refact: Clean up Claude's code a bit

  Branch: GraniteFourDense
  Signed-off-by: Gabe Goodhart <[email protected]>

* style: pre-commit format

  Branch: GraniteFourDense
  Signed-off-by: Gabe Goodhart <[email protected]>

* version bump

---------

Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Awni Hannun <[email protected]>
1 parent cf8cfd0 commit b264da7
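The commit body above describes two orthogonal axes: MoE vs. dense feed-forward, and hybrid (mamba + attention) vs. attention-only. A standalone sketch of the selection rules, with hypothetical config values (only the two decision rules mirror the patch below; nothing here is the actual mlx_lm API):

# Standalone sketch, not part of the patch. `use_moe` is inferred from
# num_local_experts, and hybrid-ness from whether "mamba" appears in layer_types.

def describe_variant(num_local_experts, layer_types):
    use_moe = bool(num_local_experts)      # same rule as ModelArgs.use_moe in the diff
    is_hybrid = "mamba" in layer_types     # non-hybrid devolves to granite/granitemoe
    block = "MoE + shared_mlp" if use_moe else "dense SwiGLU MLP"
    mixer = "mamba/attention hybrid" if is_hybrid else "attention only"
    return f"{mixer}, {block}"

# Hypothetical configs for the four combinations
print(describe_variant(64, ["mamba", "attention", "mamba"]))    # original granitemoehybrid
print(describe_variant(None, ["mamba", "attention", "mamba"]))  # dense hybrid
print(describe_variant(64, ["attention"] * 3))                  # devolves to granitemoe
print(describe_variant(None, ["attention"] * 3))                # devolves to granite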

File tree

2 files changed (+102, -34 lines)


mlx_lm/_version.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 # Copyright © 2023-2025 Apple Inc.
 
-__version__ = "0.28.1"
+__version__ = "0.28.2"

mlx_lm/models/granitemoehybrid.py

Lines changed: 101 additions & 33 deletions
@@ -20,6 +20,7 @@
 
 @dataclass
 class ModelArgs(BaseModelArgs):
+    # Required fields (no defaults)
     model_type: str
     vocab_size: int
     hidden_size: int
@@ -29,34 +30,42 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     num_key_value_heads: int
     attention_bias: bool
-
-    # Scalar multipliers
     embedding_multiplier: float
     attention_multiplier: float
     logits_scaling: float
     residual_multiplier: float
-
-    # MoE parameters
-    num_local_experts: int
-    num_experts_per_tok: int
-    shared_intermediate_size: int
-
-    # Mamba parameters
-    mamba_n_heads: int
-    mamba_d_head: int
-    mamba_proj_bias: bool
-    mamba_d_state: int
-    mamba_d_conv: int
-    mamba_n_groups: int
-    mamba_conv_bias: bool
-
     layer_types: List[str]
     rms_norm_eps: float
     rope_theta: float
+
+    # Optional fields (with defaults)
+    # MoE parameters (optional for dense mode)
+    num_local_experts: Optional[int] = None
+    num_experts_per_tok: Optional[int] = None
+    shared_intermediate_size: Optional[int] = None
+
+    # Mamba parameters (optional for non-hybrid mode)
+    mamba_n_heads: Optional[int] = None
+    mamba_d_head: Optional[int] = None
+    mamba_proj_bias: Optional[bool] = None
+    mamba_d_state: Optional[int] = None
+    mamba_d_conv: Optional[int] = None
+    mamba_n_groups: Optional[int] = None
+    mamba_conv_bias: Optional[bool] = None
+
+    # Dense MLP parameters (for non-MoE mode)
+    mlp_bias: bool = False
+
+    # Other optional parameters
     position_embedding_type: str = "rope"
     tie_word_embeddings: bool = True
     time_step_limit: Tuple[float, float] = (0.001, 100.0)
 
+    # Mode flags - inferred from num_local_experts
+    @property
+    def use_moe(self) -> bool:
+        return bool(self.num_local_experts)
+
 
 class GraniteMoeHybridRMSNormGated(nn.Module):
     def __init__(self, hidden_size: int, eps: float = 1e-6):
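A note on the new `use_moe` property: because it is derived from `num_local_experts`, a config that omits the MoE fields (or sets the expert count to 0) falls into the dense path. A toy stand-in illustrating just that rule (not the real ModelArgs class):

from dataclasses import dataclass
from typing import Optional

# Toy stand-in for the relevant part of ModelArgs, for illustration only
@dataclass
class ToyArgs:
    num_local_experts: Optional[int] = None

    @property
    def use_moe(self) -> bool:
        return bool(self.num_local_experts)

print(ToyArgs().use_moe)                      # False: dense MLP path
print(ToyArgs(num_local_experts=0).use_moe)   # False: dense MLP path
print(ToyArgs(num_local_experts=64).use_moe)  # True: MoE + shared_mlp path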
@@ -314,11 +323,27 @@ def __call__(self, x: mx.array) -> mx.array:
         return self.output_linear(nn.silu(gate) * up)
 
 
+class GraniteMoeHybridMLP(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        dim = args.hidden_size
+        hidden_dim = args.intermediate_size
+        mlp_bias = args.mlp_bias
+
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
 class GraniteMoeHybridLayer(nn.Module):
     def __init__(self, args: ModelArgs, layer_type: str):
         super().__init__()
         self.layer_type = layer_type
         self.residual_multiplier = args.residual_multiplier
+        self.use_moe = args.use_moe
 
         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 
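The added GraniteMoeHybridMLP is a standard SwiGLU feed-forward block, matching the dense MLP in granite.py. A small MLX sketch of the same computation with made-up sizes, to show the shapes involved (names and dimensions here are illustrative only, not taken from a real config):

import mlx.core as mx
import mlx.nn as nn

# Hypothetical sizes standing in for hidden_size and intermediate_size
dim, hidden_dim = 8, 32
gate_proj = nn.Linear(dim, hidden_dim, bias=False)
up_proj = nn.Linear(dim, hidden_dim, bias=False)
down_proj = nn.Linear(hidden_dim, dim, bias=False)

x = mx.random.normal((1, 4, dim))  # (batch, sequence, hidden_size)
y = down_proj(nn.silu(gate_proj(x)) * up_proj(x))
print(y.shape)  # (1, 4, 8): the block preserves the hidden size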
@@ -329,8 +354,14 @@ def __init__(self, args: ModelArgs, layer_type: str):
         else:
             raise ValueError(f"Unknown layer type: {layer_type}")
 
-        self.shared_mlp = GraniteMoeHybridSharedMLP(args)
-        self.block_sparse_moe = GraniteMoeHybridMoE(args)
+        # MoE or dense MLP after attention/mamba
+        if self.use_moe:
+            self.shared_mlp = GraniteMoeHybridSharedMLP(args)
+            self.block_sparse_moe = GraniteMoeHybridMoE(args)
+        else:
+            # Dense MLP mode
+            self.mlp = GraniteMoeHybridMLP(args)
+
         self.post_attention_layernorm = nn.RMSNorm(
             args.hidden_size, eps=args.rms_norm_eps
         )
@@ -352,13 +383,16 @@ def __call__(
 
         hidden_states = residual + hidden_states * self.residual_multiplier
 
-        # Second block: MoE + shared_mlp
+        # Second block: MoE + shared_mlp OR dense MLP
         residual = hidden_states
         normed = self.post_attention_layernorm(hidden_states)
 
-        moe_out = self.block_sparse_moe(normed)
-        shared_out = self.shared_mlp(normed)
-        mlp_out = moe_out + shared_out
+        if self.use_moe:
+            moe_out = self.block_sparse_moe(normed)
+            shared_out = self.shared_mlp(normed)
+            mlp_out = moe_out + shared_out
+        else:
+            mlp_out = self.mlp(normed)
 
         hidden_states = residual + mlp_out * self.residual_multiplier
 
@@ -375,9 +409,16 @@ def __init__(self, args: ModelArgs):
         ]
         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.embedding_multiplier = args.embedding_multiplier
-        self.fa_idx = args.layer_types.index("attention")
-        self.ssm_idx = args.layer_types.index("mamba")
-        self.layer_types = args.layer_types
+
+        # Handle hybrid vs non-hybrid mode
+        self.fa_idx = (
+            args.layer_types.index("attention")
+            if "attention" in args.layer_types
+            else None
+        )
+        self.ssm_idx = (
+            args.layer_types.index("mamba") if "mamba" in args.layer_types else None
+        )
 
     def __call__(
         self,
@@ -389,11 +430,16 @@ def __call__(
         if cache is None:
             cache = [None] * len(self.layers)
 
-        attn_mask = create_attention_mask(hidden_states, cache[self.fa_idx])
-        mamba_mask = create_ssm_mask(hidden_states, cache[self.ssm_idx])
+        # Create masks based on what layer types exist
+        attn_mask = None
+        mamba_mask = None
+
+        if self.fa_idx is not None:
+            attn_mask = create_attention_mask(hidden_states, cache[self.fa_idx])
+        if self.ssm_idx is not None:
+            mamba_mask = create_ssm_mask(hidden_states, cache[self.ssm_idx])
 
-        cache_counter = 0
-        for layer, c, layer_type in zip(self.layers, cache, self.layer_types):
+        for layer, c in zip(self.layers, cache):
             mask = attn_mask if layer.layer_type == "attention" else mamba_mask
             hidden_states = layer(hidden_states, mask=mask, cache=c)
 
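In the old code, `layer_types.index("mamba")` would raise ValueError for a non-hybrid checkpoint that has no mamba layers. The guarded lookup introduced above can be sketched in plain Python with made-up layer lists:

def first_index(layer_types, kind):
    # Mirrors the guarded lookup: None when the layer type is absent
    return layer_types.index(kind) if kind in layer_types else None

hybrid = ["mamba", "mamba", "attention", "mamba"]
attention_only = ["attention"] * 4

print(first_index(hybrid, "attention"), first_index(hybrid, "mamba"))                  # 2 0
print(first_index(attention_only, "attention"), first_index(attention_only, "mamba"))  # 0 None

Because a layer of a missing type never appears in the stack, the corresponding None mask is never consumed in the per-layer loop.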
@@ -443,8 +489,11 @@ def sanitize(self, weights):
             if "conv1d.weight" in k and v.shape[-1] != 1:
                 weights[k] = v.moveaxis(2, 1)
 
-        # Handle MoE weight transformation to SwitchGLU format
-        if "model.layers.0.block_sparse_moe.input_linear.weight" in weights:
+        # Handle MoE weight transformation to SwitchGLU format (only for MoE models)
+        if (
+            self.args.use_moe
+            and "model.layers.0.block_sparse_moe.input_linear.weight" in weights
+        ):
             for l in range(self.args.num_hidden_layers):
                 prefix = f"model.layers.{l}.block_sparse_moe"
 
@@ -461,12 +510,31 @@ def sanitize(self, weights):
                     f"{prefix}.output_linear.weight"
                 )
 
+        # Handle dense MLP weight transformation (for dense models)
+        elif (
+            not self.args.use_moe
+            and "model.layers.0.shared_mlp.input_linear.weight" in weights
+        ):
+            for l in range(self.args.num_hidden_layers):
+                prefix = f"model.layers.{l}.shared_mlp"
+
+                # Transform shared_mlp weights to standard mlp weights
+                input_weight = weights.pop(f"{prefix}.input_linear.weight")
+                # Split into gate and up projections (each half)
+                gate_proj, up_proj = mx.split(input_weight, 2, axis=0)
+                weights[f"model.layers.{l}.mlp.gate_proj.weight"] = gate_proj
+                weights[f"model.layers.{l}.mlp.up_proj.weight"] = up_proj
+
+                weights[f"model.layers.{l}.mlp.down_proj.weight"] = weights.pop(
+                    f"{prefix}.output_linear.weight"
+                )
+
         return weights
 
     @property
     def quant_predicate(self):
         def predicate(path, _):
-            if path.endswith("router.layer"):
+            if self.args.use_moe and path.endswith("router.layer"):
                 return {"group_size": 64, "bits": 8}
             return True
 
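For dense checkpoints, sanitize repacks the fused `shared_mlp.input_linear` weight into separate `gate_proj` and `up_proj` weights by halving it along the output dimension. A minimal MLX sketch of that split, with hypothetical sizes standing in for hidden_size and intermediate_size:

import mlx.core as mx

# Hypothetical sizes: dim ~ hidden_size, hidden_dim ~ intermediate_size
dim, hidden_dim = 8, 32

# Fused input_linear weight as the diff assumes it is stored:
# gate rows stacked on top of up rows, shape (2 * hidden_dim, dim)
input_weight = mx.random.normal((2 * hidden_dim, dim))

gate_proj, up_proj = mx.split(input_weight, 2, axis=0)
print(gate_proj.shape, up_proj.shape)  # (32, 8) (32, 8)

# Each half can then be loaded into an nn.Linear(dim, hidden_dim) weight,
# which is how GraniteMoeHybridMLP declares gate_proj and up_proj.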