|
|
|
|
|
import os |
|
import torch |
|
import comfy |
|
|
|
from einops import rearrange |
|
from comfy import model_base, model_management |
|
from .lvdm.modules.networks.openaimodel3d import UNetModel as DynamiCrafterUNetModel |
|
|
|
from .utils.model_utils import DynamiCrafterBase, DYNAMICRAFTER_CONFIG, load_image_proj_dict, load_dynamicrafter_dict, get_image_proj_model |
|
|
|
class DynamiCrafter:
    """Load a DynamiCrafter checkpoint and wire its image conditioning into a model patcher.

    ``load_dynamicrafter`` builds the model patcher and the image projection
    model from a checkpoint file. ``process_image_conditioning`` encodes the
    input images, stores the resulting conditioning on the patcher and installs
    ``_forward`` as the UNet function wrapper.
    """

    # Largest frame count accepted when interpolation mode is requested.
    MAX_INTERPOLATION_FRAMES = 16

    def __init__(self):
        # Assigned by process_image_conditioning(); _forward() reads the
        # conditioning dict from its model_options.
        self.model_patcher = None

    def get_conditioning_pair(self, c_crossattn, use_cfg: bool):
        """Select the cross-attention rows used for the two CFG passes.

        Without CFG the tensor is returned unchanged. With CFG the batch is
        scanned for the first two neighbouring rows that differ, and those rows
        are returned stacked as ``[row i + 1, row i]``.

        Raises:
            ValueError: if ``use_cfg`` is set and no two adjacent rows differ
                (this includes batches with fewer than two rows).
        """
        if not use_cfg:
            return c_crossattn

        for negative_idx in range(c_crossattn.shape[0] - 1):
            positive = c_crossattn[[negative_idx + 1]]
            negative = c_crossattn[[negative_idx]]
            if not torch.equal(positive, negative):
                return torch.cat([positive, negative])

        raise ValueError("Could not get the appropriate conditioning group.")

    def _forward(self, *args):
        """UNet function wrapper: run the wrapped model once per conditioning pass.

        ``args[0]`` is the callable to invoke for each pass and ``args[1]`` is a
        dict providing ``'input'``, ``'timestep'`` and ``'c'`` (the model
        keyword arguments). The stored conditioning (see
        ``assign_forward_args``) supplies the concat latents, image embeddings,
        frame-rate tensor and frame count.

        Returns the per-pass outputs, concatenated in reverse pass order and
        flattened back to a ``(b t) c h w`` layout.
        """
        conditioning = self.model_patcher.model_options['transformer_options']['conditioning']

        apply_model, fd = args[0], args[1]
        x, t, model_in_kwargs = fd['input'], fd['timestep'], fd['c']

        # Popped because c_crossattn is passed explicitly below; leaving it in
        # the kwargs would supply the same keyword twice.
        c_crossattn = model_in_kwargs.pop("c_crossattn")
        c_concat = conditioning['c_concat']
        frames_per_pass = conditioning['num_video_frames']
        fs = conditioning['fs']

        # A latent batch larger than the configured frame count is treated as
        # two stacked passes (classifier-free guidance).
        use_cfg = x.shape[0] > frames_per_pass
        batch_size = 2 if use_cfg else 1
        total_frames = frames_per_pass * batch_size

        if use_cfg:
            c_concat = torch.cat([c_concat] * 2)

        self.validate_forwardable_latent(x, c_concat, total_frames, use_cfg)

        x_in = rearrange(x, '(b t) c h w -> b c t h w', b=batch_size)
        c_concat = rearrange(c_concat, '(b t) c h w -> b c t h w', b=batch_size)

        c_crossattn = self.get_conditioning_pair(c_crossattn, use_cfg)
        c_crossattn = c_crossattn[:batch_size]
        context_in = c_crossattn

        img_embs = conditioning['image_emb']
        if use_cfg:
            img_embs = torch.cat([img_embs, conditioning['image_emb_uncond']])

        # One frame-rate entry per pass.
        fs = torch.cat([fs] * x_in.shape[0])

        outs = []
        for i in range(batch_size):
            model_in_kwargs['transformer_options']['cond_idx'] = i
            x_out = apply_model(
                x_in[[i]],
                t=t[:1],
                context_in=context_in[[i]],
                c_crossattn=c_crossattn,
                # NOTE(review): the keyword is spelled 'cc_concat' in the
                # original code; kept as-is — confirm against the UNet.
                cc_concat=c_concat[[i]],
                num_video_frames=frames_per_pass,
                img_emb=img_embs[[i]],
                fs=fs[[i]],
                **model_in_kwargs
            )
            outs.append(x_out)

        # Pass 0 uses row i + 1 of the conditioning pair and pass 1 uses row i,
        # so the outputs are reversed before being concatenated.
        x_out = torch.cat(outs[::-1])
        return rearrange(x_out, 'b c t h w -> (b t) c h w')

    def assign_forward_args(
        self,
        model,
        c_concat,
        image_emb,
        image_emb_uncond,
        fs,
        frames,
    ):
        """Store the values ``_forward`` needs under the model's transformer options."""
        model.model_options['transformer_options']['conditioning'] = {
            "c_concat": c_concat,
            "image_emb": image_emb,
            'image_emb_uncond': image_emb_uncond,
            "fs": fs,
            "num_video_frames": frames,
        }

    def validate_forwardable_latent(self, latent, c_concat, num_video_frames, use_cfg):
        """Check that the latent batch and spatial size match the conditioning.

        Args:
            latent: latent tensor; dim 0 is the batch, the last two dims are
                height and width.
            c_concat: concat conditioning; only its last two dims are compared.
            num_video_frames: expected batch size (already doubled by the
                caller when ``use_cfg`` is set).
            use_cfg: only affects the numbers reported in the error message.

        Raises:
            ValueError: if the batch is neither ``num_video_frames`` nor twice
                that, or if height/width differ between the two tensors.
        """
        batch = latent.shape[0]
        if batch != num_video_frames and batch != num_video_frames * 2:
            latent_batch_size = batch // 2 if use_cfg else batch
            num_frames = num_video_frames // 2 if use_cfg else num_video_frames
            raise ValueError(
                "Please make sure your latent inputs match the number of frames in the DynamiCrafter Processor. "
                f"Got a latent batch size of ({latent_batch_size}) with number of frames being ({num_frames})."
            )

        latent_h, latent_w = latent.shape[-2:]
        c_concat_h, c_concat_w = c_concat.shape[-2:]

        if (latent_h, latent_w) != (c_concat_h, c_concat_w):
            # Sizes are reported as width, height; the * 8 presumably converts
            # latent units to pixels — TODO confirm the VAE scale factor.
            raise ValueError(
                "Please make sure that your input latent and image frames are the same height and width. "
                f"Image Size: {c_concat_w * 8}, {c_concat_h * 8}, Latent Size: {latent_w * 8}, {latent_h * 8}"
            )

    def _build_frame_conditioning(self, c_concat, frames, use_interpolate, single_image):
        """Expand the encoded image latents into exactly ``frames`` concat frames."""
        if single_image:
            c_concat = torch.cat([c_concat] * frames, dim=0)[:frames]
            if use_interpolate:
                # Keep only the first frame; every other frame is zeroed.
                mask = torch.zeros_like(c_concat)
                mask[:1] = c_concat[:1]
                c_concat = mask
            return c_concat

        used_interpolate_processing = False

        if use_interpolate and c_concat.shape[0] in (2, 3):
            input_frame_count = c_concat.shape[0]

            # Zero frames with the same dtype/device/trailing shape as c_concat.
            masked_frames = c_concat.new_zeros((frames,) + tuple(c_concat.shape[1:]))
            masked_frames[:1] = c_concat[:1]
            masked_frames[-1:] = c_concat[[-1]]

            if input_frame_count == 3:
                middle_idx = masked_frames.shape[0] // 2
                middle_idx_frame = c_concat.shape[0] // 2
                masked_frames[[middle_idx]] = c_concat[[middle_idx_frame]]

            c_concat = masked_frames
            used_interpolate_processing = True

            print(f"Using interpolation mode with {input_frame_count} frames.")

        if c_concat.shape[0] < frames and not used_interpolate_processing:
            print(
                "Multiple images found, but interpolation mode is unset. Using the first frame as condition.",
            )
            c_concat = torch.cat([c_concat[:1]] * frames)

        return c_concat[:frames]

    def process_image_conditioning(
        self,
        model,
        clip_vision,
        vae,
        image_proj_model,
        images,
        use_interpolate,
        fps: int,
        frames: int,
        scale_latents: bool
    ):
        """Encode the input images and attach DynamiCrafter conditioning to ``model``.

        Args:
            model: model patcher; receives the UNet function wrapper and the
                conditioning dict.
            clip_vision: object whose ``encode_image`` result exposes
                ``'last_hidden_state'``.
            vae: VAE used to encode the first three channels of ``images``.
            image_proj_model: callable projecting the CLIP vision states.
            images: image batch indexed as ``images[:, :, :, :3]``.
            use_interpolate: enable interpolation-style frame masking.
            fps: value placed in the ``fs`` conditioning tensor.
            frames: number of video frames to condition.
            scale_latents: re-encode with inputs mapped by ``(x - .5) * 2`` and
                scale the processed latent by 1.3.

        Returns:
            ``(model, {"samples": zeros like the concat conditioning},
            {"samples": encoded image latents})``.

        Raises:
            ValueError: if interpolation is requested with more than
                ``MAX_INTERPOLATION_FRAMES`` frames. Raised before any encoding
                or model mutation takes place.
        """
        if use_interpolate and frames > self.MAX_INTERPOLATION_FRAMES:
            raise ValueError(
                "When using interpolation mode, the maximum amount of frames are "
                f"{self.MAX_INTERPOLATION_FRAMES}. "
                "If you're doing long video generation, consider using the last frame "
                "from the first generation for the next one (autoregressive)."
            )

        self.model_patcher = model
        rgb_images = images[:, :, :, :3]
        encoded_latent = vae.encode(rgb_images)

        encoded_image = clip_vision.encode_image(images[:1])['last_hidden_state']
        image_emb = image_proj_model(encoded_image)

        # The unconditional embedding comes from an all-zero image.
        encoded_image_uncond = clip_vision.encode_image(torch.zeros_like(images[:1]))['last_hidden_state']
        image_emb_uncond = image_proj_model(encoded_image_uncond)

        if scale_latents:
            # Temporarily swap the VAE input preprocessing; always restore it,
            # even if encoding fails, so the VAE is not left patched.
            vae_process_input = vae.process_input
            vae.process_input = lambda image: (image - .5) * 2
            try:
                c_concat = vae.encode(rgb_images)
            finally:
                vae.process_input = vae_process_input
            c_concat = model.model.process_latent_in(c_concat) * 1.3
        else:
            c_concat = model.model.process_latent_in(encoded_latent)

        fs = torch.tensor([fps], dtype=torch.long, device=model_management.intermediate_device())

        model.set_model_unet_function_wrapper(self._forward)

        image_count = encoded_latent.shape[0]
        c_concat = self._build_frame_conditioning(
            c_concat, frames, use_interpolate, single_image=image_count == 1
        )

        # Bring the returned latent up to `frames` entries: repeat a single
        # image, or pad a shorter sequence with its last frame. Longer
        # sequences are returned untouched, as before.
        if image_count == 1:
            encoded_latent = torch.cat([encoded_latent] * frames)[:frames]
        elif image_count < frames:
            encoded_latent = torch.cat(
                [encoded_latent] + [encoded_latent[-1:]] * (frames - image_count)
            )[:frames]

        self.assign_forward_args(model, c_concat, image_emb, image_emb_uncond, fs, frames)

        return (model, {"samples": torch.zeros_like(c_concat)}, {"samples": encoded_latent},)

    def load_model_sicts(self, model_path: str):
        """Load the checkpoint and split it into UNet and image-projection state dicts."""
        model_state_dict = comfy.utils.load_torch_file(model_path)
        dynamicrafter_dict = load_dynamicrafter_dict(model_state_dict)
        image_proj_dict = load_image_proj_dict(model_state_dict)

        return dynamicrafter_dict, image_proj_dict

    def get_prediction_type(self, is_eps: bool, model_config):
        """Return the model type for the checkpoint.

        For non-eps checkpoints, an existing
        ``image_cross_attention_scale_learnable`` entry in
        ``model_config.unet_config`` is set to ``False`` as a side effect.
        """
        if not is_eps and "image_cross_attention_scale_learnable" in model_config.unet_config:
            model_config.unet_config["image_cross_attention_scale_learnable"] = False

        return model_base.ModelType.EPS if is_eps else model_base.ModelType.V_PREDICTION

    def handle_model_management(self, dynamicrafter_dict: dict, model_config):
        """Configure inference dtypes on ``model_config`` and pick devices.

        Returns:
            ``(load_device, inital_load_device)``.
        """
        parameters = comfy.utils.calculate_parameters(dynamicrafter_dict, "model.diffusion_model.")
        load_device = model_management.get_torch_device()
        unet_dtype = model_management.unet_dtype(
            model_params=parameters,
            supported_dtypes=model_config.supported_inference_dtypes
        )
        manual_cast_dtype = model_management.unet_manual_cast(
            unet_dtype,
            load_device,
            model_config.supported_inference_dtypes
        )
        model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
        inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)

        return load_device, inital_load_device

    def check_leftover_keys(self, state_dict: dict):
        """Print the keys still present in ``state_dict``, if any."""
        if state_dict:
            print("left over keys:", state_dict.keys())

    def load_dynamicrafter(self, model_path):
        """Build the model patcher and image projection model from a checkpoint.

        Returns:
            ``(model_patcher, image_proj_model)``.

        Raises:
            FileNotFoundError: if ``model_path`` does not exist.
        """
        # Fail with a clear error instead of falling through without a model.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"DynamiCrafter checkpoint not found: {model_path}")

        dynamicrafter_dict, image_proj_dict = self.load_model_sicts(model_path)
        model_config = DynamiCrafterBase(DYNAMICRAFTER_CONFIG)

        dynamicrafter_dict, is_eps = model_config.process_dict_version(state_dict=dynamicrafter_dict)

        model_type = self.get_prediction_type(is_eps, model_config)
        load_device, inital_load_device = self.handle_model_management(dynamicrafter_dict, model_config)

        model = model_base.BaseModel(
            model_config,
            model_type=model_type,
            device=inital_load_device,
            unet_model=DynamiCrafterUNetModel
        )

        image_proj_model = get_image_proj_model(image_proj_dict)
        model.load_model_weights(dynamicrafter_dict, "model.diffusion_model.")
        self.check_leftover_keys(dynamicrafter_dict)

        model_patcher = comfy.model_patcher.ModelPatcher(
            model,
            load_device=load_device,
            offload_device=model_management.unet_offload_device(),
            current_device=inital_load_device
        )

        return (model_patcher, image_proj_model,)