From 3a59a6e28b2026e55a229b179f0e876cfc2d32c7 Mon Sep 17 00:00:00 2001
From: Qirui Sun
Date: Thu, 8 May 2025 18:28:54 +0000
Subject: [PATCH 1/5] support wan camera models

---
 comfy/ldm/wan/model.py    | 143 ++++++++++++++++++++++++++++++++++++++
 comfy/model_base.py       |  10 +++
 comfy/supported_models.py |  12 +++-
 comfy_extras/nodes_wan.py |  43 ++++++++++++
 4 files changed, 207 insertions(+), 1 deletion(-)

diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index fc5ff40c5c9..623ae6efb8b 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -247,6 +247,60 @@ def forward(self, c, x, **kwargs):
         return c_skip, c
 
 
+class WanCamAdapter(nn.Module):
+    def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}):
+        super(WanCamAdapter, self).__init__()
+
+        # Pixel Unshuffle: reduce spatial dimensions by a factor of 8
+        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
+
+        # Convolution: reduce spatial dimensions by a factor
+        # of 2 (without overlap)
+        self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+        # Residual blocks for feature extraction
+        self.residual_blocks = nn.Sequential(
+            *[WanCamResidualBlock(out_dim, operation_settings = operation_settings) for _ in range(num_residual_blocks)]
+        )
+
+    def forward(self, x):
+        # Reshape to merge the frame dimension into batch
+        bs, c, f, h, w = x.size()
+        x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
+
+        # Pixel Unshuffle operation
+        x_unshuffled = self.pixel_unshuffle(x)
+
+        # Convolution operation
+        x_conv = self.conv(x_unshuffled)
+
+        # Feature extraction with residual blocks
+        out = self.residual_blocks(x_conv)
+
+        # Reshape to restore original bf dimension
+        out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
+
+        # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
+        out = out.permute(0, 2, 1, 3, 4)
+
+        return out
+
+
+class WanCamResidualBlock(nn.Module):
+    def __init__(self, dim, operation_settings={}):
+        super(WanCamResidualBlock, self).__init__()
+        self.conv1 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, x):
+        residual = x
+        out = self.relu(self.conv1(x))
+        out = self.conv2(out)
+        out += residual
+        return out
+
+
 class Head(nn.Module):
 
     def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
@@ -637,3 +691,92 @@ def block_wrap(args):
         # unpatchify
         x = self.unpatchify(x, grid_sizes)
         return x
