44import secrets
55import string
66import sys
7+ from json import loads
78
89import torch
910import yaml
1011from typing import List , Union , Optional , Dict , Type
11- from pydantic import BaseModel , Field , ValidationError
12+ from pydantic import BaseModel , Field , ValidationError , field_validator
1213
1314from contants import ModelType
1415
@@ -134,40 +135,35 @@ class ResourcePathsConfig(BaseModel):
134135
135136
class BaseModelConfig(BaseModel):
    """Common base for every per-model configuration entry.

    ``model_type`` may be omitted by the caller; it defaults to ``None`` so
    that it can be filled in later (for example by inferring it from the
    other fields of the entry).
    """

    # In pydantic v2 an ``Optional[...]`` annotation alone does NOT make a
    # field optional; an explicit default is required for that.
    model_type: Optional[str] = None

    class Config:
        # Empty tuple disables pydantic's protected-namespace check, which
        # would otherwise warn about the ``model_`` prefix of ``model_type``.
        protected_namespaces = ()
141142
142143
class VITSModelConfig(BaseModelConfig):
    """Configuration entry for a VITS model."""

    # Both fields default to None, so the annotation must admit None;
    # otherwise passing None explicitly (e.g. when re-loading a dumped
    # config) fails validation.
    vits_path: Optional[str] = None  # presumably the model weights file
    config_path: Optional[str] = None  # presumably the model's JSON config
147147
148148
class W2V2VITSModelConfig(BaseModelConfig):
    """Configuration entry for a W2V2-VITS model."""

    # Both fields default to None, so the annotation must admit None;
    # otherwise passing None explicitly (e.g. when re-loading a dumped
    # config) fails validation.
    vits_path: Optional[str] = None  # presumably the model weights file
    config_path: Optional[str] = None  # presumably the model's JSON config
153152
154153
class HuBertVITSModelConfig(BaseModelConfig):
    """Configuration entry for a HuBert-VITS model."""

    # Both fields default to None, so the annotation must admit None;
    # otherwise passing None explicitly (e.g. when re-loading a dumped
    # config) fails validation.
    vits_path: Optional[str] = None  # presumably the model weights file
    config_path: Optional[str] = None  # presumably the model's JSON config
159157
160158
class BertVITS2ModelConfig(BaseModelConfig):
    """Configuration entry for a Bert-VITS2 model."""

    # Both fields default to None, so the annotation must admit None;
    # otherwise passing None explicitly (e.g. when re-loading a dumped
    # config) fails validation.
    vits_path: Optional[str] = None  # presumably the model weights file
    config_path: Optional[str] = None  # presumably the model's JSON config
165162
166163
class GPTSoVITSModelConfig(BaseModelConfig):
    """Configuration entry for a GPT-SoVITS model.

    An entry of this kind is recognised by carrying ``t2s_path`` instead of
    ``config_path``.
    """

    # Both fields default to None, so the annotation must admit None;
    # otherwise passing None explicitly (e.g. when re-loading a dumped
    # config) fails validation.
    vits_path: Optional[str] = None  # presumably the SoVITS weights file
    t2s_path: Optional[str] = None  # presumably the text-to-semantic (GPT) weights file
171167
172168
173169MODEL_TYPE_MAP : Dict [str , Type [BaseModelConfig ]] = {
@@ -190,6 +186,41 @@ class TTSModelConfig(BaseModel):
190186 GPTSoVITSModelConfig ,
191187 ]] = Field (default_factory = list )
192188
@classmethod
def recognition_model_type_by_config(cls, config: dict) -> str:
    """Guess the model type from a parsed model configuration.

    Decision order:

    1. ``"use_spk_conditioned_encoder"`` present in ``config["model"]``
       -> ``ModelType.BERT_VITS2``
    2. no ``"symbols"`` entry (missing or ``None``)
       -> ``ModelType.HUBERT_VITS``
    3. ``config["data"]["emotion_embedding"]`` truthy
       -> ``ModelType.W2V2_VITS``
    4. otherwise -> ``ModelType.VITS``

    Args:
        config: The model configuration as a dict (e.g. a loaded JSON file).

    Returns:
        The matching ``ModelType`` member.

    Raises:
        KeyError: If ``config`` has no ``"model"`` section, or has
            ``"symbols"`` but no ``"data"`` section.
    """
    if "use_spk_conditioned_encoder" in config["model"]:
        return ModelType.BERT_VITS2

    if config.get("symbols") is None:
        return ModelType.HUBERT_VITS

    # Only consult the "data" section on the branch that actually needs it,
    # so configs without it can still be classified by the rules above.
    if config["data"].get("emotion_embedding", False):
        return ModelType.W2V2_VITS
    return ModelType.VITS
