Skip to content

Commit 96e4787

Browse files
committed
Update: recognition_model_type
1 parent 01c1a5e commit 96e4787

File tree

1 file changed

+48
-9
lines changed

1 file changed

+48
-9
lines changed

config.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import secrets
55
import string
66
import sys
7+
from json import loads
78

89
import torch
910
import yaml
1011
from typing import List, Union, Optional, Dict, Type
11-
from pydantic import BaseModel, Field, ValidationError
12+
from pydantic import BaseModel, Field, ValidationError, field_validator
1213

1314
from contants import ModelType
1415

@@ -134,40 +135,35 @@ class ResourcePathsConfig(BaseModel):
134135

135136

136137
class BaseModelConfig(BaseModel):
137-
model_type: str
138+
model_type: Optional[str]
138139

139140
class Config:
140141
protected_namespaces = ()
141142

142143

143144
class VITSModelConfig(BaseModelConfig):
144-
model_type: str = ModelType.VITS
145145
vits_path: str = None
146146
config_path: str = None
147147

148148

149149
class W2V2VITSModelConfig(BaseModelConfig):
150-
model_type: str = ModelType.W2V2_VITS
151150
vits_path: str = None
152151
config_path: str = None
153152

154153

155154
class HuBertVITSModelConfig(BaseModelConfig):
156-
model_type: str = ModelType.HUBERT_VITS
157155
vits_path: str = None
158156
config_path: str = None
159157

160158

161159
class BertVITS2ModelConfig(BaseModelConfig):
162-
model_type: str = ModelType.BERT_VITS2
163160
vits_path: str = None
164161
config_path: str = None
165162

166163

167164
class GPTSoVITSModelConfig(BaseModelConfig):
168-
model_type: str = ModelType.GPT_SOVITS
169-
gpt_path: str = None
170-
sovits_path: str = None
165+
vits_path: str = None
166+
t2s_path: str = None
171167

172168

173169
MODEL_TYPE_MAP: Dict[str, Type[BaseModelConfig]] = {
@@ -190,6 +186,41 @@ class TTSModelConfig(BaseModel):
190186
GPTSoVITSModelConfig,
191187
]] = Field(default_factory=list)
192188

189+
@classmethod
190+
def recognition_model_type_by_config(self, config: dict) -> str:
191+
symbols = config.get("symbols", None)
192+
emotion_embedding = config["data"].get("emotion_embedding", False)
193+
194+
if "use_spk_conditioned_encoder" in config["model"]:
195+
model_type = ModelType.BERT_VITS2
196+
return model_type
197+
198+
if symbols != None:
199+
if not emotion_embedding:
200+
mode_type = ModelType.VITS
201+
else:
202+
mode_type = ModelType.W2V2_VITS
203+
else:
204+
mode_type = ModelType.HUBERT_VITS
205+
206+
return mode_type
207+
208+
@field_validator('tts_models', mode="before")
209+
def infer_model_type(cls, v):
210+
result = []
211+
for model in v:
212+
if 'model_type' not in model:
213+
if 'vits_path' in model and 'config_path' in model:
214+
with open(model["config_path"], 'r', encoding='utf-8') as f:
215+
data = loads(f.read())
216+
model['model_type'] = cls.recognition_model_type_by_config(data)
217+
elif 'vits_path' in model and 't2s_path' in model:
218+
model['model_type'] = ModelType.GPT_SOVITS
219+
220+
model_class = MODEL_TYPE_MAP[model['model_type']]
221+
result.append(model_class(**model))
222+
return result
223+
193224
def add_model(self, model_config: BaseModelConfig):
194225
if not isinstance(model_config, BaseModelConfig):
195226
raise TypeError("model_config must be an instance of BaseModelConfig")
@@ -205,8 +236,16 @@ def update_tts_models(self, tts_models: list):
205236
for item in tts_models:
206237
tts_model = item["tts_model"]
207238
model_type = tts_model.get("model_type")
239+
208240
if model_type:
209241
model_type = model_type.upper().replace("_", "-")
242+
else:
243+
if tts_model.get("t2s_path"):
244+
model_type = ModelType.GPT_SOVITS
245+
else:
246+
with open(tts_model["config_path"], 'r', encoding='utf-8') as f:
247+
data = f.read()
248+
model_type = self.recognition_model_type(loads(data))
210249
model_class = MODEL_TYPE_MAP.get(ModelType(model_type))
211250
if model_class is not None:
212251
try:

0 commit comments

Comments
 (0)