Skip to content

Commit 46266f1

Browse files
authored
Merge pull request #741 from Chiyan200/main
Fix Settings Loader Issues: Resolve KeyErrors, Path Handling, and Component Assignment (#731)
2 parents 129014c + 24fe39d commit 46266f1

File tree

1 file changed

+91
-54
lines changed

1 file changed

+91
-54
lines changed

src/f5_tts/train/finetune_gradio.py

Lines changed: 91 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -116,41 +116,57 @@ def load_settings(project_name):
116116
path_project = os.path.join(path_project_ckpts, project_name)
117117
file_setting = os.path.join(path_project, "setting.json")
118118

119-
if not os.path.isfile(file_setting):
120-
settings = {
121-
"exp_name": "F5TTS_Base",
122-
"learning_rate": 1e-05,
123-
"batch_size_per_gpu": 1000,
124-
"batch_size_type": "frame",
125-
"max_samples": 64,
126-
"grad_accumulation_steps": 1,
127-
"max_grad_norm": 1,
128-
"epochs": 100,
129-
"num_warmup_updates": 2,
130-
"save_per_updates": 300,
131-
"keep_last_n_checkpoints": -1,
132-
"last_per_updates": 100,
133-
"finetune": True,
134-
"file_checkpoint_train": "",
135-
"tokenizer_type": "pinyin",
136-
"tokenizer_file": "",
137-
"mixed_precision": "none",
138-
"logger": "wandb",
139-
"bnb_optimizer": False,
140-
}
141-
else:
119+
# Default settings
120+
default_settings = {
121+
"exp_name": "F5TTS_Base",
122+
"learning_rate": 1e-05,
123+
"batch_size_per_gpu": 1000,
124+
"batch_size_type": "frame",
125+
"max_samples": 64,
126+
"grad_accumulation_steps": 1,
127+
"max_grad_norm": 1,
128+
"epochs": 100,
129+
"num_warmup_updates": 2,
130+
"save_per_updates": 300,
131+
"keep_last_n_checkpoints": -1,
132+
"last_per_updates": 100,
133+
"finetune": True,
134+
"file_checkpoint_train": "",
135+
"tokenizer_type": "pinyin",
136+
"tokenizer_file": "",
137+
"mixed_precision": "none",
138+
"logger": "wandb",
139+
"bnb_optimizer": False,
140+
}
141+
142+
# Load settings from file if it exists
143+
if os.path.isfile(file_setting):
142144
with open(file_setting, "r") as f:
143-
settings = json.load(f)
144-
if "logger" not in settings:
145-
settings["logger"] = "wandb"
146-
if "bnb_optimizer" not in settings:
147-
settings["bnb_optimizer"] = False
148-
if "keep_last_n_checkpoints" not in settings:
149-
settings["keep_last_n_checkpoints"] = -1 # default to keep all checkpoints
150-
if "last_per_updates" not in settings: # patch for backward compatibility, with before f992c4e
151-
settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
145+
file_settings = json.load(f)
146+
default_settings.update(file_settings)
152147

153-
return settings
148+
# Return as a tuple in the correct order
149+
return (
150+
default_settings["exp_name"],
151+
default_settings["learning_rate"],
152+
default_settings["batch_size_per_gpu"],
153+
default_settings["batch_size_type"],
154+
default_settings["max_samples"],
155+
default_settings["grad_accumulation_steps"],
156+
default_settings["max_grad_norm"],
157+
default_settings["epochs"],
158+
default_settings["num_warmup_updates"],
159+
default_settings["save_per_updates"],
160+
default_settings["keep_last_n_checkpoints"],
161+
default_settings["last_per_updates"],
162+
default_settings["finetune"],
163+
default_settings["file_checkpoint_train"],
164+
default_settings["tokenizer_type"],
165+
default_settings["tokenizer_file"],
166+
default_settings["mixed_precision"],
167+
default_settings["logger"],
168+
default_settings["bnb_optimizer"],
169+
)
154170

155171

156172
# Load metadata
@@ -1579,27 +1595,48 @@ def get_audio_select(file_sample):
15791595
stop_button = gr.Button("Stop Training", interactive=False)
15801596

15811597
if projects_selelect is not None:
1582-
settings = load_settings(projects_selelect)
1583-
1584-
exp_name.value = settings["exp_name"]
1585-
learning_rate.value = settings["learning_rate"]
1586-
batch_size_per_gpu.value = settings["batch_size_per_gpu"]
1587-
batch_size_type.value = settings["batch_size_type"]
1588-
max_samples.value = settings["max_samples"]
1589-
grad_accumulation_steps.value = settings["grad_accumulation_steps"]
1590-
max_grad_norm.value = settings["max_grad_norm"]
1591-
epochs.value = settings["epochs"]
1592-
num_warmup_updates.value = settings["num_warmup_updates"]
1593-
save_per_updates.value = settings["save_per_updates"]
1594-
keep_last_n_checkpoints.value = settings["keep_last_n_checkpoints"]
1595-
last_per_updates.value = settings["last_per_updates"]
1596-
ch_finetune.value = settings["finetune"]
1597-
file_checkpoint_train.value = settings["file_checkpoint_train"]
1598-
tokenizer_type.value = settings["tokenizer_type"]
1599-
tokenizer_file.value = settings["tokenizer_file"]
1600-
mixed_precision.value = settings["mixed_precision"]
1601-
cd_logger.value = settings["logger"]
1602-
ch_8bit_adam.value = settings["bnb_optimizer"]
1598+
(
1599+
exp_name_value,
1600+
learning_rate_value,
1601+
batch_size_per_gpu_value,
1602+
batch_size_type_value,
1603+
max_samples_value,
1604+
grad_accumulation_steps_value,
1605+
max_grad_norm_value,
1606+
epochs_value,
1607+
num_warmup_updates_value,
1608+
save_per_updates_value,
1609+
keep_last_n_checkpoints_value,
1610+
last_per_updates_value,
1611+
finetune_value,
1612+
file_checkpoint_train_value,
1613+
tokenizer_type_value,
1614+
tokenizer_file_value,
1615+
mixed_precision_value,
1616+
logger_value,
1617+
bnb_optimizer_value,
1618+
) = load_settings(projects_selelect)
1619+
1620+
# Assigning values to the respective components
1621+
exp_name.value = exp_name_value
1622+
learning_rate.value = learning_rate_value
1623+
batch_size_per_gpu.value = batch_size_per_gpu_value
1624+
batch_size_type.value = batch_size_type_value
1625+
max_samples.value = max_samples_value
1626+
grad_accumulation_steps.value = grad_accumulation_steps_value
1627+
max_grad_norm.value = max_grad_norm_value
1628+
epochs.value = epochs_value
1629+
num_warmup_updates.value = num_warmup_updates_value
1630+
save_per_updates.value = save_per_updates_value
1631+
keep_last_n_checkpoints.value = keep_last_n_checkpoints_value
1632+
last_per_updates.value = last_per_updates_value
1633+
ch_finetune.value = finetune_value
1634+
file_checkpoint_train.value = file_checkpoint_train_value
1635+
tokenizer_type.value = tokenizer_type_value
1636+
tokenizer_file.value = tokenizer_file_value
1637+
mixed_precision.value = mixed_precision_value
1638+
cd_logger.value = logger_value
1639+
ch_8bit_adam.value = bnb_optimizer_value
16031640

16041641
ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
16051642
txt_info_train = gr.Text(label="Info", value="")

0 commit comments

Comments (0)