From ae9f499b47825a8af21c60fd69c62ce159155389 Mon Sep 17 00:00:00 2001 From: linsalrob Date: Tue, 5 May 2026 17:05:54 +0800 Subject: [PATCH] adding xpu support --- src/phold/features/autotune.py | 10 +++++++++- src/phold/features/predict_3Di.py | 4 ++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/phold/features/autotune.py b/src/phold/features/autotune.py index 5bf6bd6..f8cfe08 100644 --- a/src/phold/features/autotune.py +++ b/src/phold/features/autotune.py @@ -46,10 +46,14 @@ def sample_probe_sequences(seqs, n=5000, seed=0): def device_synchronize(device: torch.device): + # I think this whole block can be replacesd with + # torch.accelerator.synchronize(device) if device.type == "cuda": torch.cuda.synchronize(device) + elif device.type == 'xpu': + torch.xpu.synchronize(device) elif device.type == "mps": - torch.mps.synchronize() + torch.mps.synchronize(device) # CPU and others: no-op def autotune_batching_real_data( @@ -80,6 +84,10 @@ def autotune_batching_real_data( # check for NVIDIA/cuda if torch.cuda.is_available(): device = torch.device("cuda:0") + # check for intel xpu + elif torch.xpu.is_available(): + device = torch.device("xpu:0") + dev_name = "xpu" # check for apple silicon/metal elif torch.backends.mps.is_available(): device = torch.device("mps") diff --git a/src/phold/features/predict_3Di.py b/src/phold/features/predict_3Di.py index e7fbcc6..fcb5444 100644 --- a/src/phold/features/predict_3Di.py +++ b/src/phold/features/predict_3Di.py @@ -108,6 +108,10 @@ def get_T5_model( if torch.cuda.is_available(): device = torch.device("cuda:0") dev_name = "cuda:0" + # check for intel xpu + elif torch.xpu.is_available(): + device = torch.device("xpu:0") + dev_name = "xpu" # check for apple silicon/metal elif torch.backends.mps.is_available(): device = torch.device("mps")