
Commit 31ed623

feat(gpu): add partial dml support
1 parent 46204ca commit 31ed623

File tree: 6 files changed, +30 -17 lines

  ChatTTS/core.py
  ChatTTS/model/embed.py
  ChatTTS/model/gpt.py
  ChatTTS/utils/gpu.py
  examples/web/funcs.py
  examples/web/webui.py

ChatTTS/core.py

Lines changed: 6 additions & 1 deletion
@@ -141,10 +141,11 @@ def load(
         compile: bool = False,
         custom_path: Optional[FileLike] = None,
         device: Optional[torch.device] = None,
-        coef: Optional[torch.Tensor] = None,
+        coef: Optional[str] = None,
         use_flash_attn=False,
         use_vllm=False,
         experimental: bool = False,
+        enable_cache=False,
     ) -> bool:
         download_path = self.download_models(source, force_redownload, custom_path)
         if download_path is None:
@@ -156,6 +157,7 @@ def load(
             use_flash_attn=use_flash_attn,
             use_vllm=use_vllm,
             experimental=experimental,
+            enable_cache=enable_cache,
             **{
                 k: os.path.join(download_path, v)
                 for k, v in asdict(self.config.path).items()
@@ -287,6 +289,7 @@ def _load(
         use_flash_attn=False,
         use_vllm=False,
         experimental: bool = False,
+        enable_cache = False,
     ):
         if device is None:
             device = select_device(experimental=experimental)
@@ -351,6 +354,7 @@ def _load(
             device=device,
             device_gpt=self.device_gpt,
             logger=self.logger,
+            enable_cache=enable_cache,
         ).eval()
         assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
         gpt.load_pretrained(gpt_ckpt_path, embed_path, experimental=experimental)
@@ -425,6 +429,7 @@ def _infer(
             text_tokens = refined.ids
             text_tokens = [i[i.less(self.tokenizer.break_0_ids)] for i in text_tokens]
             text = self.tokenizer.decode(text_tokens)
+            self.logger.debug("refined texts %s", str(text))
             refined.destroy()
             if refine_text_only:
                 if split_text and isinstance(text, list):
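The new `enable_cache` flag is threaded from `Chat.load` through `_load` into the GPT module, and `coef` is now annotated as a string, matching how the web UI already passes it (`--coef` is a `str` argument). A minimal usage sketch of the post-commit signature; the variable names are illustrative, not part of the repo:

    import ChatTTS

    chat = ChatTTS.Chat()

    # enable_cache defaults to False, so generation runs without a KV cache
    # unless the caller opts in explicitly.
    ok = chat.load(source="local", enable_cache=True)
    assert ok, "model load failed"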

ChatTTS/model/embed.py

Lines changed: 8 additions & 5 deletions
@@ -54,24 +54,27 @@ def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor
         get_emb
         """
         device = next(self.parameters()).device
+        input_ids_dev = input_ids.to(device)
+        text_mask_dev = text_mask.to(device)
+
         emb_text: torch.Tensor = self.emb_text(
-            input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(device)
+            input_ids_dev[text_mask_dev].narrow(1, 0, 1).squeeze_(1)
         )

-        text_mask_inv = text_mask.logical_not().to(device)
-        masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(device)
+        text_mask_inv = text_mask_dev.logical_not()
+        masked_input_ids: torch.Tensor = input_ids_dev[text_mask_inv]

         emb_code = [
             self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
         ]
         emb_code = torch.stack(emb_code, 2).sum(2)

         emb = torch.zeros(
-            (input_ids.shape[:-1]) + (emb_text.shape[-1],),
+            (input_ids_dev.shape[:-1]) + (emb_text.shape[-1],),
             device=emb_text.device,
             dtype=emb_text.dtype,
         )
-        emb[text_mask] = emb_text
+        emb[text_mask_dev] = emb_text
         emb[text_mask_inv] = emb_code.to(emb.dtype)

         del emb_text, emb_code, text_mask_inv
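The forward pass now moves `input_ids` and `text_mask` to the module's device once, up front, and indexes the device-local copies, instead of boolean-mask indexing on the original device and calling `.to(device)` on each intermediate result; cross-device mask indexing can fail on backends such as DirectML. A minimal sketch of the same pattern in isolation; `move_inputs_once` and the toy tensors are hypothetical:

    import torch

    def move_inputs_once(module: torch.nn.Module, *tensors: torch.Tensor):
        # Resolve the module's device once and move every input to it
        # before any mask indexing, mirroring the change above.
        device = next(module.parameters()).device
        return tuple(t.to(device) for t in tensors)

    emb = torch.nn.Embedding(10, 4)
    ids = torch.randint(0, 10, (2, 3))
    mask = torch.tensor([[True, False, True], [False, True, True]])
    ids_dev, mask_dev = move_inputs_once(emb, ids, mask)
    selected = ids_dev[mask_dev]  # indexing now happens on a single device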

ChatTTS/model/gpt.py

Lines changed: 5 additions & 7 deletions
@@ -28,6 +28,7 @@ def __init__(
         device=torch.device("cpu"),
         device_gpt=torch.device("cpu"),
         logger=logging.getLogger(__name__),
+        enable_cache=False,
     ):
         super().__init__()

@@ -36,6 +37,8 @@ def __init__(
         self.device = device
         self.device_gpt = device_gpt

+        self.enable_cache = enable_cache
+
         self.generator = torch.Generator(device=device)

         self.num_vq = int(gpt_config["num_vq"])
@@ -142,7 +145,6 @@ def prepare(self, compile=False):
     class _GenerationInputs:
         position_ids: torch.Tensor
         cache_position: torch.Tensor
-        use_cache: bool
         input_ids: Optional[torch.Tensor] = None
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
         attention_mask: Optional[torch.Tensor] = None
@@ -167,7 +169,6 @@ def _prepare_generation_inputs(
         inputs_embeds: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        use_cache=True,
     ) -> _GenerationInputs:
         # With static cache, the `past_key_values` is None
         # TODO joao: standardize interface for the different Cache classes and remove of this if
@@ -230,8 +231,7 @@ def _prepare_generation_inputs(
             and attention_mask is not None
             and cache_length + input_ids.shape[1] > max_cache_length
         ):
-            start_pos = attention_mask.shape[1] - max_cache_length
-            attention_mask = attention_mask.narrow(1, start_pos, max_cache_length)
+            attention_mask = attention_mask.narrow(1, -max_cache_length, max_cache_length)

         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
@@ -258,7 +258,6 @@ def _prepare_generation_inputs(
         model_inputs = self._GenerationInputs(
             position_ids=position_ids,
             cache_position=cache_position,
-            use_cache=use_cache,
         )

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
@@ -399,7 +398,6 @@ def generate(
                 inputs_ids,
                 past_key_values,
                 attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
-                use_cache=not self.is_te_llama,
             )

             if i > 0:
@@ -423,7 +421,7 @@ def generate(
                 position_ids=model_input.position_ids,
                 past_key_values=model_input.past_key_values,
                 inputs_embeds=model_input.inputs_embeds,
-                use_cache=model_input.use_cache,
+                use_cache=not self.is_te_llama and self.enable_cache,
                 output_attentions=return_attn,
                 cache_position=model_input.cache_position,
             )
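With `use_cache` removed from `_GenerationInputs`, the decision is made once per forward call: the KV cache is used only when the backbone is not the TransformerEngine Llama wrapper (`is_te_llama`) and the caller opted in at load time. A standalone sketch of that gating, assuming nothing beyond the two flags visible in the diff:

    def resolve_use_cache(is_te_llama: bool, enable_cache: bool) -> bool:
        # Mirrors `use_cache=not self.is_te_llama and self.enable_cache`.
        return (not is_te_llama) and enable_cache

    assert resolve_use_cache(False, True) is True
    assert resolve_use_cache(False, False) is False
    assert resolve_use_cache(True, True) is False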

ChatTTS/utils/gpu.py

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,5 @@
+import importlib.util
+
 import torch

 try:
@@ -43,6 +45,10 @@ def select_device(min_memory=2047, experimental=False):
         else:
             logger.get_logger().info("found Apple GPU, but use CPU.")
             device = torch.device("cpu")
+    elif importlib.util.find_spec("torch_directml") is not None:
+        import torch_directml
+
+        device = torch_directml.device(torch_directml.default_device())
     else:
         logger.get_logger().warning("no GPU or NPU found, use CPU instead")
         device = torch.device("cpu")
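`importlib.util.find_spec` probes for the optional `torch_directml` package without importing it, so the import only happens when a DirectML adapter will actually be used. A minimal standalone sketch of the same probe; `pick_device` is a hypothetical name, not part of the module:

    import importlib.util

    import torch

    def pick_device() -> torch.device:
        # Prefer a DirectML adapter when the optional package is installed,
        # otherwise fall back to CPU, mirroring the new branch in select_device().
        if importlib.util.find_spec("torch_directml") is not None:
            import torch_directml

            return torch_directml.device(torch_directml.default_device())
        return torch.device("cpu")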

examples/web/funcs.py

Lines changed: 3 additions & 3 deletions
@@ -62,12 +62,12 @@ def on_audio_seed_change(audio_seed_input):
     return rand_spk


-def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool:
+def load_chat(cust_path: Optional[str], coef: Optional[str], enable_cache = False) -> bool:
     if cust_path == None:
-        ret = chat.load(coef=coef)
+        ret = chat.load(coef=coef, enable_cache=enable_cache)
     else:
         logger.info("local model path: %s", cust_path)
-        ret = chat.load("custom", custom_path=cust_path, coef=coef)
+        ret = chat.load("custom", custom_path=cust_path, coef=coef, enable_cache=enable_cache)
     global custom_path
     custom_path = cust_path
     if ret:

examples/web/webui.py

Lines changed: 2 additions & 1 deletion
@@ -261,11 +261,12 @@ def make_audio(autoplay, stream):
     parser.add_argument("--root_path", type=str, help="root path")
     parser.add_argument("--custom_path", type=str, help="custom model path")
     parser.add_argument("--coef", type=str, help="custom dvae coefficient")
+    parser.add_argument("--enable_cache", action="store_true", help="enable model cache")
     args = parser.parse_args()

     logger.info("loading ChatTTS model...")

-    if load_chat(args.custom_path, args.coef):
+    if load_chat(args.custom_path, args.coef, args.enable_cache):
         logger.info("Models loaded successfully.")
     else:
         logger.error("Models load failed.")
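The flag is wired end to end: `--enable_cache` uses `store_true`, so it is False unless present on the command line (e.g. `python examples/web/webui.py --enable_cache`), and `load_chat` forwards it to `Chat.load`. A small parsing sketch with a hypothetical stand-alone parser, just to illustrate the semantics:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--enable_cache", action="store_true", help="enable model cache")

    print(parser.parse_args([]).enable_cache)                  # False
    print(parser.parse_args(["--enable_cache"]).enable_cache)  # True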
