|
65 | 65 | MobileBertConfig, |
66 | 66 | MobileBertForSequenceClassification, |
67 | 67 | MobileBertTokenizer, |
| 68 | + ReformerConfig, |
| 69 | + ReformerTokenizer, |
68 | 70 | RobertaConfig, |
69 | 71 | RobertaTokenizer, |
70 | 72 | WEIGHTS_NAME, |
|
80 | 82 | from simpletransformers.classification.classification_utils import ( |
81 | 83 | InputExample, |
82 | 84 | LazyClassificationDataset, |
| 85 | + ClassificationDataset, |
83 | 86 | convert_examples_to_features, |
84 | 87 | ) |
85 | 88 | from simpletransformers.classification.transformer_models.albert_model import AlbertForSequenceClassification |
|
97 | 100 | from simpletransformers.config.utils import sweep_config_to_sweep_values |
98 | 101 | from simpletransformers.custom_models.models import ElectraForSequenceClassification |
99 | 102 |
|
| 103 | +from transformers.models.reformer import ReformerForSequenceClassification |
| 104 | + |
100 | 105 | try: |
101 | 106 | import wandb |
102 | 107 |
|
@@ -147,6 +152,7 @@ def __init__( |
147 | 152 | "layoutlm": (LayoutLMConfig, LayoutLMForSequenceClassification, LayoutLMTokenizer), |
148 | 153 | "longformer": (LongformerConfig, LongformerForSequenceClassification, LongformerTokenizer), |
149 | 154 | "mobilebert": (MobileBertConfig, MobileBertForSequenceClassification, MobileBertTokenizer), |
| 155 | + "reformer": (ReformerConfig, ReformerForSequenceClassification, ReformerTokenizer), |
150 | 156 | "roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), |
151 | 157 | "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer), |
152 | 158 | "xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer), |
@@ -1137,97 +1143,101 @@ def load_and_cache_examples( |
1137 | 1143 | os.makedirs(self.args.cache_dir, exist_ok=True) |
1138 | 1144 |
|
1139 | 1145 | mode = "dev" if evaluate else "train" |
1140 | | - cached_features_file = os.path.join( |
1141 | | - args.cache_dir, |
1142 | | - "cached_{}_{}_{}_{}_{}".format( |
1143 | | - mode, args.model_type, args.max_seq_length, self.num_labels, len(examples), |
1144 | | - ), |
1145 | | - ) |
1146 | | - |
1147 | | - if os.path.exists(cached_features_file) and ( |
1148 | | - (not args.reprocess_input_data and not no_cache) |
1149 | | - or (mode == "dev" and args.use_cached_eval_features and not no_cache) |
1150 | | - ): |
1151 | | - features = torch.load(cached_features_file) |
1152 | | - if verbose: |
1153 | | - logger.info(f" Features loaded from cache at {cached_features_file}") |
1154 | | - else: |
1155 | | - if verbose: |
1156 | | - logger.info(" Converting to features started. Cache is not used.") |
1157 | | - if args.sliding_window: |
1158 | | - logger.info(" Sliding window enabled") |
1159 | | - |
1160 | | - # If labels_map is defined, then labels need to be replaced with ints |
1161 | | - if self.args.labels_map and not self.args.regression: |
1162 | | - for example in examples: |
1163 | | - if multi_label: |
1164 | | - example.label = [self.args.labels_map[label] for label in example.label] |
1165 | | - else: |
1166 | | - example.label = self.args.labels_map[example.label] |
1167 | | - |
1168 | | - features = convert_examples_to_features( |
1169 | | - examples, |
1170 | | - args.max_seq_length, |
1171 | | - tokenizer, |
1172 | | - output_mode, |
1173 | | - # XLNet has a CLS token at the end |
1174 | | - cls_token_at_end=bool(args.model_type in ["xlnet"]), |
1175 | | - cls_token=tokenizer.cls_token, |
1176 | | - cls_token_segment_id=2 if args.model_type in ["xlnet"] else 0, |
1177 | | - sep_token=tokenizer.sep_token, |
1178 | | - # RoBERTa uses an extra separator b/w pairs of sentences, |
1179 | | - # cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805 |
1180 | | - sep_token_extra=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]), |
1181 | | - # PAD on the left for XLNet |
1182 | | - pad_on_left=bool(args.model_type in ["xlnet"]), |
1183 | | - pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], |
1184 | | - pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0, |
1185 | | - process_count=process_count, |
1186 | | - multi_label=multi_label, |
1187 | | - silent=args.silent or silent, |
1188 | | - use_multiprocessing=args.use_multiprocessing, |
1189 | | - sliding_window=args.sliding_window, |
1190 | | - flatten=not evaluate, |
1191 | | - stride=args.stride, |
1192 | | - add_prefix_space=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]), |
1193 | | - # avoid padding in case of single example/online inferencing to decrease execution time |
1194 | | - pad_to_max_length=bool(len(examples) > 1), |
1195 | | - args=args, |
| 1146 | + if args.sliding_window or self.args.model_type == "layoutlm": |
| 1147 | + cached_features_file = os.path.join( |
| 1148 | + args.cache_dir, |
| 1149 | + "cached_{}_{}_{}_{}_{}".format( |
| 1150 | + mode, args.model_type, args.max_seq_length, self.num_labels, len(examples), |
| 1151 | + ), |
1196 | 1152 | ) |
1197 | | - if verbose and args.sliding_window: |
1198 | | - logger.info(f" {len(features)} features created from {len(examples)} samples.") |
1199 | 1153 |
|
1200 | | - if not no_cache: |
1201 | | - torch.save(features, cached_features_file) |
| 1154 | + if os.path.exists(cached_features_file) and ( |
| 1155 | + (not args.reprocess_input_data and not no_cache) |
| 1156 | + or (mode == "dev" and args.use_cached_eval_features and not no_cache) |
| 1157 | + ): |
| 1158 | + features = torch.load(cached_features_file) |
| 1159 | + if verbose: |
| 1160 | + logger.info(f" Features loaded from cache at {cached_features_file}") |
| 1161 | + else: |
| 1162 | + if verbose: |
| 1163 | + logger.info(" Converting to features started. Cache is not used.") |
| 1164 | + if args.sliding_window: |
| 1165 | + logger.info(" Sliding window enabled") |
| 1166 | + |
| 1167 | + # If labels_map is defined, then labels need to be replaced with ints |
| 1168 | + if self.args.labels_map and not self.args.regression: |
| 1169 | + for example in examples: |
| 1170 | + if multi_label: |
| 1171 | + example.label = [self.args.labels_map[label] for label in example.label] |
| 1172 | + else: |
| 1173 | + example.label = self.args.labels_map[example.label] |
| 1174 | + |
| 1175 | + features = convert_examples_to_features( |
| 1176 | + examples, |
| 1177 | + args.max_seq_length, |
| 1178 | + tokenizer, |
| 1179 | + output_mode, |
| 1180 | + # XLNet has a CLS token at the end |
| 1181 | + cls_token_at_end=bool(args.model_type in ["xlnet"]), |
| 1182 | + cls_token=tokenizer.cls_token, |
| 1183 | + cls_token_segment_id=2 if args.model_type in ["xlnet"] else 0, |
| 1184 | + sep_token=tokenizer.sep_token, |
| 1185 | + # RoBERTa uses an extra separator b/w pairs of sentences, |
| 1186 | + # cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805 |
| 1187 | + sep_token_extra=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]), |
| 1188 | + # PAD on the left for XLNet |
| 1189 | + pad_on_left=bool(args.model_type in ["xlnet"]), |
| 1190 | + pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], |
| 1191 | + pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0, |
| 1192 | + process_count=process_count, |
| 1193 | + multi_label=multi_label, |
| 1194 | + silent=args.silent or silent, |
| 1195 | + use_multiprocessing=args.use_multiprocessing, |
| 1196 | + sliding_window=args.sliding_window, |
| 1197 | + flatten=not evaluate, |
| 1198 | + stride=args.stride, |
| 1199 | + add_prefix_space=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]), |
| 1200 | + # avoid padding in case of single example/online inferencing to decrease execution time |
| 1201 | + pad_to_max_length=bool(len(examples) > 1), |
| 1202 | + args=args, |
| 1203 | + ) |
| 1204 | + if verbose and args.sliding_window: |
| 1205 | + logger.info(f" {len(features)} features created from {len(examples)} samples.") |
| 1206 | + |
| 1207 | + if not no_cache: |
| 1208 | + torch.save(features, cached_features_file) |
1202 | 1209 |
|
1203 | | - if args.sliding_window and evaluate: |
1204 | | - features = [ |
1205 | | - [feature_set] if not isinstance(feature_set, list) else feature_set for feature_set in features |
1206 | | - ] |
1207 | | - window_counts = [len(sample) for sample in features] |
1208 | | - features = [feature for feature_set in features for feature in feature_set] |
| 1210 | + if args.sliding_window and evaluate: |
| 1211 | + features = [ |
| 1212 | + [feature_set] if not isinstance(feature_set, list) else feature_set for feature_set in features |
| 1213 | + ] |
| 1214 | + window_counts = [len(sample) for sample in features] |
| 1215 | + features = [feature for feature_set in features for feature in feature_set] |
1209 | 1216 |
|
1210 | | - all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) |
1211 | | - all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) |
1212 | | - all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) |
| 1217 | + all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) |
| 1218 | + all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) |
| 1219 | + all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) |
1213 | 1220 |
|
1214 | | - if self.args.model_type == "layoutlm": |
1215 | | - all_bboxes = torch.tensor([f.bboxes for f in features], dtype=torch.long) |
| 1221 | + if self.args.model_type == "layoutlm": |
| 1222 | + all_bboxes = torch.tensor([f.bboxes for f in features], dtype=torch.long) |
1216 | 1223 |
|
1217 | | - if output_mode == "classification": |
1218 | | - all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long) |
1219 | | - elif output_mode == "regression": |
1220 | | - all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float) |
| 1224 | + if output_mode == "classification": |
| 1225 | + all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long) |
| 1226 | + elif output_mode == "regression": |
| 1227 | + all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float) |
1221 | 1228 |
|
1222 | | - if self.args.model_type == "layoutlm": |
1223 | | - dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_bboxes) |
1224 | | - else: |
1225 | | - dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) |
| 1229 | + if self.args.model_type == "layoutlm": |
| 1230 | + dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_bboxes) |
| 1231 | + else: |
| 1232 | + dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) |
1226 | 1233 |
|
1227 | | - if args.sliding_window and evaluate: |
1228 | | - return dataset, window_counts |
| 1234 | + if args.sliding_window and evaluate: |
| 1235 | + return dataset, window_counts |
| 1236 | + else: |
| 1237 | + return dataset |
1229 | 1238 | else: |
1230 | | - return dataset |
| 1239 | + train_dataset = ClassificationDataset(examples, self.tokenizer, self.args, mode=mode, multi_label=multi_label, output_mode=output_mode) |
| 1240 | + return train_dataset |
1231 | 1241 |
|
1232 | 1242 | def compute_metrics(self, preds, model_outputs, labels, eval_examples=None, multi_label=False, **kwargs): |
1233 | 1243 | """ |
|
0 commit comments