@@ -1,5 +1,6 @@
 import json
 from collections import Counter
+from functools import partial
 from pathlib import Path
 from typing import List, TypeVar
 
@@ -20,6 +21,7 @@
 
 T = TypeVar('T')
 
+ARGS = None
 MODEL = 'bert-base-uncased'
 
 tokenizer = BertTokenizer.from_pretrained(
@@ -36,8 +38,12 @@
 SEP = tokenizer.vocab['[SEP]']
 
 writer = SummaryWriter()
-_scorers = ['accuracy', 'f1_macro', 'precision_macro', 'recall_macro']
-scorers = {name: metrics.get_scorer(name) for name in _scorers}
+scorers = {
+    'accuracy': metrics.accuracy_score,
+    'f1_micro': partial(metrics.f1_score, average='micro'),
+    'f1_macro': partial(metrics.f1_score, average='macro'),
+    'f1_weighted': partial(metrics.f1_score, average='weighted'),
+}
 
 
 @attr.s(auto_attribs=True, slots=True)
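The removed `metrics.get_scorer` scorers must be called as `scorer(estimator, X, y_true)`, while the `partial`-wrapped metric functions above are plain `(y_true, y_pred)` callables, which is how `eval` below invokes them. A minimal sketch of how the new `scorers` dict is consumed (the toy labels are made up for illustration):

    from functools import partial
    from sklearn import metrics

    scorers = {
        'accuracy': metrics.accuracy_score,
        'f1_macro': partial(metrics.f1_score, average='macro'),
    }
    y_true, y_pred = [0, 1, 2, 1], [0, 1, 1, 1]
    for name, scorer in scorers.items():
        # every entry, plain function or partial, takes (y_true, y_pred)
        print(name, scorer(y_true, y_pred))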
@@ -134,14 +140,14 @@ def create_batch(examples: List[Example]):
         for ex in examples
     ]
     return (
-        torch.tensor(tokens),
-        torch.tensor(segments),
-        torch.tensor(mask),
-        torch.tensor(labels),
+        torch.tensor(tokens, device=ARGS.device),
+        torch.tensor(segments, device=ARGS.device),
+        torch.tensor(mask, device=ARGS.device),
+        torch.tensor(labels, device=ARGS.device),
     )
 
 
-def train(model, data_loader, epochs):
+def train(model, train_data, eval_data, epochs):
     param_optimizer = list(model.named_parameters())
     no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
     optimizer_grouped_parameters = [
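Passing `device=` to `torch.tensor` allocates each batch tensor directly on the target device instead of building it on the CPU and copying it over with `.to()`. A small sketch (the `cuda` fallback line is an assumption for illustration, not part of this diff):

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # constructed on `device` in one step, no separate .to(device) copy
    tokens = torch.tensor([[101, 2023, 102]], device=device)
    assert tokens.device.type == device.type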
@@ -160,24 +166,27 @@ def train(model, data_loader, epochs):
     ]
 
     optimizer = BertAdam(
-        optimizer_grouped_parameters, lr=ARGS.learning_rate, warmup=0.1, t_total=len(data_loader)
+        optimizer_grouped_parameters,
+        lr=ARGS.learning_rate,
+        warmup=0.1,
+        t_total=len(train_data),
     )
 
     model.train()
 
     for epoch in trange(epochs, desc="Train epoch"):
-        for step, batch in enumerate(tqdm(data_loader, desc="Iteration")):
+        for step, batch in enumerate(tqdm(train_data, desc="Iteration")):
             loss = model(*batch)
-            print("loss", loss.item())
+            tqdm.write(f"loss={loss.item()}")
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
 
-            writer.add_scalar('train loss', loss.item(), step)
+            writer.add_scalar('train/loss', loss.item(), step)
 
     writer.add_graph('bert', model, batch[-1])
 
-    eval(model, data_loader)
+    eval(model, eval_data)
 
 
 def eval(model, data_loader):
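The `no_decay` list drives the usual BERT fine-tuning pattern of exempting biases and LayerNorm parameters from weight decay; the body of `optimizer_grouped_parameters` is collapsed out of this hunk, so here is a sketch of the conventional shape (the `0.01` default and the `'weight_decay'` key name are assumptions; early pytorch-pretrained-bert examples spelled it `'weight_decay_rate'`). Note too that `BertAdam`'s `t_total` counts total optimization steps for the warmup schedule, so `t_total=len(train_data)` only spans one epoch's worth of steps when `epochs > 1`.

    # a sketch, mirroring the pytorch-pretrained-bert fine-tuning examples
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        # decayed: parameters whose name matches none of the no_decay patterns
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        # exempt: biases and LayerNorm weights
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]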
@@ -187,22 +196,39 @@ def eval(model, data_loader):
     all_preds = []
     # all_probs = []
 
-    for step, batch in enumerate(tqdm(data_loader, desc="Eval")):
-        print('train batch shape', batch.shape)
-        with torch.no_grad():
+    with torch.no_grad():
+        for step, batch in enumerate(tqdm(data_loader, desc="Eval")):
             assert len(batch) == 4, "We should have labels here"
+            labels = batch[3]
+            targets = (
+                labels != -100
+            )  # the ignored index in the loss (= we ignore the tokens that are not the target)
+
             logits = model(*batch[:3])
             # probs = F.softmax(logits, dim=1)[:,1] # TODO check this
-            predictions = logits.argmax(dim=1).tolist()
+            predictions = logits.argmax(dim=-1)[targets].tolist()
             # all_probs += probs.tolist()
-            labels = batch[3].tolist()
+            labels = labels[targets].tolist()
 
-        all_preds += predictions
-        all_labels += labels
+            all_preds += predictions
+            all_labels += labels
 
     # writer.add_pr_curve('eval', labels=all_labels, predictions=all_probs)
-    writer.add_scalar('eval/acc', metrics.accuracy_score(all_labels, all_preds))
-    writer.add_scalar('eval/f1', metrics.f1_score(all_labels, all_preds))
+    tqdm.write(f"labels={' '.join(map(str, all_labels))}")
+    tqdm.write(f"preds ={' '.join(map(str, all_preds))}")
+    # writer.add_scalar('eval/acc', metrics.accuracy_score(all_labels, all_preds))
+    # writer.add_scalar('eval/f1 micro', metrics.f1_score(all_labels, all_preds))
+    for name, scorer in scorers.items():
+        writer.add_scalar(f'eval/{name}', scorer(all_labels, all_preds))
+    writer.add_text(
+        'eval/classification_report',
+        metrics.classification_report(
+            all_labels,
+            all_preds,
+            labels=[0, 1, 2],
+            target_names='None Positive Negative'.split(),
+        ),
+    )
 
 
 def main(args):
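`-100` is the default `ignore_index` of PyTorch's cross-entropy loss, which is presumably why it marks non-target tokens here; masking with `labels != -100` keeps only the positions that carry a real label before scoring. A toy illustration of the indexing used above:

    import torch

    labels = torch.tensor([-100, 1, -100, 2])  # only positions 1 and 3 are real targets
    logits = torch.randn(4, 3)                 # (seq_len, num_labels)
    targets = labels != -100                   # boolean mask over positions
    predictions = logits.argmax(dim=-1)[targets]
    print(predictions.tolist(), labels[targets].tolist())  # two predictions, labels [1, 2]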
@@ -238,7 +264,7 @@ def flatten_aspects(ex):
     # text = [('[MASK]' if tok in LOCATIONS else tok) for tok in ex['text']]
     ids = tokenizer.convert_tokens_to_ids(ex['text'])
     targets = [loc for loc in LOCATIONS if loc in ex['text']]
-    for target in targets:
+    for i, target in enumerate(targets):
         target_idx = ex['text'].index(target)
         for aspect in aspects:
             sentiment_or_none = next(
@@ -261,40 +287,49 @@ def flatten_aspects(ex):
     )
 
     processed = ds.map_many(flatten_aspects)
+    if ARGS.debug:
+        processed.train = processed.train[: 2 * ARGS.batch_size]
+        processed.dev = processed.dev[: 2 * ARGS.batch_size]
+        processed.test = processed.test[: 2 * ARGS.batch_size]
 
     processed.print_head()
 
-    writer.add_text('params/bert_model', MODEL)
-    writer.add_text('params/batch_size', ARGS.batch_size)
-    writer.add_text('params/learning_rate', ARGS.learning_rate)
-    writer.add_text('params/weight_decay', ARGS.weight_decay)
+    writer.add_text('params', f"model={MODEL} params={str(ARGS)}")
 
-    model = BertForTokenClassification.from_pretrained(MODEL, num_labels=3)
     # 3 labels for None/neutral, Positive, Negative
-
-    # lm = BertForMaskedLM.from_pretrained(MODEL)
-    # lm.eval()
-    # for ex in processed.train:
-    #     tokens_tensor = torch.tensor([ex.token_ids])
-    #     segments_tensor = torch.tensor([segment_ids_from_token_ids(ex.token_ids)])
-    #     print(tokens_tensor)
-    #     print(segments_tensor)
-
-    #     predictions = lm(tokens_tensor)
-    #     print(predictions)
-    #     predicted_index = torch.argmax(predictions[0, ex.target_idx]).item()
-    #     predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])
-
-    #     print(predicted_index)
-    #     print(predicted_token)
-    #     print(tokenizer.convert_ids_to_tokens(ex.token_ids), ex.text)
-
-    #     print(model(*create_batch([ex])))
-
-    #     break
-
-    loader = DataLoader(
-        processed.train, batch_size=batch_size, shuffle=True, collate_fn=create_batch
+    model = BertForTokenClassification.from_pretrained(MODEL, num_labels=3)
+    model.to(ARGS.device)
+
+    if ARGS.balanced_sampler:
+        class_counts = Counter(ex.sentiment for ex in processed.train)
+        class_min = class_counts.most_common()[-1][1]
+        writer.add_text(
+            'info/balanced_sampler_weights',
+            str(
+                {
+                    sentiment: class_min / count
+                    for sentiment, count in class_counts.items()
+                }
+            ),
+        )
+        weights = [
+            len(processed.train) / class_counts[ex.sentiment] for ex in processed.train
+        ]
+        sampler = torch.utils.data.WeightedRandomSampler(
+            weights=weights, num_samples=len(processed.train)
+        )
+    else:
+        sampler = None
+
+    train_loader = DataLoader(
+        processed.train,
+        batch_size=ARGS.batch_size,
+        shuffle=not ARGS.balanced_sampler,
+        sampler=sampler,
+        collate_fn=create_batch,
+    )
+    eval_loader = DataLoader(
+        processed.dev, batch_size=ARGS.batch_size, collate_fn=create_batch
     )
 
-    train(model, loader, epochs=1)
+    train(model, train_loader, eval_loader, epochs=ARGS.epochs)
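`WeightedRandomSampler` draws dataset indices with probability proportional to each example's weight (with replacement by default), so weighting every example by `len(train) / count(class)` makes the sentiment classes roughly equally frequent per epoch. This is also why the loader sets `shuffle=not ARGS.balanced_sampler`: `DataLoader` forbids combining `shuffle=True` with a `sampler`. A toy sketch with made-up class counts:

    from collections import Counter
    import torch

    sentiments = ['None'] * 8 + ['Positive'] + ['Negative']  # imbalanced toy data
    counts = Counter(sentiments)
    weights = [len(sentiments) / counts[s] for s in sentiments]
    sampler = torch.utils.data.WeightedRandomSampler(weights=weights, num_samples=1000)
    # iterating the sampler yields indices; the three classes come out roughly even
    print(Counter(sentiments[i] for i in sampler))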