7 changes: 7 additions & 0 deletions src/transformers/generation/configuration_utils.py
@@ -436,6 +436,13 @@ def __init__(self, **kwargs):
         self._commit_hash = kwargs.pop("_commit_hash", None)
         self.transformers_version = kwargs.pop("transformers_version", __version__)
 
+        # Ensure backward compatibility for BART CNN models
+        if self._from_model_config and kwargs.get("force_bos_token_to_be_generated", False):
+            self.forced_bos_token_id = self.bos_token_id
+            logger.warning_once(
+                f"Please make sure the generation config includes `forced_bos_token_id={self.bos_token_id}`. "
+            )
+
         # Additional attributes without default values
         if not self._from_model_config:
             # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
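For illustration, a minimal sketch of the new behavior (an assumption-laden example, not part of the PR: it assumes the legacy flag survives into the model config's serialized dict, and uses `GenerationConfig.from_model_config`, the public entry point that sets `_from_model_config` internally):

from transformers import BartConfig, GenerationConfig

# A legacy BART-CNN-style config that still carries the deprecated flag;
# unknown kwargs are stored as plain attributes on the model config.
model_config = BartConfig(force_bos_token_to_be_generated=True)

# Building a GenerationConfig from the model config now resolves the
# legacy flag here (with a one-time warning) instead of in BartConfig.
gen_config = GenerationConfig.from_model_config(model_config)
assert gen_config.forced_bos_token_id == model_config.bos_token_id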
13 changes: 0 additions & 13 deletions src/transformers/models/bart/configuration_bart.py
@@ -14,8 +14,6 @@
 # limitations under the License.
 """BART model configuration"""
 
-import warnings
-
 from ...configuration_utils import PreTrainedConfig
 from ...utils import logging
 
@@ -80,9 +78,6 @@ class BartConfig(PreTrainedConfig):
             Whether or not the model should return the last key/values attentions (not used by all models).
         num_labels (`int`, *optional*, defaults to 3):
             The number of labels to use in [`BartForSequenceClassification`].
-        forced_eos_token_id (`int`, *optional*, defaults to 2):
-            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
-            `eos_token_id`.
 
     Example:
 
@@ -130,7 +125,6 @@ def __init__(
         eos_token_id=2,
         is_encoder_decoder=True,
         decoder_start_token_id=2,
-        forced_eos_token_id=2,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -161,16 +155,9 @@
             eos_token_id=eos_token_id,
             is_encoder_decoder=is_encoder_decoder,
             decoder_start_token_id=decoder_start_token_id,
-            forced_eos_token_id=forced_eos_token_id,
             **kwargs,
         )
         self.tie_encoder_decoder = True
-        # ensure backward compatibility for BART CNN models
-        if kwargs.get("force_bos_token_to_be_generated", False):
-            self.forced_bos_token_id = self.bos_token_id
-            warnings.warn(
-                f"Please make sure the generation config includes `forced_bos_token_id={self.bos_token_id}`. "
-            )
 
 
 __all__ = ["BartConfig"]
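With the explicit argument gone, `BartConfig` itself no longer defaults `forced_eos_token_id` to 2; callers that relied on the old default can still pass it explicitly. A sketch, assuming the kwarg is simply forwarded through `**kwargs` to `PreTrainedConfig` and stored as an attribute:

from transformers import BartConfig

# `forced_eos_token_id` now travels through **kwargs to the parent class
# rather than being a named BartConfig argument defaulting to 2.
config = BartConfig(forced_eos_token_id=2)
assert config.forced_eos_token_id == 2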
14 changes: 0 additions & 14 deletions src/transformers/models/mvp/configuration_mvp.py
@@ -14,8 +14,6 @@
 # limitations under the License.
 """MVP model configuration"""
 
-import warnings
-
 from ...configuration_utils import PreTrainedConfig
 from ...utils import logging
 
@@ -78,9 +76,6 @@ class MvpConfig(PreTrainedConfig):
             Scale embeddings by diving by sqrt(d_model).
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
-        forced_eos_token_id (`int`, *optional*, defaults to 2):
-            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
-            `eos_token_id`.
         use_prompt (`bool`, *optional*, defaults to `False`):
             Whether or not to use prompt.
         prompt_length (`int`, *optional*, defaults to 100):
@@ -132,7 +127,6 @@ def __init__(
         eos_token_id=2,
         is_encoder_decoder=True,
         decoder_start_token_id=2,
-        forced_eos_token_id=2,
         use_prompt=False,
         prompt_length=100,
         prompt_mid_dim=800,
@@ -168,16 +162,8 @@
             eos_token_id=eos_token_id,
             is_encoder_decoder=is_encoder_decoder,
             decoder_start_token_id=decoder_start_token_id,
-            forced_eos_token_id=forced_eos_token_id,
             **kwargs,
         )
-
-        if kwargs.get("force_bos_token_to_be_generated", False):
-            self.forced_bos_token_id = self.bos_token_id
-            warnings.warn(
-                f"Please make sure the generated config includes `forced_bos_token_id={self.bos_token_id}` . "
-                "The config can simply be saved and uploaded again to be fixed."
-            )
 
 
 __all__ = ["MvpConfig"]
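The removed `MvpConfig` warning advised that "The config can simply be saved and uploaded again to be fixed." With the shim now living in `GenerationConfig`, that round trip might look like the sketch below (the directory path is hypothetical, and it assumes the legacy flag reaches the generation config as in the earlier example):

from transformers import GenerationConfig, MvpConfig

# Hypothetical legacy checkpoint whose config still carries the flag.
model_config = MvpConfig(force_bos_token_to_be_generated=True)

# The shim resolves the flag once, setting `forced_bos_token_id`...
gen_config = GenerationConfig.from_model_config(model_config)

# ...and saving the resolved config persists the fix, so the warning
# is not emitted on future loads.
gen_config.save_pretrained("./fixed-mvp-checkpoint")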