diff --git a/core/__init__.py b/core/__init__.py
index 03af283..efaf22d 100644
--- a/core/__init__.py
+++ b/core/__init__.py
@@ -112,7 +112,7 @@ def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimize
 
     # perform the training here
     @abstractmethod
-    def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
+    def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers, single_gpu: bool=False):
         raise NotImplementedError("This method needs to be overriden")
 
     # ------------
@@ -357,7 +357,7 @@ def __call__(self, single_gpu=False):
         # TRAIN
         if self.is_main_node:
             print("**TRAINING STARTING...**")
-        self.train(data, extras, models, optimizers, schedulers)
+        self.train(data, extras, models, optimizers, schedulers, single_gpu)
 
         if single_gpu is False:
             barrier()
diff --git a/train/base.py b/train/base.py
index 4e8a6ef..e2c84bb 100755
--- a/train/base.py
+++ b/train/base.py
@@ -239,7 +239,7 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext
         raise NotImplementedError("This method needs to be overriden")
 
     def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: Optimizers,
-              schedulers: WarpCore.Schedulers):
+              schedulers: WarpCore.Schedulers, single_gpu: bool=False):
         start_iter = self.info.iter + 1
         max_iters = self.config.updates * self.config.grad_accum_steps
         if self.is_main_node:
@@ -304,13 +304,14 @@ def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, op
                     'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(),
                     'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(),
                 }
-                self.save_checkpoints(models, optimizers)
+                self.save_checkpoints(models, optimizers, single_gpu=single_gpu)
                 if self.is_main_node:
                     create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
                 self.sample(models, data, extras)
 
-    def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
-        barrier()
+    def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None, single_gpu=False):
+        if not single_gpu:
+            barrier()
         suffix = '' if suffix is None else suffix
         self.save_info(self.info, suffix=suffix)
         models_dict = models.to_dict()
@@ -325,7 +326,7 @@ def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
             self.save_optimizer(optimizer, f'{key}_optim{suffix}',
                                 fsdp_model=models_dict[key] if self.config.use_fsdp else None)
         if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
-            self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps // 1000}k", single_gpu=single_gpu)
+            self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps // 1000}k", single_gpu=single_gpu)
         torch.cuda.empty_cache()
 
     def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
diff --git a/train/train_c_lora.py b/train/train_c_lora.py
index 8b83eee..c3f7c49 100755
--- a/train/train_c_lora.py
+++ b/train/train_c_lora.py
@@ -320,11 +320,11 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext
 
 if __name__ == '__main__':
     print("Launching Script")
+    single_gpu = len(sys.argv) > 2 and sys.argv[2].lower() in ('true', '1')
     warpcore = WurstCore(
         config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
-        device=torch.device(int(os.environ.get("SLURM_LOCALID")))
+        device=torch.device(int(os.environ.get("SLURM_LOCALID")) if not single_gpu else 0)
     )
     warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
 
-    # RUN TRAINING
-    warpcore()
+    warpcore(single_gpu=single_gpu)
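
The pattern these hunks apply, threading a single_gpu flag down to save_checkpoints so the torch.distributed barrier is skipped when no process group exists, can be summarized in a short standalone sketch. The helper below is illustrative only: the function names and checkpoint dict mirror the diff but are not the repo's actual API.

    # Illustrative sketch (not the repo's API): guard the distributed
    # barrier so the same checkpoint path works with and without a
    # process group, as the hunks above do for save_checkpoints().
    import torch
    import torch.distributed as dist

    def save_checkpoint(state: dict, path: str, single_gpu: bool = False):
        if not single_gpu:
            # dist.barrier() blocks until every rank arrives; without an
            # initialized process group it raises a RuntimeError, which
            # is why single-GPU runs must skip it.
            dist.barrier()
        torch.save(state, path)

    # Flag-free alternative: detect the setup instead of threading a
    # boolean through every signature.
    def save_checkpoint_auto(state: dict, path: str):
        if dist.is_available() and dist.is_initialized():
            dist.barrier()
        torch.save(state, path)

The flag-based approach in the diff makes the single-GPU code path explicit at each call site; the dist.is_initialized() check above would achieve the same effect without changing any signatures, at the cost of hiding the decision inside the helper. With the argument parsing above, a hypothetical single-GPU launch would look like python train/train_c_lora.py your_config.yaml true (the config path is illustrative), while omitting the second argument keeps the distributed SLURM behavior.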