@@ -156,12 +156,10 @@ def group_expert_select(
     norm_topk_prob,
     score_function,
 ):
-    in_type = gates.dtype
-
     scores = (
-        mx.sigmoid(gates.astype(mx.float32))
+        mx.sigmoid(gates)
         if score_function == "sigmoid"
-        else mx.softmax(gates.astype(mx.float32), axis=-1)
+        else mx.softmax(gates, axis=-1, precise=True)
     )
     orig_scores = scores
 
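A minimal standalone sketch (not from the repo; shapes made up) of why the explicit float32 upcast can go away here: in MLX, mx.softmax(..., precise=True) accumulates in float32 internally while returning the input dtype, so it matches the old upcast-then-softmax path, and sigmoid is elementwise and numerically tame enough to run in the native dtype.

import mlx.core as mx

gates = mx.random.normal((4, 8)).astype(mx.bfloat16)

# Old path: upcast to float32 before the softmax.
old = mx.softmax(gates.astype(mx.float32), axis=-1)
# New path: precise=True accumulates in float32 internally and
# returns the input dtype (bfloat16 here).
new = mx.softmax(gates, axis=-1, precise=True)

print(mx.allclose(old, new.astype(mx.float32), atol=1e-2))  # True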
@@ -185,7 +183,7 @@ def group_expert_select(
     scores = scores / (scores.sum(axis=-1, keepdims=True) + 1e-20)
 
     scores = scores * routed_scaling_factor
-    return inds, scores.astype(in_type)
+    return inds, scores
 
 
 class LLaDA2MoeGate(nn.Module):
@@ -207,7 +205,7 @@ def __init__(self, args: ModelArgs):
     def __call__(self, x):
         orig_shape = x.shape
         x = x.reshape(-1, x.shape[-1])
-        gates = mx.matmul(x.astype(mx.float32), self.weight.T)
+        gates = mx.matmul(x, self.weight.T)
 
         indices, scores = group_expert_select(
             gates,
@@ -248,7 +246,7 @@ def __init__(self, args: ModelArgs):
     def __call__(self, x):
         inds, scores = self.gate(x)
         y = self.switch_mlp(x, inds)
-        y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
+        y = (y * scores[..., None]).sum(axis=-2)
         if self.shared_experts is not None:
             y = y + self.shared_experts(x)
         return y
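A shape-level sketch (standalone, illustrative names and sizes) of the expert-mixing line above: scores broadcasts across the feature axis, the sum over axis=-2 collapses the top-k selected experts, and since the gate now keeps scores in the input dtype the old trailing .astype(y.dtype) is redundant.

import mlx.core as mx

N, top_k, D = 2, 4, 8  # tokens, experts per token, hidden size
y = mx.random.normal((N, top_k, D)).astype(mx.bfloat16)
scores = mx.softmax(mx.random.normal((N, top_k)), axis=-1).astype(mx.bfloat16)

# (N, top_k, D) * (N, top_k, 1) -> sum over experts -> (N, D)
out = (y * scores[..., None]).sum(axis=-2)
print(out.shape, out.dtype)  # (2, 8), bfloat16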
@@ -336,12 +334,14 @@ def head_dim(self):
     def n_kv_heads(self):
         return self.args.num_key_value_heads
 
-    def _create_block_diagonal_mask(self, num_blocks: int, block_length: int):
+    def _create_block_diagonal_mask(
+        self, num_blocks: int, block_length: int, dtype=mx.float32
+    ):
         """Create block-diagonal attention mask for diffusion generation."""
         mask = mx.tril(mx.ones((num_blocks, num_blocks)))
         mask = mx.repeat(mx.repeat(mask, block_length, axis=0), block_length, axis=1)
         mask = mask[None, None, :, :]
-        return mx.where(mask, 0.0, float("-inf")).astype(mx.bfloat16)
+        return mx.where(mask, 0.0, float("-inf")).astype(dtype)
 
     def _select_tokens_to_update(
         self, confidence: mx.array, mask: mx.array, num_tokens: int, threshold: float
@@ -422,7 +422,9 @@ def generate_step(
         num_blocks = (prompt_length + max_tokens + block_length - 1) // block_length
         total_length = num_blocks * block_length
 
-        mask = self._create_block_diagonal_mask(num_blocks, block_length)
+        mask = self._create_block_diagonal_mask(
+            num_blocks, block_length, dtype=self.model.word_embeddings.weight.dtype
+        )
         transfer_schedule = self._get_num_transfer_tokens(block_length, steps)
 
         x = mx.full((1, total_length), mask_id, dtype=mx.int32)
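For intuition, a tiny standalone repro (sizes chosen for illustration) of the mask that the new dtype parameter flows into: each block_length-sized run of tokens attends to itself and all earlier blocks, while future blocks get -inf, now in the embedding dtype rather than hard-coded bfloat16.

import mlx.core as mx

num_blocks, block_length = 3, 2
mask = mx.tril(mx.ones((num_blocks, num_blocks)))
mask = mx.repeat(mx.repeat(mask, block_length, axis=0), block_length, axis=1)
additive = mx.where(mask, 0.0, float("-inf")).astype(mx.float16)
print(additive)  # 6x6: zeros on/below the 2x2-block diagonal, -inf above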