from __future__ import annotations

import numpy as np
import torch
from einops import rearrange

import comfy.model_management
import nodes
from comfy_api.v3 import io

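# base_T_norm and base_angle set the overall scale of the translation (scene units)
# and rotation (radians) applied over a trajectory; each preset below gives the
# direction/sign of the Euler angles and translation vector T that
# get_camera_motion() scales per frame.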
CAMERA_DICT = {
    "base_T_norm": 1.5,
    "base_angle": np.pi / 3,
    "Static": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 0.0, 0.0]},
    "Pan Up": {"angle": [0.0, 0.0, 0.0], "T": [0.0, -1.0, 0.0]},
    "Pan Down": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 1.0, 0.0]},
    "Pan Left": {"angle": [0.0, 0.0, 0.0], "T": [-1.0, 0.0, 0.0]},
    "Pan Right": {"angle": [0.0, 0.0, 0.0], "T": [1.0, 0.0, 0.0]},
    "Zoom In": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 0.0, 2.0]},
    "Zoom Out": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 0.0, -2.0]},
    "Anti Clockwise (ACW)": {"angle": [0.0, 0.0, -1.0], "T": [0.0, 0.0, 0.0]},
    "ClockWise (CW)": {"angle": [0.0, 0.0, 1.0], "T": [0.0, 0.0, 0.0]},
}


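# process_pose_params() turns a list of CameraCtrl-style pose rows into per-frame
# Plücker embeddings: intrinsics are rescaled to the sample aspect ratio, poses are
# made relative to the first frame, and ray_condition() yields one [H, W, 6] map
# per frame.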
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device="cpu"):
    def get_relative_pose(cam_params):
        """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
        abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
        abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
        cam_to_origin = 0
        target_cam_c2w = np.array([[1, 0, 0, 0], [0, 1, 0, -cam_to_origin], [0, 0, 1, 0], [0, 0, 0, 1]])
        abs2rel = target_cam_c2w @ abs_w2cs[0]
        ret_poses = [target_cam_c2w] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
        return np.array(ret_poses, dtype=np.float32)

    # Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
    cam_params = [Camera(cam_param) for cam_param in cam_params]

    sample_wh_ratio = width / height
    pose_wh_ratio = original_pose_width / original_pose_height

    # Rescale the normalized focal lengths so the poses match the sample aspect ratio.
    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = height * pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fx = resized_ori_w * cam_param.fx / width
    else:
        resized_ori_h = width / pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fy = resized_ori_h * cam_param.fy / height

    intrinsic = np.asarray(
        [[cam_param.fx * width, cam_param.fy * height, cam_param.cx * width, cam_param.cy * height] for cam_param in cam_params],
        dtype=np.float32,
    )

    K = torch.as_tensor(intrinsic)[None]  # [1, n_frame, 4]
    c2ws = get_relative_pose(cam_params)  # poses relative to the first frame
    c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
    plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
    plucker_embedding = plucker_embedding[None]
    return rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]


class Camera:
    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""

    def __init__(self, entry):
        # entry layout (as built in WanCameraEmbedding.execute below):
        # [unused, fx, fy, cx, cy, 0, 0, *c2w] with the flattened 4x4 camera-to-world
        # matrix in the last 16 values.
        fx, fy, cx, cy = entry[1:5]
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        c2w_mat = np.array(entry[7:]).reshape(4, 4)
        self.c2w_mat = c2w_mat
        self.w2c_mat = np.linalg.inv(c2w_mat)


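# ray_condition() encodes every pixel's viewing ray as a 6D Plücker coordinate
# (o x d, d), where o is the camera origin and d the normalized ray direction in
# world space, giving the model a dense per-pixel description of the camera pose.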
def ray_condition(K, c2w, H, W, device):
    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
    # c2w: B, V, 4, 4
    # K: B, V, 4

    B = K.shape[0]

    j, i = torch.meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
        indexing="ij",
    )
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B, V, 1

    # Unproject pixel centres to normalized camera-space ray directions.
    zs = torch.ones_like(i)  # [B, 1, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3

    # Rotate the directions into world space and fetch the per-frame camera origins.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3
    # Plücker moment: origin x direction
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    return plucker


def get_camera_motion(angle, T, speed, n=81):
    def compute_R_from_rad_angle(angles):
        # Build a rotation matrix from Euler angles, composed as Rz @ Ry @ Rx.
        theta_x, theta_y, theta_z = angles
        Rx = np.array([[1, 0, 0], [0, np.cos(theta_x), -np.sin(theta_x)], [0, np.sin(theta_x), np.cos(theta_x)]])

        Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)], [0, 1, 0], [-np.sin(theta_y), 0, np.cos(theta_y)]])

        Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0], [np.sin(theta_z), np.cos(theta_z), 0], [0, 0, 1]])

        R = np.dot(Rz, np.dot(Ry, Rx))
        return R

    # Ramp the rotation and translation linearly from frame 0 to frame n - 1.
    RT = []
    for i in range(n):
        _angle = (i / n) * speed * (CAMERA_DICT["base_angle"]) * angle
        R = compute_R_from_rad_angle(_angle)
        _T = (i / n) * speed * (CAMERA_DICT["base_T_norm"]) * (T.reshape(3, 1))
        _RT = np.concatenate([R, _T], axis=1)
        RT.append(_RT)
    RT = np.stack(RT)  # (n, 3, 4)
    return RT


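# Illustrative example (not executed): building per-frame extrinsics for a
# "Zoom In" move at the default speed. Each RT[i] is a 3x4 [R|T] matrix whose
# magnitude grows linearly over the trajectory.
#
#   angle = np.array(CAMERA_DICT["Zoom In"]["angle"])
#   T = np.array(CAMERA_DICT["Zoom In"]["T"])
#   RT = get_camera_motion(angle, T, speed=1.0, n=81)  # shape (81, 3, 4)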
class WanCameraEmbedding(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="WanCameraEmbedding_V3",
            category="camera",
            inputs=[
                io.Combo.Input(
                    "camera_pose",
                    options=[
                        "Static",
                        "Pan Up",
                        "Pan Down",
                        "Pan Left",
                        "Pan Right",
                        "Zoom In",
                        "Zoom Out",
                        "Anti Clockwise (ACW)",
                        "ClockWise (CW)",
                    ],
                    default="Static",
                ),
                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True),
                io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True),
                io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True),
                io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True),
                io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True),
            ],
            outputs=[
                io.WanCameraEmbedding.Output(display_name="camera_embedding"),
                io.Int.Output(display_name="width"),
                io.Int.Output(display_name="height"),
                io.Int.Output(display_name="length"),
            ],
        )

    @classmethod
    def execute(cls, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5) -> io.NodeOutput:
        """
        Use the camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmann et al., 2021).
        Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
        """
        angle = np.array(CAMERA_DICT[camera_pose]["angle"])
        T = np.array(CAMERA_DICT[camera_pose]["T"])
        RT = get_camera_motion(angle, T, speed, length)

        # Build CameraCtrl-style pose rows: intrinsics, two padding zeros, then the
        # flattened 4x4 extrinsic matrix ([R|T] plus the homogeneous row).
        trajs = []
        for cp in RT.tolist():
            traj = [fx, fy, cx, cy, 0, 0]
            traj.extend(cp[0])
            traj.extend(cp[1])
            traj.extend(cp[2])
            traj.extend([0, 0, 0, 1])
            trajs.append(traj)

        cam_params = np.array([[float(x) for x in pose] for pose in trajs])
        cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1)
        control_camera_video = process_pose_params(cam_params, width=width, height=height)
        control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device())

        # Repeat the first frame 4 times so the frame count becomes divisible by 4.
        control_camera_video = torch.concat(
            [torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), control_camera_video[:, :, 1:]], dim=2
        ).transpose(1, 2)

        # Pack every 4 consecutive frames into the channel dimension,
        # [b, f, 6, h, w] -> [b, 24, f // 4, h, w], to line up with the temporally
        # downsampled video latent.
        b, f, c, h, w = control_camera_video.shape
        control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
        control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)

        return io.NodeOutput(control_camera_video, width, height, length)



NODES_LIST = [
    WanCameraEmbedding,
]