Skip to content

Commit 0b3c88a

Browse files
committed
Fix #48
1 parent 23e3c77 commit 0b3c88a

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

trainer/trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1417,7 +1417,12 @@ def test_run(self) -> None:
14171417
else:
14181418
test_outputs = self.model.test(self.training_assets, self.test_loader, None)
14191419
if hasattr(self.model, "test_log") or (self.num_gpus > 1 and hasattr(self.model.module, "test_log")):
1420-
self.model.test_log(test_outputs, self.dashboard_logger, self.training_assets, self.total_steps_done)
1420+
if self.num_gpus > 1:
1421+
self.model.module.test_log(
1422+
test_outputs, self.dashboard_logger, self.training_assets, self.total_steps_done
1423+
)
1424+
else:
1425+
self.model.test_log(test_outputs, self.dashboard_logger, self.training_assets, self.total_steps_done)
14211426

14221427
def _restore_best_loss(self):
14231428
"""Restore the best loss from the args.best_path if provided else

0 commit comments

Comments
 (0)