
Commit c2cf31e

Merge pull request #729 from hcsolakoglu/fix-ckpt-rotation
Exclude pretrained models from the checkpoint rotation logic
2 parents 46266f1 + 2d27d2c

File tree

3 files changed: 28 additions & 7 deletions

src/f5_tts/model/trainer.py

Lines changed: 23 additions & 5 deletions

@@ -160,10 +160,14 @@ def save_checkpoint(self, update, last=False):
                     return
                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
                 if self.keep_last_n_checkpoints > 0:
+                    # Updated logic to exclude pretrained model from rotation
                     checkpoints = [
                         f
                         for f in os.listdir(self.checkpoint_path)
-                        if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
+                        if f.startswith("model_")
+                        and not f.startswith("pretrained_")  # Exclude pretrained models
+                        and f.endswith(".pt")
+                        and f != "model_last.pt"
                     ]
                     checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
                     while len(checkpoints) > self.keep_last_n_checkpoints:

@@ -183,10 +187,24 @@ def load_checkpoint(self):
         if "model_last.pt" in os.listdir(self.checkpoint_path):
             latest_checkpoint = "model_last.pt"
         else:
-            latest_checkpoint = sorted(
-                [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
-                key=lambda x: int("".join(filter(str.isdigit, x))),
-            )[-1]
+            # Updated to consider pretrained models for loading but prioritize training checkpoints
+            all_checkpoints = [
+                f
+                for f in os.listdir(self.checkpoint_path)
+                if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
+            ]
+
+            # First try to find regular training checkpoints
+            training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"]
+            if training_checkpoints:
+                latest_checkpoint = sorted(
+                    training_checkpoints,
+                    key=lambda x: int("".join(filter(str.isdigit, x))),
+                )[-1]
+            else:
+                # If no training checkpoints, use pretrained model
+                latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
+
         # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
         checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
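
Taken together, the two hunks above change which files rotation may delete and which file a run resumes from. The sketch below is not part of the diff: it pulls both policies out of the Trainer class as standalone functions, and the directory listing is hypothetical.

def rotation_candidates(files):
    # Only numbered training checkpoints are eligible for deletion;
    # model_last.pt and pretrained_*.pt copies are never candidates.
    checkpoints = [
        f
        for f in files
        if f.startswith("model_")
        and not f.startswith("pretrained_")  # pretrained copies are excluded from rotation
        and f.endswith(".pt")
        and f != "model_last.pt"
    ]
    checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
    return checkpoints

def pick_resume_checkpoint(files):
    # Resume order: model_last.pt first, then the highest-numbered training
    # checkpoint, and only then fall back to a pretrained_*.pt copy.
    if "model_last.pt" in files:
        return "model_last.pt"
    all_checkpoints = [f for f in files if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")]
    training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"]
    if training_checkpoints:
        return sorted(training_checkpoints, key=lambda x: int("".join(filter(str.isdigit, x))))[-1]
    return next(f for f in all_checkpoints if f.startswith("pretrained_"))

files = ["pretrained_model_1200000.pt", "model_1000.pt", "model_2000.pt", "model_last.pt"]
print(rotation_candidates(files))     # ['model_1000.pt', 'model_2000.pt']
print(pick_resume_checkpoint(files))  # 'model_last.pt'
print(pick_resume_checkpoint(["pretrained_model_1200000.pt"]))  # falls back to the pretrained copy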

src/f5_tts/train/finetune_cli.py

Lines changed: 2 additions & 1 deletion

@@ -111,7 +111,8 @@ def main():
         if not os.path.isdir(checkpoint_path):
             os.makedirs(checkpoint_path, exist_ok=True)
 
-        file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path))
+        # Change: Add 'pretrained_' prefix to copied model
+        file_checkpoint = os.path.join(checkpoint_path, "pretrained_" + os.path.basename(ckpt_path))
         if not os.path.isfile(file_checkpoint):
             shutil.copy2(ckpt_path, file_checkpoint)
             print("copy checkpoint for finetune")

src/f5_tts/train/finetune_gradio.py

Lines changed: 3 additions & 1 deletion

@@ -1099,7 +1099,9 @@ def vocab_extend(project_name, symbols, model_type):
     dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
     new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
     os.makedirs(new_ckpt_path, exist_ok=True)
-    new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt")
+
+    # Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
+    new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_model_1200000.pt")
 
     size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
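
For context, expand_model_embeddings (called above) writes the copied checkpoint to new_ckpt_file while making room for num_new_tokens extra vocabulary entries. The snippet below is only a hypothetical illustration of that embedding-resizing idea on a bare weight tensor; it is not the repo's actual helper, and the shapes are made up.

import torch

def expand_embedding_weight(weight: torch.Tensor, num_new_tokens: int) -> torch.Tensor:
    # Append num_new_tokens freshly initialized rows to an embedding table,
    # keeping the original rows (and therefore the pretrained vocab) unchanged.
    old_vocab, dim = weight.shape
    new_rows = torch.zeros(num_new_tokens, dim, dtype=weight.dtype)
    torch.nn.init.normal_(new_rows, mean=0.0, std=0.02)  # small random init for the new tokens
    return torch.cat([weight, new_rows], dim=0)  # shape: (old_vocab + num_new_tokens, dim)

weight = torch.randn(2545, 512)  # hypothetical original text-embedding table
expanded = expand_embedding_weight(weight, num_new_tokens=42)
print(expanded.shape)  # torch.Size([2587, 512])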
