
Commit 4b6d9a5

Support Iluvatar CoreX
1 parent 8f05fb4 commit 4b6d9a5

2 files changed: +29 -1 lines changed


README.md

Lines changed: 7 additions & 0 deletions
````diff
@@ -293,6 +293,13 @@ For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a
 2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
 3. Launch ComfyUI by running `python main.py`
 
+#### Iluvatar Corex
+
+For models compatible with the Iluvatar Extension for PyTorch. Here's a step-by-step guide tailored to your platform and installation method:
+
+1. Install the Iluvatar Corex Toolkit by following the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
+2. Since cudaMallocAsync is not supported, launch ComfyUI by running `python main.py --disable-cuda-malloc`
+
 # Running
 
 ```python main.py```
````
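A quick way to confirm the extension is visible before launching: the probe below mirrors the `hasattr(torch, "corex")` check this commit adds to `comfy/model_management.py`; it assumes the Iluvatar PyTorch build exposes a `torch.corex` attribute once installed.

```python
# Minimal sanity check, assuming the Iluvatar PyTorch build exposes torch.corex
# (the same attribute comfy/model_management.py probes in this commit).
import torch

if hasattr(torch, "corex"):
    print("Iluvatar CoreX extension detected; launch with: python main.py --disable-cuda-malloc")
else:
    print("torch.corex not found; this install will use the regular CUDA/CPU code paths")
```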

comfy/model_management.py

Lines changed: 22 additions & 1 deletion
```diff
@@ -128,6 +128,11 @@ def get_supported_float8_types():
 except:
     mlu_available = False
 
+try:
+    ixuca_available = hasattr(torch, "corex")
+except:
+    ixuca_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
@@ -151,6 +156,12 @@ def is_mlu():
         return True
     return False
 
+def is_ixuca():
+    global ixuca_available
+    if ixuca_available:
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -288,7 +299,7 @@ def is_amd():
         if torch_version_numeric[0] >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
-    if is_intel_xpu() or is_ascend_npu() or is_mlu():
+    if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
 except:
@@ -1027,6 +1038,8 @@ def xformers_enabled():
         return False
     if is_mlu():
         return False
+    if is_ixuca():
+        return False
     if directml_enabled:
         return False
     return XFORMERS_IS_AVAILABLE
@@ -1062,6 +1075,8 @@ def pytorch_attention_flash_attention():
             return True
         if is_amd():
             return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
+        if is_ixuca():
+            return True
     return False
 
 def force_upcast_attention_dtype():
@@ -1181,6 +1196,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if is_mlu():
         return True
 
+    if is_ixuca():
+        return True
+
     if torch.version.hip:
         return True
 
@@ -1241,6 +1259,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if is_ascend_npu():
         return True
 
+    if is_ixuca():
+        return True
+
     if is_amd():
         arch = torch.cuda.get_device_properties(device).gcnArchName
         if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
```
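Condensed, the change follows the pattern `comfy/model_management.py` already uses for the MLU and Ascend backends: probe the backend once at import time, expose a small accessor, and branch on it in the capability checks. The snippet below is a simplified standalone sketch of that flow, not the module's actual layout.

```python
import torch

# Probe once at import time; a failed probe just means the backend is unavailable.
try:
    ixuca_available = hasattr(torch, "corex")
except Exception:
    ixuca_available = False

def is_ixuca():
    # Accessor mirroring is_mlu()/is_ascend_npu(); the capability checks branch on it.
    return ixuca_available

# Simplified view of the branches the commit adds: when is_ixuca() is true,
# xformers is disabled, PyTorch (flash) attention is preferred, and both
# fp16 and bf16 are treated as supported.
if is_ixuca():
    print("Iluvatar CoreX: xformers off, PyTorch attention on, fp16/bf16 allowed")
else:
    print("Not an Iluvatar CoreX build of PyTorch")
```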
