Commit 129014c

Merge pull request #735 from DDXDB/main
Added Intel XPU support

2 parents: 9e51878 + 6fb913f

File tree: 7 files changed, +69 −6 lines changed

README.md

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,10 @@ pip install torch==2.3.0+cu118 torchaudio==2.3.0+cu118 --extra-index-url https:/
 
 # AMD GPU: install pytorch with your ROCm version, e.g.
 pip install torch==2.5.1+rocm6.2 torchaudio==2.5.1+rocm6.2 --extra-index-url https://download.pytorch.org/whl/rocm6.2
+
+# intel GPU: install pytorch with your XPU version, e.g.
+# Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit must be installed
+pip install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu
 ```
 
 Then you can choose from a few options below:
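After installing the nightly XPU wheel, it is worth confirming that PyTorch actually sees the Intel GPU before running F5-TTS. A minimal sanity check (my sketch, not part of this commit):

import torch

# The XPU wheel's version string normally carries an "+xpu" suffix (assumption about the nightly build).
print(torch.__version__)
print(torch.xpu.is_available())  # True once the oneAPI runtime and GPU driver are set up
if torch.xpu.is_available():
    x = torch.randn(2, 3, device="xpu")  # allocate a small tensor on the Intel GPU
    print(x.device, torch.xpu.get_device_name(0))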

src/f5_tts/api.py

Lines changed: 9 additions & 1 deletion
@@ -47,7 +47,15 @@ def __init__(
         else:
             import torch
 
-            self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+            self.device = (
+                "cuda"
+                if torch.cuda.is_available()
+                else "xpu"
+                if torch.xpu.is_available()
+                else "mps"
+                if torch.backends.mps.is_available()
+                else "cpu"
+            )
 
         # Load models
         self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
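Python's conditional expression is right-associative, so this chain probes accelerators in priority order: cuda, then xpu, then mps, then cpu; the first backend that reports available wins. The same fallback written as an explicit helper, purely for illustration (the function name is mine, not the repo's):

import torch

def pick_device() -> str:
    # Same priority order as the commit: cuda -> xpu -> mps -> cpu.
    if torch.cuda.is_available():
        return "cuda"
    if torch.xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

print(pick_device())  # e.g. "xpu" on an Intel GPU box without CUDA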

src/f5_tts/eval/eval_utmos.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def main():
     parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
     args = parser.parse_args()
 
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
 
     predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
     predictor = predictor.to(device)
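Note that this chain skips mps, so on Apple Silicon the script stays on cpu. A hedged usage sketch for scoring one file on the selected device (the audio path is hypothetical; SpeechMOS takes a (batch, time) waveform plus its sample rate):

import torch
import torchaudio

device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).to(device)

wave, sr = torchaudio.load("sample.wav")  # hypothetical path; a mono file loads as shape (1, time)
score = predictor(wave.to(device), sr)
print(f"UTMOS: {score.item():.3f}")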

src/f5_tts/infer/speech_edit.py

Lines changed: 9 additions & 1 deletion
@@ -10,7 +10,15 @@
 from f5_tts.model import CFM, DiT, UNetT
 from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
 
-device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+device = (
+    "cuda"
+    if torch.cuda.is_available()
+    else "xpu"
+    if torch.xpu.is_available()
+    else "mps"
+    if torch.backends.mps.is_available()
+    else "cpu"
+)
 
 
 # --------------------- Dataset Settings -------------------- #
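One caveat: torch.xpu only exists on recent PyTorch builds, so a bare torch.xpu.is_available() raises AttributeError on older installs. If this module-level probe had to tolerate old versions, a getattr guard would cover it (an assumption about deployment environments, not something this commit does):

import torch

def xpu_available() -> bool:
    # torch.xpu is absent on older PyTorch builds; guard before probing.
    xpu = getattr(torch, "xpu", None)
    return xpu is not None and xpu.is_available()

device = "cuda" if torch.cuda.is_available() else "xpu" if xpu_available() else "cpu"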

src/f5_tts/infer/utils_infer.py

Lines changed: 9 additions & 1 deletion
@@ -33,7 +33,15 @@
 
 _ref_audio_cache = {}
 
-device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+device = (
+    "cuda"
+    if torch.cuda.is_available()
+    else "xpu"
+    if torch.xpu.is_available()
+    else "mps"
+    if torch.backends.mps.is_available()
+    else "cpu"
+)
 
 # -----------------------------------------
 
src/f5_tts/socket_server.py

Lines changed: 7 additions & 1 deletion
@@ -17,7 +17,13 @@
 class TTSStreamingProcessor:
     def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
         self.device = device or (
-            "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+            "cuda"
+            if torch.cuda.is_available()
+            else "xpu"
+            if torch.xpu.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
         )
 
         # Load the model using the provided checkpoint and vocab files
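Since __init__ accepts device=None, a caller can still pin the processor to one backend and bypass the autodetect chain entirely. A construction sketch under that assumption (all file paths and the reference text are hypothetical):

import torch
from f5_tts.socket_server import TTSStreamingProcessor

processor = TTSStreamingProcessor(
    ckpt_file="ckpts/model.pt",         # hypothetical checkpoint path
    vocab_file="data/vocab.txt",        # hypothetical vocab path
    ref_audio="ref.wav",                # hypothetical reference clip
    ref_text="Reference transcript.",   # transcript matching the reference clip
    device="xpu",                       # explicit override; skips cuda/xpu/mps/cpu detection
    dtype=torch.float32,
)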

src/f5_tts/train/finetune_gradio.py

Lines changed: 30 additions & 1 deletion
@@ -46,7 +46,15 @@
 path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
 file_train = str(files("f5_tts").joinpath("train/finetune_cli.py"))
 
-device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+device = (
+    "cuda"
+    if torch.cuda.is_available()
+    else "xpu"
+    if torch.xpu.is_available()
+    else "mps"
+    if torch.backends.mps.is_available()
+    else "cpu"
+)
 
 
 # Save settings from a JSON file

@@ -889,6 +897,13 @@ def calculate_train(
             gpu_properties = torch.cuda.get_device_properties(i)
             total_memory += gpu_properties.total_memory / (1024**3)  # in GB
 
+    elif torch.xpu.is_available():
+        gpu_count = torch.xpu.device_count()
+        total_memory = 0
+        for i in range(gpu_count):
+            gpu_properties = torch.xpu.get_device_properties(i)
+            total_memory += gpu_properties.total_memory / (1024**3)
+
     elif torch.backends.mps.is_available():
         gpu_count = 1
         total_memory = psutil.virtual_memory().available / (1024**3)

@@ -1284,7 +1299,21 @@ def get_gpu_stats():
                 f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
                 f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
             )
+    elif torch.xpu.is_available():
+        gpu_count = torch.xpu.device_count()
+        for i in range(gpu_count):
+            gpu_name = torch.xpu.get_device_name(i)
+            gpu_properties = torch.xpu.get_device_properties(i)
+            total_memory = gpu_properties.total_memory / (1024**3)  # in GB
+            allocated_memory = torch.xpu.memory_allocated(i) / (1024**2)  # in MB
+            reserved_memory = torch.xpu.memory_reserved(i) / (1024**2)  # in MB
 
+            gpu_stats += (
+                f"GPU {i} Name: {gpu_name}\n"
+                f"Total GPU memory (GPU {i}): {total_memory:.2f} GB\n"
+                f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
+                f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
+            )
     elif torch.backends.mps.is_available():
         gpu_count = 1
         gpu_stats += "MPS GPU\n"
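The new XPU branches mirror the CUDA ones call-for-call: device_count, get_device_properties(i).total_memory for capacity in GB (1024**3 bytes), and memory_allocated / memory_reserved for live usage in MB (1024**2 bytes). A standalone sketch of the same introspection, runnable on any XPU build (it prints nothing when no Intel GPU is visible):

import torch

if torch.xpu.is_available():
    for i in range(torch.xpu.device_count()):
        props = torch.xpu.get_device_properties(i)
        print(
            f"XPU {i}: {torch.xpu.get_device_name(i)} | "
            f"{props.total_memory / 1024**3:.2f} GB total | "
            f"{torch.xpu.memory_allocated(i) / 1024**2:.2f} MB allocated | "
            f"{torch.xpu.memory_reserved(i) / 1024**2:.2f} MB reserved"
        )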