207+
@field_validator('tts_models', mode="before")
@classmethod
def infer_model_type(cls, v):
    """Turn raw ``tts_models`` entries into concrete model-config objects.

    Runs before pydantic's own validation of the field. For each entry:

    * an entry that is already a ``BaseModelConfig`` instance is kept as is;
    * if the entry has no (or an empty) ``model_type``, the type is inferred:
      with both ``vits_path`` and ``config_path`` present, the JSON file at
      ``config_path`` is read and classified by
      ``recognition_model_type_by_config``; with ``vits_path`` and
      ``t2s_path`` present, the type is ``ModelType.GPT_SOVITS``;
    * the entry is then instantiated with the class registered for its type
      in ``MODEL_TYPE_MAP``.

    Args:
        v: Iterable of raw entries (dicts) or ready-made config instances.

    Returns:
        A list of model-config instances, in input order.

    Raises:
        ValueError: If an entry is not a dict, its type cannot be inferred,
            or its type has no class in ``MODEL_TYPE_MAP``. Pydantic turns a
            ``ValueError`` raised in a validator into a ``ValidationError``
            (a bare ``KeyError`` would escape unconverted).
        OSError: If the file at ``config_path`` cannot be opened.
    """
    result = []
    for entry in v:
        if isinstance(entry, BaseModelConfig):
            result.append(entry)
            continue
        if not isinstance(entry, dict):
            raise ValueError(
                f"tts_models entries must be mappings, got {type(entry).__name__}"
            )

        # Work on a copy so the caller's data is not modified as a side effect.
        model = dict(entry)

        if not model.get('model_type'):
            if 'vits_path' in model and 'config_path' in model:
                with open(model["config_path"], 'r', encoding='utf-8') as f:
                    data = loads(f.read())
                model['model_type'] = cls.recognition_model_type_by_config(data)
            elif 'vits_path' in model and 't2s_path' in model:
                model['model_type'] = ModelType.GPT_SOVITS
            else:
                raise ValueError(
                    "cannot infer model_type: entry needs 'vits_path' together "
                    "with either 'config_path' or 't2s_path'"
                )

        model_class = MODEL_TYPE_MAP.get(model['model_type'])
        if model_class is None:
            raise ValueError(f"unknown model_type: {model['model_type']!r}")
        result.append(model_class(**model))
    return result
223+
193224 def add_model (self , model_config : BaseModelConfig ):
194225 if not isinstance (model_config , BaseModelConfig ):
195226 raise TypeError ("model_config must be an instance of BaseModelConfig" )
@@ -205,8 +236,16 @@ def update_tts_models(self, tts_models: list):
205236 for item in tts_models :
206237 tts_model = item ["tts_model" ]
207238 model_type = tts_model .get ("model_type" )
239+
208240 if model_type :
209241 model_type = model_type .upper ().replace ("_" , "-" )
242+ else :
243+ if tts_model .get ("t2s_path" ):
244+ model_type = ModelType .GPT_SOVITS
245+ else :
246+ with open (tts_model ["config_path" ], 'r' , encoding = 'utf-8' ) as f :
247+ data = f .read ()
248+ model_type = self .recognition_model_type (loads (data ))
210249 model_class = MODEL_TYPE_MAP .get (ModelType (model_type ))
211250 if model_class is not None :
212251 try :
0 commit comments