Skip to content

Commit 62315fb

Browse files
authored
Dynamic VRAM fixes - Ace 1.5 performance + a VRAM leak (Comfy-Org#12368)
* Revert the threaded model loader change. This change was only needed to get around the PyTorch 2.7 mempool bugs, and should have been reverted along with Comfy-Org#12260. Reverting it fixes a different memory leak where PyTorch gets confused about cache emptying.
* Load non-comfy weights.
* MPDynamic: pre-generate the tensors for vbars. Apparently this is an expensive operation that slows things down.
* Bump to aimdo 1.8. New features: watermark limit feature, logging enhancements, -O2 build on Linux.
1 parent a0302cc commit 62315fb

5 files changed

Lines changed: 20 additions & 35 deletions

File tree

comfy/model_management.py

Lines changed: 6 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import psutil
2020
import logging
2121
from enum import Enum
22-
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
22+
from comfy.cli_args import args, PerformanceFeature
2323
import threading
2424
import torch
2525
import sys
@@ -651,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
651651
soft_empty_cache()
652652
return unloaded_models
653653

654-
def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
654+
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
655655
cleanup_models_gc()
656656
global vram_state
657657

@@ -747,26 +747,6 @@ def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, m
747747
current_loaded_models.insert(0, loaded_model)
748748
return
749749

750-
def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
751-
with torch.inference_mode():
752-
load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
753-
soft_empty_cache()
754-
755-
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
756-
#Deliberately load models outside of the Aimdo mempool so they can be retained across
757-
#nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are
758-
#thread local. So exploit that to escape context
759-
if enables_dynamic_vram():
760-
t = threading.Thread(
761-
target=load_models_gpu_thread,
762-
args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
763-
)
764-
t.start()
765-
t.join()
766-
else:
767-
load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
768-
minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
769-
770750
def load_model_gpu(model):
771751
return load_models_gpu([model])
772752

@@ -1226,21 +1206,16 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
12261206
if dtype is None:
12271207
dtype = weight._model_dtype
12281208

1229-
r = torch.empty_like(weight, dtype=dtype, device=device)
1230-
12311209
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
12321210
if signature is not None:
1233-
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
1234-
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
1211+
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, weight._v_tensor)[0]
12351212
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
12361213
weight._v_signature = signature
12371214
#Send it over
12381215
v_tensor.copy_(weight, non_blocking=non_blocking)
1239-
#always take a deep copy even if _v is good, as we have no reasonable point to unpin
1240-
#a non comfy weight
1241-
r.copy_(v_tensor)
1242-
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
1243-
return r
1216+
return v_tensor.to(dtype=dtype)
1217+
1218+
r = torch.empty_like(weight, dtype=dtype, device=device)
12441219

12451220
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
12461221
#Offloaded casting could skip this, however it would make the quantizations

comfy/model_patcher.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1492,7 +1492,9 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
14921492
if vbar is not None:
14931493
vbar.prioritize()
14941494

1495-
#We have way more tools for acceleration on comfy weight offloading, so always
1495+
#We force reserve VRAM for the non comfy-weight so we don't have to deal
1496+
#with pin and unpin synchronization which can be expensive for small weights
1497+
#with a high layer rate (e.g. autoregressive LLMs).
14961498
#prioritize the non-comfy weights (note the order reverse).
14971499
loading = self._load_list(prio_comfy_cast_weights=True)
14981500
loading.sort(reverse=True)
@@ -1541,6 +1543,7 @@ def setup_param(self, m, n, param_key):
15411543

15421544
if vbar is not None and not hasattr(m, "_v"):
15431545
m._v = vbar.alloc(v_weight_size)
1546+
m._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(m._v, device_to)
15441547
allocated_size += v_weight_size
15451548

15461549
else:
@@ -1555,8 +1558,10 @@ def setup_param(self, m, n, param_key):
15551558
weight_size = geometry.numel() * geometry.element_size()
15561559
if vbar is not None and not hasattr(weight, "_v"):
15571560
weight._v = vbar.alloc(weight_size)
1561+
weight._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device_to)
15581562
weight._model_dtype = model_dtype
15591563
allocated_size += weight_size
1564+
vbar.set_watermark_limit(allocated_size)
15601565

15611566
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
15621567

comfy/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
8787

8888
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
8989
if signature is not None:
90-
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
90+
xfer_dest = s._v_tensor
9191
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
9292

9393
if not resident:

execution.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313

1414
import torch
1515

16+
from comfy.cli_args import args
1617
import comfy.memory_management
1718
import comfy.model_management
19+
import comfy_aimdo.model_vbar
20+
1821
from latent_preview import set_preview_method
1922
import nodes
2023
from comfy_execution.caching import (
@@ -527,8 +530,10 @@ def pre_execute_cb(call_index):
527530
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
528531
finally:
529532
if allocator is not None:
533+
if args.verbose == "DEBUG":
534+
comfy_aimdo.model_vbar.vbars_analyze()
530535
comfy.model_management.reset_cast_buffers()
531-
torch.cuda.synchronize()
536+
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
532537

533538
if has_pending_tasks:
534539
pending_async_nodes[unique_id] = output_data

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ alembic
2222
SQLAlchemy
2323
av>=14.2.0
2424
comfy-kitchen>=0.2.7
25-
comfy-aimdo>=0.1.7
25+
comfy-aimdo>=0.1.8
2626
requests
2727

2828
#non essential dependencies:

0 commit comments

Comments
 (0)