46 | 46 | path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts")) |
47 | 47 | file_train = str(files("f5_tts").joinpath("train/finetune_cli.py")) |
48 | 48 |
49 | | -device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
| 49 | +device = (
| 50 | +    "cuda"
| 51 | +    if torch.cuda.is_available()
| 52 | +    else "xpu"
| 53 | +    if torch.xpu.is_available()
| 54 | +    else "mps"
| 55 | +    if torch.backends.mps.is_available()
| 56 | +    else "cpu"
| 57 | +)
50 | 58 |
51 | 59 |
52 | 60 | # Save settings from a JSON file |
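Reviewer note: the new cascade resolves in priority order CUDA → XPU → MPS → CPU. On PyTorch builds that predate the `torch.xpu` namespace, the bare `torch.xpu.is_available()` probe would raise AttributeError; a minimal guarded sketch of the same selection (hypothetical `pick_device` helper, not part of this patch):

```python
import torch

def pick_device() -> str:
    """Resolve the preferred accelerator, mirroring the cascade above."""
    if torch.cuda.is_available():
        return "cuda"
    # hasattr() guard: torch.xpu only exists on recent PyTorch builds
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```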
@@ -889,6 +897,13 @@ def calculate_train( |
889 | 897 |             gpu_properties = torch.cuda.get_device_properties(i)
890 | 898 |             total_memory += gpu_properties.total_memory / (1024**3)  # in GB
891 | 899 |
| 900 | +    elif torch.xpu.is_available():
| 901 | +        gpu_count = torch.xpu.device_count()
| 902 | +        total_memory = 0
| 903 | +        for i in range(gpu_count):
| 904 | +            gpu_properties = torch.xpu.get_device_properties(i)
| 905 | +            total_memory += gpu_properties.total_memory / (1024**3)  # in GB
| 906 | +
892 | 907 |     elif torch.backends.mps.is_available():
893 | 908 |         gpu_count = 1
894 | 909 |         total_memory = psutil.virtual_memory().available / (1024**3)
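This reuses the `device_count`/`get_device_properties` pattern from the CUDA branch, summing memory across all visible XPUs. A self-contained sketch of the aggregation (hypothetical `total_xpu_memory_gb` helper, assuming a PyTorch build that ships `torch.xpu`):

```python
import torch

def total_xpu_memory_gb() -> float:
    """Sum total_memory over all visible XPU devices, in GiB."""
    if not (hasattr(torch, "xpu") and torch.xpu.is_available()):
        return 0.0
    return sum(
        torch.xpu.get_device_properties(i).total_memory
        for i in range(torch.xpu.device_count())
    ) / (1024**3)  # bytes -> GiB
```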
@@ -1284,7 +1299,21 @@ def get_gpu_stats(): |
1284 | 1299 | f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n" |
1285 | 1300 | f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n" |
1286 | 1301 | ) |
| 1302 | + elif torch.xpu.is_available(): |
| 1303 | + gpu_count = torch.xpu.device_count() |
| 1304 | + for i in range(gpu_count): |
| 1305 | + gpu_name = torch.xpu.get_device_name(i) |
| 1306 | + gpu_properties = torch.xpu.get_device_properties(i) |
| 1307 | + total_memory = gpu_properties.total_memory / (1024**3) # in GB |
| 1308 | + allocated_memory = torch.xpu.memory_allocated(i) / (1024**2) # in MB |
| 1309 | + reserved_memory = torch.xpu.memory_reserved(i) / (1024**2) # in MB |
1287 | 1310 |
|
| 1311 | + gpu_stats += ( |
| 1312 | + f"GPU {i} Name: {gpu_name}\n" |
| 1313 | + f"Total GPU memory (GPU {i}): {total_memory:.2f} GB\n" |
| 1314 | + f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n" |
| 1315 | + f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n" |
| 1316 | + ) |
1288 | 1317 | elif torch.backends.mps.is_available(): |
1289 | 1318 | gpu_count = 1 |
1290 | 1319 | gpu_stats += "MPS GPU\n" |
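Here `torch.xpu` mirrors the CUDA-style memory accounting already used in the CUDA branch (`memory_allocated`/`memory_reserved`). A quick smoke test of the new stats path (assumes a PyTorch build with `torch.xpu`; prints a fallback line on machines without an Intel GPU):

```python
import torch

if hasattr(torch, "xpu") and torch.xpu.is_available():
    for i in range(torch.xpu.device_count()):
        props = torch.xpu.get_device_properties(i)
        print(
            f"XPU {i}: {torch.xpu.get_device_name(i)}, "
            f"{props.total_memory / (1024**3):.2f} GB total, "
            f"{torch.xpu.memory_allocated(i) / (1024**2):.2f} MB allocated, "
            f"{torch.xpu.memory_reserved(i) / (1024**2):.2f} MB reserved"
        )
else:
    print("No XPU visible; get_gpu_stats() would fall through to the MPS/CPU branches.")
```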