diff --git a/.gitignore b/.gitignore
index 628f31f..9b40272 100644
--- a/.gitignore
+++ b/.gitignore
@@ -133,6 +133,7 @@
 results/
 ckp/
 checkpoints/
 *.swp
+wandb/
 Dockerfile
 build_dgx.sh
diff --git a/micromind/core.py b/micromind/core.py
index 2f887be..9e94aeb 100644
--- a/micromind/core.py
+++ b/micromind/core.py
@@ -25,11 +25,14 @@
 
 # This is used ONLY if you are not using argparse to get the hparams
 default_cfg = {
+    "project_name": "micromind",
     "output_folder": "results",
     "experiment_name": "micromind_exp",
     "opt": "adam",  # this is ignored if you are overriding the configure_optimizers
     "lr": 0.001,  # this is ignored if you are overriding the configure_optimizers
     "debug": False,
+    "log_wandb": False,
+    "wandb_resume": "auto",  # ["allow", "must", "never", "auto" or None]
 }
 
 
@@ -381,7 +384,8 @@ def compute_macs(self, input_shape: Union[List, Tuple]):
 
     def on_train_start(self):
         """Initializes the optimizer, modules and puts the networks on the right
-        devices. Optionally loads checkpoint if already present.
+        devices. Optionally loads checkpoint if already present. It also starts the
+        wandb logger if selected.
 
         This function gets executed at the beginning of every training.
         """
@@ -389,6 +393,17 @@ def on_train_start(self):
         # pass debug status to checkpointer
         self.checkpointer.debug = self.hparams.debug
 
+        if self.hparams.log_wandb:
+            import wandb
+
+            self.wlog = wandb.init(
+                project=self.hparams.project_name,
+                name=self.hparams.experiment_name,
+                resume=self.hparams.wandb_resume,
+                id=self.hparams.experiment_name,
+                config=self.hparams,
+            )
+
         init_opt = self.configure_optimizers()
         if isinstance(init_opt, list) or isinstance(init_opt, tuple):
             self.opt, self.lr_sched = init_opt
@@ -449,6 +464,8 @@ def init_devices(self):
 
     def on_train_end(self):
         """Runs at the end of each training. Cleans up before exiting."""
+        if self.hparams.log_wandb:
+            self.wlog.finish()
         pass
 
     def eval(self):
@@ -531,6 +548,9 @@ def train(
                 # ok for cos_lr
                 self.lr_sched.step()
 
+            if self.hparams.log_wandb:
+                self.wlog.log({"lr": self.lr_sched.get_last_lr()})
+
             for m in self.metrics:
                 if (
                     self.current_epoch + 1
@@ -574,6 +594,10 @@ def train(
             else:
                 val_metrics = train_metrics.update({"val_loss": loss_epoch / (idx + 1)})
 
+            if self.hparams.log_wandb:  # wandb log
+                self.wlog.log(train_metrics)
+                self.wlog.log(val_metrics)
+
             if e >= 1 and self.debug:
                 break
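For reference, below is a minimal usage sketch of how the new logging options could be enabled from user code. Only the configuration keys come from `default_cfg` in the hunks above; the `SimpleNamespace` container and the idea that this hparams object is what gets handed to the training class are assumptions for illustration, not part of the patch.

```python
# Hypothetical sketch: enabling the wandb logging added in this patch.
# The keys mirror default_cfg from the diff; how the hparams object is
# passed into micromind is assumed here, not shown by the patch itself.
from types import SimpleNamespace

hparams = SimpleNamespace(
    output_folder="results",
    experiment_name="micromind_exp",  # also reused as the wandb run name and id
    opt="adam",
    lr=0.001,
    debug=False,
    project_name="micromind",         # new key: wandb project
    log_wandb=True,                   # new key: turn wandb logging on
    wandb_resume="auto",              # new key: "allow", "must", "never", "auto" or None
)

# With log_wandb=True, on_train_start() calls wandb.init(...), train() logs the
# learning rate plus train/val metrics each epoch, and on_train_end() closes
# the run with wandb.finish().
```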