jaxmetaverse's picture
Upload folder using huggingface_hub
82ea528 verified
#credit to ExponentialML for this module
#from https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter
import os
import torch
import comfy
from einops import rearrange
from comfy import model_base, model_management
from .lvdm.modules.networks.openaimodel3d import UNetModel as DynamiCrafterUNetModel
from .utils.model_utils import DynamiCrafterBase, DYNAMICRAFTER_CONFIG, load_image_proj_dict, load_dynamicrafter_dict, get_image_proj_model
class DynamiCrafter:
    """Glue between a ComfyUI ModelPatcher and the DynamiCrafter video UNet.

    ``load_dynamicrafter`` builds the patched model from a checkpoint.
    ``process_image_conditioning`` encodes the conditioning images, stores them in
    the patcher's ``transformer_options`` and installs ``_forward`` as the UNet
    function wrapper. ``_forward`` then injects that conditioning into every
    denoising call.
    """

    def __init__(self):
        # Assigned in process_image_conditioning; _forward reads the stored
        # conditioning back out of this patcher's model_options.
        self.model_patcher = None

    # There is probably a better way to do this, but with the apply_model callback, this seems necessary.
    # The model gets wrapped around a CFG Denoiser class, and handles the conditioning parts there.
    # We cannot access it, so we must find the conditioning according to how ComfyUI handles it.
    def get_conditioning_pair(self, c_crossattn, use_cfg: bool):
        """Recover one (positive, negative) cross-attention pair from a batched tensor.

        Scans adjacent rows along dim 0 and returns the first pair that differs,
        ordered as ``[row i + 1, row i]``.

        Args:
            c_crossattn: Batched cross-attention conditioning (rows along dim 0).
            use_cfg: When False, ``c_crossattn`` is returned unchanged.

        Returns:
            ``c_crossattn`` itself without CFG, otherwise a two-row tensor.

        Raises:
            ValueError: If no two adjacent rows differ.
        """
        if not use_cfg:
            return c_crossattn
        for negative_idx in range(c_crossattn.shape[0] - 1):
            positive = c_crossattn[[negative_idx + 1]]
            negative = c_crossattn[[negative_idx]]
            if not torch.equal(positive, negative):
                return torch.cat([positive, negative])
        raise ValueError("Could not get the appropriate conditioning group.")

    # apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}
    def _forward(self, *args):
        """UNet function wrapper installed via ``set_model_unet_function_wrapper``.

        Args:
            *args: ``(apply_model, forward_dict)``. ``forward_dict`` has the keys
                ``"input"``, ``"timestep"``, ``"c"`` and ``"cond_or_uncond"``.

        Returns:
            The model output rearranged from ``b c t h w`` back to ``(b t) c h w``.
        """
        transformer_options = self.model_patcher.model_options['transformer_options']
        conditioning = transformer_options['conditioning']
        apply_model = args[0]
        fd = args[1]  # forward dict
        x, t, model_in_kwargs, _ = fd['input'], fd['timestep'], fd['c'], fd['cond_or_uncond']
        # Popped so it is not passed twice via **model_in_kwargs below.
        # NOTE: this mutates the dict ComfyUI handed in.
        c_crossattn = model_in_kwargs.pop("c_crossattn")
        c_concat = conditioning['c_concat']
        num_video_frames = conditioning['num_video_frames']
        fs = conditioning['fs']

        # A video batch always holds >= 2 rows (one per frame), so CFG cannot be
        # detected from "batch > 1". A batch larger than the frame count is taken
        # to contain both the cond and the uncond halves.
        use_cfg = x.shape[0] > num_video_frames
        batch_size = 2 if use_cfg else 1
        if use_cfg:
            num_video_frames *= 2
            c_concat = torch.cat([c_concat] * 2)

        self.validate_forwardable_latent(x, c_concat, num_video_frames, use_cfg)
        x_in = rearrange(x, '(b t) c h w -> b c t h w', b=batch_size)
        c_concat = rearrange(c_concat, '(b t) c h w -> b c t h w', b=batch_size)

        # We always assume video, so there will always be batched conditionings.
        c_crossattn = self.get_conditioning_pair(c_crossattn, use_cfg)
        c_crossattn = c_crossattn[:2] if use_cfg else c_crossattn[:1]
        context_in = c_crossattn

        img_embs = conditioning['image_emb']
        if use_cfg:
            img_embs = torch.cat([img_embs, conditioning['image_emb_uncond']])
        fs = torch.cat([fs] * x_in.shape[0])

        # Frames per sample handed to the UNet (undoes the CFG doubling above).
        frames_per_sample = num_video_frames // 2 if batch_size > 1 else num_video_frames
        outs = []
        for i in range(batch_size):
            model_in_kwargs['transformer_options']['cond_idx'] = i
            outs.append(apply_model(
                x_in[[i]],
                t=torch.cat([t[:1]]),
                context_in=context_in[[i]],
                c_crossattn=c_crossattn,
                # "cc" avoids a naming conflict with the apply_model wrapper;
                # it is handled in the UNet forward.
                cc_concat=c_concat[[i]],
                num_video_frames=frames_per_sample,
                img_emb=img_embs[[i]],
                fs=fs[[i]],
                **model_in_kwargs
            ))
        # NOTE(review): outputs are concatenated in reverse order, presumably to
        # line up with ComfyUI's cond/uncond batch layout given the
        # [positive, negative] pair order above — confirm before changing.
        x_out = torch.cat(list(reversed(outs)))
        return rearrange(x_out, 'b c t h w -> (b t) c h w')

    def assign_forward_args(
        self,
        model,
        c_concat,
        image_emb,
        image_emb_uncond,
        fs,
        frames,
    ):
        """Store the per-run conditioning in ``model.model_options['transformer_options']``.

        ``_forward`` reads it back under the ``'conditioning'`` key.
        """
        model.model_options['transformer_options']['conditioning'] = {
            "c_concat": c_concat,
            "image_emb": image_emb,
            'image_emb_uncond': image_emb_uncond,
            "fs": fs,
            "num_video_frames": frames,
        }

    def validate_forwardable_latent(self, latent, c_concat, num_video_frames, use_cfg):
        """Check that the latent batch matches the frame count and the concat's spatial size.

        Args:
            latent: Latent batch; dim 0 is compared against the frame count and the
                last two dims are treated as (height, width).
            c_concat: Concat conditioning; its last two dims must match ``latent``.
            num_video_frames: Expected frame count (already doubled by the caller
                when CFG is active).
            use_cfg: Whether the batch holds both cond and uncond halves; only
                affects the numbers reported in the error message.

        Raises:
            ValueError: On a frame-count or a height/width mismatch.
        """
        batch = latent.shape[0]
        if batch not in (num_video_frames, num_video_frames * 2):
            latent_batch_size = batch // 2 if use_cfg else batch
            num_frames = num_video_frames // 2 if use_cfg else num_video_frames
            raise ValueError(
                "Please make sure your latent inputs match the number of frames in the DynamiCrafter Processor. "
                f"Got a latent batch size of ({latent_batch_size}) with number of frames being ({num_frames})."
            )

        latent_h, latent_w = latent.shape[-2:]
        c_concat_h, c_concat_w = c_concat.shape[-2:]
        if (latent_h, latent_w) != (c_concat_h, c_concat_w):
            # Sizes are reported as "width, height" for both tensors. The * 8
            # assumes an 8x VAE spatial downscale — NOTE(review): confirm for this VAE.
            raise ValueError(
                "Please make sure that your input latent and image frames are the same height and width. "
                f"Image Size: {c_concat_w * 8}, {c_concat_h * 8}, Latent Size: {latent_w * 8}, {latent_h * 8}"
            )

    def process_image_conditioning(
        self,
        model,
        clip_vision,
        vae,
        image_proj_model,
        images,
        use_interpolate,
        fps: int,
        frames: int,
        scale_latents: bool
    ):
        """Encode the conditioning images and install the DynamiCrafter forward wrapper.

        Args:
            model: ComfyUI ModelPatcher. It is mutated in place: the UNet wrapper
                and the conditioning dict are set on it.
            clip_vision: CLIP vision encoder; ``encode_image(...)['last_hidden_state']`` is used.
            vae: VAE used to encode the RGB channels of ``images``.
            image_proj_model: Callable projecting CLIP hidden states to image embeddings.
            images: Image batch. The trailing dim is channels (only the first 3 are
                VAE-encoded); ``images[:1]`` is used for the CLIP embedding.
            use_interpolate: Enable start/end-frame interpolation conditioning.
            fps: Frame-rate conditioning value passed to the UNet as ``fs``.
            frames: Number of frames to generate.
            scale_latents: Re-encode with a ``(x - .5) * 2`` input mapping and
                scale the concat latents.

        Returns:
            ``(model, {"samples": zeros shaped like c_concat}, {"samples": encoded_latent})``.

        Raises:
            ValueError: If interpolation is requested with more than 16 frames.
        """
        # Validate before any encoding or model mutation so bad input fails fast.
        if use_interpolate and frames > 16:
            raise ValueError(
                "When using interpolation mode, the maximum amount of frames are 16. "
                "If you're doing long video generation, consider using the last frame "
                "from the first generation for the next one (autoregressive)."
            )

        # NOTE(review): the incoming patcher is mutated rather than cloned, so the
        # wrapper also applies to the upstream model object.
        self.model_patcher = model
        rgb_images = images[:, :, :, :3]
        encoded_latent = vae.encode(rgb_images)

        encoded_image = clip_vision.encode_image(images[:1])['last_hidden_state']
        image_emb = image_proj_model(encoded_image)
        # Unconditional embedding comes from an all-zero image of the same shape.
        encoded_image_uncond = clip_vision.encode_image(torch.zeros_like(images[:1]))['last_hidden_state']
        image_emb_uncond = image_proj_model(encoded_image_uncond)

        if scale_latents:
            # Temporarily swap the VAE's input mapping. Restore it even if
            # encoding fails so the shared VAE object is not left patched.
            original_process_input = vae.process_input
            vae.process_input = lambda image: (image - .5) * 2
            try:
                c_concat = vae.encode(rgb_images)
            finally:
                vae.process_input = original_process_input
            # NOTE(review): 1.3 is an empirical scale carried over from upstream;
            # its derivation is not documented here.
            c_concat = model.model.process_latent_in(c_concat) * 1.3
        else:
            c_concat = model.model.process_latent_in(encoded_latent)

        fs = torch.tensor([fps], dtype=torch.long, device=model_management.intermediate_device())
        model.set_model_unet_function_wrapper(self._forward)

        used_interpolate_processing = False
        if encoded_latent.shape[0] == 1:
            # Single image: repeat it for every frame.
            c_concat = torch.cat([c_concat] * frames, dim=0)[:frames]
            if use_interpolate:
                # Keep only the first frame; every other frame is zeroed.
                mask = torch.zeros_like(c_concat)
                mask[:1] = c_concat[:1]
                c_concat = mask
                used_interpolate_processing = True
        elif use_interpolate and c_concat.shape[0] in (2, 3):
            input_frame_count = c_concat.shape[0]
            # Zero tensor with `frames` rows, same per-frame shape/dtype/device as the concat.
            masked_frames = c_concat.new_zeros((frames, *c_concat.shape[1:]))
            # Start frame
            masked_frames[:1] = c_concat[:1]
            end_frame_idx = -1
            # TODO: expose `speed`; at 1.0 the last input image is always the end frame.
            speed = 1.0
            if speed < 1.0:
                possible_speeds = list(torch.linspace(0, 1.0, input_frame_count))
                end_frame_idx = min(
                    enumerate(possible_speeds), key=lambda n: abs(n[1] - speed)
                )[0]
            # End frame
            masked_frames[-1:] = c_concat[[end_frame_idx]]
            # Possible middle frame, but not working at the moment.
            if input_frame_count == 3:
                masked_frames[[masked_frames.shape[0] // 2]] = c_concat[[input_frame_count // 2]]
            c_concat = masked_frames
            used_interpolate_processing = True
            print(f"Using interpolation mode with {input_frame_count} frames.")

        if c_concat.shape[0] < frames and not used_interpolate_processing:
            print(
                "Multiple images found, but interpolation mode is unset. Using the first frame as condition.",
            )
            c_concat = torch.cat([c_concat[:1]] * frames)
        c_concat = c_concat[:frames]

        if encoded_latent.shape[0] == 1:
            encoded_latent = torch.cat([encoded_latent] * frames)[:frames]
        elif encoded_latent.shape[0] < frames:
            # Pad by repeating the last encoded frame.
            missing = frames - encoded_latent.shape[0]
            encoded_latent = torch.cat([encoded_latent] + [encoded_latent[-1:]] * missing)[:frames]

        # We could store this as a state in this Node Class Instance, but to prevent any weird edge cases,
        # this should always be passed through the 'stateless' way, and let ComfyUI handle the transformer_options state.
        self.assign_forward_args(model, c_concat, image_emb, image_emb_uncond, fs, frames)
        return (model, {"samples": torch.zeros_like(c_concat)}, {"samples": encoded_latent},)

    # Loader for the DynamiCrafter model.
    def load_model_sicts(self, model_path: str):
        """Load a checkpoint and split it into (UNet state dict, image-projection state dict)."""
        model_state_dict = comfy.utils.load_torch_file(model_path)
        dynamicrafter_dict = load_dynamicrafter_dict(model_state_dict)
        image_proj_dict = load_image_proj_dict(model_state_dict)
        return dynamicrafter_dict, image_proj_dict

    def get_prediction_type(self, is_eps: bool, model_config):
        """Return EPS or V_PREDICTION and, for v-prediction, freeze the learnable image cross-attention scale flag."""
        if not is_eps and "image_cross_attention_scale_learnable" in model_config.unet_config:
            model_config.unet_config["image_cross_attention_scale_learnable"] = False
        return model_base.ModelType.EPS if is_eps else model_base.ModelType.V_PREDICTION

    def handle_model_management(self, dynamicrafter_dict: dict, model_config):
        """Choose UNet dtypes/devices for the checkpoint and apply the dtypes to ``model_config``.

        Returns:
            ``(load_device, initial_load_device)``.
        """
        parameters = comfy.utils.calculate_parameters(dynamicrafter_dict, "model.diffusion_model.")
        load_device = model_management.get_torch_device()
        unet_dtype = model_management.unet_dtype(
            model_params=parameters,
            supported_dtypes=model_config.supported_inference_dtypes
        )
        manual_cast_dtype = model_management.unet_manual_cast(
            unet_dtype,
            load_device,
            model_config.supported_inference_dtypes
        )
        model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
        inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
        return load_device, inital_load_device

    def check_leftover_keys(self, state_dict: dict):
        """Print any keys remaining in ``state_dict`` after weight loading."""
        if state_dict:
            print("left over keys:", state_dict.keys())

    def load_dynamicrafter(self, model_path):
        """Load a DynamiCrafter checkpoint into a ComfyUI ModelPatcher.

        Args:
            model_path: Path to the checkpoint file.

        Returns:
            ``(model_patcher, image_proj_model)``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist (previously this
                silently returned ``None``).
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"DynamiCrafter checkpoint not found: {model_path}")

        dynamicrafter_dict, image_proj_dict = self.load_model_sicts(model_path)
        model_config = DynamiCrafterBase(DYNAMICRAFTER_CONFIG)
        dynamicrafter_dict, is_eps = model_config.process_dict_version(state_dict=dynamicrafter_dict)
        model_type = self.get_prediction_type(is_eps, model_config)
        load_device, inital_load_device = self.handle_model_management(dynamicrafter_dict, model_config)
        model = model_base.BaseModel(
            model_config,
            model_type=model_type,
            device=inital_load_device,
            unet_model=DynamiCrafterUNetModel
        )
        image_proj_model = get_image_proj_model(image_proj_dict)
        model.load_model_weights(dynamicrafter_dict, "model.diffusion_model.")
        # NOTE(review): assumes load_model_weights consumes the keys it loads,
        # so anything still present here was not used.
        self.check_leftover_keys(dynamicrafter_dict)
        model_patcher = comfy.model_patcher.ModelPatcher(
            model,
            load_device=load_device,
            offload_device=model_management.unet_offload_device(),
            current_device=inital_load_device
        )
        return (model_patcher, image_proj_model,)