# train.py
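"""Training loop for comment ranking (cr) and edit anchoring (ea).

Each minibatch can contribute two objectives: a pairwise hinge ranking loss
over positive/negative comment-edit pairs (rank_loss) and a cross-entropy
loss for anchoring a comment to the edit it describes.
"""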
import datetime
import glob
import os
import sys
import torch
import torch.nn.functional as F
from eval import eval
from process_data import to_var, make_vector, gen_cmntrank_batches, gen_editanch_batches


def rank_loss(score_pos, score_neg):
    """Pairwise hinge ranking loss: sum(max(0, margin - score_pos + score_neg)).

    Equivalent per pair to F.margin_ranking_loss with target=1 and margin=1,
    written out explicitly so the per-pair losses can also be returned.
    """
    # unit margin with the same shape as the positive scores
    margin = to_var(torch.FloatTensor(score_pos.size()).fill_(1))
    loss_list = (margin - score_pos + score_neg).clamp(min=0)
    loss = loss_list.sum()
    return loss, loss_list
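
# A toy check of rank_loss (hypothetical scores, shapes assumed (batch, 1)):
#   score_pos = to_var(torch.FloatTensor([[0.9], [0.2]]))
#   score_neg = to_var(torch.FloatTensor([[0.4], [0.5]]))
#   loss, per_pair = rank_loss(score_pos, score_neg)
#   # per_pair == [[0.5], [1.3]] and loss == 1.8 under the unit margin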


def cal_batch_loss(loss_list, batch_size, index_list):
    """Average per-pair losses back into one loss value per batch sample.

    index_list[i] holds the number of ranking pairs generated for sample i;
    samples that produced no pairs contribute a loss of 0.
    """
    # move to CPU before converting, in case the model ran on GPU
    loss_list = loss_list.data.squeeze(1).cpu().numpy().tolist()
    start = 0
    batch_loss_list = []
    for i in range(batch_size):
        cur_bs = index_list[i]
        end = start + cur_bs
        if cur_bs == 0:
            batch_loss_list.append(0)
        else:
            batch_loss_list.append(sum(loss_list[start:end]) / cur_bs)
        start = end
    return batch_loss_list
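
# Example: with batch_size=3 and index_list=[2, 0, 1], a flat loss list of
# [0.4, 0.6, 0.3] maps to [0.5, 0, 0.3]: pairs are averaged per sample, and
# the pairless middle sample gets 0.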


def train(args, model, dataset, train_df, val_df, optimizer, w2i, n_epoch, start_epoch, batch_size):
    """Run the training loop for comment ranking (cr) and/or edit anchoring (ea)."""
    print('----Train----')
    # timestamp label for this training run
    label = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    model.train()
    cmnt_sent_len = args.max_cmnt_length
    diff_sent_len = args.max_diff_length
    ctx_sent_len = args.max_ctx_length
    train_size = len(train_df)
    sample_size = int(train_size)
    cr_best_acc, ea_best_acc = 0, 0
    for epoch in range(start_epoch, n_epoch + 1):
        print('============================== Epoch ', epoch, ' ==============================')
        for step, batch in enumerate(dataset.iterate_minibatches(train_df, batch_size, epoch, n_epoch), start=1):
            total_loss = 0
            # comment ranking: score each true (comment, edit) pair above sampled negatives
            if args.cr_train:
                cr_pos_batch, cr_neg_batch = gen_cmntrank_batches(batch, w2i, cmnt_sent_len, diff_sent_len,
                                                                  ctx_sent_len, args.rank_num)
                if len(cr_pos_batch[0]) > 0:
                    # TODO: combine positive and negative samples into a single forward pass
                    pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action = \
                        make_vector(cr_pos_batch, w2i, cmnt_sent_len, ctx_sent_len)
                    neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action = \
                        make_vector(cr_neg_batch, w2i, cmnt_sent_len, ctx_sent_len)
                    score_pos, _ = model(pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action,
                                         cr_mode=True)
                    score_neg, _ = model(neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action,
                                         cr_mode=True)
                    loss, _ = rank_loss(score_pos, score_neg)
                    total_loss += loss.data.item()
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
            # edit anchoring: classify which edit within the revision the comment refers to
            if args.ea_train:
                ea_batch, ea_truth = gen_editanch_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len,
                                                          args.anchor_num)
                if len(ea_batch[0]) > 0:
                    cmnt, src_token, src_action, tgt_token, tgt_action = \
                        make_vector(ea_batch, w2i, cmnt_sent_len, ctx_sent_len)
                    logit, _ = model(cmnt, src_token, src_action, tgt_token, tgt_action, cr_mode=False)
                    target = to_var(torch.tensor(ea_truth))
                    loss = F.cross_entropy(logit, target)
                    total_loss += loss.data.item()
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
            if step % args.log_interval == 0:
                try:
                    # total_loss covers whichever of the cr/ea losses ran this step
                    sys.stdout.write(
                        '\rEpoch[{}] Batch[{}] - loss: {:.6f}\n'.format(epoch, step, total_loss))
                    sys.stdout.flush()
                except Exception:
                    print("Unexpected error:", sys.exc_info()[0])
            if step % args.test_interval == 0:
                if args.val_size > 0:
                    val_df = val_df[:args.val_size]
                cr_acc, ea_acc = eval(dataset, val_df, w2i, model, args)
                model.train()  # eval() switches to eval mode; switch back for training
                if args.cr_train and cr_acc > cr_best_acc:
                    cr_best_acc = cr_acc
                    if args.save_best:
                        save(model, args.save_dir, 'best_cr', epoch, step, cr_best_acc, args.no_action)
                if args.ea_train and ea_acc > ea_best_acc:
                    ea_best_acc = ea_acc
                    if args.save_best:
                        save(model, args.save_dir, 'best_ea', epoch, step, ea_best_acc, args.no_action)


def save(model, save_dir, save_prefix, epoch, steps, best_result, no_action=False):
    """Save the model's state dict, keeping only the newest checkpoint per prefix."""
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    save_prefix = os.path.join(save_dir, save_prefix)
    # remove previously saved checkpoints that share this prefix
    for file_name in sorted(glob.glob(save_prefix + '*')):
        if os.path.exists(file_name):
            os.remove(file_name)
    result_str = '%.3f' % best_result
    save_path = '{}_steps_{:02}_{:06}_{}.pt'.format(save_prefix, epoch, steps, result_str)
    print("Save best model", save_path)
    torch.save(model.state_dict(), save_path)
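

# A minimal usage sketch, not part of the original script. `build_model`,
# `load_dataset`, and the arg names below are hypothetical stand-ins for the
# project's own setup code:
#
#   model = build_model(args)
#   dataset, train_df, val_df, w2i = load_dataset(args)
#   optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
#   train(args, model, dataset, train_df, val_df, optimizer, w2i,
#         n_epoch=args.epochs, start_epoch=1, batch_size=args.batch_size)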