Skip to content

Commit 96d5b0d — "fix bart and mvp" (1 parent: bdee088)

File tree: 3 files changed (+7 additions, −27 deletions)

src/transformers/generation/configuration_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,13 @@ def __init__(self, **kwargs):
436436
self._commit_hash = kwargs.pop("_commit_hash", None)
437437
self.transformers_version = kwargs.pop("transformers_version", __version__)
438438

439+
# Ensure backward compatibility for BART CNN models
440+
if self._from_model_config and kwargs.get("force_bos_token_to_be_generated", False):
441+
self.forced_bos_token_id = self.bos_token_id
442+
logger.warning_once(
443+
f"Please make sure the generation config includes `forced_bos_token_id={self.bos_token_id}`. "
444+
)
445+
439446
# Additional attributes without default values
440447
if not self._from_model_config:
441448
# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a

src/transformers/models/bart/configuration_bart.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
# limitations under the License.
1515
"""BART model configuration"""
1616

17-
import warnings
18-
1917
from ...configuration_utils import PreTrainedConfig
2018
from ...utils import logging
2119

@@ -80,9 +78,6 @@ class BartConfig(PreTrainedConfig):
8078
Whether or not the model should return the last key/values attentions (not used by all models).
8179
num_labels (`int`, *optional*, defaults to 3):
8280
The number of labels to use in [`BartForSequenceClassification`].
83-
forced_eos_token_id (`int`, *optional*, defaults to 2):
84-
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
85-
`eos_token_id`.
8681
8782
Example:
8883
@@ -130,7 +125,6 @@ def __init__(
130125
eos_token_id=2,
131126
is_encoder_decoder=True,
132127
decoder_start_token_id=2,
133-
forced_eos_token_id=2,
134128
**kwargs,
135129
):
136130
self.vocab_size = vocab_size
@@ -161,16 +155,9 @@ def __init__(
161155
eos_token_id=eos_token_id,
162156
is_encoder_decoder=is_encoder_decoder,
163157
decoder_start_token_id=decoder_start_token_id,
164-
forced_eos_token_id=forced_eos_token_id,
165158
**kwargs,
166159
)
167160
self.tie_encoder_decoder = True
168-
# ensure backward compatibility for BART CNN models
169-
if kwargs.get("force_bos_token_to_be_generated", False):
170-
self.forced_bos_token_id = self.bos_token_id
171-
warnings.warn(
172-
f"Please make sure the generation config includes `forced_bos_token_id={self.bos_token_id}`. "
173-
)
174161

175162

176163
__all__ = ["BartConfig"]

src/transformers/models/mvp/configuration_mvp.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
# limitations under the License.
1515
"""MVP model configuration"""
1616

17-
import warnings
18-
1917
from ...configuration_utils import PreTrainedConfig
2018
from ...utils import logging
2119

@@ -78,9 +76,6 @@ class MvpConfig(PreTrainedConfig):
7876
Scale embeddings by diving by sqrt(d_model).
7977
use_cache (`bool`, *optional*, defaults to `True`):
8078
Whether or not the model should return the last key/values attentions (not used by all models).
81-
forced_eos_token_id (`int`, *optional*, defaults to 2):
82-
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
83-
`eos_token_id`.
8479
use_prompt (`bool`, *optional*, defaults to `False`):
8580
Whether or not to use prompt.
8681
prompt_length (`int`, *optional*, defaults to 100):
@@ -132,7 +127,6 @@ def __init__(
132127
eos_token_id=2,
133128
is_encoder_decoder=True,
134129
decoder_start_token_id=2,
135-
forced_eos_token_id=2,
136130
use_prompt=False,
137131
prompt_length=100,
138132
prompt_mid_dim=800,
@@ -168,16 +162,8 @@ def __init__(
168162
eos_token_id=eos_token_id,
169163
is_encoder_decoder=is_encoder_decoder,
170164
decoder_start_token_id=decoder_start_token_id,
171-
forced_eos_token_id=forced_eos_token_id,
172165
**kwargs,
173166
)
174167

175-
if kwargs.get("force_bos_token_to_be_generated", False):
176-
self.forced_bos_token_id = self.bos_token_id
177-
warnings.warn(
178-
f"Please make sure the generated config includes `forced_bos_token_id={self.bos_token_id}` . "
179-
"The config can simply be saved and uploaded again to be fixed."
180-
)
181-
182168

183169
__all__ = ["MvpConfig"]

Comments (0)