
Commit b326d4d

chore: factor mistral out of tokenizer_utils
1 parent: ba8ede7

File tree

2 files changed (+453 additions, -475 deletions)

mlx_lm/tokenizer_mistral_utils.py

Lines changed: 377 additions & 0 deletions
@@ -0,0 +1,377 @@
"""
Mistral-specific tokenizer utilities.

This module contains all Mistral-specific functionality that was previously embedded
in tokenizer_utils.py, providing clean separation between standard HuggingFace
tokenizer support and Mistral tokenizer support.
"""

try:
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer
    from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.protocol.instruct.messages import (
        UserMessage,
        AssistantMessage,
        SystemMessage,
        ToolMessage,
    )
    from mistral_common.protocol.instruct.tool_calls import (
        ToolCall,
        FunctionCall,
        Function,
        Tool,
    )

    MISTRAL_AVAILABLE = True
except ImportError:
    MistralTokenizer = None  # type: ignore
    Tekkenizer = None  # type: ignore
    SpecialTokenPolicy = None  # type: ignore
    ChatCompletionRequest = None  # type: ignore
    UserMessage = None  # type: ignore
    AssistantMessage = None  # type: ignore
    SystemMessage = None  # type: ignore
    ToolMessage = None  # type: ignore
    ToolCall = None  # type: ignore
    FunctionCall = None  # type: ignore
    Function = None  # type: ignore
    Tool = None  # type: ignore

    MISTRAL_AVAILABLE = False

class StreamingDetokenizer:
    """The streaming detokenizer interface so that we can detokenize one token at a time."""

    __slots__ = ("text", "tokens", "offset")

    def reset(self):
        raise NotImplementedError()

    def add_token(self, token):
        raise NotImplementedError()

    def finalize(self):
        raise NotImplementedError()

    @property
    def last_segment(self):
        """Return the last segment of readable text since last time this property was accessed."""
        text = self.text
        segment = text[self.offset :]
        self.offset = len(text)
        return segment

class MistralStreamingDetokenizer(StreamingDetokenizer):
    """Efficient streaming detokenizer for MistralTokenizer with byte/unicode edge handling."""

    def __init__(self, tokenizer):
        # Extract the underlying Tekkenizer from MistralTokenizer
        if hasattr(tokenizer, "instruct_tokenizer") and hasattr(
            tokenizer.instruct_tokenizer, "tokenizer"
        ):
            self._tokenizer = tokenizer.instruct_tokenizer.tokenizer
        else:
            self._tokenizer = tokenizer
        if MISTRAL_AVAILABLE and Tekkenizer is not None:
            assert isinstance(self._tokenizer, Tekkenizer)
        self.reset()

    def reset(self):
        self.offset = 0
        self.tokens = []
        self._text = ""
        self._buffer = []
        self._current_text = ""

    def add_token(self, token):
        self._buffer.append(token)
        self.tokens.append(token)
        # Decode only the buffer to avoid re-detokenizing the whole sequence
        if MISTRAL_AVAILABLE and SpecialTokenPolicy is not None:
            decoded = self._tokenizer.decode(
                self._buffer, special_token_policy=SpecialTokenPolicy.KEEP
            )
        else:
            decoded = self._tokenizer.decode(self._buffer)
        # Heuristic: only flush the buffer once the decoded text is valid, i.e.
        # it does not end with the Unicode replacement character that marks an
        # incomplete multi-byte sequence
        if decoded and not decoded.endswith("\ufffd"):
            self._text += decoded
            self._buffer.clear()
            self._current_text = ""
        else:
            self._current_text = decoded

    def finalize(self):
        if self._buffer:
            # Use the same decode policy as add_token so special tokens are
            # not dropped on the final flush
            if MISTRAL_AVAILABLE and SpecialTokenPolicy is not None:
                decoded = self._tokenizer.decode(
                    self._buffer, special_token_policy=SpecialTokenPolicy.KEEP
                )
            else:
                decoded = self._tokenizer.decode(self._buffer)
            self._text += decoded
            self._buffer = []
            self._current_text = ""

    @property
    def text(self):
        return self._text + self._current_text

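# Illustrative streaming loop (a sketch, not part of the module's API;
# `tok` and `token_ids` are hypothetical stand-ins for a loaded
# MistralTokenizer and a stream of generated token IDs):
#
#     detok = MistralStreamingDetokenizer(tok)
#     for token_id in token_ids:
#         detok.add_token(token_id)
#         print(detok.last_segment, end="")
#     detok.finalize()
#     print(detok.last_segment)
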
class MistralTokenizerWrapper:
    """Helper class that provides Mistral-specific tokenizer functionality."""

    def __init__(self, tokenizer):
        self._tokenizer = tokenizer

    def is_mistral_tokenizer(self, tokenizer) -> bool:
        """Check if tokenizer is a MistralTokenizer."""
        return hasattr(tokenizer, "instruct_tokenizer") and hasattr(
            tokenizer.instruct_tokenizer, "tokenizer"
        )

    def get_underlying_tokenizer(self, tokenizer):
        """Get the underlying Tekkenizer from MistralTokenizer if applicable."""
        if self.is_mistral_tokenizer(tokenizer):
            return tokenizer.instruct_tokenizer.tokenizer
        return tokenizer

    def get_vocab(self, tokenizer):
        """Get vocabulary from Mistral tokenizer."""
        if self.is_mistral_tokenizer(tokenizer):
            # For MistralTokenizer, get vocab from the underlying tokenizer
            underlying_tokenizer = self.get_underlying_tokenizer(tokenizer)
            if hasattr(underlying_tokenizer, "vocab") and callable(
                underlying_tokenizer.vocab
            ):
                vocab_list = underlying_tokenizer.vocab()
                return {token: idx for idx, token in enumerate(vocab_list)}  # type: ignore
        return {}

    def has_mistral_chat_completion(self, tokenizer):
        """Check if tokenizer supports the Mistral chat completion API."""
        return (
            hasattr(tokenizer, "encode_chat_completion")
            and ChatCompletionRequest is not None
            and UserMessage is not None
            and AssistantMessage is not None
            and SystemMessage is not None
        )

    def get_eos_token_id(self, tokenizer):
        """Get EOS token ID from Mistral tokenizer."""
        if self.is_mistral_tokenizer(tokenizer):
            underlying_tokenizer = self.get_underlying_tokenizer(tokenizer)
            return getattr(underlying_tokenizer, "eos_id", None)
        return None

    def encode(self, tokenizer, text, add_special_tokens=True, **kwargs):
        """Custom encode method for Mistral tokenizers."""
        if self.is_mistral_tokenizer(tokenizer):
            # For MistralTokenizer, use the underlying Tekkenizer with bos/eos parameters
            underlying_tokenizer = self.get_underlying_tokenizer(tokenizer)
            return underlying_tokenizer.encode(
                text,
                bos=add_special_tokens,
                eos=False,  # Usually we don't want EOS during encoding
                **kwargs,
            )
        else:
            raise ValueError("Not a Mistral tokenizer")

    def convert_to_mistral_messages(self, messages):
        """Convert OpenAI-format messages to Mistral-common format."""
        if not MISTRAL_AVAILABLE:
            return []

        mistral_messages = []
        # Track tool calls to map IDs back to function names
        tool_call_map = {}

        for msg in messages:
            role = msg["role"]
            content = msg.get("content")

            if role == "system" and SystemMessage is not None:
                mistral_messages.append(SystemMessage(content=content))

            elif role == "user" and UserMessage is not None:
                mistral_messages.append(UserMessage(content=content))

            elif role == "assistant" and AssistantMessage is not None:
                if "tool_calls" in msg and msg["tool_calls"]:
                    try:
                        if ToolCall is not None and FunctionCall is not None:
                            tool_calls = []
                            for tool_call in msg["tool_calls"]:
                                if tool_call.get("type") == "function":
                                    function_call = tool_call["function"]
                                    call_id = tool_call["id"]
                                    func_name = function_call["name"]

                                    # Store mapping for later tool result messages
                                    tool_call_map[call_id] = func_name

                                    tool_calls.append(
                                        ToolCall(
                                            id=call_id,
                                            function=FunctionCall(
                                                name=func_name,
                                                arguments=function_call["arguments"],
                                            ),
                                        )
                                    )

                            mistral_messages.append(
                                AssistantMessage(content=content, tool_calls=tool_calls)
                            )
                        else:
                            mistral_messages.append(AssistantMessage(content=content))
                    except (ImportError, TypeError):
                        mistral_messages.append(AssistantMessage(content=content))
                else:
                    mistral_messages.append(AssistantMessage(content=content))

            elif role == "tool":
                try:
                    if ToolMessage is not None:
                        tool_call_id = msg["tool_call_id"]
                        name = msg.get("name", "")

                        # If name is missing, try to get it from our mapping
                        if not name and tool_call_id in tool_call_map:
                            name = tool_call_map[tool_call_id]

                        # If we still don't have a name, log a warning but continue
                        if not name:
                            print(
                                f"Warning: Tool message missing function name for call_id {tool_call_id}"
                            )

                        mistral_messages.append(
                            ToolMessage(
                                tool_call_id=tool_call_id,
                                name=name,
                                content=content,
                            )
                        )
                except (ImportError, TypeError):
                    pass

        return mistral_messages
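
    # Illustrative conversion (a sketch, not exhaustive): an OpenAI-style
    # message list such as
    #     [{"role": "system", "content": "Be terse."},
    #      {"role": "user", "content": "Hi"}]
    # maps to
    #     [SystemMessage(content="Be terse."), UserMessage(content="Hi")]
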
    def convert_to_mistral_tools(self, tools):
        """Convert OpenAI-format tools to Mistral-common format."""
        if not tools or Tool is None or Function is None:
            return None

        mistral_tools = []
        for tool in tools:
            if tool.get("type") == "function" and "function" in tool:
                func_def = tool["function"]
                mistral_tool = Tool(
                    function=Function(
                        name=func_def["name"],
                        description=func_def.get("description", ""),
                        parameters=func_def.get("parameters", {}),
                    )
                )
                mistral_tools.append(mistral_tool)
        return mistral_tools

    def apply_mistral_chat_template(
        self, tokenizer, messages, add_generation_prompt=True, tools=None
    ):
        """Apply chat template using Mistral tokenizer."""
        if not MISTRAL_AVAILABLE or ChatCompletionRequest is None:
            raise ValueError("Mistral libraries not available")

        # Convert to Mistral-common format; any error propagates so the main
        # tokenizer wrapper can handle the fallback
        mistral_messages = self.convert_to_mistral_messages(messages)
        mistral_tools = self.convert_to_mistral_tools(tools)

        # Create a ChatCompletionRequest and encode it with the MistralTokenizer
        request = ChatCompletionRequest(messages=mistral_messages, tools=mistral_tools)
        result = tokenizer.encode_chat_completion(request)

        # If no generation prompt is wanted and the rendered text ends with a
        # space, drop the trailing token (unless it is the only/leading token)
        if not add_generation_prompt and result.text.endswith(" "):
            return (
                result.tokens[:-1]
                if result.tokens and result.tokens[-1] != result.tokens[0]
                else result.tokens
            )

        return result.tokens

    def get_bos_token(self, tokenizer):
        """Get BOS token from Mistral tokenizer."""
        if self.is_mistral_tokenizer(tokenizer):
            underlying_tokenizer = self.get_underlying_tokenizer(tokenizer)
            if hasattr(underlying_tokenizer, "bos_id"):
                try:
                    return underlying_tokenizer.decode([underlying_tokenizer.bos_id])
                except Exception:
                    return None
        return None

    def get_eos_token(self, tokenizer):
        """Get EOS token from Mistral tokenizer."""
        if self.is_mistral_tokenizer(tokenizer):
            underlying_tokenizer = self.get_underlying_tokenizer(tokenizer)
            if hasattr(underlying_tokenizer, "eos_id"):
                try:
                    return underlying_tokenizer.decode([underlying_tokenizer.eos_id])
                except Exception:
                    return None
        return None

    def save_pretrained(self, tokenizer, save_directory, **kwargs):
        """Save Mistral tokenizer."""
        import shutil
        from pathlib import Path

        save_path = Path(save_directory)
        save_path.mkdir(parents=True, exist_ok=True)

        if self.is_mistral_tokenizer(tokenizer):
            # For MistralTokenizer, check if the underlying tokenizer has a file_path
            underlying_tokenizer = self.get_underlying_tokenizer(tokenizer)
            if (
                hasattr(underlying_tokenizer, "file_path")
                and underlying_tokenizer.file_path
            ):
                # Copy the original tekken.json file
                tekken_file = Path(underlying_tokenizer.file_path)
                if tekken_file.exists():
                    shutil.copy2(tekken_file, save_path / "tekken.json")
                else:
                    print(f"Warning: Could not find tekken.json at {tekken_file}")
            else:
                print(
                    "Warning: MistralTokenizer has no file_path, cannot save tekken.json"
                )


def load_mistral_tokenizer(model_path, eos_token_ids=None):
    """Load a Mistral tokenizer if tekken.json exists."""
    tekken_file = model_path / "tekken.json"
    if tekken_file.exists() and MistralTokenizer is not None:
        tokenizer = MistralTokenizer.from_file(str(tekken_file))
        if isinstance(eos_token_ids, int):
            eos_token_ids = [eos_token_ids]
        return tokenizer, MistralStreamingDetokenizer, eos_token_ids
    return None, None, None
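
A minimal end-to-end sketch of how the pieces above fit together (illustrative only, not part of the commit; "path/to/model" is a placeholder for a directory containing tekken.json, and the prompt tokens stand in for generated token IDs):

    from pathlib import Path

    from mlx_lm.tokenizer_mistral_utils import (
        MistralTokenizerWrapper,
        load_mistral_tokenizer,
    )

    tokenizer, detokenizer_class, eos_ids = load_mistral_tokenizer(Path("path/to/model"))
    if tokenizer is not None:
        # Render an OpenAI-style conversation through the Mistral chat template
        wrapper = MistralTokenizerWrapper(tokenizer)
        tokens = wrapper.apply_mistral_chat_template(
            tokenizer, [{"role": "user", "content": "Hello"}]
        )

        # Stream the tokens back into text one at a time
        detok = detokenizer_class(tokenizer)
        for token_id in tokens:
            detok.add_token(token_id)
            print(detok.last_segment, end="")
        detok.finalize()
        print(detok.last_segment)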
