Description
The code base does not support resuming training, and it does not save the model state in a way that would make resuming possible. The code saves only the state_dict data for the generator and discriminator, to files with a .tar extension, even though torch.save simply pickles the object (the output is not actually a tar archive).
Second, for a checkpoint to be useful for resuming training, more data has to be stored: the epoch, the model state_dict, and the optimizer state_dict, as in the PyTorch saving/loading tutorial:
```python
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # ... plus anything else needed to resume
}, PATH)
```
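Applied to this code base, a combined checkpoint would have to cover both networks and both optimizers. A minimal sketch, assuming the generator/discriminator and their optimizers are called netG, netD, optimizerG, and optimizerD (placeholder names; the repo's actual variables may differ):

```python
import torch

def save_gan_checkpoint(path, epoch, netG, netD, optimizerG, optimizerD):
    """Save everything needed to resume GAN training in one file.

    netG/netD/optimizerG/optimizerD are placeholder names standing in
    for whatever this code base actually calls its networks/optimizers.
    """
    torch.save({
        'epoch': epoch,                                    # last completed epoch
        'netG_state_dict': netG.state_dict(),              # generator weights
        'netD_state_dict': netD.state_dict(),              # discriminator weights
        'optimizerG_state_dict': optimizerG.state_dict(),  # G optimizer state (e.g. Adam moments)
        'optimizerD_state_dict': optimizerD.state_dict(),  # D optimizer state
    }, path)
```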
Resuming then rebuilds the model and optimizer and restores their saved state:

```python
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.train()  # put the model back in training mode before resuming
```
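For the GAN case, the resume side mirrors this: restore both networks and both optimizers, then continue the epoch loop from where the checkpoint left off (same placeholder names as above):

```python
import torch

def load_gan_checkpoint(path, netG, netD, optimizerG, optimizerD):
    """Restore the state saved by save_gan_checkpoint above.

    Returns the epoch to resume from; all names are placeholders.
    """
    checkpoint = torch.load(path)
    netG.load_state_dict(checkpoint['netG_state_dict'])
    netD.load_state_dict(checkpoint['netD_state_dict'])
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
    optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
    netG.train()  # back to training mode before resuming
    netD.train()
    return checkpoint['epoch'] + 1  # next epoch to run

# The training loop can then start at the restored epoch instead of 0:
# start_epoch = load_gan_checkpoint('checkpoint.pt', netG, netD, optimizerG, optimizerD)
# for epoch in range(start_epoch, num_epochs): ...
```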