Skip to content

Commit f56302d

Browse files
author
Thilina Rajapakse
committed
Updated tokenization logic in classification models
1 parent 5840749 commit f56302d

File tree

4 files changed

+178
-88
lines changed

4 files changed

+178
-88
lines changed

simpletransformers/classification/classification_model.py

Lines changed: 92 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@
6565
MobileBertConfig,
6666
MobileBertForSequenceClassification,
6767
MobileBertTokenizer,
68+
ReformerConfig,
69+
ReformerTokenizer,
6870
RobertaConfig,
6971
RobertaTokenizer,
7072
WEIGHTS_NAME,
@@ -80,6 +82,7 @@
8082
from simpletransformers.classification.classification_utils import (
8183
InputExample,
8284
LazyClassificationDataset,
85+
ClassificationDataset,
8386
convert_examples_to_features,
8487
)
8588
from simpletransformers.classification.transformer_models.albert_model import AlbertForSequenceClassification
@@ -97,6 +100,8 @@
97100
from simpletransformers.config.utils import sweep_config_to_sweep_values
98101
from simpletransformers.custom_models.models import ElectraForSequenceClassification
99102

103+
from transformers.models.reformer import ReformerForSequenceClassification
104+
100105
try:
101106
import wandb
102107

@@ -147,6 +152,7 @@ def __init__(
147152
"layoutlm": (LayoutLMConfig, LayoutLMForSequenceClassification, LayoutLMTokenizer),
148153
"longformer": (LongformerConfig, LongformerForSequenceClassification, LongformerTokenizer),
149154
"mobilebert": (MobileBertConfig, MobileBertForSequenceClassification, MobileBertTokenizer),
155+
"reformer": (ReformerConfig, ReformerForSequenceClassification, ReformerTokenizer),
150156
"roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
151157
"xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
152158
"xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
@@ -1137,97 +1143,101 @@ def load_and_cache_examples(
11371143
os.makedirs(self.args.cache_dir, exist_ok=True)
11381144

11391145
mode = "dev" if evaluate else "train"
1140-
cached_features_file = os.path.join(
1141-
args.cache_dir,
1142-
"cached_{}_{}_{}_{}_{}".format(
1143-
mode, args.model_type, args.max_seq_length, self.num_labels, len(examples),
1144-
),
1145-
)
1146-
1147-
if os.path.exists(cached_features_file) and (
1148-
(not args.reprocess_input_data and not no_cache)
1149-
or (mode == "dev" and args.use_cached_eval_features and not no_cache)
1150-
):
1151-
features = torch.load(cached_features_file)
1152-
if verbose:
1153-
logger.info(f" Features loaded from cache at {cached_features_file}")
1154-
else:
1155-
if verbose:
1156-
logger.info(" Converting to features started. Cache is not used.")
1157-
if args.sliding_window:
1158-
logger.info(" Sliding window enabled")
1159-
1160-
# If labels_map is defined, then labels need to be replaced with ints
1161-
if self.args.labels_map and not self.args.regression:
1162-
for example in examples:
1163-
if multi_label:
1164-
example.label = [self.args.labels_map[label] for label in example.label]
1165-
else:
1166-
example.label = self.args.labels_map[example.label]
1167-
1168-
features = convert_examples_to_features(
1169-
examples,
1170-
args.max_seq_length,
1171-
tokenizer,
1172-
output_mode,
1173-
# XLNet has a CLS token at the end
1174-
cls_token_at_end=bool(args.model_type in ["xlnet"]),
1175-
cls_token=tokenizer.cls_token,
1176-
cls_token_segment_id=2 if args.model_type in ["xlnet"] else 0,
1177-
sep_token=tokenizer.sep_token,
1178-
# RoBERTa uses an extra separator b/w pairs of sentences,
1179-
# cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
1180-
sep_token_extra=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]),
1181-
# PAD on the left for XLNet
1182-
pad_on_left=bool(args.model_type in ["xlnet"]),
1183-
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
1184-
pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
1185-
process_count=process_count,
1186-
multi_label=multi_label,
1187-
silent=args.silent or silent,
1188-
use_multiprocessing=args.use_multiprocessing,
1189-
sliding_window=args.sliding_window,
1190-
flatten=not evaluate,
1191-
stride=args.stride,
1192-
add_prefix_space=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]),
1193-
# avoid padding in case of single example/online inferencing to decrease execution time
1194-
pad_to_max_length=bool(len(examples) > 1),
1195-
args=args,
1146+
if args.sliding_window or self.args.model_type == "layoutlm":
1147+
cached_features_file = os.path.join(
1148+
args.cache_dir,
1149+
"cached_{}_{}_{}_{}_{}".format(
1150+
mode, args.model_type, args.max_seq_length, self.num_labels, len(examples),
1151+
),
11961152
)
1197-
if verbose and args.sliding_window:
1198-
logger.info(f" {len(features)} features created from {len(examples)} samples.")
11991153

