
Resuming training #7

@powerspowers

Description


The code base does not support resuming training, and it doesn't save the model state in a way that would allow training to be resumed. The code saves the state_dict data for the generator and discriminator to .tar files for some reason, even though the output is just a Python pickle (not a tar archive).

Second, for a checkpoint to be useful for resuming training, more data has to be stored: the epoch, the model state_dict, and the optimizer state_dict.

# Save everything needed to resume training in one checkpoint
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    ...
}, PATH)


Then, to resume, the model and optimizer are rebuilt and the saved state is loaded back:

# Recreate the model and optimizer before loading their saved state
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# Put the model back in training mode before continuing
model.train()
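
For this code base specifically, the same pattern would have to cover both networks and both optimizers. A minimal sketch, assuming hypothetical names netG, netD, optimizerG, and optimizerD for the generator, the discriminator, and their optimizers:

# Save one checkpoint that covers both GAN networks and their optimizers
torch.save({
    'epoch': epoch,
    'netG_state_dict': netG.state_dict(),
    'netD_state_dict': netD.state_dict(),
    'optimizerG_state_dict': optimizerG.state_dict(),
    'optimizerD_state_dict': optimizerD.state_dict(),
}, 'checkpoint.pt')

# To resume: rebuild the networks and optimizers, restore their state,
# and continue from the next epoch
checkpoint = torch.load('checkpoint.pt')
netG.load_state_dict(checkpoint['netG_state_dict'])
netD.load_state_dict(checkpoint['netD_state_dict'])
optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
start_epoch = checkpoint['epoch'] + 1
netG.train()
netD.train()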
