
Commit cd9884d

vincentamato and awni authored
Add Qwen2-VL model implementation (#384)
* Add Qwen2-VL + Qwen2.5-VL
* Fixed model sanitize method to handle both HF and MLX parameter formats
* Cleaned up MRoPE implementation
* Formatted code
* Added type casting in MRoPE
* Removed unused instance variables
* Removed unnecessary MRoPE implementation
* Bump version

Co-authored-by: Awni Hannun <[email protected]>
1 parent 249b0a1 commit cd9884d

4 files changed, 67 additions and 2 deletions


mlx_lm/_version.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 # Copyright © 2023-2025 Apple Inc.

-__version__ = "0.26.3"
+__version__ = "0.26.4"

mlx_lm/models/qwen2_vl.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+# Copyright © 2025 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_flatten, tree_unflatten
+
+from . import qwen2
+from .base import BaseModelArgs
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    text_config: dict
+
+    @classmethod
+    def from_dict(cls, params):
+        if "text_config" not in params:
+            return cls(model_type=params["model_type"], text_config=params)
+        return cls(**params)
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.language_model = qwen2.Model(qwen2.ModelArgs.from_dict(args.text_config))
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+        mask: Optional[mx.array] = None,
+        input_embeddings: Optional[mx.array] = None,
+    ):
+        return self.language_model(
+            inputs, cache=cache, mask=mask, input_embeddings=input_embeddings
+        )
+
+    def sanitize(self, weights):
+        weights = tree_unflatten(list(weights.items()))
+        weights.pop("visual", None)
+        weights.pop("vision_tower", None)
+        weights = dict(tree_flatten(weights))
+
+        sanitized = {}
+        for key, value in weights.items():
+            if not key.startswith("language_model."):
+                key = "language_model." + key
+            sanitized[key] = value
+        return sanitized
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers

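For illustration, here is a minimal, self-contained sketch of the key remapping that sanitize performs on a checkpoint. The weight names below are placeholders and the string values stand in for real mx.array tensors; real Qwen2-VL checkpoints contain many more parameters.

from mlx.utils import tree_flatten, tree_unflatten

def remap_keys(weights: dict) -> dict:
    # Drop the vision tower ("visual" for Qwen2-VL, "vision_tower" for
    # Qwen2.5-VL) and prefix everything else with "language_model.".
    tree = tree_unflatten(list(weights.items()))
    tree.pop("visual", None)
    tree.pop("vision_tower", None)
    flat = dict(tree_flatten(tree))
    return {
        k if k.startswith("language_model.") else "language_model." + k: v
        for k, v in flat.items()
    }

# Placeholder keys for illustration only.
hf_style = {
    "visual.patch_embed.proj.weight": "vision weight",
    "model.layers.0.self_attn.q_proj.weight": "text weight",
    "lm_head.weight": "output head",
}
print(remap_keys(hf_style))
# -> {'language_model.model.layers.0.self_attn.q_proj.weight': 'text weight',
#     'language_model.lm_head.weight': 'output head'}
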
mlx_lm/models/rope_utils.py

Lines changed: 6 additions & 1 deletion
@@ -251,6 +251,11 @@ def initialize_rope(
             short_factor=scaling_config["short_factor"],
             long_factor=scaling_config["long_factor"],
         )
-
+    elif rope_type == "mrope":
+        mrope_section = scaling_config.get("mrope_section", [])
+        assert (
+            len(mrope_section) == 3
+        ), f"MRoPE currently only supports 3 sections, got {len(mrope_section)}."
+        return nn.RoPE(dims, traditional=traditional, base=base)
     else:
         raise ValueError(f"Unsupported RoPE type {rope_type}")

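For context, Qwen2-VL style configs describe multimodal RoPE with a rope_scaling block roughly like the sketch below; the section sizes, head dimension, and base are illustrative examples, not values taken from this diff. For text-only inference all three position streams (temporal, height, width) carry the same index, which is why the new branch can fall back to a standard nn.RoPE.

import mlx.nn as nn

# Illustrative rope_scaling entry from a Qwen2-VL style config.json.
scaling_config = {"rope_type": "mrope", "mrope_section": [16, 24, 24]}

mrope_section = scaling_config.get("mrope_section", [])
assert len(mrope_section) == 3, "expects temporal/height/width sections"

# With identical position ids in all three sections, plain RoPE over the
# full head dimension is equivalent (example dims and base shown here).
rope = nn.RoPE(dims=128, traditional=False, base=1_000_000)
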
mlx_lm/utils.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@
     "phi-msft": "phixtral",
     "falcon_mamba": "mamba",
     "kimi_k2": "deepseek_v3",
+    "qwen2_5_vl": "qwen2_vl",
 }

 MAX_FILE_SIZE_GB = 5
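
The new entry lets Qwen2.5-VL checkpoints (model_type "qwen2_5_vl") reuse the qwen2_vl module. A rough sketch of how a remapping table like this is consulted follows; the resolve_module helper is illustrative, not mlx_lm's actual loader code.

import importlib

# Trimmed copy of the mapping shown in the diff above.
MODEL_REMAPPING = {"qwen2_5_vl": "qwen2_vl"}

def resolve_module(model_type: str):
    # Fall back to the model_type itself when no remapping exists.
    arch = MODEL_REMAPPING.get(model_type, model_type)
    return importlib.import_module(f"mlx_lm.models.{arch}")

# Both Qwen2-VL and Qwen2.5-VL checkpoints resolve to qwen2_vl.py.
print(resolve_module("qwen2_5_vl").__name__)  # mlx_lm.models.qwen2_vl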
