Skip to content

Commit 0a2dd86

Browse files
authored
MultiGPU Work Units For Accelerated Sampling (CORE-184) (Comfy-Org#7063)
1 parent 04879a8 commit 0a2dd86

16 files changed

Lines changed: 1683 additions & 248 deletions

comfy/cli_args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __call__(self, parser, namespace, values, option_string=None):
4949
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
5050
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
5151
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
52-
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
52+
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use, as a comma-separated list (e.g. '0' or '0,1'). All other devices will not be visible.")
5353
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
5454
cm_group = parser.add_mutually_exclusive_group()
5555
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")

comfy/controlnet.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
You should have received a copy of the GNU General Public License
1616
along with this program. If not, see <https://www.gnu.org/licenses/>.
1717
"""
18-
18+
from __future__ import annotations
1919

2020
import torch
2121
from enum import Enum
2222
import math
2323
import os
2424
import logging
25+
import copy
2526
import comfy.utils
2627
import comfy.model_management
2728
import comfy.model_detection
@@ -38,7 +39,7 @@
3839
import comfy.ldm.flux.controlnet
3940
import comfy.ldm.qwen_image.controlnet
4041
import comfy.cldm.dit_embedder
41-
from typing import TYPE_CHECKING
42+
from typing import TYPE_CHECKING, Union
4243
if TYPE_CHECKING:
4344
from comfy.hooks import HookGroup
4445

@@ -64,6 +65,18 @@ class StrengthType(Enum):
6465
CONSTANT = 1
6566
LINEAR_UP = 2
6667

68+
class ControlIsolation:
69+
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
70+
def __init__(self, control: ControlBase):
71+
self.control = control
72+
self.orig_previous_controlnet = control.previous_controlnet
73+
74+
def __enter__(self):
75+
self.control.previous_controlnet = None
76+
77+
def __exit__(self, *args):
78+
self.control.previous_controlnet = self.orig_previous_controlnet
79+
6780
class ControlBase:
6881
def __init__(self):
6982
self.cond_hint_original = None
@@ -77,14 +90,15 @@ def __init__(self):
7790
self.compression_ratio = 8
7891
self.upscale_algorithm = 'nearest-exact'
7992
self.extra_args = {}
80-
self.previous_controlnet = None
93+
self.previous_controlnet: Union[ControlBase, None] = None
8194
self.extra_conds = []
8295
self.strength_type = StrengthType.CONSTANT
8396
self.concat_mask = False
8497
self.extra_concat_orig = []
8598
self.extra_concat = None
8699
self.extra_hooks: HookGroup = None
87100
self.preprocess_image = lambda a: a
101+
self.multigpu_clones: dict[torch.device, ControlBase] = {}
88102

89103
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
90104
self.cond_hint_original = cond_hint
@@ -111,17 +125,38 @@ def set_previous_controlnet(self, controlnet):
111125
def cleanup(self):
112126
if self.previous_controlnet is not None:
113127
self.previous_controlnet.cleanup()
114-
128+
for device_cnet in self.multigpu_clones.values():
129+
with ControlIsolation(device_cnet):
130+
device_cnet.cleanup()
115131
self.cond_hint = None
116132
self.extra_concat = None
117133
self.timestep_range = None
118134

119135
def get_models(self):
120136
out = []
137+
for device_cnet in self.multigpu_clones.values():
138+
out += device_cnet.get_models_only_self()
121139
if self.previous_controlnet is not None:
122140
out += self.previous_controlnet.get_models()
123141
return out
124142

143+
def get_models_only_self(self):
144+
'Calls get_models, but temporarily sets previous_controlnet to None.'
145+
with ControlIsolation(self):
146+
return self.get_models()
147+
148+
def get_instance_for_device(self, device):
149+
'Returns instance of this Control object intended for selected device.'
150+
return self.multigpu_clones.get(device, self)
151+
152+
def deepclone_multigpu(self, load_device, autoregister=False):
153+
'''
154+
Create deep clone of Control object where model(s) is set to other devices.
155+
156+
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
157+
'''
158+
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
159+
125160
def get_extra_hooks(self):
126161
out = []
127162
if self.extra_hooks is not None:
@@ -130,7 +165,7 @@ def get_extra_hooks(self):
130165
out += self.previous_controlnet.get_extra_hooks()
131166
return out
132167

133-
def copy_to(self, c):
168+
def copy_to(self, c: ControlBase):
134169
c.cond_hint_original = self.cond_hint_original
135170
c.strength = self.strength
136171
c.timestep_percent_range = self.timestep_percent_range
@@ -284,6 +319,14 @@ def copy(self):
284319
self.copy_to(c)
285320
return c
286321

322+
def deepclone_multigpu(self, load_device, autoregister=False):
323+
c = self.copy()
324+
c.control_model = copy.deepcopy(c.control_model)
325+
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
326+
if autoregister:
327+
self.multigpu_clones[load_device] = c
328+
return c
329+
287330
def get_models(self):
288331
out = super().get_models()
289332
out.append(self.control_model_wrapped)
@@ -314,6 +357,10 @@ def pre_run(self, model, percent_to_timestep_function):
314357
super().pre_run(model, percent_to_timestep_function)
315358
self.set_extra_arg("base_model", model.diffusion_model)
316359

360+
def cleanup(self):
361+
self.extra_args.pop("base_model", None)
362+
super().cleanup()
363+
317364
def copy(self):
318365
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
319366
c.control_model = self.control_model
@@ -906,6 +953,14 @@ def copy(self):
906953
self.copy_to(c)
907954
return c
908955

956+
def deepclone_multigpu(self, load_device, autoregister=False):
957+
c = self.copy()
958+
c.t2i_model = copy.deepcopy(c.t2i_model)
959+
c.device = load_device
960+
if autoregister:
961+
self.multigpu_clones[load_device] = c
962+
return c
963+
909964
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
910965
compression_ratio = 8
911966
upscale_algorithm = 'nearest-exact'

comfy/ldm/hunyuan3dv2_1/hunyuandit.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -607,9 +607,13 @@ def __init__(
607607
def forward(self, x, t, context, transformer_options = {}, **kwargs):
608608

609609
x = x.movedim(-1, -2)
610-
if context.shape[0] >= 2:
611-
uncond_emb, cond_emb = context.chunk(2, dim = 0)
612-
context = torch.cat([cond_emb, uncond_emb], dim = 0)
610+
611+
swap_cfg_halves = context.shape[0] >= 2
612+
613+
if swap_cfg_halves:
614+
first_half, second_half = context.chunk(2, dim = 0)
615+
context = torch.cat([second_half, first_half], dim = 0)
616+
613617
main_condition = context
614618

615619
t = 1.0 - t
@@ -657,8 +661,8 @@ def block_wrap(args):
657661
output = self.final_layer(combined)
658662
output = output.movedim(-2, -1) * (-1.0)
659663

660-
if output.shape[0] >= 2:
661-
cond_emb, uncond_emb = output.chunk(2, dim = 0)
662-
return torch.cat([uncond_emb, cond_emb])
663-
else:
664-
return output
664+
if swap_cfg_halves:
665+
first_half, second_half = output.chunk(2, dim = 0)
666+
output = torch.cat([second_half, first_half], dim = 0)
667+
668+
return output

comfy/memory_management.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import math
22
import ctypes
3-
import threading
43
import dataclasses
54
import torch
65
from typing import NamedTuple
@@ -10,7 +9,7 @@
109

1110
class TensorFileSlice(NamedTuple):
1211
file_ref: object
13-
thread_id: int
12+
lock: object
1413
offset: int
1514
size: int
1615

@@ -43,7 +42,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
4342
file_obj = info.file_ref
4443
if (destination.device.type != "cpu"
4544
or file_obj is None
46-
or threading.get_ident() != info.thread_id
4745
or destination.numel() * destination.element_size() < info.size
4846
or tensor.numel() * tensor.element_size() != info.size
4947
or tensor.storage_offset() != 0
@@ -57,27 +55,29 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
5755
if hostbuf is not None:
5856
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
5957
device_ptr = destination2.data_ptr() if destination2 is not None else 0
60-
hostbuf.read_file_slice(file_obj, info.offset, info.size,
61-
offset=destination.data_ptr() - hostbuf.get_raw_address(),
62-
stream=stream_ptr,
63-
device_ptr=device_ptr,
64-
device=None if destination2 is None else destination2.device.index)
58+
with info.lock:
59+
hostbuf.read_file_slice(file_obj, info.offset, info.size,
60+
offset=destination.data_ptr() - hostbuf.get_raw_address(),
61+
stream=stream_ptr,
62+
device_ptr=device_ptr,
63+
device=None if destination2 is None else destination2.device.index)
6564
return True
6665

6766
buf_type = ctypes.c_ubyte * info.size
6867
view = memoryview(buf_type.from_address(destination.data_ptr()))
6968

7069
try:
71-
file_obj.seek(info.offset)
72-
done = 0
73-
while done < info.size:
74-
try:
75-
n = file_obj.readinto(view[done:])
76-
except OSError:
77-
return False
78-
if n <= 0:
79-
return False
80-
done += n
70+
with info.lock:
71+
file_obj.seek(info.offset)
72+
done = 0
73+
while done < info.size:
74+
try:
75+
n = file_obj.readinto(view[done:])
76+
except OSError:
77+
return False
78+
if n <= 0:
79+
return False
80+
done += n
8181
return True
8282
finally:
8383
view.release()

0 commit comments

Comments
 (0)