|
24 | 24 | '--lr_scheduler_divisor': { |
25 | 25 | 'type': int, |
26 | 26 | }, |
| 27 | + '--test_only_at_end': { |
| 28 | + 'action': 'store_true', |
| 29 | + }, |
27 | 30 | } |
28 | 31 |
|
29 | 32 | FLAGS = args_parse.parse_common_options( |
@@ -236,15 +239,16 @@ def test_loop_fn(loader, epoch): |
236 | 239 | xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) |
237 | 240 | train_loop_fn(train_device_loader, epoch) |
238 | 241 | xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) |
239 | | - accuracy = test_loop_fn(test_device_loader, epoch) |
240 | | - xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format( |
241 | | - epoch, test_utils.now(), accuracy)) |
242 | | - max_accuracy = max(accuracy, max_accuracy) |
243 | | - test_utils.write_to_summary( |
244 | | - writer, |
245 | | - epoch, |
246 | | - dict_to_write={'Accuracy/test': accuracy}, |
247 | | - write_xla_metrics=True) |
| 242 | + if not FLAGS.test_only_at_end or epoch == FLAGS.num_epochs: |
| 243 | + accuracy = test_loop_fn(test_device_loader, epoch) |
| 244 | + xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format( |
| 245 | + epoch, test_utils.now(), accuracy)) |
| 246 | + max_accuracy = max(accuracy, max_accuracy) |
| 247 | + test_utils.write_to_summary( |
| 248 | + writer, |
| 249 | + epoch, |
| 250 | + dict_to_write={'Accuracy/test': accuracy}, |
| 251 | + write_xla_metrics=True) |
248 | 252 | if FLAGS.metrics_debug: |
249 | 253 | xm.master_print(met.metrics_report()) |
250 | 254 |
|
|
0 commit comments