Commit 2c0b8d4

fixed bug where distributed training would stall

1 parent 283155a commit 2c0b8d4

4 files changed (+55 -33 lines)

src/experiment/experiment/logger.py
Lines changed: 3 additions & 2 deletions

@@ -286,11 +286,12 @@ def _copy_results_over():
         util_file.create_directory(self._results_dir)
 
         # setup logging (with convenience function)
-        self._stdout_logger = pp_logging.setup_stdout(self._results_dir)
+        stdout_file = os.path.join(self._results_dir, "experiment.log")
+        self._stdout_logger = pp_logging.setup_stdout(stdout_file)
 
         # setup train logger
         self._train_logger = pp_logging.TrainLogger(
-            self._log_dir, self.global_tag, self._class_to_names
+            self._log_dir, stdout_file, self.global_tag, self._class_to_names,
        )
 
         # initialize writer
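
Note on the change above: setup_stdout used to receive the results directory and derive the log-file name internally; the path is now built once here and handed to both setup_stdout and TrainLogger, so every consumer shares the same experiment.log, presumably so the file can be re-opened from spawned worker processes. A minimal sketch of the wiring, with hypothetical stand-ins for the pp_logging calls (illustrative names, not this repository's API):

    import os


    def setup_stdout(log_file):
        """Hypothetical stand-in for pp_logging.setup_stdout."""
        print("mirroring stdout to", log_file)


    def make_train_logger(log_dir, stdout_file):
        """Hypothetical stand-in for pp_logging.TrainLogger."""
        return {"log_dir": log_dir, "stdout_file": stdout_file}


    results_dir = "results"  # assumption: any writable directory
    os.makedirs(results_dir, exist_ok=True)

    # build the path once and hand the same string to both consumers
    stdout_file = os.path.join(results_dir, "experiment.log")
    setup_stdout(stdout_file)
    train_logger = make_train_logger("logs", stdout_file)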

src/provable_pruning/provable_pruning/util/logging/stdout.py
Lines changed: 4 additions & 8 deletions

@@ -1,7 +1,6 @@
 """A module with our customization for stdout to include a file log."""
 import sys
 import datetime
-import os.path
 import re
 
 
@@ -22,7 +21,7 @@ def __init__(self, file_name):
         self._last_msg_len = 0
 
         # this will be the file where we also log
-        self._stdout_file = open(file_name, "w")
+        self._stdout_file = file_name
 
     def write(self, msg, name=None):
         """Write to file and console.
@@ -77,7 +76,8 @@ def write(self, msg, name=None):
 
         # also write to log file
         time_tag = datetime.datetime.utcnow().strftime("%Y-%m-%d, %H:%M:%S.%f")
-        print(f"{time_tag}: {msg}", file=self._stdout_file)
+        with open(self._stdout_file, "a") as logfile:
+            print(f"{time_tag}: {msg}", file=logfile)
 
         # store last_name
         self._last_name = name
@@ -88,14 +88,10 @@
     def flush(self):
         """Flush console and file."""
         self._stdout_original.flush()
-        self._stdout_file.flush()
 
 
-def setup_stdout(results_dir):
+def setup_stdout(log_file):
     """Set up stdout logger with this function."""
-    # log file name
-    log_file = os.path.join(results_dir, "experiment.log")
-
     # get an instance of the stdout logger
     stdout_logger = _StdoutLogger(log_file)
 
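
Note on the change above: the logger now stores only the log-file path and re-opens it in append mode for every message, instead of holding a handle open in "w" mode for the logger's lifetime. A logger re-created inside a spawned worker therefore appends to the existing log rather than truncating it, and no open handle needs to cross a process boundary; flush() accordingly only flushes the console. A minimal self-contained sketch of the pattern (an illustrative class, not this repository's _StdoutLogger):

    import datetime


    class AppendLogger:
        """Log messages by re-opening the file in append mode each time."""

        def __init__(self, file_name):
            # store only the path; no handle is kept open
            self._file_name = file_name

        def write(self, msg):
            time_tag = datetime.datetime.utcnow().strftime("%Y-%m-%d, %H:%M:%S.%f")
            with open(self._file_name, "a") as logfile:
                print(f"{time_tag}: {msg}", file=logfile)


    logger = AppendLogger("experiment.log")
    logger.write("works the same from any process that knows the path")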

src/provable_pruning/provable_pruning/util/logging/train.py
Lines changed: 28 additions & 9 deletions

@@ -4,6 +4,7 @@
 import math
 
 from torch.utils import tensorboard as tb
+from .stdout import setup_stdout
 from .tensorboard import log_scalar
 
 
@@ -16,14 +17,23 @@ class TrainLogger(object):
     multiprocessing context.
     """
 
-    def __init__(self, log_dir=None, global_tag=None, class_to_names=None):
+    def __init__(
+        self,
+        log_dir=None,
+        stdout_file=None,
+        global_tag=None,
+        class_to_names=None,
+    ):
         """Initialize the train logger.
 
         If the optional arguments are not supplied, the logger will print
-        updates about the training progress but won't log it to tensorboard
+        updates about the training progress but won't log it to tensorboard or
+        log it to a file
         """
         self._global_tag = global_tag
         self._logdir = log_dir
+        self._stdout_file = stdout_file
+        self._stdout_init = False
         self._diagnostics_step = 20 if "imagenet" in self._logdir else 50
         self._class_to_names = class_to_names
 
@@ -80,8 +90,9 @@ def initialize(
         s_idx=None,
     ):
         """Initialize the logger for the current (re-)training session."""
-        # reset the writer
+        # reset the writer and logger
         self._writer = None
+        self._stdout_init = False
 
         # setup parameters
         if self._class_to_names is None:
@@ -171,7 +182,7 @@ def train_diagnostics(
         self._t_last_print = time.time()
 
         # print progress
-        print(
+        self._print(
             self._progress_str.format(
                 epoch + 1, step, loss, acc1 * 100.0, acc5 * 100.0, t_elapsed
             )
@@ -225,9 +236,9 @@ def test_diagnostics(self, epoch, loss, acc1, acc5):
         """Finish test statistics computations and store them."""
         # store statistics
         self.test_epoch.append(epoch)
-        self.test_acc1.append(loss)
-        self.test_acc5.append(acc1)
-        self.test_loss.append(acc5)
+        self.test_loss.append(float(loss))
+        self.test_acc1.append(acc1)
+        self.test_acc5.append(acc5)
 
         # get the writer
         writer = self._get_writer()
@@ -268,7 +279,7 @@ def test_diagnostics(self, epoch, loss, acc1, acc5):
         )
 
         # print progress
-        print(
+        self._print(
             self._test_str.format(
                 self.test_epoch[-1] + 1,
                 self.test_loss[-1],
@@ -280,8 +291,16 @@
     def epoch_diagnostics(self, t_total, t_loading, t_optim, t_enforce, t_log):
         """Print diagnostics around the timing of one epoch."""
         t_remaining = t_total - sum([t_loading, t_optim, t_enforce, t_log])
-        print(
+        self._print(
             self._timing_str.format(
                 t_total, t_loading, t_optim, t_enforce, t_log, t_remaining
             )
         )
+
+    def _print(self, value):
+        """Print and ensure we are also printing to file."""
+        if not self._stdout_init and self._stdout_file is not None:
+            stdout = setup_stdout(self._stdout_file)
+            stdout.write(" " * 200, name=self.name)
+            self._stdout_init = True
+        print(value)
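
Note on the new _print helper above: with the "spawn" start method, a sys.stdout replacement installed in the parent process is not inherited by worker processes (each worker is a fresh interpreter), so the logger lazily re-runs setup_stdout on first print inside whichever process ends up calling it, and initialize() resets the flag for each (re-)training session. The test_diagnostics hunk is a separate genuine bugfix: loss, acc1, and acc5 were previously appended to the wrong lists. A minimal sketch of the lazy per-process re-initialization idea (hypothetical names, not the repository's API):

    import sys


    class _Tee:
        """Write to the original stream and mirror each message to a file."""

        def __init__(self, stream, log_file):
            self._stream = stream
            self._log_file = log_file

        def write(self, msg):
            self._stream.write(msg)
            with open(self._log_file, "a") as logfile:
                logfile.write(msg)

        def flush(self):
            self._stream.flush()


    class LazyLogger:
        """Install the stdout tee on first use inside the current process."""

        def __init__(self, log_file=None):
            self._log_file = log_file
            self._initialized = False

        def log(self, value):
            # defer installing the tee until we know which process prints
            if not self._initialized and self._log_file is not None:
                sys.stdout = _Tee(sys.stdout, self._log_file)
                self._initialized = True
            print(value)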

src/provable_pruning/provable_pruning/util/train.py
Lines changed: 20 additions & 14 deletions

@@ -315,7 +315,7 @@ def _train_procedure(
         torch.set_grad_enabled(True)
 
         # setup torch.distributed and spawn processes
-        num_workers = self.train_loader.num_workers // self.num_gpus
+        num_workers = self.train_loader.num_workers // max(self.num_gpus, 1)
 
         # empty gpu cache to make sure everything is ready for retraining
         torch.cuda.empty_cache()
@@ -456,6 +456,9 @@ def train_with_worker(
             file_name_checkpoint, net_handle, optimizer, loc
         )
 
+        # wait for all processes to load the checkpoint
+        dist.barrier()
+
     # this may be non-zero in the case of rewinding ...
     if not found_checkpoint:
         start_epoch = params["startEpoch"]
@@ -476,9 +479,6 @@ def train_with_worker(
     if not is_cpu:
         cudnn.benchmark = True
 
-    # switch to train mode
-    net_parallel.train()
-
     # convenience function for storing check points
     def store_checkpoints(epoch):
         # save checkpoint at the end of every epoch with 0 worker
@@ -518,15 +518,14 @@ def store_checkpoints(epoch):
         )
 
         # test after one epoch
-        if gpu_id == 0 and train_logger is not None:
-            _test_one_epoch(
-                loader=test_loader,
-                criterion=criterion,
-                epoch=epoch,
-                device=worker_device,
-                net=net_parallel,
-                train_logger=train_logger,
-            )
+        _test_one_epoch(
+            loader=test_loader,
+            criterion=criterion,
+            epoch=epoch,
+            device=worker_device,
+            net=net_parallel,
+            train_logger=train_logger if gpu_id == 0 else None,
+        )
 
     # store final checkpoint
     store_checkpoints(params["numEpochs"])
@@ -540,6 +539,7 @@ def store_checkpoints(epoch):
 
     # destroy process group at the end
     if is_distributed:
+        dist.barrier()
         dist.destroy_process_group()
 
 
@@ -561,6 +561,9 @@ def _train_one_epoch(
     t_enforce = 0.0
     t_log = 0.0
 
+    # switch to train mode
+    net_parallel.train()
+
     # go through one epoch and train
     for i, (images, targets) in enumerate(train_loader):
 
@@ -624,6 +627,9 @@ def _test_one_epoch(loader, criterion, epoch, device, net, train_logger=None):
     loss = 0
     num_total = 0
 
+    # switch to eval mode
+    net.eval()
+
     with torch.no_grad():
         for images, targets in loader:
             # move to correct device
@@ -647,7 +653,7 @@ def _test_one_epoch(loader, criterion, epoch, device, net, train_logger=None):
     acc5 /= num_total
     loss /= num_total
 
-    # make sure loss is also a regular float (not torch.Tensor)/s
+    # make sure loss is also a regular float (not torch.Tensor)
     loss = float(loss)
 
     if train_logger is not None:
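
Note on the fix above, the heart of this commit: collective operations in torch.distributed block until every rank of the process group participates, so guarding a code path that contains one behind gpu_id == 0 stalls the whole job: rank 0 waits inside the collective while the other ranks run ahead or exit. Here, running the DDP-wrapped net_parallel forward on rank 0 only can deadlock, since DDP's forward can broadcast buffers across all ranks. The change therefore has every rank enter _test_one_epoch and gates only the logging on rank 0, and the added dist.barrier() calls keep ranks in lockstep after checkpoint loading and before destroy_process_group. Relatedly, net_parallel.train() now runs at the start of every training epoch and net.eval() at the start of every test pass, so each phase executes in the correct mode, and max(self.num_gpus, 1) guards the worker division against num_gpus == 0 on CPU-only runs. A minimal CPU reproduction of the stall pattern and its fix, using the gloo backend (a hypothetical standalone example, not code from this repository):

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def worker(rank, world_size):
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # BAD: if only rank 0 reached a collective, e.g.
        #     if rank == 0:
        #         dist.all_reduce(result)
        # rank 0 would block forever waiting for the other ranks.

        # GOOD (the pattern of the fix): every rank runs the collective,
        # and only rank 0 does the rank-specific work such as logging.
        result = torch.tensor([float(rank)])
        dist.all_reduce(result)  # all ranks must participate
        if rank == 0:
            print(f"sum over ranks: {result.item()}")

        dist.barrier()  # keep ranks in lockstep before tearing down
        dist.destroy_process_group()


    if __name__ == "__main__":
        world_size = 2
        mp.spawn(worker, args=(world_size,), nprocs=world_size)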
