Skip to content

Commit f2f8f44

Browse files
authored
Cherrypick PR#2878 into r1.8 (#2880)
1 parent 66e6cd1 commit f2f8f44

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

test/test_train_mp_imagenet.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
'--lr_scheduler_divisor': {
2525
'type': int,
2626
},
27+
'--test_only_at_end': {
28+
'action': 'store_true',
29+
},
2730
}
2831

2932
FLAGS = args_parse.parse_common_options(
@@ -236,15 +239,16 @@ def test_loop_fn(loader, epoch):
236239
xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
237240
train_loop_fn(train_device_loader, epoch)
238241
xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
239-
accuracy = test_loop_fn(test_device_loader, epoch)
240-
xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
241-
epoch, test_utils.now(), accuracy))
242-
max_accuracy = max(accuracy, max_accuracy)
243-
test_utils.write_to_summary(
244-
writer,
245-
epoch,
246-
dict_to_write={'Accuracy/test': accuracy},
247-
write_xla_metrics=True)
242+
if not FLAGS.test_only_at_end or epoch == FLAGS.num_epochs:
243+
accuracy = test_loop_fn(test_device_loader, epoch)
244+
xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
245+
epoch, test_utils.now(), accuracy))
246+
max_accuracy = max(accuracy, max_accuracy)
247+
test_utils.write_to_summary(
248+
writer,
249+
epoch,
250+
dict_to_write={'Accuracy/test': accuracy},
251+
write_xla_metrics=True)
248252
if FLAGS.metrics_debug:
249253
xm.master_print(met.metrics_report())
250254

0 commit comments

Comments
 (0)