+
+class CameraWanModel(WanModel):
+    r"""
+    Wan diffusion backbone with a camera-control adapter, for camera-conditioned image-to-video generation.
+ """ + + def __init__(self, + model_type='camera', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flf_pos_embed_token_number=None, + image_model=None, + in_dim_control_adapter=24, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings) + + + def forward_orig( + self, + x, + t, + context, + clip_fea=None, + freqs=None, + camera_conditions = None, + transformer_options={}, + **kwargs, + ): + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + if self.control_adapter is not None and camera_conditions is not None: + x_camera = self.control_adapter(camera_conditions).to(x.dtype) + x = x + x_camera + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x \ No newline at end of file diff --git a/comfy/model_base.py b/comfy/model_base.py index 6d27930dc13..ee2970ff7fd 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1075,6 +1075,16 @@ def extra_conds(self, **kwargs): out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) return out +class WAN21_Camera(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + camera_conditions = kwargs.get("camera_conditions", None) + out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions) + return out class 
Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index fef25eb24f6..667393ac056 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -992,6 +992,16 @@ def get_model(self, state_dict, prefix="", device=None): out = model_base.WAN21(self, image_to_video=False, device=device) return out +class WAN21_Camera(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "i2v", + "in_dim": 32, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_Camera(self, image_to_video=False, device=device) + return out class WAN21_Vace(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1129,6 +1139,6 @@ def get_model(self, state_dict, prefix="", device=None): def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 9dda645971a..d5f3b13d348 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -297,6 +297,48 @@ def op(self, samples, trim_amount): samples_out["samples"] = s1[:, :, trim_amount:] return (samples_out,) +class WanCameraImageToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "camera_conditions": ("LATENT", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, camera_conditions, width, height, length, batch_size, start_image=None, clip_vision_output=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], 
device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(start_image[:, :, :, :3]) + concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, 'camera_conditions': camera_conditions}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, 'camera_conditions': camera_conditions}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, @@ -305,4 +347,5 @@ def op(self, samples, trim_amount): "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, "WanVaceToVideo": WanVaceToVideo, "TrimVideoLatent": TrimVideoLatent, + "WanCameraImageToVideo": WanCameraImageToVideo, } From 7a7a5ea3231d895ba6c38633ba025ae0be9fd277 Mon Sep 17 00:00:00 2001 From: Qirui Sun Date: Sun, 11 May 2025 01:50:55 +0000 Subject: [PATCH 2/5] fix by ruff check --- comfy/ldm/wan/model.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 623ae6efb8b..a996dedf416 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -250,14 +250,14 @@ def forward(self, c, x, **kwargs): class WanCamAdapter(nn.Module): def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}): super(WanCamAdapter, self).__init__() - + # Pixel Unshuffle: reduce spatial dimensions by a factor of 8 self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8) - + # Convolution: reduce spatial dimensions by a factor # of 2 (without overlap) self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - + # Residual blocks for feature extraction self.residual_blocks = nn.Sequential( *[WanCamResidualBlock(out_dim, operation_settings = operation_settings) for _ in range(num_residual_blocks)] @@ -267,19 +267,19 @@ def forward(self, x): # Reshape to merge the frame dimension into batch bs, c, f, h, w = x.size() x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w) - + # Pixel Unshuffle operation x_unshuffled = self.pixel_unshuffle(x) - + # Convolution operation x_conv = self.conv(x_unshuffled) - + # Feature extraction with residual blocks out = self.residual_blocks(x_conv) - + # Reshape to restore original bf dimension out = out.view(bs, f, out.size(1), out.size(2), out.size(3)) - + # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames out = out.permute(0, 2, 1, 3, 4) @@ -779,4 +779,4 @@ def block_wrap(args): # unpatchify x = self.unpatchify(x, grid_sizes) - return x \ No newline at end of file + return x From 
4c7206027f403e7ff069ad92db75c4148908e09f Mon Sep 17 00:00:00 2001 From: Qirui Sun Date: Tue, 13 May 2025 18:15:32 +0000 Subject: [PATCH 3/5] change camera_condition type; make camera_condition optional --- comfy/model_base.py | 3 ++- comfy_extras/nodes_wan.py | 12 ++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index ee2970ff7fd..b0065fcecde 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1083,7 +1083,8 @@ def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) camera_conditions = kwargs.get("camera_conditions", None) - out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions) + if camera_conditions is not None: + out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions) return out class Hunyuan3Dv2(BaseModel): diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index d5f3b13d348..a91b4aba948 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -303,7 +303,6 @@ def INPUT_TYPES(s): return {"required": {"positive": ("CONDITIONING", ), "negative": ("CONDITIONING", ), "vae": ("VAE", ), - "camera_conditions": ("LATENT", ), "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), @@ -311,6 +310,7 @@ def INPUT_TYPES(s): }, "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), "start_image": ("IMAGE", ), + "camera_conditions": ("WAN_CAMERA_EMBEDDING", ), }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @@ -319,7 +319,7 @@ def INPUT_TYPES(s): CATEGORY = "conditioning/video_models" - def encode(self, positive, negative, vae, camera_conditions, width, height, length, batch_size, start_image=None, clip_vision_output=None): + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None): latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) @@ -329,8 +329,12 @@ def encode(self, positive, negative, vae, camera_conditions, width, height, leng concat_latent_image = vae.encode(start_image[:, :, :, :3]) concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, 'camera_conditions': camera_conditions}) - negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, 'camera_conditions': camera_conditions}) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) + + if camera_conditions is not None: + positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions}) + negative = node_helpers.conditioning_set_values(negative, {'camera_conditions': camera_conditions}) if clip_vision_output is not None: positive = 
node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})

From 99d2d90ecfb8b152197b70fb6cc80896b0b8ff57 Mon Sep 17 00:00:00 2001
From: Qirui Sun
Date: Tue, 13 May 2025 20:01:13 +0000
Subject: [PATCH 4/5] support camera trajectory nodes

---
 comfy_extras/nodes_camera_trajectory.py | 218 ++++++++++++++++++++++++
 nodes.py                                |   1 +
 2 files changed, 219 insertions(+)
 create mode 100644 comfy_extras/nodes_camera_trajectory.py