1200-
if not no_cache:
1201-
torch.save(features, cached_features_file)
1154+
if os.path.exists(cached_features_file) and (
1155+
(not args.reprocess_input_data and not no_cache)
1156+
or (mode == "dev" and args.use_cached_eval_features and not no_cache)
1157+
):
1158+
features = torch.load(cached_features_file)
1159+
if verbose:
1160+
logger.info(f" Features loaded from cache at {cached_features_file}")
1161+
else:
1162+
if verbose:
1163+
logger.info(" Converting to features started. Cache is not used.")
1164+
if args.sliding_window:
1165+
logger.info(" Sliding window enabled")
1166+
1167+
# If labels_map is defined, then labels need to be replaced with ints
1168+
if self.args.labels_map and not self.args.regression:
1169+
for example in examples:
1170+
if multi_label:
1171+
example.label = [self.args.labels_map[label] for label in example.label]
1172+
else:
1173+
example.label = self.args.labels_map[example.label]
1174+
1175+
features = convert_examples_to_features(
1176+
examples,
1177+
args.max_seq_length,
1178+
tokenizer,
1179+
output_mode,
1180+
# XLNet has a CLS token at the end
1181+
cls_token_at_end=bool(args.model_type in ["xlnet"]),
1182+
cls_token=tokenizer.cls_token,
1183+
cls_token_segment_id=2 if args.model_type in ["xlnet"] else 0,
1184+
sep_token=tokenizer.sep_token,
1185+
# RoBERTa uses an extra separator b/w pairs of sentences,
1186+
# cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
1187+
sep_token_extra=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]),
1188+
# PAD on the left for XLNet
1189+
pad_on_left=bool(args.model_type in ["xlnet"]),
1190+
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
1191+
pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
1192+
process_count=process_count,
1193+
multi_label=multi_label,
1194+
silent=args.silent or silent,
1195+
use_multiprocessing=args.use_multiprocessing,
1196+
sliding_window=args.sliding_window,
1197+
flatten=not evaluate,
1198+
stride=args.stride,
1199+
add_prefix_space=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]),
1200+
# avoid padding in case of single example/online inferencing to decrease execution time
1201+
pad_to_max_length=bool(len(examples) > 1),
1202+
args=args,
1203+
)
1204+
if verbose and args.sliding_window:
1205+
logger.info(f" {len(features)} features created from {len(examples)} samples.")
1206+
1207+
if not no_cache:
1208+
torch.save(features, cached_features_file)
12021209

1203-
if args.sliding_window and evaluate:
1204-
features = [
1205-
[feature_set] if not isinstance(feature_set, list) else feature_set for feature_set in features
1206-
]
1207-
window_counts = [len(sample) for sample in features]
1208-
features = [feature for feature_set in features for feature in feature_set]
1210+
if args.sliding_window and evaluate:
1211+
features = [
1212+
[feature_set] if not isinstance(feature_set, list) else feature_set for feature_set in features
1213+
]
1214+
window_counts = [len(sample) for sample in features]
1215+
features = [feature for feature_set in features for feature in feature_set]
12091216

1210-
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
1211-
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
1212-
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
1217+
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
1218+
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
1219+
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
12131220

1214-
if self.args.model_type == "layoutlm":
1215-
all_bboxes = torch.tensor([f.bboxes for f in features], dtype=torch.long)
1221+
if self.args.model_type == "layoutlm":
1222+
all_bboxes = torch.tensor([f.bboxes for f in features], dtype=torch.long)
12161223

1217-
if output_mode == "classification":
1218-
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
1219-
elif output_mode == "regression":
1220-
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)
1224+
if output_mode == "classification":
1225+
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
1226+
elif output_mode == "regression":
1227+
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)
12211228

1222-
if self.args.model_type == "layoutlm":
1223-
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_bboxes)
1224-
else:
1225-
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
1229+
if self.args.model_type == "layoutlm":
1230+
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_bboxes)
1231+
else:
1232+
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
12261233

