
Commit 7cd49d9

Type casting fixes
1 parent c863a68 commit 7cd49d9

File tree

1 file changed: +12 −10 lines

mlx_lm/models/llada2_moe.py

Lines changed: 12 additions & 10 deletions
@@ -156,12 +156,10 @@ def group_expert_select(
     norm_topk_prob,
     score_function,
 ):
-    in_type = gates.dtype
-
     scores = (
-        mx.sigmoid(gates.astype(mx.float32))
+        mx.sigmoid(gates)
         if score_function == "sigmoid"
-        else mx.softmax(gates.astype(mx.float32), axis=-1)
+        else mx.softmax(gates, axis=-1, precise=True)
     )
     orig_scores = scores

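The `precise=True` flag replaces the manual float32 round-trip here. A minimal sketch, assuming MLX's `softmax` accumulates in higher precision internally and returns the input dtype (the shapes are made up):

```python
import mlx.core as mx

gates = mx.random.normal((2, 8)).astype(mx.bfloat16)  # hypothetical gate logits

# precise=True does the upcast internally but keeps the input dtype,
# so the explicit astype(mx.float32) round-trip is no longer needed.
a = mx.softmax(gates, axis=-1, precise=True)
b = mx.softmax(gates.astype(mx.float32), axis=-1).astype(mx.bfloat16)

print(a.dtype)                       # bfloat16, same as the input
print(mx.allclose(a, b, atol=1e-2))  # True up to bfloat16 rounding
```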
@@ -185,7 +183,7 @@ def group_expert_select(
     scores = scores / (scores.sum(axis=-1, keepdims=True) + 1e-20)

     scores = scores * routed_scaling_factor
-    return inds, scores.astype(in_type)
+    return inds, scores


 class LLaDA2MoeGate(nn.Module):
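With the upcasts gone, `scores` never leaves the dtype of `gates`, so the `in_type` bookkeeping and the cast on return have nothing left to do. A quick smoke test of the sigmoid path (the scaling factor and shapes are hypothetical):

```python
import mlx.core as mx

gates = mx.random.normal((4, 16)).astype(mx.bfloat16)

scores = mx.sigmoid(gates)                                # stays bfloat16
scores = scores / (scores.sum(axis=-1, keepdims=True) + 1e-20)
scores = scores * 2.5                                     # hypothetical routed_scaling_factor

print(scores.dtype)  # bfloat16, matching gates.dtype -- no cast back needed
```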
@@ -207,7 +205,7 @@ def __init__(self, args: ModelArgs):
     def __call__(self, x):
         orig_shape = x.shape
         x = x.reshape(-1, x.shape[-1])
-        gates = mx.matmul(x.astype(mx.float32), self.weight.T)
+        gates = mx.matmul(x, self.weight.T)

         indices, scores = group_expert_select(
             gates,
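The gate projection now runs in the activation dtype; the precision that mattered for the scoring itself is recovered later by `softmax(..., precise=True)`. A minimal sketch with made-up sizes:

```python
import mlx.core as mx

x = mx.random.normal((3, 64)).astype(mx.bfloat16)        # flattened tokens
weight = mx.random.normal((8, 64)).astype(mx.bfloat16)   # (num_experts, hidden)

gates = mx.matmul(x, weight.T)   # no float32 round-trip on the inputs
print(gates.shape, gates.dtype)  # (3, 8) bfloat16
```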
@@ -248,7 +246,7 @@ def __init__(self, args: ModelArgs):
     def __call__(self, x):
         inds, scores = self.gate(x)
         y = self.switch_mlp(x, inds)
-        y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
+        y = (y * scores[..., None]).sum(axis=-2)
         if self.shared_experts is not None:
             y = y + self.shared_experts(x)
         return y
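Since the gate now returns `scores` in the same dtype as the expert outputs, the product no longer promotes to float32 and the trailing `astype(y.dtype)` was dead weight. A sketch of the combine step with hypothetical shapes:

```python
import mlx.core as mx

# (batch, tokens, top_k, hidden): stacked outputs of the selected experts
y = mx.random.normal((1, 5, 2, 64)).astype(mx.bfloat16)
# (batch, tokens, top_k): routing weights, same dtype as y after this commit
scores = mx.softmax(mx.random.normal((1, 5, 2)), axis=-1).astype(mx.bfloat16)

out = (y * scores[..., None]).sum(axis=-2)  # weighted sum over the expert axis
print(out.shape, out.dtype)                 # (1, 5, 64) bfloat16
```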
@@ -336,12 +334,14 @@ def head_dim(self):
     def n_kv_heads(self):
         return self.args.num_key_value_heads

-    def _create_block_diagonal_mask(self, num_blocks: int, block_length: int):
+    def _create_block_diagonal_mask(
+        self, num_blocks: int, block_length: int, dtype=mx.float32
+    ):
         """Create block-diagonal attention mask for diffusion generation."""
         mask = mx.tril(mx.ones((num_blocks, num_blocks)))
         mask = mx.repeat(mx.repeat(mask, block_length, axis=0), block_length, axis=1)
         mask = mask[None, None, :, :]
-        return mx.where(mask, 0.0, float("-inf")).astype(mx.bfloat16)
+        return mx.where(mask, 0.0, float("-inf")).astype(dtype)

     def _select_tokens_to_update(
         self, confidence: mx.array, mask: mx.array, num_tokens: int, threshold: float
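To see what this construction produces, here is a worked example with `num_blocks=2` and `block_length=2` (values chosen only for illustration):

```python
import mlx.core as mx

# tril gives block-level causality; repeat tiles each entry into a
# block_length x block_length square; where maps 1 -> 0.0 and 0 -> -inf.
mask = mx.tril(mx.ones((2, 2)))
mask = mx.repeat(mx.repeat(mask, 2, axis=0), 2, axis=1)
mask = mask[None, None, :, :]
mask = mx.where(mask, 0.0, float("-inf")).astype(mx.float16)

print(mask[0, 0])
# [[0, 0, -inf, -inf],
#  [0, 0, -inf, -inf],
#  [0, 0,    0,    0],
#  [0, 0,    0,    0]]
```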
@@ -422,7 +422,9 @@ def generate_step(
         num_blocks = (prompt_length + max_tokens + block_length - 1) // block_length
         total_length = num_blocks * block_length

-        mask = self._create_block_diagonal_mask(num_blocks, block_length)
+        mask = self._create_block_diagonal_mask(
+            num_blocks, block_length, dtype=self.model.word_embeddings.weight.dtype
+        )
         transfer_schedule = self._get_num_transfer_tokens(block_length, steps)

         x = mx.full((1, total_length), mask_id, dtype=mx.int32)
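Taking the mask dtype from `word_embeddings.weight` keeps the additive mask in the same precision as the rest of the forward pass, whatever dtype the model was loaded in. A hedged sketch with float16 stand-in values:

```python
import mlx.core as mx

scores = mx.random.normal((1, 1, 4, 4)).astype(mx.float16)  # attention scores

# Mask built in the model's own dtype: adding it to the scores does not
# promote them to a different precision than the hard-coded bfloat16 would.
mask = mx.where(mx.tril(mx.ones((4, 4))), 0.0, float("-inf")).astype(mx.float16)

out = mx.softmax(scores + mask, axis=-1)
print(out.dtype)  # float16, end to end
```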
