
Commit 9eda706

V3: 7 more nodes

1 parent bc6b011 commit 9eda706

9 files changed: +861 −1 lines changed

comfy_api/v3/io.py

Lines changed: 6 additions & 0 deletions
@@ -439,6 +439,12 @@ def as_dict(self):
 class Image(ComfyTypeIO):
     Type = torch.Tensor
 
+
+@comfytype(io_type="WAN_CAMERA_EMBEDDING")
+class WanCameraEmbedding(ComfyTypeIO):
+    Type = torch.Tensor
+
+
 @comfytype(io_type="WEBCAM")
 class Webcam(ComfyTypeIO):
     Type = str
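This registers WAN_CAMERA_EMBEDDING as a first-class io type whose payload is a plain torch.Tensor. For orientation, here is a minimal sketch of how a v3 node could declare the new type in its schema, mirroring the io.Image.Input / .Output pattern used by the nodes in this commit; the node below is a hypothetical example, not part of the commit:

from comfy_api.v3 import io

class CameraEmbeddingPassthrough(io.ComfyNodeV3):  # hypothetical example node
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="CameraEmbeddingPassthrough_V3",
            category="camera",
            inputs=[io.WanCameraEmbedding.Input("camera_embedding")],
            outputs=[io.WanCameraEmbedding.Output(display_name="camera_embedding")],
        )

    @classmethod
    def execute(cls, camera_embedding) -> io.NodeOutput:
        # the underlying Type is torch.Tensor, so this is a plain tensor pass-through
        return io.NodeOutput(camera_embedding)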
Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
from __future__ import annotations

import numpy as np
import torch
from einops import rearrange

import comfy.model_management
import nodes
from comfy_api.v3 import io

CAMERA_DICT = {
    "base_T_norm": 1.5,
    "base_angle": np.pi / 3,
    "Static": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 0.0, 0.0]},
    "Pan Up": {"angle": [0.0, 0.0, 0.0], "T": [0.0, -1.0, 0.0]},
    "Pan Down": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 1.0, 0.0]},
    "Pan Left": {"angle": [0.0, 0.0, 0.0], "T": [-1.0, 0.0, 0.0]},
    "Pan Right": {"angle": [0.0, 0.0, 0.0], "T": [1.0, 0.0, 0.0]},
    "Zoom In": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 0.0, 2.0]},
    "Zoom Out": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 0.0, -2.0]},
    "Anti Clockwise (ACW)": {"angle": [0.0, 0.0, -1.0], "T": [0.0, 0.0, 0.0]},
    "ClockWise (CW)": {"angle": [0.0, 0.0, 1.0], "T": [0.0, 0.0, 0.0]},
}


def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device="cpu"):
    """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""

    def get_relative_pose(cam_params):
        """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
        abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
        abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
        cam_to_origin = 0
        target_cam_c2w = np.array([[1, 0, 0, 0], [0, 1, 0, -cam_to_origin], [0, 0, 1, 0], [0, 0, 0, 1]])
        abs2rel = target_cam_c2w @ abs_w2cs[0]
        ret_poses = [target_cam_c2w] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
        return np.array(ret_poses, dtype=np.float32)

    cam_params = [Camera(cam_param) for cam_param in cam_params]

    sample_wh_ratio = width / height
    pose_wh_ratio = original_pose_width / original_pose_height

    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = height * pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fx = resized_ori_w * cam_param.fx / width
    else:
        resized_ori_h = width / pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fy = resized_ori_h * cam_param.fy / height

    intrinsic = np.asarray(
        [[cam_param.fx * width, cam_param.fy * height, cam_param.cx * width, cam_param.cy * height] for cam_param in cam_params],
        dtype=np.float32,
    )

    K = torch.as_tensor(intrinsic)[None]  # [1, n_frame, 4]
    c2ws = get_relative_pose(cam_params)
    c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
    plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
    plucker_embedding = plucker_embedding[None]
    return rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]


class Camera:
    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""

    def __init__(self, entry):
        fx, fy, cx, cy = entry[1:5]
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        c2w_mat = np.array(entry[7:]).reshape(4, 4)
        self.c2w_mat = c2w_mat
        self.w2c_mat = np.linalg.inv(c2w_mat)


def ray_condition(K, c2w, H, W, device):
    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
    # c2w: B, V, 4, 4
    # K: B, V, 4

    B = K.shape[0]

    j, i = torch.meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
        indexing="ij",
    )
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B, V, 1

    zs = torch.ones_like(i)  # [B, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3

    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3
    # c2w @ directions
    rays_dxo = torch.cross(rays_o, rays_d)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    # plucker = plucker.permute(0, 1, 4, 2, 3)
    return plucker


def get_camera_motion(angle, T, speed, n=81):
    def compute_R_form_rad_angle(angles):
        theta_x, theta_y, theta_z = angles
        Rx = np.array([[1, 0, 0], [0, np.cos(theta_x), -np.sin(theta_x)], [0, np.sin(theta_x), np.cos(theta_x)]])
        Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)], [0, 1, 0], [-np.sin(theta_y), 0, np.cos(theta_y)]])
        Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0], [np.sin(theta_z), np.cos(theta_z), 0], [0, 0, 1]])
        return np.dot(Rz, np.dot(Ry, Rx))

    RT = []
    for i in range(n):
        _angle = (i / n) * speed * CAMERA_DICT["base_angle"] * angle
        R = compute_R_form_rad_angle(_angle)
        _T = (i / n) * speed * CAMERA_DICT["base_T_norm"] * T.reshape(3, 1)
        _RT = np.concatenate([R, _T], axis=1)
        RT.append(_RT)
    return np.stack(RT)


class WanCameraEmbedding(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="WanCameraEmbedding_V3",
            category="camera",
            inputs=[
                io.Combo.Input(
                    "camera_pose",
                    options=[
                        "Static",
                        "Pan Up",
                        "Pan Down",
                        "Pan Left",
                        "Pan Right",
                        "Zoom In",
                        "Zoom Out",
                        "Anti Clockwise (ACW)",
                        "ClockWise (CW)",
                    ],
                    default="Static",
                ),
                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True),
                io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True),
                io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True),
                io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True),
                io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True),
            ],
            outputs=[
                io.WanCameraEmbedding.Output(display_name="camera_embedding"),
                io.Int.Output(display_name="width"),
                io.Int.Output(display_name="height"),
                io.Int.Output(display_name="length"),
            ],
        )

    @classmethod
    def execute(cls, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5) -> io.NodeOutput:
        """
        Use the camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmann et al., 2021).
        Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
        """
        angle = np.array(CAMERA_DICT[camera_pose]["angle"])
        T = np.array(CAMERA_DICT[camera_pose]["T"])
        RT = get_camera_motion(angle, T, speed, length)

        trajs = []
        for cp in RT.tolist():
            traj = [fx, fy, cx, cy, 0, 0]
            traj.extend(cp[0])
            traj.extend(cp[1])
            traj.extend(cp[2])
            traj.extend([0, 0, 0, 1])
            trajs.append(traj)

        cam_params = np.array([[float(x) for x in pose] for pose in trajs])
        cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1)
        control_camera_video = process_pose_params(cam_params, width=width, height=height)
        control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device())

        control_camera_video = torch.concat(
            [torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), control_camera_video[:, :, 1:]], dim=2
        ).transpose(1, 2)

        # Reshape, transpose, and view into the desired shape
        b, f, c, h, w = control_camera_video.shape
        control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
        control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)

        return io.NodeOutput(control_camera_video, width, height, length)


NODES_LIST = [
    WanCameraEmbedding,
]
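For each pixel, the embedding is the camera ray's Plücker coordinates: given a ray origin o and unit direction d, the six channels are (o × d, d), which is exactly what ray_condition concatenates as rays_dxo and rays_d. A minimal NumPy sketch of that construction for a single ray (the helper name is hypothetical, not from the commit):

import numpy as np

def plucker_from_ray(origin, direction):
    # mirror ray_condition: normalize the direction, then stack (o x d, d)
    d = direction / np.linalg.norm(direction)
    moment = np.cross(origin, d)  # rays_dxo in ray_condition
    return np.concatenate([moment, d])  # 6 channels per pixel

print(plucker_from_ray(np.array([0.0, 1.0, 0.0]), np.array([0.1, -0.2, 1.0])).shape)  # (6,)

Tracing execute's final reshapes: the [length, H, W, 6] embedding becomes [1, 6, length, H, W], the first frame is repeated 4 times (giving length + 3 frames), and packing groups of 4 frames into channels yields [1, 24, (length + 3) // 4, H, W], e.g. [1, 24, 21, 480, 832] for the default inputs.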

comfy_extras/v3/nodes_canny.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
from __future__ import annotations

from kornia.filters import canny

import comfy.model_management
from comfy_api.v3 import io


class Canny(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="Canny_V3",
            category="image/preprocessors",
            inputs=[
                io.Image.Input("image"),
                io.Float.Input("low_threshold", default=0.4, min=0.01, max=0.99, step=0.01),
                io.Float.Input("high_threshold", default=0.8, min=0.01, max=0.99, step=0.01),
            ],
            outputs=[io.Image.Output()],
        )

    @classmethod
    def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput:
        output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
        img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
        return io.NodeOutput(img_out)


NODES_LIST = [
    Canny,
]
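The node works on ComfyUI's IMAGE layout, [B, H, W, C] floats in [0, 1], while kornia's canny expects [B, C, H, W]; hence the movedim calls on the way in and out, and output[1] keeps kornia's thresholded edge map rather than the gradient magnitude in output[0]. A hedged smoke test, calling execute directly outside a graph purely for illustration:

import torch

image = torch.rand(1, 480, 640, 3)  # ComfyUI IMAGE: [B, H, W, C] in [0, 1]
result = Canny.execute(image, low_threshold=0.4, high_threshold=0.8)
# the io.NodeOutput wraps a [1, 480, 640, 3] tensor: the single edge channel
# from kornia is repeated across 3 channels to stay a valid IMAGE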

comfy_extras/v3/nodes_cfg.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
from __future__ import annotations

import torch

from comfy_api.v3 import io


def optimized_scale(positive, negative):
    positive_flat = positive.reshape(positive.shape[0], -1)
    negative_flat = negative.reshape(negative.shape[0], -1)

    # Calculate the dot product
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)

    # Squared norm of the unconditional prediction
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8

    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    st_star = dot_product / squared_norm

    return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))


class CFGNorm(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls) -> io.SchemaV3:
        return io.SchemaV3(
            node_id="CFGNorm_V3",
            category="advanced/guidance",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
            ],
            outputs=[io.Model.Output("patched_model", display_name="patched_model")],
        )

    @classmethod
    def execute(cls, model, strength) -> io.NodeOutput:
        m = model.clone()

        def cfg_norm(args):
            cond_p = args["cond_denoised"]
            pred_text_ = args["denoised"]

            norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
            norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
            scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
            return pred_text_ * scale * strength

        m.set_model_sampler_post_cfg_function(cfg_norm)
        return io.NodeOutput(m)


class CFGZeroStar(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls) -> io.SchemaV3:
        return io.SchemaV3(
            node_id="CFGZeroStar_V3",
            category="advanced/guidance",
            inputs=[
                io.Model.Input("model"),
            ],
            outputs=[io.Model.Output("patched_model", display_name="patched_model")],
        )

    @classmethod
    def execute(cls, model) -> io.NodeOutput:
        m = model.clone()

        def cfg_zero_star(args):
            guidance_scale = args["cond_scale"]
            x = args["input"]
            cond_p = args["cond_denoised"]
            uncond_p = args["uncond_denoised"]
            out = args["denoised"]
            alpha = optimized_scale(x - cond_p, x - uncond_p)

            return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)

        m.set_model_sampler_post_cfg_function(cfg_zero_star)
        return io.NodeOutput(m)


NODES_LIST = [
    CFGNorm,
    CFGZeroStar,
]
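optimized_scale computes the per-sample projection coefficient alpha = <v_cond, v_uncond> / ||v_uncond||^2, and CFGZeroStar applies it to the noise-direction residuals (x − cond_denoised, x − uncond_denoised). Algebraically, the returned expression is standard CFG with the unconditional branch rescaled by alpha: out + u(alpha − 1) + s·u(1 − alpha) = alpha·u + s(c − alpha·u). A small sketch checking that identity on random tensors (illustrative only; it assumes optimized_scale from the module above is in scope, and in the sampler alpha comes from the residuals, not the denoised outputs themselves):

import torch

torch.manual_seed(0)
cond = torch.randn(2, 4, 8, 8)    # stand-in for the conditional prediction
uncond = torch.randn(2, 4, 8, 8)  # stand-in for the unconditional prediction
s = 7.5                           # guidance scale

alpha = optimized_scale(cond, uncond)
cfg = uncond + s * (cond - uncond)  # standard CFG combine
rewritten = cfg + uncond * (alpha - 1.0) + s * uncond * (1.0 - alpha)
direct = alpha * uncond + s * (cond - alpha * uncond)  # uncond rescaled by alpha
print(torch.allclose(rewritten, direct, atol=1e-5))  # True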
