     mean_squared_error,
     roc_curve,
     auc,
-    average_precision_score
+    average_precision_score,
 )
 from tensorboardX import SummaryWriter
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
 from transformers import (
     AlbertConfig,
     AlbertTokenizer,
+    AutoConfig,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
     BertConfig,
     BertTokenizer,
     BertweetTokenizer,
     CamembertConfig,
     CamembertTokenizer,
+    DebertaConfig,
+    DebertaForSequenceClassification,
+    DebertaTokenizer,
     DistilBertConfig,
     DistilBertTokenizer,
     ElectraConfig,
     LayoutLMConfig,
     LayoutLMTokenizer,
     LongformerConfig,
-    LongformerForSequenceClassification,
     LongformerTokenizer,
+    MPNetConfig,
+    MPNetForSequenceClassification,
+    MPNetTokenizer,
     MobileBertConfig,
-    MobileBertForSequenceClassification,
     MobileBertTokenizer,
-    ReformerConfig,
-    ReformerTokenizer,
     RobertaConfig,
     RobertaTokenizer,
+    SqueezeBertConfig,
+    SqueezeBertForSequenceClassification,
+    SqueezeBertTokenizer,
     WEIGHTS_NAME,
     XLMConfig,
     XLMRobertaConfig,
 from simpletransformers.classification.transformer_models.distilbert_model import DistilBertForSequenceClassification
 from simpletransformers.classification.transformer_models.flaubert_model import FlaubertForSequenceClassification
 from simpletransformers.classification.transformer_models.layoutlm_model import LayoutLMForSequenceClassification
+from simpletransformers.classification.transformer_models.longformer_model import LongformerForSequenceClassification
+from simpletransformers.classification.transformer_models.mobilebert_model import MobileBertForSequenceClassification
 from simpletransformers.classification.transformer_models.roberta_model import RobertaForSequenceClassification
 from simpletransformers.classification.transformer_models.xlm_model import XLMForSequenceClassification
 from simpletransformers.classification.transformer_models.xlm_roberta_model import XLMRobertaForSequenceClassification
 from simpletransformers.config.utils import sweep_config_to_sweep_values
 from simpletransformers.custom_models.models import ElectraForSequenceClassification

-from transformers.models.reformer import ReformerForSequenceClassification

 try:
     import wandb
 logger = logging.getLogger(__name__)


+MODELS_WITHOUT_CLASS_WEIGHTS_SUPPORT = ["squeezebert", "deberta", "mpnet"]
+
+MODELS_WITH_EXTRA_SEP_TOKEN = ["roberta", "camembert", "xlmroberta", "longformer", "mpnet"]
+
+MODELS_WITH_ADD_PREFIX_SPACE = ["roberta", "camembert", "xlmroberta", "longformer", "mpnet"]
+
+MODELS_WITHOUT_SLIDING_WINDOW_SUPPORT = ["squeezebert"]
+
+
 class ClassificationModel:
     def __init__(
         self,
         model_type,
         model_name,
+        tokenizer_type=None,
+        tokenizer_name=None,
         num_labels=None,
         weight=None,
         args=None,
@@ -132,6 +152,9 @@ def __init__(
         Args:
             model_type: The type of model (bert, xlnet, xlm, roberta, distilbert)
             model_name: The exact architecture and trained weights to use. This may be a Hugging Face Transformers compatible pre-trained model, a community model, or the path to a directory containing model files.
+            tokenizer_type: The type of tokenizer (auto, bert, xlnet, xlm, roberta, distilbert, etc.) to use. If a string is passed, Simple Transformers will try to initialize a tokenizer class from the available MODEL_CLASSES.
+                            Alternatively, a Tokenizer class (subclassed from PreTrainedTokenizer) can be passed.
+            tokenizer_name: The name/path to the tokenizer. If the tokenizer_type is not specified, the model_type will be used to determine the type of the tokenizer.
             num_labels (optional): The number of labels or classes in the dataset.
             weight (optional): A list of length num_labels containing the weights to assign to each label for loss calculation.
             args (optional): Default args will be used if this parameter is not provided. If provided, it should be a dict containing the args that should be changed in the default args.
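A minimal sketch of how the two new arguments combine: loading fine-tuned weights from a local directory while taking the vocabulary from the original pre-trained tokenizer. The checkpoint path here is hypothetical, any BERT-style checkpoint would do.

```python
from simpletransformers.classification import ClassificationModel

# Hypothetical local checkpoint; the tokenizer is loaded separately
# from the original pre-trained vocabulary.
model = ClassificationModel(
    "bert",
    "outputs/checkpoint-2000",
    tokenizer_type="bert",
    tokenizer_name="bert-base-uncased",
    use_cuda=False,
)
```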
@@ -143,17 +166,20 @@ def __init__(

         MODEL_CLASSES = {
             "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
+            "auto": (AutoConfig, AutoModelForSequenceClassification, AutoTokenizer),
             "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
             "bertweet": (RobertaConfig, RobertaForSequenceClassification, BertweetTokenizer),
             "camembert": (CamembertConfig, CamembertForSequenceClassification, CamembertTokenizer),
+            "deberta": (DebertaConfig, DebertaForSequenceClassification, DebertaTokenizer),
             "distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
             "electra": (ElectraConfig, ElectraForSequenceClassification, ElectraTokenizer),
             "flaubert": (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer),
             "layoutlm": (LayoutLMConfig, LayoutLMForSequenceClassification, LayoutLMTokenizer),
             "longformer": (LongformerConfig, LongformerForSequenceClassification, LongformerTokenizer),
             "mobilebert": (MobileBertConfig, MobileBertForSequenceClassification, MobileBertTokenizer),
-            "reformer": (ReformerConfig, ReformerForSequenceClassification, ReformerTokenizer),
+            "mpnet": (MPNetConfig, MPNetForSequenceClassification, MPNetTokenizer),
             "roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
+            "squeezebert": (SqueezeBertConfig, SqueezeBertForSequenceClassification, SqueezeBertTokenizer),
             "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
             "xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
             "xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
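The new "auto" entry defers architecture selection to the checkpoint's own config via the Hugging Face Auto classes, so architectures without a dedicated entry can still be loaded. Roughly what that tuple resolves to at load time (the model name is illustrative):

```python
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

# The Auto* classes read the architecture from the checkpoint's config.json,
# so one code path covers RoBERTa, DistilBERT, DeBERTa, and so on.
config = AutoConfig.from_pretrained("distilroberta-base", num_labels=2)
model = AutoModelForSequenceClassification.from_pretrained("distilroberta-base", config=config)
tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
```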
@@ -166,6 +192,9 @@ def __init__(
         elif isinstance(args, ClassificationArgs):
             self.args = args

+        if model_type in MODELS_WITHOUT_SLIDING_WINDOW_SUPPORT and self.args.sliding_window:
+            raise ValueError("{} does not currently support sliding window".format(model_type))
+
         if self.args.thread_count:
             torch.set_num_threads(self.args.thread_count)

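This guard fails fast at construction time instead of surfacing as a shape error deep inside training. A sketch of the failure mode it introduces (checkpoint name is illustrative):

```python
from simpletransformers.classification import ClassificationModel

# Raises ValueError: "squeezebert does not currently support sliding window"
model = ClassificationModel(
    "squeezebert",
    "squeezebert/squeezebert-uncased",
    args={"sliding_window": True},
    use_cuda=False,
)
```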
@@ -200,13 +229,24 @@ def __init__(
             self.args.labels_list = [i for i in range(len_labels_list)]

         config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
+
+        if tokenizer_type is not None:
+            if isinstance(tokenizer_type, str):
+                _, _, tokenizer_class = MODEL_CLASSES[tokenizer_type]
+            else:
+                tokenizer_class = tokenizer_type
+
         if num_labels:
             self.config = config_class.from_pretrained(model_name, num_labels=num_labels, **self.args.config)
             self.num_labels = num_labels
         else:
             self.config = config_class.from_pretrained(model_name, **self.args.config)
             self.num_labels = self.config.num_labels
-        self.weight = weight
+
+        if model_type in MODELS_WITHOUT_CLASS_WEIGHTS_SUPPORT and weight is not None:
+            raise ValueError("{} does not currently support class weights".format(model_type))
+        else:
+            self.weight = weight

         if use_cuda:
             if torch.cuda.is_available():
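Same fail-fast pattern for class weights: the weight list is only stored for architectures whose loss computation accepts it. A sketch with illustrative checkpoint names:

```python
from simpletransformers.classification import ClassificationModel

# Fine: BERT's sequence classification wrapper accepts per-class weights.
model = ClassificationModel("bert", "bert-base-cased", weight=[0.5, 1.5], use_cuda=False)

# Raises ValueError: deberta is in MODELS_WITHOUT_CLASS_WEIGHTS_SUPPORT.
# ClassificationModel("deberta", "microsoft/deberta-base", weight=[0.5, 1.5])
```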
@@ -275,17 +315,20 @@ def __init__(
             except AttributeError:
                 raise AttributeError("fp16 requires Pytorch >= 1.6. Please update Pytorch or turn off fp16.")

-        if model_name in [
+        if tokenizer_name is None:
+            tokenizer_name = model_name
+
+        if tokenizer_name in [
             "vinai/bertweet-base",
             "vinai/bertweet-covid19-base-cased",
             "vinai/bertweet-covid19-base-uncased",
         ]:
             self.tokenizer = tokenizer_class.from_pretrained(
-                model_name, do_lower_case=self.args.do_lower_case, normalization=True, **kwargs
+                tokenizer_name, do_lower_case=self.args.do_lower_case, normalization=True, **kwargs
             )
         else:
             self.tokenizer = tokenizer_class.from_pretrained(
-                model_name, do_lower_case=self.args.do_lower_case, **kwargs
+                tokenizer_name, do_lower_case=self.args.do_lower_case, **kwargs
             )

         if self.args.special_tokens_list:
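The BERTweet special case exists because its tokenizer ships a tweet normalizer that has to be switched on at load time; with tokenizer_name now decoupled from model_name, the check must look at the tokenizer rather than the model. A quick illustration of the flag, to the best of my understanding of BertweetTokenizer:

```python
from transformers import BertweetTokenizer

# normalization=True runs BERTweet's tweet normalizer, which maps user
# mentions and URLs to the special tokens @USER and HTTPURL.
tokenizer = BertweetTokenizer.from_pretrained("vinai/bertweet-base", normalization=True)
print(tokenizer.tokenize("Reading the docs at https://huggingface.co"))
```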
@@ -294,6 +337,8 @@ def __init__(

         self.args.model_name = model_name
         self.args.model_type = model_type
+        self.args.tokenizer_name = tokenizer_name
+        self.args.tokenizer_type = tokenizer_type

         if model_type in ["camembert", "xlmroberta"]:
             warnings.warn(
@@ -1184,7 +1229,7 @@ def load_and_cache_examples(
                 sep_token=tokenizer.sep_token,
                 # RoBERTa uses an extra separator b/w pairs of sentences,
                 # cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
-                sep_token_extra=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]),
+                sep_token_extra=args.model_type in MODELS_WITH_EXTRA_SEP_TOKEN,
                 # PAD on the left for XLNet
                 pad_on_left=bool(args.model_type in ["xlnet"]),
                 pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
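Replacing the inline list with MODELS_WITH_EXTRA_SEP_TOKEN keeps this special case in one place. The extra separator itself is easy to see on a sentence pair:

```python
from transformers import RobertaTokenizer

tok = RobertaTokenizer.from_pretrained("roberta-base")
ids = tok.encode("first sentence", "second sentence")
print(tok.decode(ids))
# '<s>first sentence</s></s>second sentence</s>'
# Note the doubled </s> between the pair, unlike BERT's single [SEP].
```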
@@ -1196,7 +1241,7 @@ def load_and_cache_examples(
                 sliding_window=args.sliding_window,
                 flatten=not evaluate,
                 stride=args.stride,
-                add_prefix_space=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]),
+                add_prefix_space=args.model_type in MODELS_WITH_ADD_PREFIX_SPACE,
                 # avoid padding in case of single example/online inferencing to decrease execution time
                 pad_to_max_length=bool(len(examples) > 1),
                 args=args,
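MODELS_WITH_ADD_PREFIX_SPACE covers the same byte-level BPE family: those tokenizers fold a leading space into the token itself, so a word at the start of a string tokenizes differently unless a prefix space is added first:

```python
from transformers import RobertaTokenizer

tok = RobertaTokenizer.from_pretrained("roberta-base")
print(tok.tokenize("hello"))                         # ['hello']
print(tok.tokenize("hello", add_prefix_space=True))  # ['Ġhello'], matching mid-sentence occurrences
```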
@@ -1236,8 +1281,10 @@ def load_and_cache_examples(
             else:
                 return dataset
         else:
-            train_dataset = ClassificationDataset(examples, self.tokenizer, self.args, mode=mode, multi_label=multi_label, output_mode=output_mode)
-            return train_dataset
+            dataset = ClassificationDataset(
+                examples, self.tokenizer, self.args, mode=mode, multi_label=multi_label, output_mode=output_mode
+            )
+            return dataset

     def compute_metrics(self, preds, model_outputs, labels, eval_examples=None, multi_label=False, **kwargs):
         """
@@ -1302,7 +1349,10 @@ def compute_metrics(self, preds, model_outputs, labels, eval_examples=None, multi_label=False, **kwargs):
             auroc = auc(fpr, tpr)
             auprc = average_precision_score(labels, scores)
             return (
-                {**{"mcc": mcc, "tp": tp, "tn": tn, "fp": fp, "fn": fn, "auroc": auroc, "auprc": auprc}, **extra_metrics},
+                {
+                    **{"mcc": mcc, "tp": tp, "tn": tn, "fp": fp, "fn": fn, "auroc": auroc, "auprc": auprc},
+                    **extra_metrics,
+                },
                 wrong,
             )
         else:
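For binary tasks the dict above comes from standard scikit-learn calls; a self-contained reproduction with toy data (labels and scores invented purely for illustration):

```python
import numpy as np
from sklearn.metrics import auc, average_precision_score, confusion_matrix, matthews_corrcoef, roc_curve

labels = np.array([0, 1, 1, 0, 1])             # ground truth
preds = np.array([0, 1, 0, 0, 1])              # hard predictions
scores = np.array([0.2, 0.9, 0.4, 0.1, 0.8])   # positive-class scores

mcc = matthews_corrcoef(labels, preds)
tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()
fpr, tpr, _ = roc_curve(labels, scores)
auroc = auc(fpr, tpr)                            # area under the ROC curve
auprc = average_precision_score(labels, scores)  # area under precision-recall
print({"mcc": mcc, "tp": tp, "tn": tn, "fp": fp, "fn": fn, "auroc": auroc, "auprc": auprc})
```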
@@ -1575,7 +1625,7 @@ def _move_model_to_device(self):

     def _get_inputs_dict(self, batch):
         if isinstance(batch[0], dict):
-            inputs = {key: value.squeeze().to(self.device) for key, value in batch[0].items()}
+            inputs = {key: value.squeeze(1).to(self.device) for key, value in batch[0].items()}
             inputs["labels"] = batch[1].to(self.device)
         else:
             batch = tuple(t.to(self.device) for t in batch)
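The change from squeeze() to squeeze(1) matters when the batch itself has size 1: bare squeeze() drops every size-1 dimension, including the batch axis, which breaks single-example prediction. Restricting it to dim 1 removes only the extra axis:

```python
import torch

token_ids = torch.randint(0, 100, (1, 1, 128))  # (batch=1, extra dim, seq_len)

print(token_ids.squeeze().shape)   # torch.Size([128])    batch axis lost too
print(token_ids.squeeze(1).shape)  # torch.Size([1, 128]) batch axis preserved
```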