Add class to manage running a GPSR training and loading the results#43
Add class to manage running a GPSR training and loading the results#43lboult wants to merge 8 commits intoroussel-ryan:mainfrom
Conversation
roussel-ryan
left a comment
There was a problem hiding this comment.
Good start, we want to minimize the number of times the user has to make a subclass (ie. its not required to use this class for a single case). I also think it would be useful to include convenience methods that plot the loss curves or get the most recent version / plot the most recent loss curve
| This includes preparing datasets, models, logging, checkpointing, and trainer setup. | ||
| """ | ||
|
|
||
| def __init__(self, hparams, log_name="scans"): |
There was a problem hiding this comment.
Instead of having users subclass this object, I think it makes more sense for them to create the dataset/GPSRLattice objects externally and then pass them as required arguments to initialize GPSRRun. This would give a reasonable level of flexibility without having to overwrite the class
There was a problem hiding this comment.
So I changed it to do it this way in e9738da
The reason for it being this way intially is that when loading from a checkpoint, one has to provide the gpsr_lattice object again, which I found kinda clumsy... but maybe this is better for simplicity
| Initialize the GPSRRun with hyperparameters and logging setup. | ||
|
|
||
| Args: | ||
| hparams (dict): Hyperparameters for the model and training. |
There was a problem hiding this comment.
List out the keys of hparams here OR have them be keyword arguments to the init method with docstrings and assign them as attributes to the class (I would prefer the latter)
There was a problem hiding this comment.
Agreed it's better for them to be actual arguments: I tried this out in e0dc0ea... Note that I didn't explicitly add them as attributes of the object though, instead they are just bundled into a hyperparameter dictionary like before (they have to do this at some point for the logging anyway). I could maybe have them as attributes if needed
| ) | ||
| return checkpoint_callback | ||
|
|
||
| def setup_trainer(self): |
There was a problem hiding this comment.
might want to add **kwargs as an argument to this function that would allow additional options to be passed to the trainer without having to overwrite this method
| print("Hyperparameters:") | ||
| pprint(self.hparams) | ||
|
|
||
| def setup_training(self, train_dataset): |
There was a problem hiding this comment.
the train_dataset argument should be removed here in favor of putting it as a required argument in __init__
There was a problem hiding this comment.
The reason I had this here is that one also uses the class for loading in already trained models... It could perhaps still be an argument in the initialisation but by default be None in that case
There was a problem hiding this comment.
I did this in e0dc0ea: the user can still leave the train_dataset as None on init though, in the case when they are just loading in a checkpoint rather than training
|
Note that to accept the PR, you will need to update the examples as well |
Hey,
Here's a class that tries to roll together a lot of the workflows for setting up the GPSR training and also allowing models to be loaded in from checkpoints.
I separated it into a lot of functions such that users can override specific bits of it when they need to change something slightly...
Let me know what you think :)