
Commit 9e51878

0.4.2 fix trainer with grad_accum
1 parent 12d6970 commit 9e51878

File tree

2 files changed: +9 -9 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "0.4.1"
+version = "0.4.2"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}

src/f5_tts/model/trainer.py

Lines changed: 8 additions & 8 deletions
@@ -56,14 +56,8 @@ def __init__(
 
         if logger == "wandb" and not wandb.api.api_key:
             logger = None
-        print(f"Using logger: {logger}")
         self.log_samples = log_samples
 
-        if grad_accumulation_steps > 1 and self.is_main:
-            print(
-                "Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
-            )
-
         self.accelerator = Accelerator(
             log_with=logger if logger == "wandb" else None,
             kwargs_handlers=[ddp_kwargs],
@@ -106,6 +100,12 @@ def __init__(
             self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
             self.ema_model.to(self.accelerator.device)
 
+        print(f"Using logger: {logger}")
+        if grad_accumulation_steps > 1:
+            print(
+                "Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
+            )
+
         self.epochs = epochs
         self.num_warmup_updates = num_warmup_updates
         self.save_per_updates = save_per_updates
@@ -357,7 +357,7 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
                     self.writer.add_scalar("loss", loss.item(), global_update)
                     self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
 
-                if global_update % self.save_per_updates == 0:
+                if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients:
                     self.save_checkpoint(global_update)
 
                 if self.log_samples and self.accelerator.is_local_main_process:
@@ -391,7 +391,7 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
                         f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
                     )
 
-                if global_update % self.last_per_updates == 0:
+                if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
                     self.save_checkpoint(global_update, last=True)
 
         self.save_checkpoint(global_update, last=True)
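
Note on the fix: the substantive change is the `self.accelerator.sync_gradients` guard on both checkpoint conditions, so with gradient accumulation a save only fires on a micro-batch that actually completes an optimizer update rather than mid-accumulation. Below is a minimal, self-contained sketch of that pattern with Hugging Face Accelerate; the model, data, and save interval are placeholder assumptions for illustration, not code from this repository.

# Minimal sketch of gating checkpoint saves on accelerator.sync_gradients.
# Placeholder model/data/interval; not the F5-TTS trainer itself.
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)  # assumed value

model = torch.nn.Linear(10, 1)                                    # toy model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))  # dummy data
loader = DataLoader(dataset, batch_size=8)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

save_per_updates = 2  # stand-in for the trainer's save_per_updates
global_update = 0

for x, y in loader:
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()      # Accelerate skips the real step until gradients sync
        optimizer.zero_grad()

    # sync_gradients is True only on the micro-batch that closes an
    # accumulation window, i.e. when the optimizer actually updated.
    if accelerator.sync_gradients:
        global_update += 1
        if global_update % save_per_updates == 0:
            print(f"save checkpoint at update {global_update}")  # save_checkpoint(...) in the real trainer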
