Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions comfy/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def from_string(cls, value: str):
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
parser.add_argument("--fast-disk", action="store_true", help="Prefer disk-backed dynamic loading and offload over unpinned RAM. Can be faster for users with fast NVME disks.")

parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")

Expand Down
40 changes: 27 additions & 13 deletions comfy/memory_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from typing import NamedTuple

import comfy_aimdo.host_buffer
from comfy.quant_ops import QuantizedTensor


Expand All @@ -17,32 +18,34 @@ class TensorFileSlice(NamedTuple):
def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):

if isinstance(tensor, QuantizedTensor):
if not isinstance(destination, QuantizedTensor):
return False
if tensor._layout_cls != destination._layout_cls:
return False

if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream,
if not read_tensor_file_slice_into(tensor._qdata,
destination._qdata if destination is not None else None, stream=stream,
destination2=(destination2._qdata if destination2 is not None else None)):
return False

dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
if destination is not None:
dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
if destination2 is not None:
dst_orig_dtype = destination2._params.orig_dtype
destination2._params.copy_from(destination._params, non_blocking=True)
destination2._params.copy_from(destination._params if destination is not None else tensor._params, non_blocking=True)
destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
return True

info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
if info is None:
return False

if destination is not None and destination.device.type != "cpu" and destination2 is None:
destination2 = destination
destination = None

file_obj = info.file_ref
if (destination.device.type != "cpu"
or file_obj is None
or destination.numel() * destination.element_size() < info.size
if (file_obj is None
or (destination is None and destination2 is None)
or (destination is not None and (destination.device.type != "cpu" or destination.numel() * destination.element_size() < info.size))
or (destination2 is not None and (destination2.device.type == "cpu" or destination2.numel() * destination2.element_size() < info.size))
or tensor.numel() * tensor.element_size() != info.size
or tensor.storage_offset() != 0
or not tensor.is_contiguous()):
Expand All @@ -51,6 +54,14 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
if info.size == 0:
return True

if destination is None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
comfy_aimdo.host_buffer.read_file_to_device(file_obj, info.offset, info.size,
stream_ptr, destination2.data_ptr(),
destination2.device.index,
mark_cold=False)
return True

hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
if hostbuf is not None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
Expand All @@ -63,6 +74,9 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
device=None if destination2 is None else destination2.device.index)
return True

if not hasattr(file_obj, "seek") or not hasattr(file_obj, "readinto"):
return False

buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))

Expand Down
87 changes: 40 additions & 47 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,14 +641,17 @@ def free_pins(size, evict_active=False):
return freed_total

def ensure_pin_budget(size, evict_active=False):
shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available
if args.fast_disk:
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
else:
shortfall = size + max(comfy.memory_management.RAM_CACHE_HEADROOM / 2, 2048 * 1024 ** 2) - psutil.virtual_memory().available
if shortfall <= 0:
return True

to_free = shortfall + PIN_PRESSURE_HYSTERESIS
return free_pins(to_free, evict_active=evict_active) >= shortfall

def ensure_pin_registerable(size, evict_active=False):
def ensure_pin_registerable(size, evict_active=True):
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
Expand All @@ -658,10 +661,17 @@ def ensure_pin_registerable(size, evict_active=False):
shortfall += REGISTERABLE_PIN_HYSTERESIS
for loaded_model in reversed(current_loaded_models):
model = loaded_model.model
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
if model is not None and model.is_dynamic() and not model.model.dynamic_pins[model.load_device]["active"]:
shortfall -= model.unregister_inactive_pins(shortfall)
if shortfall <= 0:
return True
if evict_active:
for loaded_model in current_loaded_models:
model = loaded_model.model
if model is not None and model.is_dynamic() and model.model.dynamic_pins[model.load_device]["active"]:
shortfall -= model.unregister_inactive_pins(shortfall)
if shortfall <= 0:
return True
return shortfall <= REGISTERABLE_PIN_HYSTERESIS

class LoadedModel:
Expand Down Expand Up @@ -803,9 +813,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for x in can_unload_sorted:
i = x[-1]
memory_to_free = 1e32
if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None):
if not DISABLE_SMART_MEMORY or device is None:
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
if for_dynamic:
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size()
Expand All @@ -817,6 +827,10 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))

if not for_dynamic and pins_required > 0:
ensure_pin_budget(pins_required)
ensure_pin_registerable(pins_required)

if len(unloaded_model) > 0:
soft_empty_cache()
elif device is not None:
Expand Down Expand Up @@ -879,15 +893,19 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
model_to_unload.model_finalizer.detach()

total_memory_required = {}
total_pins_required = {}
for loaded_model in models_to_load:
device = loaded_model.device
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
if not loaded_model.model.is_dynamic():
total_pins_required[device] = total_pins_required.get(device, 0) + loaded_model.model_memory()

for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem,
device,
for_dynamic=free_for_dynamic)
for_dynamic=free_for_dynamic,
pins_required=total_pins_required.get(device, 0))

for device in total_memory_required:
if device != torch.device("cpu"):
Expand Down Expand Up @@ -1283,7 +1301,6 @@ def current_stream(device):
LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
STREAM_PIN_BUFFERS = {}

DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3

Expand Down Expand Up @@ -1326,42 +1343,13 @@ def get_aimdo_cast_buffer(offload_stream, device):
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer

def get_pin_buffer(offload_stream):
pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
if pin_buffer is None:
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3), mark_cold=False)
STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
elif offload_stream is not None:
event = getattr(pin_buffer, "_comfy_event", None)
if event is not None:
event.synchronize()
delattr(pin_buffer, "_comfy_event")
return pin_buffer

def resize_pin_buffer(pin_buffer, size):
global TOTAL_PINNED_MEMORY
old_size = pin_buffer.size
if size <= old_size:
return True
growth = size - old_size
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
ensure_pin_budget(growth, evict_active=True)
ensure_pin_registerable(growth, evict_active=True)
try:
pin_buffer.extend(size=size, reallocate=True)
except RuntimeError:
return False
TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
return True

def reset_cast_buffers():
global TOTAL_PINNED_MEMORY
global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT

LARGEST_CASTED_WEIGHT = (None, 0)
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
if offload_stream is not None:
offload_stream.synchronize()
synchronize()
Expand All @@ -1370,20 +1358,24 @@ def reset_cast_buffers():
mmap_obj.bounce()
DIRTY_MMAPS.clear()

for pin_buffer in STREAM_PIN_BUFFERS.values():
TOTAL_PINNED_MEMORY -= pin_buffer.size
TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY)

for loaded_model in current_loaded_models:
model = loaded_model.model
if model is not None and model.is_dynamic():
model.model.dynamic_pins[model.load_device]["active"] = False
pin_state = model.model.dynamic_pins[model.load_device]

if pin_state["active"]:
*_, buckets = pin_state["weights"]
for size, bucket in list(buckets.items()):
bucket[:] = [ entry for entry in bucket if entry[-1] is not None ]
if not bucket:
del buckets[size]

pin_state["active"] = False
model.partially_unload_ram(1e30, subsets=[ "patches" ])
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0])
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0], [0], {})

STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()
STREAM_PIN_BUFFERS.clear()
soft_empty_cache()

def get_offload_stream(device):
Expand Down Expand Up @@ -1436,7 +1428,7 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)

dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) if r is not None else [None] * len(tensors)
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
with wf_context:
for tensor in tensors:
Expand All @@ -1448,9 +1440,10 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
mark_mmap_dirty(storage)
dest_view.copy_(tensor, non_blocking=non_blocking)
if dest_view is not None:
dest_view.copy_(tensor, non_blocking=non_blocking)
if dest2_view is not None:
dest2_view.copy_(dest_view, non_blocking=non_blocking)
dest2_view.copy_(tensor if dest_view is None else dest_view, non_blocking=non_blocking)


def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
Expand Down
20 changes: 10 additions & 10 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1721,8 +1721,8 @@ def register_load_device(self, device):
"""
if device not in self.model.dynamic_pins:
self.model.dynamic_pins[device] = {
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}),
"hostbufs_initialized": False,
"failed": False,
"active": False,
Expand Down Expand Up @@ -1799,8 +1799,8 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
pin_state = self.model.dynamic_pins[self.load_device]
if not pin_state["hostbufs_initialized"]:
hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size())
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {})
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {})
pin_state["hostbufs_initialized"] = True
pin_state["failed"] = False
pin_state["active"] = True
Expand Down Expand Up @@ -1942,18 +1942,16 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
return freed

def loaded_ram_size(self):
return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
self.model.dynamic_pins[self.load_device]["patches"][0].size)
return (self.model.dynamic_pins[self.load_device]["weights"][0].size)

def pinned_memory_size(self):
return (self.model.dynamic_pins[self.load_device]["weights"][3][0] +
self.model.dynamic_pins[self.load_device]["patches"][3][0])
return (self.model.dynamic_pins[self.load_device]["weights"][3][0])

def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset]
split = stack_split[0]
while split >= 0:
module, offset = stack[split]
Expand All @@ -1978,10 +1976,12 @@ def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset]
while len(stack) > 0:
module, offset = stack.pop()
size = module._pin.numel() * module._pin.element_size()
module._pin_balancer_entry[-1] = None
del module._pin_balancer_entry
del module._pin
hostbuf.truncate(offset, do_unregister=module._pin_registered)
stack_split[0] = min(stack_split[0], len(stack) - 1)
Expand Down
11 changes: 11 additions & 0 deletions comfy/model_prefetch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import comfy_aimdo.model_vbar
import comfy.memory_management
import comfy.model_management
import comfy.ops

Expand Down Expand Up @@ -50,7 +51,17 @@ def prefetch_queue_pop(queue, device, module):
if hasattr(s, "_v"):
comfy_modules.append(s)

registerable_size = 0
for s in comfy_modules:
registerable_size += comfy.memory_management.vram_aligned_size([s.weight, s.bias])
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
registerable_size += lowvram_fn.memory_required()

offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
if not comfy.model_management.args.fast_disk:
comfy.model_management.ensure_pin_registerable(registerable_size)
comfy.model_management.sync_stream(device, offload_stream)
queue[0] = (offload_stream, (prefetch, comfy_modules))

Expand Down
Loading
Loading