File tree Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -44,6 +44,7 @@ class ModelArgs(BaseModelArgs):
4444 score_function : str = "sigmoid"
4545 n_group : int = 1
4646 topk_group : int = 4
47+ router_dtype : Optional [str ] = None
4748 mask_token_id : int = 156895
4849 eos_token_id : int = 156892
4950
@@ -195,9 +196,15 @@ def __init__(self, args: ModelArgs):
195196 self .topk_group = args .topk_group
196197 self .routed_scaling_factor = args .routed_scaling_factor
197198 self .score_function = args .score_function
198- self .weight = mx .zeros ((args .num_experts , args .hidden_size ))
199+
200+ if args .router_dtype == "fp32" :
201+ router_dtype = mx .float32
202+ else :
203+ router_dtype = None
204+
205+ self .weight = mx .zeros ((args .num_experts , args .hidden_size ), dtype = router_dtype )
199206 self .expert_bias = (
200- mx .zeros ((args .num_experts ,))
207+ mx .zeros ((args .num_experts ,), dtype = router_dtype )
201208 if args .moe_router_enable_expert_bias
202209 else None
203210 )
You can’t perform that action at this time.
0 commit comments