@@ -116,41 +116,57 @@ def load_settings(project_name):
116116 path_project = os .path .join (path_project_ckpts , project_name )
117117 file_setting = os .path .join (path_project , "setting.json" )
118118
119- if not os .path .isfile (file_setting ):
120- settings = {
121- "exp_name" : "F5TTS_Base" ,
122- "learning_rate" : 1e-05 ,
123- "batch_size_per_gpu" : 1000 ,
124- "batch_size_type" : "frame" ,
125- "max_samples" : 64 ,
126- "grad_accumulation_steps" : 1 ,
127- "max_grad_norm" : 1 ,
128- "epochs" : 100 ,
129- "num_warmup_updates" : 2 ,
130- "save_per_updates" : 300 ,
131- "keep_last_n_checkpoints" : - 1 ,
132- "last_per_updates" : 100 ,
133- "finetune" : True ,
134- "file_checkpoint_train" : "" ,
135- "tokenizer_type" : "pinyin" ,
136- "tokenizer_file" : "" ,
137- "mixed_precision" : "none" ,
138- "logger" : "wandb" ,
139- "bnb_optimizer" : False ,
140- }
141- else :
119+ # Default settings
120+ default_settings = {
121+ "exp_name" : "F5TTS_Base" ,
122+ "learning_rate" : 1e-05 ,
123+ "batch_size_per_gpu" : 1000 ,
124+ "batch_size_type" : "frame" ,
125+ "max_samples" : 64 ,
126+ "grad_accumulation_steps" : 1 ,
127+ "max_grad_norm" : 1 ,
128+ "epochs" : 100 ,
129+ "num_warmup_updates" : 2 ,
130+ "save_per_updates" : 300 ,
131+ "keep_last_n_checkpoints" : - 1 ,
132+ "last_per_updates" : 100 ,
133+ "finetune" : True ,
134+ "file_checkpoint_train" : "" ,
135+ "tokenizer_type" : "pinyin" ,
136+ "tokenizer_file" : "" ,
137+ "mixed_precision" : "none" ,
138+ "logger" : "wandb" ,
139+ "bnb_optimizer" : False ,
140+ }
141+
142+ # Load settings from file if it exists
143+ if os .path .isfile (file_setting ):
142144 with open (file_setting , "r" ) as f :
143- settings = json .load (f )
144- if "logger" not in settings :
145- settings ["logger" ] = "wandb"
146- if "bnb_optimizer" not in settings :
147- settings ["bnb_optimizer" ] = False
148- if "keep_last_n_checkpoints" not in settings :
149- settings ["keep_last_n_checkpoints" ] = - 1 # default to keep all checkpoints
150- if "last_per_updates" not in settings : # patch for backward compatibility, with before f992c4e
151- settings ["last_per_updates" ] = settings ["last_per_steps" ] // settings ["grad_accumulation_steps" ]
145+ file_settings = json .load (f )
146+ default_settings .update (file_settings )
152147
153- return settings
148+ # Return as a tuple in the correct order
149+ return (
150+ default_settings ["exp_name" ],
151+ default_settings ["learning_rate" ],
152+ default_settings ["batch_size_per_gpu" ],
153+ default_settings ["batch_size_type" ],
154+ default_settings ["max_samples" ],
155+ default_settings ["grad_accumulation_steps" ],
156+ default_settings ["max_grad_norm" ],
157+ default_settings ["epochs" ],
158+ default_settings ["num_warmup_updates" ],
159+ default_settings ["save_per_updates" ],
160+ default_settings ["keep_last_n_checkpoints" ],
161+ default_settings ["last_per_updates" ],
162+ default_settings ["finetune" ],
163+ default_settings ["file_checkpoint_train" ],
164+ default_settings ["tokenizer_type" ],
165+ default_settings ["tokenizer_file" ],
166+ default_settings ["mixed_precision" ],
167+ default_settings ["logger" ],
168+ default_settings ["bnb_optimizer" ],
169+ )
154170
155171
156172# Load metadata
@@ -1579,27 +1595,48 @@ def get_audio_select(file_sample):
15791595 stop_button = gr .Button ("Stop Training" , interactive = False )
15801596
15811597 if projects_selelect is not None :
1582- settings = load_settings (projects_selelect )
1583-
1584- exp_name .value = settings ["exp_name" ]
1585- learning_rate .value = settings ["learning_rate" ]
1586- batch_size_per_gpu .value = settings ["batch_size_per_gpu" ]
1587- batch_size_type .value = settings ["batch_size_type" ]
1588- max_samples .value = settings ["max_samples" ]
1589- grad_accumulation_steps .value = settings ["grad_accumulation_steps" ]
1590- max_grad_norm .value = settings ["max_grad_norm" ]
1591- epochs .value = settings ["epochs" ]
1592- num_warmup_updates .value = settings ["num_warmup_updates" ]
1593- save_per_updates .value = settings ["save_per_updates" ]
1594- keep_last_n_checkpoints .value = settings ["keep_last_n_checkpoints" ]
1595- last_per_updates .value = settings ["last_per_updates" ]
1596- ch_finetune .value = settings ["finetune" ]
1597- file_checkpoint_train .value = settings ["file_checkpoint_train" ]
1598- tokenizer_type .value = settings ["tokenizer_type" ]
1599- tokenizer_file .value = settings ["tokenizer_file" ]
1600- mixed_precision .value = settings ["mixed_precision" ]
1601- cd_logger .value = settings ["logger" ]
1602- ch_8bit_adam .value = settings ["bnb_optimizer" ]
1598+ (
1599+ exp_name_value ,
1600+ learning_rate_value ,
1601+ batch_size_per_gpu_value ,
1602+ batch_size_type_value ,
1603+ max_samples_value ,
1604+ grad_accumulation_steps_value ,
1605+ max_grad_norm_value ,
1606+ epochs_value ,
1607+ num_warmup_updates_value ,
1608+ save_per_updates_value ,
1609+ keep_last_n_checkpoints_value ,
1610+ last_per_updates_value ,
1611+ finetune_value ,
1612+ file_checkpoint_train_value ,
1613+ tokenizer_type_value ,
1614+ tokenizer_file_value ,
1615+ mixed_precision_value ,
1616+ logger_value ,
1617+ bnb_optimizer_value ,
1618+ ) = load_settings (projects_selelect )
1619+
1620+ # Assigning values to the respective components
1621+ exp_name .value = exp_name_value
1622+ learning_rate .value = learning_rate_value
1623+ batch_size_per_gpu .value = batch_size_per_gpu_value
1624+ batch_size_type .value = batch_size_type_value
1625+ max_samples .value = max_samples_value
1626+ grad_accumulation_steps .value = grad_accumulation_steps_value
1627+ max_grad_norm .value = max_grad_norm_value
1628+ epochs .value = epochs_value
1629+ num_warmup_updates .value = num_warmup_updates_value
1630+ save_per_updates .value = save_per_updates_value
1631+ keep_last_n_checkpoints .value = keep_last_n_checkpoints_value
1632+ last_per_updates .value = last_per_updates_value
1633+ ch_finetune .value = finetune_value
1634+ file_checkpoint_train .value = file_checkpoint_train_value
1635+ tokenizer_type .value = tokenizer_type_value
1636+ tokenizer_file .value = tokenizer_file_value
1637+ mixed_precision .value = mixed_precision_value
1638+ cd_logger .value = logger_value
1639+ ch_8bit_adam .value = bnb_optimizer_value
16031640
16041641 ch_stream = gr .Checkbox (label = "Stream Output Experiment" , value = True )
16051642 txt_info_train = gr .Text (label = "Info" , value = "" )
0 commit comments