Skip to content

Commit 1d11449

Browse files
authored
Add Qwen3-VL (Dense) language model implementation (#553)
* Added Qwen3-VL dense language model
* Added Qwen3-VL dense language model test
1 parent b1fc49a commit 1d11449

File tree

3 files changed

+83
-2
lines changed

3 files changed

+83
-2
lines changed

mlx_lm/models/qwen3.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,12 @@ def __call__(
141141
self,
142142
inputs: mx.array,
143143
cache=None,
144+
input_embeddings: Optional[mx.array] = None,
144145
):
145-
h = self.embed_tokens(inputs)
146+
if input_embeddings is not None:
147+
h = input_embeddings
148+
else:
149+
h = self.embed_tokens(inputs)
146150

147151
if cache is None:
148152
cache = [None] * len(self.layers)
@@ -167,8 +171,9 @@ def __call__(
167171
self,
168172
inputs: mx.array,
169173
cache=None,
174+
input_embeddings: Optional[mx.array] = None,
170175
):
171-
out = self.model(inputs, cache)
176+
out = self.model(inputs, cache, input_embeddings)
172177
if self.args.tie_word_embeddings:
173178
out = self.model.embed_tokens.as_linear(out)
174179
else:

mlx_lm/models/qwen3_vl.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright © 2025 Apple Inc.
2+
3+
from dataclasses import dataclass
4+
from typing import Optional
5+
6+
import mlx.core as mx
7+
import mlx.nn as nn
8+
from mlx.utils import tree_flatten, tree_unflatten
9+
10+
from . import qwen3
11+
from .base import BaseModelArgs
12+
13+
14+
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration for the Qwen3-VL (dense) wrapper.

    Holds the outer model type plus the nested text-backbone configuration
    dict that is forwarded to ``qwen3.ModelArgs``.
    """

    model_type: str
    text_config: dict

    @classmethod
    def from_dict(cls, params):
        # Some checkpoints ship a flat, text-only config with no nested
        # "text_config" section; in that case the whole dict *is* the text
        # config, so wrap it directly instead of delegating to the base class.
        if "text_config" in params:
            return super().from_dict(params)
        return cls(model_type=params["model_type"], text_config=params)
24+
25+
26+
class Model(nn.Module):
    """Qwen3-VL (dense) language model.

    A thin wrapper that delegates all language modeling to the qwen3 text
    backbone. No vision tower is built here; its weights are discarded in
    :meth:`sanitize`.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        # Build the text backbone from the nested text_config dict.
        text_args = qwen3.ModelArgs.from_dict(args.text_config)
        self.language_model = qwen3.Model(text_args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        input_embeddings: Optional[mx.array] = None,
    ):
        # Pure delegation: token ids (or precomputed input embeddings) are
        # forwarded unchanged to the text model.
        return self.language_model(
            inputs, cache=cache, input_embeddings=input_embeddings
        )

    def sanitize(self, weights):
        """Drop vision-tower weights and prefix the remaining keys with
        ``language_model.`` so they match this wrapper's module tree."""
        tree = tree_unflatten(list(weights.items()))
        # The vision tower is not implemented by this wrapper; discard it.
        tree.pop("vision_tower", None)
        flat = dict(tree_flatten(tree))

        prefix = "language_model."
        return {
            (key if key.startswith(prefix) else prefix + key): value
            for key, value in flat.items()
        }

    @property
    def layers(self):
        # Expose the backbone's transformer layers (used e.g. for KV-cache
        # construction by the surrounding framework).
        return self.language_model.model.layers

tests/test_models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1760,6 +1760,25 @@ def test_all_models(self):
17601760
"num_hidden_layers": 4,
17611761
"vocab_size": 1000,
17621762
},
1763+
{
1764+
"model_type": "qwen3_vl",
1765+
"text_config": {
1766+
"model_type": "qwen3",
1767+
"hidden_size": 128,
1768+
"num_hidden_layers": 4,
1769+
"intermediate_size": 256,
1770+
"num_attention_heads": 4,
1771+
"num_key_value_heads": 2,
1772+
"rms_norm_eps": 1e-5,
1773+
"vocab_size": 1000,
1774+
"head_dim": 32,
1775+
"max_position_embeddings": 1000,
1776+
"tie_word_embeddings": False,
1777+
"rope_theta": 1000,
1778+
},
1779+
"num_hidden_layers": 4,
1780+
"vocab_size": 1000,
1781+
},
17631782
{
17641783
"model_type": "seed_oss",
17651784
"hidden_size": 128,

0 commit comments

Comments (0)