Skip to content

Commit 88956e7

Browse files
authored
multigpu: use unet_manual_cast for SelectModelDevice compute dtype (Comfy-Org#14108)
1 parent da49b7d commit 88956e7

1 file changed

Lines changed: 9 additions & 13 deletions

File tree

comfy_extras/nodes_multigpu.py

Lines changed: 9 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -48,17 +48,14 @@ def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
48 48
return io.NodeOutput(model)
49 49

50 50

51-
def _force_fp32_cpu_compute(patcher: ModelPatcher):
52-
"""Force fp32 inference dtype for CPU.
53-
54-
PyTorch's CPU conv2d kernels fall back to software emulation for fp16/bf16
55-
and run ~500-600x slower than fp32, which makes a normal-sized workflow
56-
look frozen for hours. Routing through set_model_compute_dtype leaves the
57-
weights as-is and casts at use, so peak memory does not blow up."""
58-
dtype = patcher.model_dtype()
59-
if dtype in (torch.float16, torch.bfloat16):
60-
logging.info(f"Select Model Device: using fp32 compute dtype for CPU inference (model dtype was {dtype}).")
61-
patcher.set_model_compute_dtype(torch.float32)
51+
def _force_supported_compute_dtype(patcher: ModelPatcher, device: torch.device):
52+
"""Cast compute dtype to one the device supports; no-op if already supported."""
53+
weight_dtype = patcher.model_dtype()
54+
cast_dtype = comfy.model_management.unet_manual_cast(weight_dtype, device)
55+
if cast_dtype is None:
56+
return
57+
logging.info(f"Select Model Device: using {cast_dtype} compute dtype on {device} (model weight dtype was {weight_dtype}).")
58+
patcher.set_model_compute_dtype(cast_dtype)
62 59

63 60

64 61
def _remember_base_devices(patcher: ModelPatcher):
@@ -229,8 +226,7 @@ def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput:
229 226
logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})")
230 227
return io.NodeOutput(model)
231 228
if resolved is not None:
232-
if resolved.type == "cpu":
233-
_force_fp32_cpu_compute(model)
229+
_force_supported_compute_dtype(model, resolved)
234 230
_prune_multigpu_collision(model, model.load_device)
235 231
return io.NodeOutput(model)
236 232

0 commit comments

Comments
 (0)