-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathprogram.py
More file actions
376 lines (323 loc) · 12.4 KB
/
program.py
File metadata and controls
376 lines (323 loc) · 12.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from collections import OrderedDict
import paddle
from paddle import to_tensor
import paddle.nn as nn
import paddle.nn.functional as F
from ppcls.optimizer import LearningRateBuilder
from ppcls.optimizer import OptimizerBuilder
from ppcls.modeling import architectures
from ppcls.modeling.loss import CELoss
from ppcls.modeling.loss import MixCELoss
from ppcls.modeling.loss import JSDivLoss
from ppcls.modeling.loss import GoogLeNetLoss
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
def create_model(architecture, classes_num):
    """
    Instantiate the network described by the architecture config.

    Args:
        architecture(dict): architecture information. "name" (such as
            ResNet50) selects the constructor from
            ppcls.modeling.architectures; the optional "params" dict is
            forwarded to it as extra keyword arguments.
        classes_num(int): num of classes, passed as class_dim
    Returns:
        the model object returned by the selected constructor
    Raises:
        KeyError: if "name" is missing or names no known architecture
    """
    arch_name = architecture["name"]
    extra_kwargs = architecture.get("params", {})
    model_builder = architectures.__dict__[arch_name]
    return model_builder(class_dim=classes_num, **extra_kwargs)
def create_loss(feeds,
                out,
                architecture,
                classes_num=1000,
                epsilon=None,
                use_mix=False,
                use_distillation=False):
    """
    Build the loss criterion for this configuration and apply it to a batch.

    The first matching case wins:
        1. GoogLeNet loss, when architecture["name"] == "GoogLeNet"
        2. JS-divergence distillation loss, when use_distillation is set
        3. CrossEntropy loss with mix (mixup, cutmix, fmix), when use_mix
           is set
        4. plain CrossEntropy loss
    Label smoothing is applied in the criteria through epsilon.

    Args:
        feeds(dict): dict of model input variables; must contain "label",
            or "y_a", "y_b" and "lam" when use_mix is set
        out(variable or list): model output; a 3-element sequence for
            GoogLeNet, a 2-element sequence for distillation
        architecture(dict): architecture information,
            name(such as ResNet50) is needed
        classes_num(int): num of classes
        epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
        use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
        use_distillation(bool): whether to use the distillation loss
    Returns:
        loss(variable): loss variable
    """
    if architecture["name"] == "GoogLeNet":
        assert len(out) == 3, "GoogLeNet should have 3 outputs"
        criterion = GoogLeNetLoss(class_dim=classes_num, epsilon=epsilon)
        return criterion(out[0], out[1], out[2], feeds["label"])

    if use_distillation:
        assert len(out) == 2, ("distillation output length must be 2, "
                               "but got {}".format(len(out)))
        criterion = JSDivLoss(class_dim=classes_num, epsilon=epsilon)
        # second output is passed first (it is the student output per
        # create_metric's usage)
        return criterion(out[1], out[0])

    if use_mix:
        criterion = MixCELoss(class_dim=classes_num, epsilon=epsilon)
        return criterion(out, feeds['y_a'], feeds['y_b'], feeds['lam'])

    criterion = CELoss(class_dim=classes_num, epsilon=epsilon)
    return criterion(out, feeds["label"])
def create_metric(out,
                  label,
                  architecture,
                  topk=5,
                  classes_num=1000,
                  use_distillation=False,
                  mode="train"):
    """
    Create measures of model accuracy, such as top1 and top5.

    Args:
        out(variable or list): model output; for GoogLeNet the first of its
            3 outputs is scored, for distillation the second (student) one
        label(variable): ground-truth labels
        architecture(dict): architecture information,
            name(such as ResNet50) is needed
        topk(int): usually top5; clipped to classes_num
        classes_num(int): num of classes
        use_distillation(bool): whether to use distillation training
        mode(str): mode, train/valid/eval; outside "train" the accuracies
            are averaged over all cards when running multi-card
    Returns:
        fetchs(dict): OrderedDict with keys "top1" and "top{k}", where
            k = min(topk, classes_num)
    """
    if architecture["name"] == "GoogLeNet":
        assert len(out) == 3, "GoogLeNet should have 3 outputs"
        out = out[0]
    elif use_distillation:
        # just need student label to get metrics
        out = out[1]
    softmax_out = F.softmax(out)

    top1 = paddle.metric.accuracy(softmax_out, label=label, k=1)
    # top-k cannot exceed the number of classes
    k = min(topk, classes_num)
    topk_acc = paddle.metric.accuracy(softmax_out, label=label, k=k)

    # multi cards' eval: average the per-card accuracies
    if mode != "train" and paddle.distributed.get_world_size() > 1:
        world_size = paddle.distributed.get_world_size()
        # all_reduce reduces its argument in place; do not depend on its
        # return value, which is not the tensor in every Paddle version
        paddle.distributed.all_reduce(
            top1, op=paddle.distributed.ReduceOp.SUM)
        paddle.distributed.all_reduce(
            topk_acc, op=paddle.distributed.ReduceOp.SUM)
        top1 = top1 / world_size
        topk_acc = topk_acc / world_size

    fetchs = OrderedDict()
    fetchs['top1'] = top1
    fetchs['top{}'.format(k)] = topk_acc
    return fetchs