diff --git a/comfy_extras/nodes_camera_trajectory.py b/comfy_extras/nodes_camera_trajectory.py
new file mode 100644
index 00000000000..874c830fe6c
--- /dev/null
+++ b/comfy_extras/nodes_camera_trajectory.py
@@ -0,0 +1,218 @@
+import nodes
+import torch
+import numpy as np
+from einops import rearrange
+import comfy.model_management
+
+
+
+MAX_RESOLUTION = nodes.MAX_RESOLUTION
+
+CAMERA_DICT = {
+    "base_T_norm": 1.5,
+    "base_angle": np.pi/3,
+    "Static": { "angle":[0., 0., 0.], "T":[0., 0., 0.]},
+    "Pan Up": { "angle":[0., 0., 0.], "T":[0., 1., 0.]},
+    "Pan Down": { "angle":[0., 0., 0.], "T":[0.,-1.,0.]},
+    "Pan Left": { "angle":[0., 0., 0.], "T":[1.,0.,0.]},
+    "Pan Right": { "angle":[0., 0., 0.], "T": [-1.,0.,0.]},
+    "Zoom In": { "angle":[0., 0., 0.], "T": [0.,0.,-2.]},
+    "Zoom Out": { "angle":[0., 0., 0.], "T": [0.,0.,2.]},
+    "Anti Clockwise (CW)": { "angle": [0., 0., 1.], "T":[0., 0., 0.]},
+    "ClockWise (CW)": { "angle": [0., 0., -1.], "T":[0., 0., 0.]},
+}
+
+
+def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
+
+    def get_relative_pose(cam_params):
+        """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+        """
+        abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
+        abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
+        cam_to_origin = 0
+        target_cam_c2w = np.array([
+            [1, 0, 0, 0],
+            [0, 1, 0, -cam_to_origin],
+            [0, 0, 1, 0],
+            [0, 0, 0, 1]
+        ])
+        abs2rel = target_cam_c2w @ abs_w2cs[0]
+        ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
+        ret_poses = np.array(ret_poses, dtype=np.float32)
+        return ret_poses
+
+    """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+    """
+    cam_params = [Camera(cam_param) for cam_param in cam_params]
+
+    sample_wh_ratio = width / height
+    pose_wh_ratio = original_pose_width / original_pose_height  # aspect ratio the source trajectory was captured at
+
+    if pose_wh_ratio > sample_wh_ratio:
+        resized_ori_w = height * pose_wh_ratio
+        for cam_param in cam_params:
+            cam_param.fx = resized_ori_w * cam_param.fx / width
+    else:
+        resized_ori_h = width / pose_wh_ratio
+        for cam_param in cam_params:
+            cam_param.fy = resized_ori_h * cam_param.fy / height
+
+    intrinsic = np.asarray([[cam_param.fx * width,
+                             cam_param.fy * height,
+                             cam_param.cx * width,
+                             cam_param.cy * height]
+                            for cam_param in cam_params], dtype=np.float32)
+
+    K = torch.as_tensor(intrinsic)[None]  # [1, n_frame, 4]
+    c2ws = get_relative_pose(cam_params)  # poses relative to the first frame (helper defined above)
+    c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
+    plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
+    plucker_embedding = plucker_embedding[None]
+    plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
+    return plucker_embedding