1227-
if args.sliding_window and evaluate:
1228-
return dataset, window_counts
1234+
if args.sliding_window and evaluate:
1235+
return dataset, window_counts
1236+
else:
1237+
return dataset
12291238
else:
1230-
return dataset
1239+
train_dataset = ClassificationDataset(examples, self.tokenizer, self.args, mode=mode, multi_label=multi_label, output_mode=output_mode)
1240+
return train_dataset
12311241

12321242
def compute_metrics(self, preds, model_outputs, labels, eval_examples=None, multi_label=False, **kwargs):
12331243
"""

simpletransformers/classification/classification_utils.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import csv
2020
import json
21+
import logging
2122
import linecache
2223
import os
2324
import sys
@@ -44,6 +45,8 @@
4445

4546
csv.field_size_limit(2147483647)
4647

48+
logger = logging.getLogger(__name__)
49+
4750

4851
class InputExample(object):
4952
"""A single training/test example for simple sequence classification."""
@@ -84,6 +87,88 @@ def __init__(self, input_ids, input_mask, segment_ids, label_id, bboxes=None):
8487
self.bboxes = bboxes
8588

8689

90+
def preprocess_data(data):
91+
example, tokenizer, args = data
92+
93+
if example.text_b:
94+
tokenized_example = tokenizer.encode_plus(
95+
text_pair=[example.text_a, example.text_b],
96+
max_length=args.max_seq_length,
97+
truncation=True,
98+
padding="max_length",
99+
return_tensors="pt"
100+
)
101+
else:
102+
tokenized_example = tokenizer.encode_plus(
103+
text=example.text_a,
104+
max_length=args.max_seq_length,
105+
truncation=True,
106+
padding="max_length",
107+
return_tensors="pt"
108+
)
109+
110+
return {**tokenized_example, "label": example.label}
111+
112+
113+
class ClassificationDataset(Dataset):
114+
def __init__(self, data, tokenizer, args, mode, multi_label, output_mode):
115+
self.tokenizer = tokenizer
116+
self.output_mode = output_mode
117+
118+
cached_features_file = os.path.join(
119+
args.cache_dir,
120+
"cached_{}_{}_{}_{}_{}".format(
121+
mode, args.model_type, args.max_seq_length, len(args.labels_list), len(data),
122+
),
123+
)
124+
125+
if os.path.exists(cached_features_file) and (
126+
(not args.reprocess_input_data and not args.no_cache)
127+
or (mode == "dev" and args.use_cached_eval_features and not args.no_cache)
128+
):
129+
self.examples = torch.load(cached_features_file)
130+
logger.info(f" Features loaded from cache at {cached_features_file}")
131+
else:
132+
logger.info(" Converting to features started. Cache is not used.")
133+
134+
# If labels_map is defined, then labels need to be replaced with ints
135+
if args.labels_map and not args.regression:
136+
for example in data:
137+
if multi_label:
138+
example.label = [args.labels_map[label] for label in example.label]
139+
else:
140+
example.label = args.labels_map[example.label]
141+
data = [(example, tokenizer, args) for example in data]
142+
143+
if args.use_multiprocessing:
144+
with Pool(args.process_count) as p:
145+
self.examples = list(
146+
tqdm(
147+
p.imap(preprocess_data, data, chunksize=args.multiprocessing_chunksize),
148+
total=len(data),
149+
disable=args.silent,
150+
)
151+
)
152+
else:
153+
self.examples = [preprocess_data(d) for d in tqdm(data, disable=args.silent)]
154+
155+
if not args.no_cache:
156+
logger.info(" Saving features into cached file %s", cached_features_file)
157+
torch.save(self.examples, cached_features_file)
158+
159+
def __len__(self):
160+
return len(self.examples)
161+
162+
def __getitem__(self, index):
163+
features = self.examples[index]
164+
label = features.pop("label")
165+
if self.output_mode == "classification":
166+
label = torch.tensor(label, dtype=torch.long)
167+
elif self.output_mode == "regression":
168+
label = torch.tensor(label, dtype=torch.float)
169+
return features, label
170+
171+
87172
def convert_example_to_feature(
88173
example_row,
89174
pad_token=0,

simpletransformers/t5/run_simple_transformers_streamlit_app.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

tests/test_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def test_multiclass_classification(model_type, model_name):
113113
@pytest.mark.parametrize(
114114
"model_type, model_name",
115115
[
116-
# ("bert", "bert-base-uncased"),
116+
("bert", "bert-base-uncased"),
117117
("xlnet", "xlnet-base-cased"),
118118
# ("xlm", "xlm-mlm-17-1280"),
119119
# ("roberta", "roberta-base"),

0 commit comments

Comments
 (0)