diff --git a/silnlp/nmt/hugging_face_config.py b/silnlp/nmt/hugging_face_config.py
index 56c3ce7d..be960221 100644
--- a/silnlp/nmt/hugging_face_config.py
+++ b/silnlp/nmt/hugging_face_config.py
@@ -905,7 +905,12 @@ def train(self) -> None:
             model = self._convert_to_lora_model(model)
 
         # Change specific variables based on the type of model
-        model, tokenizer = self._configure_model(model, tokenizer)
+        model, tokenizer = self._configure_model(
+            model,
+            tokenizer,
+            self._config.val_src_lang if self._config.val_src_lang else self._config.test_src_lang,
+            self._config.val_trg_lang if self._config.val_trg_lang else self._config.test_trg_lang,
+        )
 
         def load_text_dataset(src_path: Path, trg_path: Path) -> Optional[Dataset]:
             if not src_path.is_file() or not trg_path.is_file():
@@ -1104,7 +1109,7 @@ def translate_test_files(
         ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST,
     ) -> None:
         tokenizer = self._config.get_tokenizer()
-        model = self._create_inference_model(ckpt, tokenizer)
+        model = self._create_inference_model(ckpt, tokenizer, self._config.test_src_lang, self._config.test_trg_lang)
         pipeline = PretokenizedTranslationPipeline(
             model=model,
             tokenizer=tokenizer,
@@ -1223,11 +1228,12 @@ def translate(
         vrefs: Optional[Iterable[VerseRef]] = None,
         ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST,
     ) -> Iterable[TranslationGroup]:
+        src_lang = self._config.data["lang_codes"].get(src_iso, src_iso)
+        trg_lang = self._config.data["lang_codes"].get(trg_iso, trg_iso)
         tokenizer = self._config.get_tokenizer()
-        model = self._create_inference_model(ckpt, tokenizer)
+        model = self._create_inference_model(ckpt, tokenizer, src_lang, trg_lang)
         if model.config.max_length is not None and model.config.max_length < 512:
             model.config.max_length = 512
-        lang_codes: Dict[str, str] = self._config.data["lang_codes"]
 
         # The tokenizer isn't wrapped until after calling _create_inference_model,
         # because the tokenizer's input/output language codes are set there
@@ -1237,8 +1243,8 @@ def translate(
         pipeline = SilTranslationPipeline(
             model=model,
             tokenizer=tokenizer,
-            src_lang=lang_codes.get(src_iso, src_iso),
-            tgt_lang=lang_codes.get(trg_iso, trg_iso),
+            src_lang=src_lang,
+            tgt_lang=trg_lang,
             device=0,
         )
 
@@ -1697,7 +1703,11 @@ def _translate_with_diverse_beam_search(
         return self._flatten_tokenized_translations(translations)
 
     def _create_inference_model(
-        self, ckpt: Union[CheckpointType, str, int], tokenizer: PreTrainedTokenizer
+        self,
+        ckpt: Union[CheckpointType, str, int],
+        tokenizer: PreTrainedTokenizer,
+        src_lang: str,
+        trg_lang: str,
     ) -> PreTrainedModel:
         if self._config.model_dir.exists():
             checkpoint_path, _ = self.get_checkpoint_path(ckpt)
@@ -1734,47 +1744,47 @@ def _create_inference_model(
         if model_name == self._config.model and len(tokenizer) != model.get_input_embeddings().weight.size(dim=0):
             model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8 if self._mixed_precision else None)
         if self._config.model_prefix == "google/madlad400" or model_name == self._config.model:
-            model, tokenizer = self._configure_model(model, tokenizer)
+            model, tokenizer = self._configure_model(model, tokenizer, src_lang, trg_lang)
         return model
 
     def _configure_model(
-        self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer
+        self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, src_lang: str, trg_lang: str
     ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
         # Set decoder_start_token_id
         if (
-            self._config.val_trg_lang != ""
+            trg_lang != ""
            and model.config.decoder_start_token_id is None
            and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast))
        ):
            if isinstance(tokenizer, MBartTokenizer):
-                model.config.decoder_start_token_id = tokenizer.lang_code_to_id[self._config.val_trg_lang]
+                model.config.decoder_start_token_id = tokenizer.lang_code_to_id[trg_lang]
            else:
-                model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(self._config.val_trg_lang)
+                model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(trg_lang)
        if self._config.model_prefix == "google/madlad400":
            model.config.decoder_start_token_id = tokenizer.pad_token_id
            model.generation_config.decoder_start_token_id = tokenizer.pad_token_id
            model.config.max_length = 256
            model.generation_config.max_new_tokens = 256
-            tokenizer.tgt_lang = self._config.val_trg_lang
+            tokenizer.tgt_lang = trg_lang
        if model.config.decoder_start_token_id is None:
            raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
 
        if (
-            self._config.val_src_lang != ""
-            and self._config.val_trg_lang != ""
+            src_lang != ""
+            and trg_lang != ""
            and isinstance(
                tokenizer, (MBartTokenizer, MBartTokenizerFast, M2M100Tokenizer, NllbTokenizer, NllbTokenizerFast)
            )
        ):
-            tokenizer.src_lang = self._config.val_src_lang
-            tokenizer.tgt_lang = self._config.val_trg_lang
+            tokenizer.src_lang = src_lang
+            tokenizer.tgt_lang = trg_lang
 
            # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
            # as the first generated token.
-            forced_bos_token_id = tokenizer.convert_tokens_to_ids(self._config.val_trg_lang)
+            forced_bos_token_id = tokenizer.convert_tokens_to_ids(trg_lang)
            model.config.forced_bos_token_id = forced_bos_token_id
            if model.generation_config is not None:
                model.generation_config.forced_bos_token_id = forced_bos_token_id
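
For context, a minimal standalone sketch of the language-code handling this refactor threads through inference. The `lang_codes` contents and the NLLB checkpoint below are illustrative assumptions, not taken from this diff; only the `.get(iso, iso)` fallback and the forced-BOS setup mirror the code above:

    # Sketch only: resolve ISO codes to model-specific language codes, falling
    # back to the raw ISO code when no mapping exists (as in translate() above).
    from transformers import NllbTokenizerFast

    lang_codes = {"en": "eng_Latn", "fr": "fra_Latn"}  # stand-in for config data["lang_codes"]
    src_iso, trg_iso = "en", "fr"
    src_lang = lang_codes.get(src_iso, src_iso)  # -> "eng_Latn"
    trg_lang = lang_codes.get(trg_iso, trg_iso)  # -> "fra_Latn"

    # NLLB-style tokenizers take explicit source/target language codes; the
    # target code's token id is what gets forced as the first generated token.
    tokenizer = NllbTokenizerFast.from_pretrained("facebook/nllb-200-distilled-600M")
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = trg_lang
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(trg_lang)

Passing src_lang/trg_lang as explicit parameters, rather than reading self._config.val_*_lang inside _configure_model, is what lets translate() supply per-call codes resolved from the requested ISO pair instead of being pinned to the validation languages.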