Skip to content

Commit a2bc266

Browse files
committed
Get router dtype from config
1 parent 7cd49d9 commit a2bc266

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

mlx_lm/models/llada2_moe.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@ class ModelArgs(BaseModelArgs):
4444
score_function: str = "sigmoid"
4545
n_group: int = 1
4646
topk_group: int = 4
47+
router_dtype: Optional[str] = None
4748
mask_token_id: int = 156895
4849
eos_token_id: int = 156892
4950

@@ -195,9 +196,15 @@ def __init__(self, args: ModelArgs):
195196
self.topk_group = args.topk_group
196197
self.routed_scaling_factor = args.routed_scaling_factor
197198
self.score_function = args.score_function
198-
self.weight = mx.zeros((args.num_experts, args.hidden_size))
199+
200+
if args.router_dtype == "fp32":
201+
router_dtype = mx.float32
202+
else:
203+
router_dtype = None
204+
205+
self.weight = mx.zeros((args.num_experts, args.hidden_size), dtype=router_dtype)
199206
self.expert_bias = (
200-
mx.zeros((args.num_experts,))
207+
mx.zeros((args.num_experts,), dtype=router_dtype)
201208
if args.moe_router_enable_expert_bias
202209
else None
203210
)

0 commit comments

Comments
 (0)