+class Camera(object):
+    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+    """
+    def __init__(self, entry):
+        fx, fy, cx, cy = entry[1:5]
+        self.fx = fx
+        self.fy = fy
+        self.cx = cx
+        self.cy = cy
+        c2w_mat = np.array(entry[7:]).reshape(4, 4)
+        self.c2w_mat = c2w_mat
+        self.w2c_mat = np.linalg.inv(c2w_mat)
+
+def ray_condition(K, c2w, H, W, device):
+    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+    """
+    # c2w: B, V, 4, 4
+    # K: B, V, 4
+
+    B = K.shape[0]
+
+    j, i = torch.meshgrid(
+        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
+        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
+        indexing='ij'
+    )
+    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]
+    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]
+
+    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B,V, 1
+
+    zs = torch.ones_like(i)  # [B, HxW]
+    xs = (i - cx) / fx * zs
+    ys = (j - cy) / fy * zs
+    zs = zs.expand_as(ys)
+
+    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
+    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3
+
+    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
+    rays_o = c2w[..., :3, 3]  # B, V, 3
+    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3
+    # c2w @ directions
+    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)
+    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
+    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
+    # plucker = plucker.permute(0, 1, 4, 2, 3)
+    return plucker
+
+def get_camera_motion(angle, T, speed, n=81):
+    def compute_R_from_rad_angle(angles):
+        theta_x, theta_y, theta_z = angles
+        Rx = np.array([[1, 0, 0],
+                       [0, np.cos(theta_x), -np.sin(theta_x)],
+                       [0, np.sin(theta_x), np.cos(theta_x)]])
+
+        Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)],
+                       [0, 1, 0],
+                       [-np.sin(theta_y), 0, np.cos(theta_y)]])
+
+        Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0],
+                       [np.sin(theta_z), np.cos(theta_z), 0],
+                       [0, 0, 1]])
+
+        R = np.dot(Rz, np.dot(Ry, Rx))
+        return R
+    RT = []
+    for i in range(n):
+        _angle = (i/n)*speed*(CAMERA_DICT["base_angle"])*angle
+        R = compute_R_from_rad_angle(_angle)
+        _T=(i/n)*speed*(CAMERA_DICT["base_T_norm"])*(T.reshape(3,1))
+        _RT = np.concatenate([R,_T], axis=1)
+        RT.append(_RT)
+    RT = np.stack(RT)
+    return RT
+
+class WanCameraEmbeding:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (CW)", "ClockWise (CW)"],{"default":"Static"}),
+                "width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
+                "height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
+                "length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}),
+            },
+            "optional":{
+                "speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}),
+                "fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
+                "fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
+                "cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
+                "cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
+            }
+
+        }
+
+    RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT")
+    RETURN_NAMES = ("camera_embedding","width","height","length")
+    FUNCTION = "run"
+    CATEGORY = "camera"
+
+    def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5):
+        """
+        Use camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmann et al., 2021)
+        Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
+        """
+        motion_list = 
[camera_pose] + speed = speed + angle = np.array(CAMERA_DICT[motion_list[0]]["angle"]) + T = np.array(CAMERA_DICT[motion_list[0]]["T"]) + RT = get_camera_motion(angle, T, speed, length) + + trajs=[] + for cp in RT.tolist(): + traj=[fx,fy,cx,cy,0,0] + traj.extend(cp[0]) + traj.extend(cp[1]) + traj.extend(cp[2]) + traj.extend([0,0,0,1]) + trajs.append(traj) + + cam_params = np.array([[float(x) for x in pose] for pose in trajs]) + cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1) + control_camera_video = process_pose_params(cam_params, width=width, height=height) + control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device()) + + control_camera_video = torch.concat( + [ + torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), + control_camera_video[:, :, 1:] + ], dim=2 + ).transpose(1, 2) + + # Reshape, transpose, and view into desired shape + b, f, c, h, w = control_camera_video.shape + control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) + control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) + + return (control_camera_video, width, height, length) + + +NODE_CLASS_MAPPINGS = { + "WanCameraEmbeding": WanCameraEmbeding, +} diff --git a/nodes.py b/nodes.py index a1ddf2dd6fc..f6171acb66d 100644 --- a/nodes.py +++ b/nodes.py @@ -2263,6 +2263,7 @@ def init_builtin_extra_nodes(): "nodes_fresca.py", "nodes_preview_any.py", "nodes_ace.py", + "nodes_camera_trajectory.py", ] import_failed = [] From b919367780c836307ecfb56bb77b1589be5fe07c Mon Sep 17 00:00:00 2001 From: George0726 Date: Tue, 13 May 2025 21:19:09 +0000 Subject: [PATCH 5/5] fix camera direction --- comfy_extras/nodes_camera_trajectory.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/comfy_extras/nodes_camera_trajectory.py b/comfy_extras/nodes_camera_trajectory.py index 874c830fe6c..b84b4785c69 100644 --- a/comfy_extras/nodes_camera_trajectory.py +++ b/comfy_extras/nodes_camera_trajectory.py @@ -12,14 +12,14 @@ "base_T_norm": 1.5, "base_angle": np.pi/3, "Static": { "angle":[0., 0., 0.], "T":[0., 0., 0.]}, - "Pan Up": { "angle":[0., 0., 0.], "T":[0., 1., 0.]}, - "Pan Down": { "angle":[0., 0., 0.], "T":[0.,-1.,0.]}, - "Pan Left": { "angle":[0., 0., 0.], "T":[1.,0.,0.]}, - "Pan Right": { "angle":[0., 0., 0.], "T": [-1.,0.,0.]}, - "Zoom In": { "angle":[0., 0., 0.], "T": [0.,0.,-2.]}, - "Zoom Out": { "angle":[0., 0., 0.], "T": [0.,0.,2.]}, - "Anti Clockwise (CW)": { "angle": [0., 0., 1.], "T":[0., 0., 0.]}, - "ClockWise (CW)": { "angle": [0., 0., -1.], "T":[0., 0., 0.]}, + "Pan Up": { "angle":[0., 0., 0.], "T":[0., -1., 0.]}, + "Pan Down": { "angle":[0., 0., 0.], "T":[0.,1.,0.]}, + "Pan Left": { "angle":[0., 0., 0.], "T":[-1.,0.,0.]}, + "Pan Right": { "angle":[0., 0., 0.], "T": [1.,0.,0.]}, + "Zoom In": { "angle":[0., 0., 0.], "T": [0.,0.,2.]}, + "Zoom Out": { "angle":[0., 0., 0.], "T": [0.,0.,-2.]}, + "Anti Clockwise (ACW)": { "angle": [0., 0., -1.], "T":[0., 0., 0.]}, + "ClockWise (CW)": { "angle": [0., 0., 1.], "T":[0., 0., 0.]}, } @@ -153,7 +153,7 @@ class WanCameraEmbeding: def INPUT_TYPES(cls): return { "required": { - "camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (CW)", "ClockWise (CW)"],{"default":"Static"}), + "camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", 
"ClockWise (CW)"],{"default":"Static"}), "width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}), "height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}), "length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}),