Commit 6940b44

Auto convert tekken.json (#42299)
* auto convert tekken.json
* fix conversion
* simplify
* nit
* model info based on the fly fix
* up
* last nit
* fixup
* call it fix mistral regex
* fix behaviour for local or only tok is saved
* style
* rm comment at wrong place
* fix escaping
* style
* fix backend tokenizer attr to _tokenizer
* update
* up
* update
* fix the last red tests
1 parent 00ab75e commit 6940b44
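As a minimal sketch of the intended effect (the model id is illustrative; it assumes a Hub repo that ships a Mistral tekken.json but no tokenizer.json, and that mistral-common is not installed), loading such a repo now converts tekken.json to a fast tokenizer on the fly instead of failing:

    from transformers import AutoTokenizer

    # from_pretrained now falls back to the repo's tekken.json vocab file and
    # converts it with the new MistralConverter into a fast tokenizer.
    tok = AutoTokenizer.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
    print(tok("Hello world").input_ids)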

File tree: 3 files changed, +172 -6 lines changed

src/transformers/convert_slow_tokenizer.py

Lines changed: 85 additions & 1 deletion
@@ -19,11 +19,13 @@
 """
 
 import warnings
+from functools import lru_cache
 from typing import Optional
 
 from packaging import version
 from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import BPE, Unigram, WordPiece
+from tqdm import tqdm
 
 from .utils import is_protobuf_available, is_sentencepiece_available, logging, requires_backends
 from .utils.import_utils import PROTOBUF_IMPORT_ERROR
@@ -1692,6 +1694,85 @@ def converted(self) -> Tokenizer:
         return tokenizer
 
 
+class MistralConverter:
+    def __init__(
+        self,
+        vocab_file=None,
+        pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
+        add_prefix_space=False,
+        additional_special_tokens=None,
+        **kwargs,
+    ):
+        self.vocab_file = vocab_file
+        self.pattern = pattern
+        self.add_prefix_space = add_prefix_space
+        self.additional_special_tokens = (
+            additional_special_tokens.keys()
+            if isinstance(additional_special_tokens, dict)
+            else additional_special_tokens
+        )
+
+    def extract_vocab_merges_from_model(self, tiktoken_url: str):
+        import base64
+        import json
+
+        with open(self.vocab_file, "r", encoding="utf-8") as f:
+            untyped = json.load(f)
+        self.pattern = untyped["config"]["pattern"]
+        self.additional_special_tokens = [
+            AddedToken(k["token_str"], special=k["is_control"]) for k in untyped["special_tokens"]
+        ]
+        bpe_ranks = untyped["vocab"]
+        byte_encoder = bytes_to_unicode()
+
+        @lru_cache
+        def token_bytes_to_string(b):
+            return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
+
+        merges = []
+        vocab = {}
+        for idx, token in enumerate(self.additional_special_tokens):
+            vocab[token.content] = idx
+        bpe_ranks = [base64.b64decode(k["token_bytes"]) for k in bpe_ranks]
+        rank_set = set(bpe_ranks)
+        for rank, token in enumerate(tqdm(bpe_ranks, desc="Converting tekken.json to tokenizer.json")):
+            vocab[token_bytes_to_string(token)] = rank
+            if len(token) == 1:
+                continue
+            local = []
+            for index in range(1, len(token)):
+                piece_l, piece_r = token[:index], token[index:]
+                if piece_l in rank_set and piece_r in rank_set and (piece_l + piece_r) in rank_set:
+                    local.append((piece_l, piece_r, rank))
+            local = sorted(local, key=lambda x: (bpe_ranks.index(x[0]), bpe_ranks.index(x[1])), reverse=False)
+            merges.extend(local)
+        merges = sorted(merges, key=lambda val: val[2], reverse=False)
+        merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
+        return vocab, merges
+
+    def tokenizer(self):
+        vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
+        tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
+        if hasattr(tokenizer.model, "ignore_merges"):
+            tokenizer.model.ignore_merges = True
+        return tokenizer
+
+    def converted(self) -> Tokenizer:
+        tokenizer = self.tokenizer()
+        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
+            [
+                pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
+                pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
+            ]
+        )
+        tokenizer.decoder = decoders.ByteLevel()
+
+        tokenizer.add_tokens(self.additional_special_tokens)
+        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
+
+        return tokenizer
+
+
 SLOW_TO_FAST_CONVERTERS = {
     "AlbertTokenizer": AlbertConverter,
     "BartTokenizer": RobertaConverter,
@@ -1771,7 +1852,10 @@ def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokeni
     if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken:
         converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
         return converter_class(transformer_tokenizer).converted()
-
+    elif transformer_tokenizer.vocab_file.endswith("tekken.json"):
+        transformer_tokenizer.original_tokenizer = transformer_tokenizer
+        logger.info("Converting from Mistral tekken.json")
+        return MistralConverter(transformer_tokenizer.vocab_file).converted()
     else:
         try:
             logger.info("Converting from Tiktoken")

src/transformers/models/auto/tokenization_auto.py

Lines changed: 2 additions & 4 deletions
@@ -748,10 +748,8 @@
         (
             "voxtral",
             (
-                "MistralCommonTokenizer"
-                if is_mistral_common_available()
-                else ("LlamaTokenizer" if is_sentencepiece_available() else None),
-                "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
+                "MistralCommonTokenizer" if is_mistral_common_available() else None,
+                "PreTrainedTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
             ),
         ),
         ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),

src/transformers/tokenization_utils_base.py

Lines changed: 85 additions & 1 deletion
@@ -33,6 +33,7 @@
 from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Optional, Union, overload
 
 import numpy as np
+from huggingface_hub import list_repo_files
 from packaging import version
 
 from . import __version__
@@ -2098,7 +2099,21 @@ def from_pretrained(
             template = template.removesuffix(".jinja")
             vocab_files[f"chat_template_{template}"] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"
 
-        # Get files from url, cache, or disk depending on the case
+        if not is_local and not local_files_only:
+            try:
+                remote_files = list_repo_files(pretrained_model_name_or_path)
+            except Exception:
+                remote_files = []
+        else:
+            remote_files = os.listdir(pretrained_model_name_or_path)
+
+        if "tokenizer_file" in vocab_files and not re.search(vocab_files["tokenizer_file"], "".join(remote_files)):
+            # mistral tokenizer names are different, but we can still convert them if
+            # mistral common is not there
+            other_pattern = re.escape("tekken.json|tokenizer.model.*")
+            if match := re.search(other_pattern, "\n".join(remote_files)):
+                vocab_files["vocab_file"] = match.group()
+
         resolved_vocab_files = {}
         for file_id, file_path in vocab_files.items():
             if file_path is None:
@@ -2417,6 +2432,75 @@ def _from_pretrained(
                 "Special tokens have been added in the vocabulary, make sure the associated word embeddings are"
                 " fine-tuned or trained."
             )
+        try:
+            vocab_size = tokenizer.vocab_size
+        except NotImplementedError:
+            vocab_size = 0
+
+        if (
+            vocab_size > 100000
+            and hasattr(tokenizer, "_tokenizer")
+            and getattr(tokenizer._tokenizer, "pre_tokenizer", None) is not None
+        ):
+            from huggingface_hub import model_info
+
+            def is_base_mistral(model_id: str) -> bool:
+                model = model_info(model_id)
+                if model.tags is not None:
+                    if re.search("base_model:.*mistralai", "".join(model.tags)):
+                        return True
+                return False
+
+            if _is_local or is_base_mistral(pretrained_model_name_or_path):
+                _config_file = cached_file(
+                    pretrained_model_name_or_path,
+                    "config.json",
+                    cache_dir=cache_dir,
+                    token=token,
+                    local_files_only=local_files_only,
+                    _raise_exceptions_for_missing_entries=False,
+                    _raise_exceptions_for_connection_errors=False,
+                    _commit_hash=_commit_hash,
+                )
+                if _config_file is not None:
+                    with open(_config_file, encoding="utf-8") as f:
+                        _config = json.load(f)
+                    transformers_version = _config.get("transformers_version")
+
+                    if transformers_version and version.parse(transformers_version) <= version.parse("4.57.2"):
+                        if _is_local and _config.model_type not in [
+                            "mistral",
+                            "mistral3",
+                            "voxstral",
+                            "ministral",
+                            "pixtral",
+                        ]:
+                            return tokenizer
+
+                        # Expose the `fix_mistral_regex` flag on the tokenizer when provided, even if no correction is applied.
+                        if "fix_mistral_regex" in init_kwargs:
+                            setattr(tokenizer, "fix_mistral_regex", init_kwargs["fix_mistral_regex"])
+
+                        fix_mistral_regex = kwargs.get("fix_mistral_regex")  # not init kwargs
+                        # only warn if its not explicitly passed
+                        if fix_mistral_regex is None and not getattr(tokenizer, "fix_mistral_regex", False):
+                            setattr(tokenizer, "fix_mistral_regex", False)
+                            logger.warning(
+                                f"The tokenizer you are loading from '{pretrained_model_name_or_path}'"
+                                f" with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. "
+                                " This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue."
+                            )
+                        elif fix_mistral_regex is True or getattr(tokenizer, "fix_mistral_regex", False):
+                            setattr(tokenizer, "fix_mistral_regex", True)
+                            import tokenizers
+
+                            tokenizer.backend_tokenizer.pre_tokenizer[0] = tokenizers.pre_tokenizers.Split(
+                                pattern=tokenizers.Regex(
+                                    r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+"
+                                ),
+                                behavior="isolated",
+                            )
+
         return tokenizer
 
     @staticmethod
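A hedged example of the `fix_mistral_regex` opt-in this hunk introduces (the model id comes from the warning above; whether the correction is actually applied depends on the checkpoint's recorded transformers_version and the other conditions in the diff):

    from transformers import AutoTokenizer

    # Opting in swaps the first pre-tokenizer Split for the corrected regex
    # instead of only emitting the warning.
    tok = AutoTokenizer.from_pretrained(
        "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
        fix_mistral_regex=True,
    )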
