
Commit 6f189e0

Author: Thilina Rajapakse (committed)

Updated classification model tokenization logic. Added deberta, mpnet, squeezebert for classification

1 parent f56302d commit 6f189e0

16 files changed (+329 lines, -59 lines)

CHANGELOG.md (19 additions, 1 deletion)

@@ -4,6 +4,20 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.60.0] - 2021-02-02
+
+# Added
+
+- Added class weights support for Longformer classification
+- Added new classification models:
+  - SqueezeBert
+  - DeBERTa
+  - MPNet
+
+# Changed
+
+- Updated ClassificationModel logic to make it easier to add new models
+
 ## [0.51.16] - 2021-01-29
 
 ## Fixed
@@ -1386,7 +1400,11 @@ Model checkpoint is now saved for all epochs again.
 
 - This CHANGELOG file to hopefully serve as an evolving example of a standardized open source project CHANGELOG.
 
-[0.51.15]: https://github.com/ThilinaRajapakse/simpletransformers/compare/2af55e9...HEAD
+[0.60.0]: https://github.com/ThilinaRajapakse/simpletransformers/compare/5840749...HEAD
+
+[0.51.16]: https://github.com/ThilinaRajapakse/simpletransformers/compare/b42898e...5840749
+
+[0.51.15]: https://github.com/ThilinaRajapakse/simpletransformers/compare/2af55e9...b42898e
 
 [0.51.14]: https://github.com/ThilinaRajapakse/simpletransformers/compare/278fca1...2af55e9
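For orientation, the headline additions above can be exercised directly from the public API. A minimal sketch of the new Longformer class-weights support (the checkpoint name and weight values are illustrative, not taken from this commit):

from simpletransformers.classification import ClassificationModel

# Hypothetical imbalanced binary task; one weight per label, in labels_list order.
model = ClassificationModel(
    "longformer",
    "allenai/longformer-base-4096",
    num_labels=2,
    weight=[1.0, 3.0],  # previously unsupported for Longformer; allowed as of 0.60.0
    use_cuda=False,
)

Passing weight applies a weighted loss during training, the usual remedy for class imbalance.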

docs/_docs/04-classification-specifics.md (20 additions, 17 deletions)

@@ -2,7 +2,7 @@
 title: Classification Specifics
 permalink: /docs/classification-specifics/
 excerpt: "Specific notes for text classification tasks."
-last_modified_at: 2020/12/21 22:13:56
+last_modified_at: 2021/02/02 02:03:09
 toc: true
 ---
 
@@ -32,22 +32,25 @@ The process of performing text classification in Simple Transformers does not de
 
 New model types are regularly added to the library. Text classification tasks currently supports the model types given below.
 
-| Model       | Model code for `ClassificationModel` |
-| ----------- | ------------------------------------ |
-| ALBERT      | albert                               |
-| BERT        | bert                                 |
-| BERTweet    | bertweet                             |
-| CamemBERT   | camembert                            |
-| RoBERTa     | roberta                              |
-| DistilBERT  | distilbert                           |
-| ELECTRA     | electra                              |
-| FlauBERT    | flaubert                             |
-| *LayoutLM   | layoutlm                             |
-| Longformer  | longformer                           |
-| *MobileBERT | mobilebert                           |
-| XLM         | xlm                                  |
-| XLM-RoBERTa | xlmroberta                           |
-| XLNet       | xlnet                                |
+| Model        | Model code for `ClassificationModel` |
+| ------------ | ------------------------------------ |
+| ALBERT       | albert                               |
+| BERT         | bert                                 |
+| BERTweet     | bertweet                             |
+| CamemBERT    | camembert                            |
+| *DeBERTa     | deberta                              |
+| DistilBERT   | distilbert                           |
+| ELECTRA      | electra                              |
+| FlauBERT     | flaubert                             |
+| LayoutLM     | layoutlm                             |
+| *Longformer  | longformer                           |
+| *MPNet       | mpnet                                |
+| MobileBERT   | mobilebert                           |
+| RoBERTa      | roberta                              |
+| *SqueezeBert | squeezebert                          |
+| XLM          | xlm                                  |
+| XLM-RoBERTa  | xlmroberta                           |
+| XLNet        | xlnet                                |
 
 \* *Not available with Multi-label classification*
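To make the updated table concrete, the three new codes plug into the existing constructor exactly like the older ones. A sketch (the Hugging Face checkpoint names are illustrative choices, not pinned by this commit):

from simpletransformers.classification import ClassificationModel

# First argument = model code from the table; second = a matching checkpoint.
deberta_model = ClassificationModel("deberta", "microsoft/deberta-base", use_cuda=False)
mpnet_model = ClassificationModel("mpnet", "microsoft/mpnet-base", use_cuda=False)
squeezebert_model = ClassificationModel("squeezebert", "squeezebert/squeezebert-uncased", use_cuda=False)

Per the footnote, the starred entries (DeBERTa, Longformer, MPNet, SqueezeBert) are not available for multi-label classification.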

examples/t5/mt5_translation/test.py (1 addition, 1 deletion)

@@ -33,4 +33,4 @@
 english_preds = model.predict(to_english)
 
 sin_eng_bleu = sacrebleu.corpus_bleu(english_preds, english_truth)
-print("Sinhalese to English: ", sin_eng_bleu.score)
\ No newline at end of file
+print("Sinhalese to English: ", sin_eng_bleu.score)

examples/t5/training_on_a_new_task/train.py (1 addition, 1 deletion)

@@ -22,6 +22,6 @@
     "wandb_project": "Question Generation with T5",
 }
 
-model = T5Model("t5","t5-large",args=model_args)
+model = T5Model("t5", "t5-large", args=model_args)
 
 model.train_model(train_df, eval_data=eval_df)

setup.py (1 addition, 1 deletion)

@@ -5,7 +5,7 @@
 
 setup(
     name="simpletransformers",
-    version="0.51.16",
+    version="0.60.0",
     author="Thilina Rajapakse",
     author_email="[email protected]",
     description="An easy-to-use wrapper library for the Transformers library.",

simpletransformers/classification/classification_model.py (67 additions, 17 deletions)

@@ -27,7 +27,7 @@
     mean_squared_error,
     roc_curve,
     auc,
-    average_precision_score
+    average_precision_score,
 )
 from tensorboardX import SummaryWriter
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
@@ -46,11 +46,17 @@
 from transformers import (
     AlbertConfig,
     AlbertTokenizer,
+    AutoConfig,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
     BertConfig,
     BertTokenizer,
     BertweetTokenizer,
     CamembertConfig,
     CamembertTokenizer,
+    DebertaConfig,
+    DebertaForSequenceClassification,
+    DebertaTokenizer,
     DistilBertConfig,
     DistilBertTokenizer,
     ElectraConfig,
@@ -60,15 +66,17 @@
     LayoutLMConfig,
     LayoutLMTokenizer,
     LongformerConfig,
-    LongformerForSequenceClassification,
     LongformerTokenizer,
+    MPNetConfig,
+    MPNetForSequenceClassification,
+    MPNetTokenizer,
     MobileBertConfig,
-    MobileBertForSequenceClassification,
     MobileBertTokenizer,
-    ReformerConfig,
-    ReformerTokenizer,
     RobertaConfig,
     RobertaTokenizer,
+    SqueezeBertConfig,
+    SqueezeBertForSequenceClassification,
+    SqueezeBertTokenizer,
     WEIGHTS_NAME,
     XLMConfig,
     XLMRobertaConfig,
@@ -91,6 +99,8 @@
 from simpletransformers.classification.transformer_models.distilbert_model import DistilBertForSequenceClassification
 from simpletransformers.classification.transformer_models.flaubert_model import FlaubertForSequenceClassification
 from simpletransformers.classification.transformer_models.layoutlm_model import LayoutLMForSequenceClassification
+from simpletransformers.classification.transformer_models.longformer_model import LongformerForSequenceClassification
+from simpletransformers.classification.transformer_models.mobilebert_model import MobileBertForSequenceClassification
 from simpletransformers.classification.transformer_models.roberta_model import RobertaForSequenceClassification
 from simpletransformers.classification.transformer_models.xlm_model import XLMForSequenceClassification
 from simpletransformers.classification.transformer_models.xlm_roberta_model import XLMRobertaForSequenceClassification
@@ -100,7 +110,6 @@
 from simpletransformers.config.utils import sweep_config_to_sweep_values
 from simpletransformers.custom_models.models import ElectraForSequenceClassification
 
-from transformers.models.reformer import ReformerForSequenceClassification
 
 try:
     import wandb
@@ -112,11 +121,22 @@
 logger = logging.getLogger(__name__)
 
 
+MODELS_WITHOUT_CLASS_WEIGHTS_SUPPORT = ["squeezebert", "deberta", "mpnet"]
+
+MODELS_WITH_EXTRA_SEP_TOKEN = ["roberta", "camembert", "xlmroberta", "longformer", "mpnet"]
+
+MODELS_WITH_ADD_PREFIX_SPACE = ["roberta", "camembert", "xlmroberta", "longformer", "mpnet"]
+
+MODELS_WITHOUT_SLIDING_WINDOW_SUPPORT = ["squeezebert"]
+
+
 class ClassificationModel:
     def __init__(
         self,
         model_type,
         model_name,
+        tokenizer_type=None,
+        tokenizer_name=None,
         num_labels=None,
         weight=None,
         args=None,
@@ -132,6 +152,9 @@ def __init__(
         Args:
             model_type: The type of model (bert, xlnet, xlm, roberta, distilbert)
             model_name: The exact architecture and trained weights to use. This may be a Hugging Face Transformers compatible pre-trained model, a community model, or the path to a directory containing model files.
+            tokenizer_type: The type of tokenizer (auto, bert, xlnet, xlm, roberta, distilbert, etc.) to use. If a string is passed, Simple Transformers will try to initialize a tokenizer class from the available MODEL_CLASSES.
+                            Alternatively, a Tokenizer class (subclassed from PreTrainedTokenizer) can be passed.
+            tokenizer_name: The name/path to the tokenizer. If the tokenizer_type is not specified, the model_type will be used to determine the type of the tokenizer.
             num_labels (optional): The number of labels or classes in the dataset.
             weight (optional): A list of length num_labels containing the weights to assign to each label for loss calculation.
             args (optional): Default args will be used if this parameter is not provided. If provided, it should be a dict containing the args that should be changed in the default args.
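The two new arguments documented above decouple the tokenizer from the model weights. A hedged sketch of how they might be combined (the local directory path is hypothetical, for illustration only):

from simpletransformers.classification import ClassificationModel

# Hypothetical scenario: model weights saved without tokenizer files,
# so the tokenizer is loaded separately from the base checkpoint.
model = ClassificationModel(
    "bert",
    "outputs/checkpoint-2000",         # hypothetical local directory with model weights
    tokenizer_type="bert",             # a MODEL_CLASSES key, or a PreTrainedTokenizer subclass
    tokenizer_name="bert-base-cased",  # tokenizer loaded from a different name/path
    use_cuda=False,
)

When tokenizer_name is omitted it falls back to model_name, so existing code keeps working unchanged.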
@@ -143,17 +166,20 @@ def __init__(
 
         MODEL_CLASSES = {
             "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
+            "auto": (AutoConfig, AutoModelForSequenceClassification, AutoTokenizer),
             "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
             "bertweet": (RobertaConfig, RobertaForSequenceClassification, BertweetTokenizer),
             "camembert": (CamembertConfig, CamembertForSequenceClassification, CamembertTokenizer),
+            "deberta": (DebertaConfig, DebertaForSequenceClassification, DebertaTokenizer),
             "distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
             "electra": (ElectraConfig, ElectraForSequenceClassification, ElectraTokenizer),
             "flaubert": (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer),
             "layoutlm": (LayoutLMConfig, LayoutLMForSequenceClassification, LayoutLMTokenizer),
             "longformer": (LongformerConfig, LongformerForSequenceClassification, LongformerTokenizer),
             "mobilebert": (MobileBertConfig, MobileBertForSequenceClassification, MobileBertTokenizer),
-            "reformer": (ReformerConfig, ReformerForSequenceClassification, ReformerTokenizer),
+            "mpnet": (MPNetConfig, MPNetForSequenceClassification, MPNetTokenizer),
             "roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
+            "squeezebert": (SqueezeBertConfig, SqueezeBertForSequenceClassification, SqueezeBertTokenizer),
             "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
             "xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
             "xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
@@ -166,6 +192,9 @@ def __init__(
         elif isinstance(args, ClassificationArgs):
             self.args = args
 
+        if model_type in MODELS_WITHOUT_SLIDING_WINDOW_SUPPORT and self.args.sliding_window:
+            raise ValueError("{} does not currently support sliding window".format(model_type))
+
         if self.args.thread_count:
             torch.set_num_threads(self.args.thread_count)
 
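This guard, together with the class-weights guard added in the next hunk, turns unsupported combinations into immediate ValueErrors instead of obscure failures mid-training. Roughly how they surface to a caller (checkpoint name illustrative):

from simpletransformers.classification import ClassificationModel

try:
    # squeezebert is listed in MODELS_WITHOUT_SLIDING_WINDOW_SUPPORT
    ClassificationModel(
        "squeezebert",
        "squeezebert/squeezebert-uncased",
        args={"sliding_window": True},
        use_cuda=False,
    )
except ValueError as e:
    print(e)  # squeezebert does not currently support sliding window

try:
    # squeezebert is also in MODELS_WITHOUT_CLASS_WEIGHTS_SUPPORT
    ClassificationModel(
        "squeezebert",
        "squeezebert/squeezebert-uncased",
        weight=[1.0, 2.0],
        use_cuda=False,
    )
except ValueError as e:
    print(e)  # squeezebert does not currently support class weights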
@@ -200,13 +229,24 @@ def __init__(
             self.args.labels_list = [i for i in range(len_labels_list)]
 
         config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
+
+        if tokenizer_type is not None:
+            if isinstance(tokenizer_type, str):
+                _, _, tokenizer_class = MODEL_CLASSES[tokenizer_type]
+            else:
+                tokenizer_class = tokenizer_type
+
         if num_labels:
             self.config = config_class.from_pretrained(model_name, num_labels=num_labels, **self.args.config)
             self.num_labels = num_labels
         else:
             self.config = config_class.from_pretrained(model_name, **self.args.config)
             self.num_labels = self.config.num_labels
-        self.weight = weight
+
+        if model_type in MODELS_WITHOUT_CLASS_WEIGHTS_SUPPORT and weight is not None:
+            raise ValueError("{} does not currently support class weights".format(model_type))
+        else:
+            self.weight = weight
 
         if use_cuda:
             if torch.cuda.is_available():
@@ -275,17 +315,20 @@ def __init__(
         except AttributeError:
             raise AttributeError("fp16 requires Pytorch >= 1.6. Please update Pytorch or turn off fp16.")
 
-        if model_name in [
+        if tokenizer_name is None:
+            tokenizer_name = model_name
+
+        if tokenizer_name in [
             "vinai/bertweet-base",
             "vinai/bertweet-covid19-base-cased",
             "vinai/bertweet-covid19-base-uncased",
         ]:
             self.tokenizer = tokenizer_class.from_pretrained(
-                model_name, do_lower_case=self.args.do_lower_case, normalization=True, **kwargs
+                tokenizer_name, do_lower_case=self.args.do_lower_case, normalization=True, **kwargs
             )
         else:
             self.tokenizer = tokenizer_class.from_pretrained(
-                model_name, do_lower_case=self.args.do_lower_case, **kwargs
+                tokenizer_name, do_lower_case=self.args.do_lower_case, **kwargs
             )
 
         if self.args.special_tokens_list:
@@ -294,6 +337,8 @@ def __init__(
 
         self.args.model_name = model_name
         self.args.model_type = model_type
+        self.args.tokenizer_name = tokenizer_name
+        self.args.tokenizer_type = tokenizer_type
 
         if model_type in ["camembert", "xlmroberta"]:
             warnings.warn(
@@ -1184,7 +1229,7 @@ def load_and_cache_examples(
                 sep_token=tokenizer.sep_token,
                 # RoBERTa uses an extra separator b/w pairs of sentences,
                 # cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
-                sep_token_extra=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]),
+                sep_token_extra=args.model_type in MODELS_WITH_EXTRA_SEP_TOKEN,
                 # PAD on the left for XLNet
                 pad_on_left=bool(args.model_type in ["xlnet"]),
                 pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
@@ -1196,7 +1241,7 @@ def load_and_cache_examples(
                 sliding_window=args.sliding_window,
                 flatten=not evaluate,
                 stride=args.stride,
-                add_prefix_space=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]),
+                add_prefix_space=args.model_type in MODELS_WITH_ADD_PREFIX_SPACE,
                 # avoid padding in case of single example/online inferencing to decrease execution time
                 pad_to_max_length=bool(len(examples) > 1),
                 args=args,
@@ -1236,8 +1281,10 @@ def load_and_cache_examples(
             else:
                 return dataset
         else:
-            train_dataset = ClassificationDataset(examples, self.tokenizer, self.args, mode=mode, multi_label=multi_label, output_mode=output_mode)
-            return train_dataset
+            dataset = ClassificationDataset(
+                examples, self.tokenizer, self.args, mode=mode, multi_label=multi_label, output_mode=output_mode
+            )
+            return dataset
 
     def compute_metrics(self, preds, model_outputs, labels, eval_examples=None, multi_label=False, **kwargs):
         """
@@ -1302,7 +1349,10 @@ def compute_metrics(self, preds, model_outputs, labels, eval_examples=None, mult
             auroc = auc(fpr, tpr)
             auprc = average_precision_score(labels, scores)
             return (
-                {**{"mcc": mcc, "tp": tp, "tn": tn, "fp": fp, "fn": fn, "auroc": auroc, "auprc": auprc}, **extra_metrics},
+                {
+                    **{"mcc": mcc, "tp": tp, "tn": tn, "fp": fp, "fn": fn, "auroc": auroc, "auprc": auprc},
+                    **extra_metrics,
+                },
                 wrong,
             )
         else:
@@ -1575,7 +1625,7 @@ def _move_model_to_device(self):
 
     def _get_inputs_dict(self, batch):
         if isinstance(batch[0], dict):
-            inputs = {key: value.squeeze().to(self.device) for key, value in batch[0].items()}
+            inputs = {key: value.squeeze(1).to(self.device) for key, value in batch[0].items()}
             inputs["labels"] = batch[1].to(self.device)
         else:
             batch = tuple(t.to(self.device) for t in batch)

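A note on the squeeze() to squeeze(1) change in _get_inputs_dict above: encode_plus(..., return_tensors="pt") yields tensors shaped [1, max_seq_length], so a collated batch arrives as [batch_size, 1, max_seq_length] and only the middle dimension should be removed. A bare squeeze() would also drop the batch dimension whenever the batch holds a single example. A standalone illustration of the difference:

import torch

# Simulate a batch holding one tokenized example: [batch=1, 1, max_seq_length=128]
batch = torch.zeros(1, 1, 128)

print(batch.squeeze().shape)    # torch.Size([128])    -- batch dimension lost
print(batch.squeeze(1).shape)   # torch.Size([1, 128]) -- batch dimension preserved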
simpletransformers/classification/classification_utils.py (3 additions, 3 deletions)

@@ -96,15 +96,15 @@ def preprocess_data(data):
             max_length=args.max_seq_length,
             truncation=True,
             padding="max_length",
-            return_tensors="pt"
+            return_tensors="pt",
         )
     else:
         tokenized_example = tokenizer.encode_plus(
             text=example.text_a,
             max_length=args.max_seq_length,
             truncation=True,
             padding="max_length",
-            return_tensors="pt"
+            return_tensors="pt",
         )
 
     return {**tokenized_example, "label": example.label}
@@ -600,7 +600,7 @@ def __init__(
         self.data = [
             dict(
                 json.load(open(os.path.join(data_path, l + self.data_type_extension))),
-                **{"images": l + image_type_extension}
+                **{"images": l + image_type_extension},
             )
             for l in files_list
         ]
