@@ -48,17 +48,14 @@ def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
4848 return io .NodeOutput (model )
4949
5050
51- def _force_fp32_cpu_compute (patcher : ModelPatcher ):
52- """Force fp32 inference dtype for CPU.
53-
54- PyTorch's CPU conv2d kernels fall back to software emulation for fp16/bf16
55- and run ~500-600x slower than fp32, which makes a normal-sized workflow
56- look frozen for hours. Routing through set_model_compute_dtype leaves the
57- weights as-is and casts at use, so peak memory does not blow up."""
58- dtype = patcher .model_dtype ()
59- if dtype in (torch .float16 , torch .bfloat16 ):
60- logging .info (f"Select Model Device: using fp32 compute dtype for CPU inference (model dtype was { dtype } )." )
61- patcher .set_model_compute_dtype (torch .float32 )
51+ def _force_supported_compute_dtype (patcher : ModelPatcher , device : torch .device ):
52+ """Cast compute dtype to one the device supports; no-op if already supported."""
53+ weight_dtype = patcher .model_dtype ()
54+ cast_dtype = comfy .model_management .unet_manual_cast (weight_dtype , device )
55+ if cast_dtype is None :
56+ return
57+ logging .info (f"Select Model Device: using { cast_dtype } compute dtype on { device } (model weight dtype was { weight_dtype } )." )
58+ patcher .set_model_compute_dtype (cast_dtype )
6259
6360
6461def _remember_base_devices (patcher : ModelPatcher ):
@@ -229,8 +226,7 @@ def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput:
229226 logging .warning (f"Select Model Device: cannot retarget model, passing through unchanged. ({ e } )" )
230227 return io .NodeOutput (model )
231228 if resolved is not None :
232- if resolved .type == "cpu" :
233- _force_fp32_cpu_compute (model )
229+ _force_supported_compute_dtype (model , resolved )
234230 _prune_multigpu_collision (model , model .load_device )
235231 return io .NodeOutput (model )
236232
0 commit comments