
Fix bug with base NLLB translation #788

Merged: 2 commits, Aug 13, 2025
silnlp/nmt/hugging_face_config.py — 46 changes: 28 additions & 18 deletions
@@ -905,7 +905,12 @@ def train(self) -> None:
             model = self._convert_to_lora_model(model)
 
         # Change specific variables based on the type of model
-        model, tokenizer = self._configure_model(model, tokenizer)
+        model, tokenizer = self._configure_model(
+            model,
+            tokenizer,
+            self._config.val_src_lang if self._config.val_src_lang else self._config.test_src_lang,
+            self._config.val_trg_lang if self._config.val_trg_lang else self._config.test_trg_lang,
+        )
 
         def load_text_dataset(src_path: Path, trg_path: Path) -> Optional[Dataset]:
             if not src_path.is_file() or not trg_path.is_file():
@@ -1104,7 +1109,7 @@ def translate_test_files(
         ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST,
     ) -> None:
         tokenizer = self._config.get_tokenizer()
-        model = self._create_inference_model(ckpt, tokenizer)
+        model = self._create_inference_model(ckpt, tokenizer, self._config.test_src_lang, self._config.test_trg_lang)
         pipeline = PretokenizedTranslationPipeline(
             model=model,
             tokenizer=tokenizer,
@@ -1223,11 +1228,12 @@ def translate(
         vrefs: Optional[Iterable[VerseRef]] = None,
         ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST,
     ) -> Iterable[TranslationGroup]:
+        src_lang = self._config.data["lang_codes"].get(src_iso, src_iso)
+        trg_lang = self._config.data["lang_codes"].get(trg_iso, trg_iso)
         tokenizer = self._config.get_tokenizer()
-        model = self._create_inference_model(ckpt, tokenizer)
+        model = self._create_inference_model(ckpt, tokenizer, src_lang, trg_lang)
         if model.config.max_length is not None and model.config.max_length < 512:
             model.config.max_length = 512
-        lang_codes: Dict[str, str] = self._config.data["lang_codes"]
 
         # The tokenizer isn't wrapped until after calling _create_inference_model,
         # because the tokenizer's input/output language codes are set there
@@ -1237,8 +1243,8 @@
         pipeline = SilTranslationPipeline(
             model=model,
             tokenizer=tokenizer,
-            src_lang=lang_codes.get(src_iso, src_iso),
-            tgt_lang=lang_codes.get(trg_iso, trg_iso),
+            src_lang=src_lang,
+            tgt_lang=trg_lang,
             device=0,
         )
 
@@ -1697,7 +1703,11 @@ def _translate_with_diverse_beam_search(
         return self._flatten_tokenized_translations(translations)
 
     def _create_inference_model(
-        self, ckpt: Union[CheckpointType, str, int], tokenizer: PreTrainedTokenizer
+        self,
+        ckpt: Union[CheckpointType, str, int],
+        tokenizer: PreTrainedTokenizer,
+        src_lang: str,
+        trg_lang: str,
     ) -> PreTrainedModel:
         if self._config.model_dir.exists():
             checkpoint_path, _ = self.get_checkpoint_path(ckpt)
@@ -1734,47 +1744,47 @@ def _create_inference_model(
         if model_name == self._config.model and len(tokenizer) != model.get_input_embeddings().weight.size(dim=0):
             model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8 if self._mixed_precision else None)
         if self._config.model_prefix == "google/madlad400" or model_name == self._config.model:
-            model, tokenizer = self._configure_model(model, tokenizer)
+            model, tokenizer = self._configure_model(model, tokenizer, src_lang, trg_lang)
 
         return model
 
     def _configure_model(
-        self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer
+        self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, src_lang: str, trg_lang: str
     ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
         # Set decoder_start_token_id
         if (
-            self._config.val_trg_lang != ""
+            trg_lang != ""
            and model.config.decoder_start_token_id is None
            and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast))
        ):
            if isinstance(tokenizer, MBartTokenizer):
-                model.config.decoder_start_token_id = tokenizer.lang_code_to_id[self._config.val_trg_lang]
+                model.config.decoder_start_token_id = tokenizer.lang_code_to_id[trg_lang]
            else:
-                model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(self._config.val_trg_lang)
+                model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(trg_lang)
 
         if self._config.model_prefix == "google/madlad400":
             model.config.decoder_start_token_id = tokenizer.pad_token_id
             model.generation_config.decoder_start_token_id = tokenizer.pad_token_id
             model.config.max_length = 256
             model.generation_config.max_new_tokens = 256
-            tokenizer.tgt_lang = self._config.val_trg_lang
+            tokenizer.tgt_lang = trg_lang
 
         if model.config.decoder_start_token_id is None:
             raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
 
         if (
-            self._config.val_src_lang != ""
-            and self._config.val_trg_lang != ""
+            src_lang != ""
+            and trg_lang != ""
            and isinstance(
                tokenizer, (MBartTokenizer, MBartTokenizerFast, M2M100Tokenizer, NllbTokenizer, NllbTokenizerFast)
            )
        ):
-            tokenizer.src_lang = self._config.val_src_lang
-            tokenizer.tgt_lang = self._config.val_trg_lang
+            tokenizer.src_lang = src_lang
+            tokenizer.tgt_lang = trg_lang
 
             # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
             # as the first generated token.
-            forced_bos_token_id = tokenizer.convert_tokens_to_ids(self._config.val_trg_lang)
+            forced_bos_token_id = tokenizer.convert_tokens_to_ids(trg_lang)
             model.config.forced_bos_token_id = forced_bos_token_id
             if model.generation_config is not None:
                 model.generation_config.forced_bos_token_id = forced_bos_token_id
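
What the diff fixes, in short: when translating with a base NLLB checkpoint (no fine-tuning run, so `val_src_lang` and `val_trg_lang` in the config are empty strings), `_configure_model` previously read the language codes from the validation settings and therefore skipped setting the tokenizer's `src_lang`/`tgt_lang` and the model's `forced_bos_token_id`. Threading `src_lang` and `trg_lang` through explicitly gives base-model inference the same setup as fine-tuned models. The sketch below shows roughly what that setup amounts to for an NLLB-style tokenizer, using only stock Hugging Face APIs; the checkpoint name and language codes are illustrative, not taken from this PR.

# Minimal sketch (not silnlp code) of the tokenizer/model state that
# _configure_model establishes for NLLB-style tokenizers at inference time.
# The checkpoint name and language codes below are illustrative.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

src_lang = "eng_Latn"  # hypothetical source language code
trg_lang = "fra_Latn"  # hypothetical target language code

tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

# With the fix, these are set even when no training config supplies val_* langs:
tokenizer.src_lang = src_lang  # source sentences are tokenized with this language prefix
tokenizer.tgt_lang = trg_lang
forced_bos_token_id = tokenizer.convert_tokens_to_ids(trg_lang)
model.config.forced_bos_token_id = forced_bos_token_id  # first generated token is the target lang code
if model.generation_config is not None:
    model.generation_config.forced_bos_token_id = forced_bos_token_id

inputs = tokenizer("Hello, world!", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])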
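
Two smaller pieces of the change are worth a note: in translate(), the lang_codes lookup is hoisted above _create_inference_model so the resolved codes reach model configuration, and in train(), the val_* language codes now fall back to the test_* codes when no validation languages are configured. A small sketch of both patterns follows; the lang_codes table and config values here are hypothetical.

# Sketch of the two resolution patterns in this diff (values are hypothetical).
lang_codes = {"en": "eng_Latn", "fr": "fra_Latn"}  # project ISO code -> model language code

def resolve(iso: str) -> str:
    # Falls back to the raw ISO code when no model-specific code is configured,
    # mirroring lang_codes.get(src_iso, src_iso) in translate().
    return lang_codes.get(iso, iso)

val_src_lang = ""           # empty when there is no validation set
test_src_lang = "eng_Latn"
# Mirrors the fallback added in train():
src_lang = val_src_lang if val_src_lang else test_src_lang

print(resolve("en"))        # "eng_Latn" (mapped to the model code)
print(resolve("swh_Latn"))  # "swh_Latn" (already a model code, passes through)
print(src_lang)             # "eng_Latn" (fell back to the test language)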