9 | 9 | import mlx.nn as nn |
10 | 10 |
11 | 11 | from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention |
| 12 | +from .rope_utils import initialize_rope |
12 | 13 | from .switch_layers import SwitchGLU |
13 | 14 |
14 | 15 |
@@ -45,85 +46,6 @@ class ModelArgs(BaseModelArgs): |
45 | 46 |     attention_bias: bool = False |
46 | 47 |
47 | 48 |
48 | | -def yarn_find_correction_dim( |
49 | | -    num_rotations, dim, base=10000, max_position_embeddings=2048 |
50 | | -): |
51 | | -    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( |
52 | | -        2 * math.log(base) |
53 | | -    ) |
54 | | - |
55 | | - |
56 | | -def yarn_find_correction_range( |
57 | | -    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 |
58 | | -): |
59 | | -    low = math.floor( |
60 | | -        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) |
61 | | -    ) |
62 | | -    high = math.ceil( |
63 | | -        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) |
64 | | -    ) |
65 | | -    return max(low, 0), min(high, dim - 1) |
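These two removed helpers invert the RoPE wavelength formula: `yarn_find_correction_dim` returns the (fractional) dimension index whose rotary period completes `num_rotations` full turns over the original context length, and `yarn_find_correction_range` takes its floor/ceil and clamps to `[0, dim - 1]`. A quick sanity check, assuming illustrative DeepSeek-style settings (`qk_rope_head_dim=64`, `base=10000`, original context 4096, `beta_fast=32`, `beta_slow=1`; these numbers are not part of this diff):

```python
import math

def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    # Invert wavelength(dim_index) to find where `num_rotations` turns fit in the context.
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )

low = math.floor(yarn_find_correction_dim(32, 64, 10000, 4096))  # 10.47 -> 10
high = math.ceil(yarn_find_correction_dim(1, 64, 10000, 4096))   # 22.51 -> 23
print(max(low, 0), min(high, 63))  # (10, 23)
```

Dimensions below the range rotate fast enough to be left alone (extrapolation); dimensions above it rotate slowly and get interpolated.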
66 | | - |
67 | | - |
68 | | -def yarn_get_mscale(scale=1, mscale=1): |
69 | | -    if scale <= 1: |
70 | | -        return 1.0 |
71 | | -    return 0.1 * mscale * math.log(scale) + 1.0 |
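`yarn_get_mscale` is the YaRN attention-temperature correction, `0.1 * mscale * ln(scale) + 1`, applied only when the context is actually extended (`scale > 1`). With an illustrative scaling factor of 40 and `mscale=1.0` (values typical of DeepSeek-V3 configs, assumed here rather than read from this diff):

```python
import math

scale, mscale = 40, 1.0
m = 0.1 * mscale * math.log(scale) + 1.0 if scale > 1 else 1.0
print(round(m, 4), round(m * m, 4))  # 1.3689 1.8739 -- m*m is what gets folded into the attention scale
```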
72 | | - |
73 | | - |
74 | | -def yarn_linear_ramp_mask(min_val, max_val, dim): |
75 | | -    if min_val == max_val: |
76 | | -        max_val += 0.001  # Prevent singularity |
77 | | - |
78 | | -    linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val) |
79 | | -    return mx.clip(linear_func, 0, 1) |
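The ramp mask is the per-dimension blend weight between the two frequency sets: 0 below `min_val`, 1 above `max_val`, linear in between. Continuing the `(10, 23)` example over `dim // 2 = 32` frequency pairs:

```python
import mlx.core as mx

low, high, n = 10, 23, 32
ramp = mx.clip((mx.arange(n, dtype=mx.float32) - low) / (high - low), 0, 1)
print(ramp[10].item(), round(ramp[16].item(), 4), ramp[23].item())  # 0.0 0.4615 1.0
```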
80 | | - |
81 | | - |
82 | | -class DeepseekV3YarnRotaryEmbedding(nn.Module): |
83 | | -    def __init__( |
84 | | -        self, |
85 | | -        dim, |
86 | | -        max_position_embeddings=2048, |
87 | | -        base=10000, |
88 | | -        scaling_factor=1.0, |
89 | | -        original_max_position_embeddings=4096, |
90 | | -        beta_fast=32, |
91 | | -        beta_slow=1, |
92 | | -        mscale=1, |
93 | | -        mscale_all_dim=0, |
94 | | -    ): |
95 | | -        super().__init__() |
96 | | -        self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale( |
97 | | -            scaling_factor, mscale_all_dim |
98 | | -        ) |
99 | | -        freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim) |
100 | | -        freq_inter = scaling_factor * freq_extra |
101 | | -        low, high = yarn_find_correction_range( |
102 | | -            beta_fast, |
103 | | -            beta_slow, |
104 | | -            dim, |
105 | | -            base, |
106 | | -            original_max_position_embeddings, |
107 | | -        ) |
108 | | -        freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) |
109 | | -        self._freqs = (freq_inter * freq_extra) / ( |
110 | | -            freq_inter * freq_mask + freq_extra * (1 - freq_mask) |
111 | | -        ) |
112 | | -
113 | | -    def __call__(self, x, offset=0): |
114 | | -        if self.mscale != 1.0: |
115 | | -            x = self.mscale * x |
116 | | -        return mx.fast.rope( |
117 | | -            x, |
118 | | -            x.shape[-1], |
119 | | -            traditional=True, |
120 | | -            base=None, |
121 | | -            scale=1.0, |
122 | | -            offset=offset, |
123 | | -            freqs=self._freqs, |
124 | | -        ) |
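Note how `_freqs` combines the two wavelength sets: `(fi * fe) / (fi * m + fe * (1 - m))` is a harmonic blend that reduces to `fe` (original wavelengths, extrapolation) where the mask is 1, and to `fi = scaling_factor * fe` (stretched wavelengths, interpolation) where it is 0. Supplying explicit `freqs` to `mx.fast.rope` is also why `base=None` is passed in `__call__`. A toy check of the blend:

```python
fe, fi = 2.0, 8.0  # toy wavelengths, fi = 4 * fe
blend = lambda m: (fi * fe) / (fi * m + fe * (1 - m))
print(blend(1.0), blend(0.0), blend(0.5))  # 2.0 8.0 3.2
```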
125 | | - |
126 | | - |
127 | 49 | class DeepseekV3Attention(nn.Module): |
128 | 50 |     def __init__(self, config: ModelArgs): |
129 | 51 |         super().__init__() |
@@ -175,35 +97,19 @@ def __init__(self, config: ModelArgs): |
175 | 97 |
176 | 98 |         if self.config.rope_scaling is not None: |
177 | 99 |             mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) |
178 | | -            scaling_factor = self.config.rope_scaling["factor"] |
179 | 100 |             if mscale_all_dim: |
180 | | -                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) |
181 | | -                self.scale = self.scale * mscale * mscale |
182 | | -
183 | | -            rope_kwargs = { |
184 | | -                key: self.config.rope_scaling[key] |
185 | | -                for key in [ |
186 | | -                    "original_max_position_embeddings", |
187 | | -                    "beta_fast", |
188 | | -                    "beta_slow", |
189 | | -                    "mscale", |
190 | | -                    "mscale_all_dim", |
191 | | -                ] |
192 | | -                if key in self.config.rope_scaling |
193 | | -            } |
194 | | -            self.rope = DeepseekV3YarnRotaryEmbedding( |
195 | | -                dim=self.qk_rope_head_dim, |
196 | | -                max_position_embeddings=self.max_position_embeddings, |
197 | | -                scaling_factor=scaling_factor, |
198 | | -                base=self.rope_theta, |
199 | | -                **rope_kwargs, |
200 | | -            ) |
201 | | -        else: |
202 | | -            self.rope = nn.RoPE( |
203 | | -                dims=self.qk_rope_head_dim, |
204 | | -                base=self.rope_theta, |
205 | | -                traditional=True, |
206 | | -            ) |
| 101 | +                scaling_factor = self.config.rope_scaling["factor"] |
| 102 | +                if scaling_factor > 1: |
| 103 | +                    s = 0.1 * mscale_all_dim * math.log(scaling_factor) + 1.0 |
| 104 | +                    self.scale = self.scale * s * s |
| 105 | +
| 106 | +        self.rope = initialize_rope( |
| 107 | +            dims=self.qk_rope_head_dim, |
| 108 | +            base=self.rope_theta, |
| 109 | +            traditional=False, |
| 110 | +            max_position_embeddings=self.max_position_embeddings, |
| 111 | +            scaling_config=self.config.rope_scaling, |
| 112 | +        ) |
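The inlined expression on new lines 103-104 is exactly the `scale > 1` branch of the deleted `yarn_get_mscale(scaling_factor, mscale_all_dim)`, so the attention scale is unchanged by the refactor; the YaRN frequency construction itself now lives behind `initialize_rope`. Reading `"factor"` only inside the `mscale_all_dim` branch also avoids a `KeyError` for scaling configs that omit it. A minimal equivalence check (helper body copied from the removed code above):

```python
import math

def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

scaling_factor, mscale_all_dim = 40, 1.0
s = 0.1 * mscale_all_dim * math.log(scaling_factor) + 1.0 if scaling_factor > 1 else 1.0
assert s == yarn_get_mscale(scaling_factor, mscale_all_dim)  # same scale adjustment
```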
207 | 113 |
208 | 114 |     def __call__( |
209 | 115 |         self, |