Skip to content

Commit 4f60189

Browse files
Include workflow_id in all execution WebSocket messages (CORE-198) (Comfy-Org#13684)
1 parent 7a063e8 commit 4f60189

7 files changed

Lines changed: 398 additions & 21 deletions

File tree

comfy_execution/jobs.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,15 +93,35 @@ def _create_text_preview(value: str) -> dict:
9393
}
9494

9595

96+
def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]:
97+
"""Extract the workflow id from a prompt's ``extra_data``.
98+
99+
The frontend stores the id at ``extra_data["extra_pnginfo"]["workflow"]["id"]``
100+
when a prompt is queued. Any value that is not a non-empty string is treated as
101+
missing so callers can rely on the return being either ``None`` or a string.
102+
"""
103+
if not isinstance(extra_data, dict):
104+
return None
105+
extra_pnginfo = extra_data.get('extra_pnginfo')
106+
if not isinstance(extra_pnginfo, dict):
107+
return None
108+
workflow = extra_pnginfo.get('workflow')
109+
if not isinstance(workflow, dict):
110+
return None
111+
workflow_id = workflow.get('id')
112+
if isinstance(workflow_id, str) and workflow_id:
113+
return workflow_id
114+
return None
115+
116+
96117
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
97118
"""Extract create_time and workflow_id from extra_data.
98119
99120
Returns:
100121
tuple: (create_time, workflow_id)
101122
"""
102123
create_time = extra_data.get('create_time')
103-
extra_pnginfo = extra_data.get('extra_pnginfo', {})
104-
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
124+
workflow_id = extract_workflow_id(extra_data)
105125
return create_time, workflow_id
106126

107127

comfy_execution/progress.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressStat
164164
if self.server_instance is None:
165165
return
166166

