
Commit cfa74ad

ivanfioravanti and awni authored
Hunyuan V1 Dense model support (#351)
* Add Hunyuan V1 Dense model and support for the --trust-remote-code option in evaluate and convert.
* Add explicit head dimension support in the Hunyuan V1 Dense model for the differences between the 0.5B, 4B, 1.8B, and 7B variants.
* Remove the unused sanitize method from the Hunyuan V1 Dense model.
* Add LoRA support.

Co-authored-by: Awni Hannun <[email protected]>
1 parent e22bdaa commit cfa74ad

File tree: 5 files changed (+274, -1 lines changed)

* mlx_lm/convert.py
* mlx_lm/evaluate.py
* mlx_lm/models/hunyuan_v1_dense.py
* mlx_lm/tuner/utils.py
* tests/test_models.py


mlx_lm/convert.py

Lines changed: 6 additions & 0 deletions
@@ -209,6 +209,12 @@ def configure_parser() -> argparse.ArgumentParser:
         action="store_true",
         default=False,
     )
+    parser.add_argument(
+        "--trust-remote-code",
+        help="Trust remote code when loading tokenizer.",
+        action="store_true",
+        default=False,
+    )
     return parser
 
 
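For reference, a minimal standalone sketch of how a store_true flag like the one added above behaves (a stripped-down argparse example, not the project's actual parser):

import argparse

# Hypothetical mini-parser mirroring only the new flag.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--trust-remote-code",
    help="Trust remote code when loading tokenizer.",
    action="store_true",
    default=False,
)

# argparse maps the dashed flag name to an underscored attribute,
# which stays False unless the flag is passed on the command line.
args = parser.parse_args(["--trust-remote-code"])
print(args.trust_remote_code)  # True
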
mlx_lm/evaluate.py

Lines changed: 11 additions & 1 deletion
@@ -71,9 +71,13 @@ def __init__(
         path_or_hf_repo: str,
         max_tokens: Optional[int] = None,
         use_chat_template: Optional[bool] = None,
+        trust_remote_code: bool = False,
     ) -> None:
         super().__init__()
-        self._model, self.tokenizer = load(path_or_hf_repo)
+        tokenizer_config = {"trust_remote_code": True if trust_remote_code else None}
+        self._model, self.tokenizer = load(
+            path_or_hf_repo, tokenizer_config=tokenizer_config
+        )
         self._max_tokens = max_tokens or self.tokenizer.model_max_length
         self._batch_size = 8
         self.use_chat_template = use_chat_template

@@ -378,6 +382,11 @@ def main():
         help="Confirm that you want to run tasks that execute untrusted code.",
         default=False,
     )
+    parser.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        help="Enable trusting remote code for tokenizer",
+    )
 
     args = parser.parse_args()
 
@@ -393,6 +402,7 @@ def main():
         args.model,
         max_tokens=args.max_tokens,
         use_chat_template=args.apply_chat_template,
+        trust_remote_code=args.trust_remote_code,
     )
     MLXLM.apply_chat_template = chat_template_fn(**args.chat_template_args)
 
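The tokenizer_config dict above is forwarded to mlx_lm.load, so the equivalent direct call looks roughly like this sketch (the repo path below is a placeholder, not taken from this commit):

from mlx_lm import load

# trust_remote_code=True lets the Hugging Face tokenizer run its custom code,
# which some checkpoints such as Hunyuan require; leaving it out keeps the default.
model, tokenizer = load(
    "some/hf-repo-or-local-path",  # placeholder path
    tokenizer_config={"trust_remote_code": True},
)
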

mlx_lm/models/hunyuan_v1_dense.py

Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
+# Copyright © 2023-2025 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    vocab_size: int
+    hidden_size: int
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    num_key_value_heads: int
+    rms_norm_eps: float
+    rope_theta: float = 10000
+    max_position_embeddings: int = 32768
+    attention_bias: bool = False
+    use_qk_norm: bool = True
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    tie_word_embeddings: bool = False
+    head_dim: Optional[int] = None
+
+    def __post_init__(self):
+        if self.rope_scaling:
+            required_keys = {"alpha", "factor", "type"}
+            if not all(key in self.rope_scaling for key in required_keys):
+                raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+
+class DynamicNTKAlphaRoPE(nn.Module):
+    def __init__(
+        self,
+        dims: int,
+        base: float = 10000,
+        scaling_alpha: float = 1.0,
+    ):
+        super().__init__()
+        self.dims = dims
+        base = base * scaling_alpha ** (dims / (dims - 2))
+        self._freqs = base ** (mx.arange(0, self.dims, 2) / self.dims)
+
+    def __call__(self, x, offset: int = 0):
+        return mx.fast.rope(
+            x,
+            self.dims,
+            traditional=False,
+            base=None,
+            scale=1.0,
+            offset=offset,
+            freqs=self._freqs,
+        )
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+        head_dim = (
+            args.head_dim if args.head_dim is not None else args.hidden_size // n_heads
+        )
+        self.head_dim = head_dim
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
+
+        self.use_qk_norm = args.use_qk_norm
+        if self.use_qk_norm:
+            self.query_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)
+            self.key_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)
+
+        scaling_alpha = 1.0
+        if args.rope_scaling and "alpha" in args.rope_scaling:
+            scaling_alpha = args.rope_scaling["alpha"]
+
+        self.rope = DynamicNTKAlphaRoPE(
+            head_dim,
+            base=args.rope_theta,
+            scaling_alpha=scaling_alpha,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(
+            0, 2, 1, 3
+        )
+        keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
+            0, 2, 1, 3
+        )
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        if self.use_qk_norm:
+            queries = self.query_layernorm(queries)
+            keys = self.key_layernorm(keys)
+
+        if cache is not None:
+            keys, values = cache.update_and_fetch(keys, values)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
+        )
+
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        dim = args.hidden_size
+        hidden_dim = args.intermediate_size
+
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.num_attention_heads = args.num_attention_heads
+        self.hidden_size = args.hidden_size
+        self.self_attn = Attention(args)
+        self.mlp = MLP(args)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+        self.args = args
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r
+        return out
+
+
+class HunyuanV1DenseModel(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.num_hidden_layers = args.num_hidden_layers
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        mask: mx.array = None,
+        cache=None,
+    ):
+        h = self.embed_tokens(inputs)
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
+
+        return self.norm(h)
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.model = HunyuanV1DenseModel(args)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        mask: mx.array = None,
+        cache=None,
+    ):
+        out = self.model(inputs, mask, cache)
+        return self.model.embed_tokens.as_linear(out)
+
+    @property
+    def layers(self):
+        return self.model.layers
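
DynamicNTKAlphaRoPE folds the NTK alpha from rope_scaling into the RoPE base via base * alpha ** (dims / (dims - 2)) and hands the resulting per-pair frequencies to mx.fast.rope. A small standalone sketch of just that frequency computation, using illustrative values rather than any real model config:

import mlx.core as mx

def ntk_alpha_freqs(dims: int, base: float = 10000.0, scaling_alpha: float = 1.0) -> mx.array:
    # Fold the NTK alpha into the base, as DynamicNTKAlphaRoPE.__init__ does above.
    scaled_base = base * scaling_alpha ** (dims / (dims - 2))
    # One frequency per rotated dimension pair; passed to mx.fast.rope via `freqs`.
    return scaled_base ** (mx.arange(0, dims, 2) / dims)

# Illustrative numbers only (head_dim=128, alpha=1000.0).
print(ntk_alpha_freqs(128, base=10000.0, scaling_alpha=1000.0))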

mlx_lm/tuner/utils.py

Lines changed: 1 addition & 0 deletions
@@ -121,6 +121,7 @@ def to_lora(layer):
         "dots1",
         "smollm3",
         "exaone4",
+        "hunyuan_v1_dense",
     }:
         keys = {"self_attn.q_proj", "self_attn.v_proj"}
         if model.model_type in ["mixtral", "phimoe"]:
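
With "hunyuan_v1_dense" in this set, LoRA adapts the q and v projections of each attention block. A rough, hypothetical illustration of the per-layer module paths those keys correspond to in the model defined above (the layer count is arbitrary; this is not the library's internal logic):

# Hypothetical sketch: enumerate the module paths implied by the LoRA keys above.
num_hidden_layers = 4  # arbitrary example value
lora_keys = {"self_attn.q_proj", "self_attn.v_proj"}

target_paths = [
    f"model.layers.{i}.{key}"
    for i in range(num_hidden_layers)
    for key in sorted(lora_keys)
]
print(target_paths[:2])  # ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.v_proj']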

tests/test_models.py

Lines changed: 32 additions & 0 deletions
@@ -1039,6 +1039,38 @@ def test_hunyuan(self):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
 
+    def test_hunyuan_v1_dense(self):
+        from mlx_lm.models import hunyuan_v1_dense
+
+        args = hunyuan_v1_dense.ModelArgs(
+            model_type="hunyuan_v1_dense",
+            hidden_size=128,
+            attention_bias=False,
+            intermediate_size=256,
+            num_attention_heads=4,
+            num_hidden_layers=4,
+            num_key_value_heads=2,
+            rms_norm_eps=1e-4,
+            rope_theta=1000,
+            vocab_size=1000,
+            use_qk_norm=True,
+            rope_scaling={
+                "alpha": 1000.0,
+                "factor": 1.0,
+                "type": "dynamic",
+                "beta_fast": 32,
+                "beta_slow": 1,
+                "mscale": 1.0,
+                "mscale_all_dim": 0.0,
+                "original_max_position_embeddings": 8192,
+            },
+            max_position_embeddings=32768,
+        )
+        model = hunyuan_v1_dense.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
     def test_olmo2(self):
         from mlx_lm.models import olmo2