def create_fetchs(feeds, net, config, mode="train"):
    """
    Run the network on one batch and collect the loss and measures.

    The loss is always computed by create_loss. Accuracy measures are added
    by create_metric only when mix is NOT active, because the mix feeds
    carry no "label" entry.

    Args:
        feeds(dict): dict of model input variables; must contain "image".
            If mix is active it holds "y_a", "y_b" and "lam" instead of
            "label".
        net: the model, called as net(feeds["image"])
        config(dict): global config; reads ARCHITECTURE, topk, classes_num,
            ls_epsilon, use_mix and use_distillation
        mode(str): mix is only active when mode == "train" and use_mix is
            set in config
    Returns:
        fetchs(dict): OrderedDict with "loss" first, followed by the
            measures from create_metric when they are computed
    """
    arch = config.ARCHITECTURE
    topk = config.topk
    classes_num = config.classes_num
    epsilon = config.get('ls_epsilon')
    use_mix = config.get('use_mix') and mode == 'train'
    use_distillation = config.get('use_distillation')

    out = net(feeds["image"])

    loss = create_loss(feeds, out, arch, classes_num, epsilon, use_mix,
                       use_distillation)
    fetchs = OrderedDict([('loss', loss)])
    if use_mix:
        return fetchs

    fetchs.update(
        create_metric(
            out,
            feeds["label"],
            arch,
            topk,
            classes_num,
            use_distillation,
            mode=mode))
    return fetchs
def create_optimizer(config, parameter_list=None):
    """
    Create an optimizer and its learning-rate schedule from config.

    The schedule length ("epochs" and "step_each_epoch", the latter being
    total_images // TRAIN.batch_size) is written into
    config['LEARNING_RATE']['params'] in place before the schedule is built.

    Args:
        config(dict): such as
            {
            'epochs': 120,
            'total_images': 1281167,
            'TRAIN': {'batch_size': 256, ...},
            'LEARNING_RATE':
                {'function': 'Cosine',
                 'params': {'lr': 0.1}
                },
            'OPTIMIZER':
                {'function': 'Momentum',
                 'params':{'momentum': 0.9},
                 'regularizer':
                    {'function': 'L2', 'factor': 0.0001}
                }
            }
        parameter_list: parameters handed to the optimizer builder
    Returns:
        (optimizer, lr): the optimizer instance and the learning-rate
            object produced by LearningRateBuilder
    """
    lr_config = config['LEARNING_RATE']
    lr_params = lr_config['params']
    schedule_len = {
        'epochs': config['epochs'],
        'step_each_epoch':
        config['total_images'] // config['TRAIN']['batch_size'],
    }
    lr_params.update(schedule_len)
    lr = LearningRateBuilder(**lr_config)()

    optimizer_builder = OptimizerBuilder(**config['OPTIMIZER'])
    return optimizer_builder(lr, parameter_list), lr
def create_feeds(batch, use_mix):
    """
    Convert one raw dataloader batch into the feed dict used by the model.

    Args:
        batch: sequence whose element 0 is the image input. With use_mix,
            elements 1-3 are y_a, y_b and lam; otherwise element 1 is the
            label. Those elements must expose .numpy().
        use_mix(bool): whether the batch comes from mix augmentation
    Returns:
        feeds(dict): {"image", "y_a", "y_b", "lam"} with use_mix, otherwise
            {"image", "label"}. Targets are reshaped to (-1, 1) tensors:
            int64 for labels, float32 for lam.
    """
    def _as_column(item, dtype):
        # numpy round-trip to cast and reshape into a (-1, 1) tensor
        return to_tensor(item.numpy().astype(dtype).reshape(-1, 1))

    feeds = {"image": batch[0]}
    if not use_mix:
        feeds["label"] = _as_column(batch[1], 'int64')
        return feeds

    feeds["y_a"] = _as_column(batch[1], "int64")
    feeds["y_b"] = _as_column(batch[2], "int64")
    feeds["lam"] = _as_column(batch[3], "float32")
    return feeds
def _create_metric_meters(config, use_mix):
    """Build the ordered meters that run() updates and logs every step."""
    meters = []
    if not use_mix:
        # must match the key create_metric emits, which clips k to
        # classes_num (e.g. 'top2' for a 2-class problem with topk=5)
        topk_name = 'top{}'.format(min(config.topk, config.classes_num))
        meters.append(("top1", AverageMeter("top1", '.5f', postfix=",")))
        meters.append(
            (topk_name, AverageMeter(topk_name, '.5f', postfix=",")))
    meters.extend([
        ("loss", AverageMeter('loss', '7.5f', postfix=",")),
        ("lr", AverageMeter('lr', 'f', postfix=",", need_avg=False)),
        ("batch_time", AverageMeter('batch_cost', '.5f', postfix=" s,")),
        ("reader_time", AverageMeter('reader_cost', '.5f', postfix=" s,")),
    ])
    return OrderedDict(meters)


