@@ -56,14 +56,8 @@ def __init__(
5656
5757 if logger == "wandb" and not wandb .api .api_key :
5858 logger = None
59- print (f"Using logger: { logger } " )
6059 self .log_samples = log_samples
6160
62- if grad_accumulation_steps > 1 and self .is_main :
63- print (
64- "Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
65- )
66-
6761 self .accelerator = Accelerator (
6862 log_with = logger if logger == "wandb" else None ,
6963 kwargs_handlers = [ddp_kwargs ],
@@ -106,6 +100,12 @@ def __init__(
106100 self .ema_model = EMA (model , include_online_model = False , ** ema_kwargs )
107101 self .ema_model .to (self .accelerator .device )
108102
103+ print (f"Using logger: { logger } " )
104+ if grad_accumulation_steps > 1 :
105+ print (
106+ "Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
107+ )
108+
109109 self .epochs = epochs
110110 self .num_warmup_updates = num_warmup_updates
111111 self .save_per_updates = save_per_updates
@@ -357,7 +357,7 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
357357 self .writer .add_scalar ("loss" , loss .item (), global_update )
358358 self .writer .add_scalar ("lr" , self .scheduler .get_last_lr ()[0 ], global_update )
359359
360- if global_update % self .save_per_updates == 0 :
360+ if global_update % self .save_per_updates == 0 and self . accelerator . sync_gradients :
361361 self .save_checkpoint (global_update )
362362
363363 if self .log_samples and self .accelerator .is_local_main_process :
@@ -391,7 +391,7 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
391391 f"{ log_samples_path } /update_{ global_update } _ref.wav" , ref_audio , target_sample_rate
392392 )
393393
394- if global_update % self .last_per_updates == 0 :
394+ if global_update % self .last_per_updates == 0 and self . accelerator . sync_gradients :
395395 self .save_checkpoint (global_update , last = True )
396396
397397 self .save_checkpoint (global_update , last = True )
0 commit comments