This repository was archived by the owner on Jul 19, 2021. It is now read-only.

Commit e385810

Author: Mathis Chenuet (committed)

end wip ?
good training + add cli args + scorers + gpu acceleration + eval on dev data

1 parent: a9492ec

File tree: 2 files changed, +128 −56 lines


main.py

Lines changed: 41 additions & 4 deletions
@@ -1,4 +1,4 @@
-#/usr/bin/env python3
+#!/usr/bin/env python3
 
 if __name__ == '__main__':
     from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
@@ -7,12 +7,49 @@
     arg = parser.add_argument
 
     arg('--data_dir', default='data/sentihood/', help="Path to the dataset directory")
-    arg('-b', '--batch_size', default=4, help="Batch size for training and evaluation")
-    arg('-lr', '--learning_rate', default=1e-4, help="Optimizer learning rate")
-    arg('-wd', '--weight_decay', default=0.01, help="Optimizer weight decay (L2)")
+    arg('-e', '--epochs', default=1, type=int, help="Training epochs")
+    arg(
+        '-b',
+        '--batch_size',
+        default=4,
+        type=int,
+        help="Batch size for training and evaluation",
+    )
+    arg(
+        '-bs',
+        '--balanced_sampler',
+        default=True,
+        type=bool,
+        help="Pick examples uniformly *between classes*",
+    )
+    arg(
+        '-lr',
+        '--learning_rate',
+        default=1e-4,
+        type=float,
+        help="Optimizer learning rate",
+    )
+    arg(
+        '-wd',
+        '--weight_decay',
+        default=0.01,
+        type=float,
+        help="Optimizer weight decay (L2)",
+    )
+    arg('--cpu', action='store_true', help="Force CPU even if CUDA is available")
+    arg(
+        '--debug',
+        action='store_true',
+        help="Debug options (truncate the datasets for faster debugging)",
+    )
 
     args = parser.parse_args()
 
+    import torch
     from sentihood import main
+
+    args.device = torch.device(
+        'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu'
+    )
     # main(**vars(args))
     main(args)
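The GPU/CPU selection added above is the standard torch.cuda.is_available() idiom; a minimal standalone sketch (the pick_device name is illustrative, not from this repository):

# Minimal sketch of the device-selection idiom added to main.py; pick_device is an
# illustrative name, not a function from this repository.
import torch

def pick_device(force_cpu: bool = False) -> torch.device:
    # Prefer CUDA when a GPU is visible, unless the user passed --cpu.
    return torch.device('cuda' if torch.cuda.is_available() and not force_cpu else 'cpu')

print(pick_device())  # device(type='cuda') on a GPU machine, otherwise device(type='cpu')

One caveat with the new flags: --balanced_sampler is declared with type=bool, and argparse converts any non-empty string (including "False") to True, so the option cannot actually be disabled from the command line; an action='store_true' / 'store_false' pair is the usual fix.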

sentihood.py

Lines changed: 87 additions & 52 deletions
@@ -1,5 +1,6 @@
 import json
 from collections import Counter
+from functools import partial
 from pathlib import Path
 from typing import List, TypeVar
 
@@ -20,6 +21,7 @@
 
 T = TypeVar('T')
 
+ARGS = None
 MODEL = 'bert-base-uncased'
 
 tokenizer = BertTokenizer.from_pretrained(
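ARGS starts out as a module-level placeholder and is expected to be filled in once by the entry point before helpers such as create_batch read ARGS.device. The diff does not show how the repository assigns it, so the sketch below is an assumption:

# Sketch of a module-level settings object. Assumption: how this repository actually
# assigns ARGS is not visible in this diff; the `global` assignment is illustrative only.
ARGS = None

def main(args):
    global ARGS
    ARGS = args  # afterwards every helper can read ARGS.device, ARGS.batch_size, ...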
@@ -36,8 +38,12 @@
 SEP = tokenizer.vocab['[SEP]']
 
 writer = SummaryWriter()
-_scorers = ['accuracy', 'f1_macro', 'precision_macro', 'recall_macro']
-scorers = {name: metrics.get_scorer(name) for name in _scorers}
+scorers = {
+    'accuracy': metrics.accuracy_score,
+    'f1_micro': partial(metrics.f1_score, average='micro'),
+    'f1_macro': partial(metrics.f1_score, average='macro'),
+    'f1_weighted': partial(metrics.f1_score, average='weighted'),
+}
 
 
 @attr.s(auto_attribs=True, slots=True)
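Every entry in the new scorers dict takes the same (y_true, y_pred) arguments because functools.partial pre-binds the average keyword; the old metrics.get_scorer(name) variant returns scorers that expect a fitted estimator instead, which is presumably why it was dropped. A small self-contained sketch with toy labels:

# Self-contained sketch (toy labels): partial pre-binds `average`, so every scorer in the
# dict shares the (y_true, y_pred) calling convention used by the eval loop.
from functools import partial
from sklearn import metrics

scorers = {
    'accuracy': metrics.accuracy_score,
    'f1_macro': partial(metrics.f1_score, average='macro'),
}

y_true, y_pred = [0, 1, 2, 1], [0, 1, 1, 1]
for name, scorer in scorers.items():
    print(f"{name}: {scorer(y_true, y_pred):.3f}")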
@@ -134,14 +140,14 @@ def create_batch(examples: List[Example]):
         for ex in examples
     ]
     return (
-        torch.tensor(tokens),
-        torch.tensor(segments),
-        torch.tensor(mask),
-        torch.tensor(labels),
+        torch.tensor(tokens, device=ARGS.device),
+        torch.tensor(segments, device=ARGS.device),
+        torch.tensor(mask, device=ARGS.device),
+        torch.tensor(labels, device=ARGS.device),
     )
 
 
-def train(model, data_loader, epochs):
+def train(model, train_data, eval_data, epochs):
     param_optimizer = list(model.named_parameters())
     no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
     optimizer_grouped_parameters = [
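Because create_batch now builds each tensor with device=ARGS.device, batches come out of the DataLoader already on the GPU and the train/eval loops never call .to(device) themselves. A toy sketch of the same idea (names and shapes are illustrative; it assumes num_workers=0, since CUDA tensors should not be created inside DataLoader worker processes):

# Toy sketch (illustrative names/shapes): a collate_fn that materialises the batch
# directly on the target device, so downstream code never moves tensors itself.
import torch

def collate(batch, device=torch.device('cpu')):
    xs, ys = zip(*batch)
    return torch.tensor(xs, device=device), torch.tensor(ys, device=device)

print(collate([([1, 2], 0), ([3, 4], 1)]))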
@@ -160,24 +166,27 @@ def train(model, data_loader, epochs):
     ]
 
     optimizer = BertAdam(
-        optimizer_grouped_parameters, lr=ARGS.learning_rate, warmup=0.1, t_total=len(data_loader)
+        optimizer_grouped_parameters,
+        lr=ARGS.learning_rate,
+        warmup=0.1,
+        t_total=len(train_data),
     )
 
     model.train()
 
     for epoch in trange(epochs, desc="Train epoch"):
-        for step, batch in enumerate(tqdm(data_loader, desc="Iteration")):
+        for step, batch in enumerate(tqdm(train_data, desc="Iteration")):
             loss = model(*batch)
-            print("loss", loss.item())
+            tqdm.write(f"loss={loss.item()}")
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
 
-            writer.add_scalar('train loss', loss.item(), step)
+            writer.add_scalar('train/loss', loss.item(), step)
 
     writer.add_graph('bert', model, batch[-1])
 
-    eval(model, data_loader)
+    eval(model, eval_data)
 
 
 def eval(model, data_loader):
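The loop now logs through tqdm.write rather than print; tqdm.write prints above the live progress bar instead of breaking it. A tiny sketch:

# Tiny sketch: tqdm.write keeps log lines from corrupting an active progress bar.
from tqdm import tqdm

for step in tqdm(range(3), desc="Iteration"):
    tqdm.write(f"loss={1.0 / (step + 1):.3f}")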
@@ -187,22 +196,39 @@ def eval(model, data_loader):
     all_preds = []
     # all_probs = []
 
-    for step, batch in enumerate(tqdm(data_loader, desc="Eval")):
-        print('train batch shape', batch.shape)
-        with torch.no_grad():
+    with torch.no_grad():
+        for step, batch in enumerate(tqdm(data_loader, desc="Eval")):
             assert len(batch) == 4, "We should have labels here"
+            labels = batch[3]
+            targets = (
+                labels != -100
+            )  # the ignored index in the loss (= we ignore the tokens that are not the target)
+
             logits = model(*batch[:3])
             # probs = F.softmax(logits, dim=1)[:,1] # TODO check this
-            predictions = logits.argmax(dim=1).tolist()
+            predictions = logits.argmax(dim=-1)[targets].tolist()
             # all_probs += probs.tolist()
-            labels = batch[3].tolist()
+            labels = labels[targets].tolist()
 
-        all_preds += predictions
-        all_labels += labels
+            all_preds += predictions
+            all_labels += labels
 
     # writer.add_pr_curve('eval', labels=all_labels, predictions=all_probs)
-    writer.add_scalar('eval/acc', metrics.accuracy_score(all_labels, all_preds))
-    writer.add_scalar('eval/f1', metrics.f1_score(all_labels, all_preds))
+    tqdm.write(f"labels={' '.join(map(str, all_labels))}")
+    tqdm.write(f"preds ={' '.join(map(str, all_preds))}")
+    # writer.add_scalar('eval/acc', metrics.accuracy_score(all_labels, all_preds))
+    # writer.add_scalar('eval/f1 micro', metrics.f1_score(all_labels, all_preds, ))
+    for name, scorer in scorers.items():
+        writer.add_scalar(f'eval/{name}', scorer(all_labels, all_preds))
+    writer.add_text(
+        'eval/classification_report',
+        metrics.classification_report(
+            all_labels,
+            all_preds,
+            labels=[0, 1, 2],
+            target_names='None Positive Negative'.split(),
+        ),
+    )
 
 
 def main(args):
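The reworked eval keeps only the positions whose label is not -100, the ignore index used by the token-classification loss, so metrics are computed on the single target-location token of each sequence rather than on every token. A toy sketch with made-up tensors:

# Toy sketch (made-up tensors): mask out positions labelled -100 (the loss's ignore
# index) before collecting gold labels and predictions for the metrics.
import torch

labels = torch.tensor([[-100, 1, -100],
                       [-100, -100, 2]])      # one real label per sequence
logits = torch.randn(2, 3, 3)                 # (batch, seq_len, num_labels)

keep = labels != -100
gold = labels[keep].tolist()                  # [1, 2]
preds = logits.argmax(dim=-1)[keep].tolist()  # predictions at the same positions
print(gold, preds)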
@@ -238,7 +264,7 @@ def flatten_aspects(ex):
         # text = [ ('[MASK]' if tok in LOCATIONS else tok) for tok in ex['text'] ]
         ids = tokenizer.convert_tokens_to_ids(ex['text'])
         targets = [loc for loc in LOCATIONS if loc in ex['text']]
-        for target in targets:
+        for i, target in enumerate(targets):
             target_idx = ex['text'].index(target)
             for aspect in aspects:
                 sentiment_or_none = next(
@@ -261,40 +287,49 @@ def flatten_aspects(ex):
                 )
 
     processed = ds.map_many(flatten_aspects)
+    if ARGS.debug:
+        processed.train = processed.train[: 2 * ARGS.batch_size]
+        processed.dev = processed.dev[: 2 * ARGS.batch_size]
+        processed.test = processed.test[: 2 * ARGS.batch_size]
 
     processed.print_head()
 
-    writer.add_text('params/bert_model', MODEL)
-    writer.add_text('params/batch_size', ARGS.batch_size)
-    writer.add_text('params/learning_rate', ARGS.learning_rate)
-    writer.add_text('params/weight_decay', ARGS.weight_decay)
+    writer.add_text('params', f"model={MODEL} params={str(ARGS)}")
 
-    model = BertForTokenClassification.from_pretrained(MODEL, num_labels=3)
     # 3 labels for None/neutral, Positive, Negative
-
-    # lm = BertForMaskedLM.from_pretrained(MODEL)
-    # lm.eval()
-    # for ex in processed.train:
-    #     tokens_tensor = torch.tensor([ex.token_ids])
-    #     segments_tensor = torch.tensor([segment_ids_from_token_ids(ex.token_ids)])
-    #     print(tokens_tensor)
-    #     print(segments_tensor)
-
-    #     predictions = lm(tokens_tensor)
-    #     print(predictions)
-    #     predicted_index = torch.argmax(predictions[0, ex.target_idx]).item()
-    #     predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])
-
-    #     print(predicted_index)
-    #     print(predicted_token)
-    #     print(tokenizer.convert_ids_to_tokens(ex.token_ids), ex.text)
-
-    #     print(model(*create_batch([ex])))
-
-    #     break
-
-    loader = DataLoader(
-        processed.train, batch_size=batch_size, shuffle=True, collate_fn=create_batch
+    model = BertForTokenClassification.from_pretrained(MODEL, num_labels=3)
+    model.to(ARGS.device)
+
+    if ARGS.balanced_sampler:
+        class_counts = Counter(ex.sentiment for ex in processed.train)
+        class_min = class_counts.most_common()[-1][1]
+        writer.add_text(
+            'info/balanced_sampler_weights',
+            str(
+                {
+                    sentiment: class_min / count
+                    for sentiment, count in class_counts.items()
+                }
+            ),
+        )
+        weights = [
+            len(processed.train) / class_counts[ex.sentiment] for ex in processed.train
+        ]
+        sampler = torch.utils.data.WeightedRandomSampler(
+            weights=weights, num_samples=len(processed.train)
+        )
+    else:
+        sampler = None
+
+    train_loader = DataLoader(
+        processed.train,
+        batch_size=ARGS.batch_size,
+        shuffle=not ARGS.balanced_sampler,
+        sampler=sampler,
+        collate_fn=create_batch,
+    )
+    eval_loader = DataLoader(
+        processed.dev, batch_size=ARGS.batch_size, collate_fn=create_batch
     )
 
-    train(model, loader, epochs=1)
+    train(model, train_loader, eval_loader, epochs=ARGS.epochs)
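The balanced sampler gives each training example a weight inversely proportional to its class frequency, so WeightedRandomSampler draws None/Positive/Negative roughly uniformly; shuffle is turned off in that case because DataLoader rejects shuffle=True together with an explicit sampler. A self-contained sketch with a toy label list:

# Self-contained sketch (toy label list): inverse-frequency weights make the sampler
# pick classes roughly uniformly regardless of how imbalanced the dataset is.
from collections import Counter
from torch.utils.data import DataLoader, WeightedRandomSampler

data = ['None'] * 8 + ['Positive'] * 3 + ['Negative'] * 1
counts = Counter(data)
weights = [len(data) / counts[label] for label in data]

sampler = WeightedRandomSampler(weights=weights, num_samples=len(data))
loader = DataLoader(data, batch_size=4, sampler=sampler)
print(next(iter(loader)))  # e.g. ['Negative', 'None', 'Positive', 'None']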
