Skip to content

Commit 29a810e

Browse files
authored
Prevent unnecessarily overwriting the default Hugging Face chat template (#2183)
Signed-off-by: Keshav Santhanam <[email protected]>
1 parent 63d4e7d commit 29a810e

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,13 @@ def __init__(
6969
pretrained_model_name_or_path=tokenizer_path,
7070
use_fast=use_fast,
7171
trust_remote_code=trust_remote_code,
72-
chat_template=chat_template,
7372
)
7473
elif merges_file is None:
7574
self.tokenizer = AutoTokenizer.from_pretrained(
7675
pretrained_model_name_or_path=tokenizer_path,
7776
vocab_file=vocab_file,
7877
use_fast=use_fast,
7978
trust_remote_code=trust_remote_code,
80-
chat_template=chat_template,
8179
)
8280
else:
8381
self.tokenizer = AutoTokenizer.from_pretrained(
@@ -86,14 +84,21 @@ def __init__(
8684
merge_files=merges_file,
8785
use_fast=use_fast,
8886
trust_remote_code=trust_remote_code,
89-
chat_template=chat_template,
9087
)
9188
except Exception as e:
9289
raise ValueError(
9390
'Unable to instantiate HuggingFace AutoTokenizer '
9491
f'for {tokenizer_path}. Exception: {e}'
9592
)
9693

94+
# Store the tokenizer's existing chat template if the user does not provide
95+
# a custom chat template. Otherwise, override the default chat template with
96+
# the user-provided template.
97+
if chat_template is None:
98+
chat_template = self.tokenizer.chat_template
99+
else:
100+
self.tokenizer.chat_template = chat_template
101+
97102
self.include_special_tokens = include_special_tokens
98103
self.original_vocab_size = len(self.tokenizer)
99104
self.chat_template = chat_template

megatron/core/tokenizers/text/text_tokenizer.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,17 @@ def __init__(self, path: str, config: dict, **kwargs) -> None:
3737
self._tokenizer = self._restore_model(**kwargs)
3838
self.additional_args = kwargs
3939
self.path = path
40-
if (
41-
config.get("chat_template", None) is None
42-
and kwargs.get("chat_template", None) is not None
43-
):
44-
self.chat_template = kwargs.get("chat_template", None)
40+
41+
config_template = config.get("chat_template", None)
42+
tokenizer_template = getattr(self._tokenizer, "chat_template", None)
43+
kwargs_template = kwargs.get("chat_template", None)
44+
45+
if config_template is not None:
46+
self.chat_template = config_template
47+
elif tokenizer_template is not None:
48+
self.chat_template = tokenizer_template
4549
else:
46-
self.chat_template = config.get("chat_template", None)
50+
self.chat_template = kwargs_template
4751

4852
def _restore_model(self, **kwargs) -> MegatronTokenizerTextAbstract:
4953
"""Returns tokenizer library object."""

0 commit comments

Comments (0)