167+
workflow_id = self.registry.workflow_id if self.registry else None
168+
167169
# Only send info for non-pending nodes
168170
active_nodes = {
169171
node_id: {
@@ -172,6 +174,7 @@ def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressStat
172174
"state": state["state"].value,
173175
"node_id": node_id,
174176
"prompt_id": prompt_id,
177+
"workflow_id": workflow_id,
175178
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
176179
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
177180
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
@@ -183,7 +186,7 @@ def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressStat
183186
# Send a combined progress_state message with all node states
184187
# Include client_id to ensure message is only sent to the initiating client
185188
self.server_instance.send_sync(
186-
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
189+
"progress_state", {"prompt_id": prompt_id, "workflow_id": workflow_id, "nodes": active_nodes}, self.server_instance.client_id
187190
)
188191

189192
@override
@@ -215,6 +218,7 @@ def update_handler(
215218
metadata = {
216219
"node_id": node_id,
217220
"prompt_id": prompt_id,
221+
"workflow_id": self.registry.workflow_id if self.registry else None,
218222
"display_node_id": self.registry.dynprompt.get_display_node_id(
219223
node_id
220224
),
@@ -240,9 +244,10 @@ class ProgressRegistry:
240244
Registry that maintains node progress state and notifies registered handlers.
241245
"""
242246

243-
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
247+
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None):
244248
self.prompt_id = prompt_id
245249
self.dynprompt = dynprompt
250+
self.workflow_id = workflow_id
246251
self.nodes: Dict[str, NodeProgressState] = {}
247252
self.handlers: Dict[str, ProgressHandler] = {}
248253

@@ -322,15 +327,15 @@ def reset_handlers(self) -> None:
322327
# Global registry instance
323328
global_progress_registry: ProgressRegistry | None = None
324329

325-
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
330+
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None) -> None:
326331
global global_progress_registry
327332

328333
# Reset existing handlers if registry exists
329334
if global_progress_registry is not None:
330335
global_progress_registry.reset_handlers()
331336

332337
# Create new registry
333-
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
338+
global_progress_registry = ProgressRegistry(prompt_id, dynprompt, workflow_id)
334339

335340

336341
def add_progress_handler(handler: ProgressHandler) -> None:

execution.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from comfy_execution.graph_utils import GraphBuilder, is_link
3939
from comfy_execution.validation import validate_node_input
4040
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
41+
from comfy_execution.jobs import extract_workflow_id
4142
from comfy_execution.utils import CurrentNodeContext
4243
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
4344
from comfy_api.latest import io, _io
@@ -417,15 +418,15 @@ def _is_intermediate_output(dynprompt, node_id):
417418
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
418419
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
419420

420-
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
421+
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs):
421422
if server.client_id is None:
422423
return
423424
cached_ui = cached.ui or {}
424-
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
425+
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
425426
if cached.ui is not None:
426427
ui_outputs[node_id] = cached.ui
427428

428-
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
429+
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
429430
unique_id = current_item
430431
real_node_id = dynprompt.get_real_node_id(unique_id)
431432
display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -435,7 +436,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
435436
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
436437
cached = await caches.outputs.get(unique_id)
437438
if cached is not None:
438-
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs)
439+
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs)
439440
get_progress_state().finish_progress(unique_id)
440441
execution_list.cache_update(unique_id, cached)
441442
return (ExecutionResult.SUCCESS, None, None)
@@ -483,7 +484,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
483484
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
484485
if server.client_id is not None:
485486
server.last_node_id = display_node_id
486-
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
487+
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
487488

488489
obj = await caches.objects.get(unique_id)
489490
if obj is None:
@@ -513,6 +514,7 @@ def execution_block_cb(block):
513514
if block.message is not None:
514515
mes = {
515516
"prompt_id": prompt_id,
517+
"workflow_id": workflow_id,
516518
"node_id": unique_id,
517519
"node_type": class_type,
518520
"executed": list(executed),
@@ -561,7 +563,7 @@ async def await_completion():
561563
"output": output_ui
562564
}
563565
if server.client_id is not None:
564-
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
566+
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
565567
if has_subgraph:
566568
cached_outputs = []
567569
new_node_ids = []
@@ -658,6 +660,7 @@ def reset(self):
658660
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
659661
self.status_messages = []
660662
self.success = True
663+
self.workflow_id = None
661664

662665
def add_message(self, event, data: dict, broadcast: bool):
663666
data = {
@@ -677,6 +680,7 @@ def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, e
677680
if isinstance(ex, comfy.model_management.InterruptProcessingException):
678681
mes = {
679682
"prompt_id": prompt_id,
683+
"workflow_id": self.workflow_id,
680684
"node_id": node_id,
681685
"node_type": class_type,
682686
"executed": list(executed),
@@ -685,6 +689,7 @@ def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, e
685689
else:
686690
mes = {
687691
"prompt_id": prompt_id,
692+
"workflow_id": self.workflow_id,
688693
"node_id": node_id,
689694
"node_type": class_type,
690695
"executed": list(executed),
@@ -723,7 +728,9 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
723728
self.server.client_id = None
724729

725730
self.status_messages = []
726-
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
731+
self.workflow_id = extract_workflow_id(extra_data)
732+
self.server.last_workflow_id = self.workflow_id
733+
self.add_message("execution_start", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False)
727734

728735
self._notify_prompt_lifecycle("start", prompt_id)
729736
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
@@ -733,7 +740,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
733740
try:
734741
with torch.inference_mode():
735742
dynamic_prompt = DynamicPrompt(prompt)
736-
reset_progress_state(prompt_id, dynamic_prompt)
743+
reset_progress_state(prompt_id, dynamic_prompt, self.workflow_id)
737744
add_progress_handler(WebUIProgressHandler(self.server))
738745
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
739746
for cache in self.caches.all:
@@ -751,7 +758,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
751758

752759
comfy.model_management.cleanup_models_gc()
753760
self.add_message("execution_cached",
754-
{ "nodes": cached_nodes, "prompt_id": prompt_id},
761+
{ "nodes": cached_nodes, "prompt_id": prompt_id, "workflow_id": self.workflow_id },
755762
broadcast=False)
756763
pending_subgraph_results = {}
757764
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
@@ -769,7 +776,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
769776
break
770777

771778
assert node_id is not None, "Node ID should not be None at this point"
772-
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
779+
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, self.workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
773780
self.success = result != ExecutionResult.FAILURE
774781
if result == ExecutionResult.FAILURE:
775782
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@@ -793,8 +800,8 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
793800
cached = await self.caches.outputs.get(node_id)
794801
if cached is not None:
795802
display_node_id = dynamic_prompt.get_display_node_id(node_id)
796-
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs)
797-
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
803+
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, self.workflow_id, ui_node_outputs)
804+
self.add_message("execution_success", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False)
798805

799806
ui_outputs = {}
800807
meta_outputs = {}
@@ -811,6 +818,8 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
811818
finally:
812819
comfy.memory_management.set_ram_cache_release_state(None, 0)
813820
self._notify_prompt_lifecycle("end", prompt_id)
821+
self.server.last_workflow_id = None
822+
self.workflow_id = None
814823

815824

816825
async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):

main.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import sys
3030
from comfy_execution.progress import get_progress_state
3131
from comfy_execution.utils import get_executing_context
32+
from comfy_execution.jobs import extract_workflow_id
3233
from comfy_api import feature_flags
3334
from app.database.db import init_db, dependencies_available
3435

@@ -317,6 +318,12 @@ def prompt_worker(q, server_instance):
317318
for k in sensitive:
318319
extra_data[k] = sensitive[k]
319320

321+
# Capture the workflow id for this prompt before execution: the
322+
# executor clears server.last_workflow_id in its finally block, so
323+
# reading it after e.execute() returns would emit workflow_id=None
324+
# on the terminal "executing" reset below.
325+
workflow_id = extract_workflow_id(extra_data)
326+
320327
asset_seeder.pause()
321328
e.execute(item[2], prompt_id, extra_data, item[4])
322329

@@ -330,7 +337,7 @@ def prompt_worker(q, server_instance):
330337
completed=e.success,
331338
messages=e.status_messages), process_item=remove_sensitive)
332339
if server_instance.client_id is not None:
333-
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
340+
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id, "workflow_id": workflow_id}, server_instance.client_id)
334341

335342
current_time = time.perf_counter()
336343
execution_time = current_time - execution_start_time
@@ -393,7 +400,7 @@ def hook(value, total, preview_image, prompt_id=None, node_id=None):
393400
prompt_id = server_instance.last_prompt_id
394401
if node_id is None:
395402
node_id = server_instance.last_node_id
396-
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
403+
progress = {"value": value, "max": total, "prompt_id": prompt_id, "workflow_id": getattr(server_instance, 'last_workflow_id', None), "node": node_id}
397404
get_progress_state().update_progress(node_id, value, total, preview_image)
398405

399406
server_instance.send_sync("progress", progress, server_instance.client_id)

server.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,11 @@ async def websocket_handler(request):
275275
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
276276
# On reconnect if we are the currently executing client send the current node
277277
if self.client_id == sid and self.last_node_id is not None:
278-
await self.send("executing", { "node": self.last_node_id }, sid)
278+
await self.send("executing", {
279+
"node": self.last_node_id,
280+
"prompt_id": getattr(self, "last_prompt_id", None),
281+
"workflow_id": getattr(self, "last_workflow_id", None),
282+
}, sid)
279283

280284
# Flag to track if we've received the first message
281285
first_message = True

0 commit comments

Comments
 (0)