diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 2168dc0bbe19..3485e8e1edde 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -436,6 +436,13 @@ def __init__(self, **kwargs):
         self._commit_hash = kwargs.pop("_commit_hash", None)
         self.transformers_version = kwargs.pop("transformers_version", __version__)
 
+        # Ensure backward compatibility for models that use `forced_bos_token_id` within their config
+        if self._from_model_config and kwargs.get("force_bos_token_to_be_generated", False):
+            self.forced_bos_token_id = self.bos_token_id
+            logger.warning_once(
+                f"Please make sure the generation config includes `forced_bos_token_id={self.bos_token_id}`. "
+            )
+
         # Additional attributes without default values
         if not self._from_model_config:
             # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
diff --git a/src/transformers/models/bart/configuration_bart.py b/src/transformers/models/bart/configuration_bart.py
index 293929ecfc11..554865411272 100644
--- a/src/transformers/models/bart/configuration_bart.py
+++ b/src/transformers/models/bart/configuration_bart.py
@@ -14,8 +14,6 @@
 # limitations under the License.
 """BART model configuration"""
 
-import warnings
-
 from ...configuration_utils import PreTrainedConfig
 from ...utils import logging
 
@@ -80,9 +78,6 @@ class BartConfig(PreTrainedConfig):
             Whether or not the model should return the last key/values attentions (not used by all models).
         num_labels (`int`, *optional*, defaults to 3):
             The number of labels to use in [`BartForSequenceClassification`].
-        forced_eos_token_id (`int`, *optional*, defaults to 2):
-            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
-            `eos_token_id`.
 
     Example:
 
@@ -130,7 +125,6 @@ def __init__(
         eos_token_id=2,
         is_encoder_decoder=True,
         decoder_start_token_id=2,
-        forced_eos_token_id=2,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -161,16 +155,9 @@ def __init__(
             eos_token_id=eos_token_id,
             is_encoder_decoder=is_encoder_decoder,
             decoder_start_token_id=decoder_start_token_id,
-            forced_eos_token_id=forced_eos_token_id,
             **kwargs,
         )
         self.tie_encoder_decoder = True
-        # ensure backward compatibility for BART CNN models
-        if kwargs.get("force_bos_token_to_be_generated", False):
-            self.forced_bos_token_id = self.bos_token_id
-            warnings.warn(
-                f"Please make sure the generation config includes `forced_bos_token_id={self.bos_token_id}`. "
-            )
 
 
 __all__ = ["BartConfig"]
diff --git a/src/transformers/models/mvp/configuration_mvp.py b/src/transformers/models/mvp/configuration_mvp.py
index 99cd2560c211..c006e19f42a0 100644
--- a/src/transformers/models/mvp/configuration_mvp.py
+++ b/src/transformers/models/mvp/configuration_mvp.py
@@ -14,8 +14,6 @@
 # limitations under the License.
 """MVP model configuration"""
 
-import warnings
-
 from ...configuration_utils import PreTrainedConfig
 from ...utils import logging
 
@@ -78,9 +76,6 @@ class MvpConfig(PreTrainedConfig):
             Scale embeddings by diving by sqrt(d_model).
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
-        forced_eos_token_id (`int`, *optional*, defaults to 2):
-            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
-            `eos_token_id`.
         use_prompt (`bool`, *optional*, defaults to `False`):
             Whether or not to use prompt.
         prompt_length (`int`, *optional*, defaults to 100):
@@ -132,7 +127,6 @@ def __init__(
         eos_token_id=2,
         is_encoder_decoder=True,
         decoder_start_token_id=2,
-        forced_eos_token_id=2,
         use_prompt=False,
         prompt_length=100,
         prompt_mid_dim=800,
@@ -168,16 +162,8 @@ def __init__(
             eos_token_id=eos_token_id,
             is_encoder_decoder=is_encoder_decoder,
             decoder_start_token_id=decoder_start_token_id,
-            forced_eos_token_id=forced_eos_token_id,
             **kwargs,
         )
-        if kwargs.get("force_bos_token_to_be_generated", False):
-            self.forced_bos_token_id = self.bos_token_id
-            warnings.warn(
-                f"Please make sure the generated config includes `forced_bos_token_id={self.bos_token_id}` . "
-                "The config can simply be saved and uploaded again to be fixed."
-            )
-
 
 
 __all__ = ["MvpConfig"]
diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py
index eba59535b6b0..3ac771067dd2 100644
--- a/tests/models/bart/test_modeling_bart.py
+++ b/tests/models/bart/test_modeling_bart.py
@@ -1279,9 +1279,7 @@ def test_contrastive_search_bart(self):
 
     @slow
     def test_decoder_attention_mask(self):
-        model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0).to(
-            torch_device
-        )
+        model = BartForConditionalGeneration.from_pretrained("facebook/bart-large").to(torch_device)
         tokenizer = self.default_tokenizer
         sentence = "UN Chief Says There Is No <mask> in Syria"
         input_ids = tokenizer(sentence, return_tensors="pt").input_ids.to(torch_device)
@@ -1302,6 +1300,7 @@ def test_decoder_attention_mask(self):
             max_new_tokens=20,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            forced_bos_token_id=0,
         )
         generated_sentence = tokenizer.batch_decode(generated_ids)[0]
         expected_sentence = "UN Chief Says There Is No Plan B for Peace in Syria"
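
For reference, a minimal sketch of the compatibility path this patch moves into `GenerationConfig` (illustrative only, not part of the diff; it assumes BART's default `bos_token_id=0` and a config that still carries the deprecated `force_bos_token_to_be_generated` flag):

```python
from transformers import BartConfig, GenerationConfig

# A legacy-style config that still carries the deprecated flag, as old
# BART CNN checkpoints do; unknown kwargs are kept as config attributes.
model_config = BartConfig(force_bos_token_to_be_generated=True)

# With this patch, the translation to `forced_bos_token_id` happens inside
# GenerationConfig.__init__ (guarded by `_from_model_config`) instead of in
# each model's configuration class, and emits a one-time warning log.
generation_config = GenerationConfig.from_model_config(model_config)

print(generation_config.forced_bos_token_id)  # 0, BART's default bos_token_id
```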