Commit f461119

deepseek v32

1 parent 0c0b722 commit f461119

File tree

4 files changed: +601 -107 lines changed

mlx_lm/models/cache.py (49 additions, 0 deletions)

@@ -659,6 +659,55 @@ def extend(self, other):
             c.extend(other)
 
 
+class KCache(_BaseCache):
+    step = 256
+
+    def __init__(self):
+        self.keys = None
+        self.offset = 0
+
+    def update_and_fetch(self, k):
+        prev = self.offset
+        if self.keys is None or (prev + k.shape[2]) > self.keys.shape[2]:
+            B, n_heads, _, head_dim = k.shape
+            n_steps = (self.step + k.shape[2] - 1) // self.step
+            shape = (B, n_heads, n_steps * self.step, head_dim)
+            new_k = mx.zeros(shape, k.dtype)
+            if self.keys is not None:
+                if prev % self.step != 0:
+                    self.keys = self.keys[..., :prev, :]
+                self.keys = mx.concatenate([self.keys, new_k], axis=2)
+            else:
+                self.keys = new_k
+
+        self.offset += k.shape[2]
+        self.keys[..., prev : self.offset, :] = k
+        return self.keys[..., : self.offset, :]
+
+    @property
+    def state(self):
+        if self.offset == self.keys.shape[2]:
+            return self.keys
+        else:
+            return (self.keys[..., : self.offset, :],)
+
+    @state.setter
+    def state(self, v):
+        self.keys = v
+        self.offset = self.keys.shape[2]
+
+    def is_trimmable(self):
+        return True
+
+    def trim(self, n):
+        n = min(self.offset, n)
+        self.offset -= n
+        return n
+
+    def make_mask(self, *args, **kwargs):
+        return create_attention_mask(*args, offset=self.offset, **kwargs)
+
+
 class BatchKVCache(_BaseCache):
     step = 256
 
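Note (not part of the diff): a minimal usage sketch of the new KCache, assuming this branch of mlx_lm is installed; the shapes are made up for illustration. The cache grows its backing buffer in multiples of `step`, returns only the valid prefix, and `trim` just rolls back `offset` without reallocating.

    import mlx.core as mx
    from mlx_lm.models.cache import KCache

    cache = KCache()

    # Prefill with 3 positions (batch=1, heads=4, head_dim=8 are illustrative).
    k = mx.random.normal((1, 4, 3, 8))
    out = cache.update_and_fetch(k)
    print(out.shape, cache.offset)   # (1, 4, 3, 8) 3; backing buffer is padded to step=256

    # Append one decode step; only the valid prefix is returned.
    k_step = mx.random.normal((1, 4, 1, 8))
    out = cache.update_and_fetch(k_step)
    print(out.shape, cache.offset)   # (1, 4, 4, 8) 4

    # Trimming only moves the offset back; the buffer is reused on the next update.
    cache.trim(2)
    print(cache.offset)              # 2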

mlx_lm/models/deepseek_v3.py (13 additions, 107 deletions)

@@ -9,6 +9,7 @@
 import mlx.nn as nn
 
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .rope_utils import initialize_rope
 from .switch_layers import SwitchGLU
 
 
@@ -45,85 +46,6 @@ class ModelArgs(BaseModelArgs):
     attention_bias: bool = False
 
 
-def yarn_find_correction_dim(
-    num_rotations, dim, base=10000, max_position_embeddings=2048
-):
-    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
-        2 * math.log(base)
-    )
-
-
-def yarn_find_correction_range(
-    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
-):
-    low = math.floor(
-        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
-    )
-    high = math.ceil(
-        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
-    )
-    return max(low, 0), min(high, dim - 1)
-
-
-def yarn_get_mscale(scale=1, mscale=1):
-    if scale <= 1:
-        return 1.0
-    return 0.1 * mscale * math.log(scale) + 1.0
-
-
-def yarn_linear_ramp_mask(min_val, max_val, dim):
-    if min_val == max_val:
-        max_val += 0.001  # Prevent singularity
-
-    linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
-    return mx.clip(linear_func, 0, 1)
-
-
-class DeepseekV3YarnRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        dim,
-        max_position_embeddings=2048,
-        base=10000,
-        scaling_factor=1.0,
-        original_max_position_embeddings=4096,
-        beta_fast=32,
-        beta_slow=1,
-        mscale=1,
-        mscale_all_dim=0,
-    ):
-        super().__init__()
-        self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
-            scaling_factor, mscale_all_dim
-        )
-        freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
-        freq_inter = scaling_factor * freq_extra
-        low, high = yarn_find_correction_range(
-            beta_fast,
-            beta_slow,
-            dim,
-            base,
-            original_max_position_embeddings,
-        )
-        freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
-        self._freqs = (freq_inter * freq_extra) / (
-            freq_inter * freq_mask + freq_extra * (1 - freq_mask)
-        )
-
-    def __call__(self, x, offset=0):
-        if self.mscale != 1.0:
-            x = self.mscale * x
-        return mx.fast.rope(
-            x,
-            x.shape[-1],
-            traditional=True,
-            base=None,
-            scale=1.0,
-            offset=offset,
-            freqs=self._freqs,
-        )
-
-
 class DeepseekV3Attention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
@@ -175,35 +97,19 @@ def __init__(self, config: ModelArgs):
 
         if self.config.rope_scaling is not None:
             mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
-            scaling_factor = self.config.rope_scaling["factor"]
             if mscale_all_dim:
-                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
-                self.scale = self.scale * mscale * mscale
-
-            rope_kwargs = {
-                key: self.config.rope_scaling[key]
-                for key in [
-                    "original_max_position_embeddings",
-                    "beta_fast",
-                    "beta_slow",
-                    "mscale",
-                    "mscale_all_dim",
-                ]
-                if key in self.config.rope_scaling
-            }
-            self.rope = DeepseekV3YarnRotaryEmbedding(
-                dim=self.qk_rope_head_dim,
-                max_position_embeddings=self.max_position_embeddings,
-                scaling_factor=scaling_factor,
-                base=self.rope_theta,
-                **rope_kwargs,
-            )
-        else:
-            self.rope = nn.RoPE(
-                dims=self.qk_rope_head_dim,
-                base=self.rope_theta,
-                traditional=True,
-            )
+                scaling_factor = self.config.rope_scaling["factor"]
+                if scaling_factor > 1:
+                    s = 0.1 * mscale_all_dim * math.log(scaling_factor) + 1.0
+                    self.scale = self.scale * s * s
+
+        self.rope = initialize_rope(
+            dims=self.qk_rope_head_dim,
+            base=self.rope_theta,
+            traditional=False,
+            max_position_embeddings=self.max_position_embeddings,
+            scaling_config=self.config.rope_scaling,
+        )
 
     def __call__(
         self,
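Note (not part of the diff, my reading of the change): when `mscale_all_dim` is set and the factor is greater than 1, the inlined expression reproduces the removed `yarn_get_mscale` helper, so `self.scale` is adjusted exactly as before; the actual frequency scaling now lives in `initialize_rope`. A quick check with illustrative values:

    import math

    def yarn_get_mscale(scale=1, mscale=1):
        # Helper removed by this commit, reproduced here only for comparison.
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    scaling_factor, mscale_all_dim = 40.0, 1.0   # illustrative DeepSeek-V3-style values
    s = 0.1 * mscale_all_dim * math.log(scaling_factor) + 1.0
    assert s == yarn_get_mscale(scaling_factor, mscale_all_dim)
    print(s * s)   # factor applied to self.scale, ~1.87 for these values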

0 commit comments