def _step_lr_scheduler(lr_scheduler, epoch, idx):
    """Advance lr_scheduler (if any) after train batch idx of epoch.

    With update_specified, the scheduler steps when the global step counted
    from update_start_step is a multiple of update_step_interval; global
    steps before update_start_step count as 0 and therefore step too.
    Without it, the scheduler steps after every batch.
    """
    if lr_scheduler is None:
        return
    if not lr_scheduler.update_specified:
        lr_scheduler.step()
        return
    global_step = lr_scheduler.step_each_epoch * epoch + idx
    steps_since_start = max(0, global_step - lr_scheduler.update_start_step)
    if steps_since_start % lr_scheduler.update_step_interval == 0:
        lr_scheduler.step()


def run(dataloader,
        config,
        net,
        optimizer=None,
        lr_scheduler=None,
        epoch=0,
        mode='train'):
    """
    Iterate once over dataloader, feeding the model and logging measures.

    In train mode every batch also runs backward, an optimizer step and
    (optionally) an LR-scheduler step.

    Args:
        dataloader: callable returning an iterable of batches
        config(dict): global config; reads print_interval, use_mix, topk,
            classes_num here, plus the keys create_fetchs reads
        net: the model passed to create_fetchs
        optimizer: optimizer instance; required when mode == 'train'
        lr_scheduler: optional learning-rate scheduler, stepped after
            each train batch
        epoch(int): epoch of training or validation; used for logging and
            LR step counting
        mode(str): 'train' enables backward/updates; 'eval' selects the
            eval log format; 'valid' makes the function return top-1
    Returns:
        the average of the top1 meter when mode == 'valid', else None
    """
    print_interval = config.get("print_interval", 10)
    use_mix = config.get("use_mix", False) and mode == "train"
    metric_list = _create_metric_meters(config, use_mix)

    tic = time.time()
    for idx, batch in enumerate(dataloader()):
        # avoid statistics from warmup time
        if idx == 10:
            metric_list["batch_time"].reset()
            metric_list["reader_time"].reset()
        metric_list['reader_time'].update(time.time() - tic)
        batch_size = len(batch[0])
        feeds = create_feeds(batch, use_mix)
        fetchs = create_fetchs(feeds, net, config, mode)
        if mode == 'train':
            fetchs['loss'].backward()
            optimizer.step()
            optimizer.clear_grad()
            metric_list['lr'].update(
                optimizer._global_learning_rate().numpy()[0], batch_size)
            _step_lr_scheduler(lr_scheduler, epoch, idx)

        for name, fetch in fetchs.items():
            metric_list[name].update(fetch.numpy()[0], batch_size)
        metric_list["batch_time"].update(time.time() - tic)
        tic = time.time()

        if idx % print_interval == 0:
            # timing meters report their running mean, the rest the
            # latest value
            fetchs_str = ' '.join([
                str(metric_list[key].mean)
                if "time" in key else str(metric_list[key].value)
                for key in metric_list
            ])
            ips_info = "ips: {:.5f} images/sec.".format(
                batch_size / metric_list["batch_time"].avg)
            if mode == 'eval':
                logger.info("{:s} step:{:<4d}, {:s} {:s}".format(
                    mode, idx, fetchs_str, ips_info))
            else:
                epoch_str = "epoch:{:<3d}".format(epoch)
                step_str = "{:s} step:{:<4d}".format(mode, idx)
                logger.info("{:s}, {:s}, {:s} {:s}".format(
                    logger.coloring(epoch_str, "HEADER")
                    if idx == 0 else epoch_str,
                    logger.coloring(step_str, "PURPLE"),
                    logger.coloring(fetchs_str, 'OKGREEN'),
                    logger.coloring(ips_info, 'OKGREEN')))

    # NOTE(review): assumes the dataloader yielded at least one batch;
    # batch_size is unbound otherwise
    end_str = ' '.join([str(m.mean) for m in metric_list.values()] +
                       [metric_list['batch_time'].total])
    ips_info = "ips: {:.5f} images/sec.".format(
        batch_size * metric_list["batch_time"].count /
        metric_list["batch_time"].sum)
    if mode == 'eval':
        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(
            logger.coloring(end_epoch_str, "RED"),
            logger.coloring(mode, "PURPLE"),
            logger.coloring(end_str, "OKGREEN"),
            logger.coloring(ips_info, "OKGREEN"), ))

    # return top1_acc in order to save the best model
    if mode == 'valid':
        return metric_list['top1'].avg