Skip to content

Commit 00431b2

Browse files
committed
Update: VITS dynamic_loading
1 parent 96e4787 commit 00431b2

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ class Config:
144144
class VITSModelConfig(BaseModelConfig):
145145
vits_path: str = None
146146
config_path: str = None
147+
dynamic_loading: Optional[bool] = False
147148

148149

149150
class W2V2VITSModelConfig(BaseModelConfig):

manager/ModelManager.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,8 @@ def _load_model_from_path(self, tts_model):
246246
"config": hps,
247247
"device": self.device
248248
}
249+
if model_type == ModelType.VITS:
250+
model_args["dynamic_loading"] = tts_model["dynamic_loading"]
249251

250252
model_class = self.model_class_map[model_type]
251253
model = model_class(**model_args)
@@ -255,7 +257,7 @@ def _load_model_from_path(self, tts_model):
255257
if bert_embedding and self.tts_front is None:
256258
self.load_VITS_PinYin_model(
257259
os.path.join(BASE_DIR, config.system.data_path, config.resource_paths_config.vits_chinese_bert))
258-
if not config.vits_config.dynamic_loading:
260+
if not model.dynamic_loading:
259261
model.load_model()
260262
self.available_tts_model.add(ModelType.VITS)
261263

vits/vits.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
class VITS:
12-
def __init__(self, vits_path, config, device="cpu", **kwargs):
12+
def __init__(self, vits_path, config, device="cpu", dynamic_loading=False, **kwargs):
1313
self.hps_ms = get_hparams_from_file(config) if isinstance(config, str) else config
1414
self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
1515
self.n_symbols = len(getattr(self.hps_ms, 'symbols', []))
@@ -23,6 +23,7 @@ def __init__(self, vits_path, config, device="cpu", **kwargs):
2323
self.sampling_rate = self.hps_ms.data.sampling_rate
2424
self.device = torch.device(device)
2525
self.vits_path = vits_path
26+
self.dynamic_loading = dynamic_loading
2627

2728
# load checkpoint
2829
# self.load_model()
@@ -39,10 +40,9 @@ def load_model(self):
3940
_ = self.net_g_ms.eval()
4041
utils.load_checkpoint(self.vits_path, self.net_g_ms)
4142
self.net_g_ms.to(self.device)
42-
43+
4344
def release_model(self):
4445
del self.net_g_ms
45-
4646

4747
def get_cleaned_text(self, text, hps, cleaned=False):
4848
if cleaned:

0 commit comments

Comments
 (0)