|
import copy |
|
from typing import Union, Callable |
|
from collections import namedtuple |
|
|
|
from einops import rearrange |
|
from torch import Tensor |
|
import torch.nn.functional as F |
|
import torch |
|
import uuid |
|
import math |
|
|
|
import comfy.conds |
|
import comfy.lora |
|
import comfy.model_management |
|
import comfy.utils |
|
from comfy.model_patcher import ModelPatcher |
|
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection |
|
from comfy.model_base import BaseModel |
|
from comfy.sd import CLIP, VAE |
|
|
|
from .ad_settings import AnimateDiffSettings, AdjustPE, AdjustWeight |
|
from .adapter_cameractrl import CameraPoseEncoder, CameraEntry, prepare_pose_embedding |
|
from .context import ContextOptions, ContextOptions, ContextOptionsGroup |
|
from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, AnimateDiffInfo, EncoderOnlyAnimateDiffModel, VersatileAttention, PerBlock, AllPerBlocks, |
|
VanillaTemporalModule, has_mid_block, normalize_ad_state_dict, get_position_encoding_max_len) |
|
from .logger import logger |
|
from .utils_motion import (ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, InputPIA, |
|
get_combined_multival, get_combined_input, get_combined_input_effect_multival, |
|
ade_broadcast_image_to, extend_to_batch_size, prepare_mask_batch) |
|
from .conditioning import HookRef, LoraHook, LoraHookGroup, LoraHookMode |
|
from .motion_lora import MotionLoraInfo, MotionLoraList |
|
from .utils_model import get_motion_lora_path, get_motion_model_path, get_sd_model_type, vae_encode_raw_batched |
|
from .sample_settings import SampleSettings, SeedNoiseGeneration |
|
from .dinklink import DinkLinkConst, get_dinklink, get_acn_outer_sample_wrapper |
|
|
|
|
|
def prepare_dinklink_register_definitions():
    '''Publish ADE entry points in the shared DinkLink registry.

    Stores `create_MotionModelPatcher` in the dict returned by `get_dinklink()`,
    under `[DinkLinkConst.ADE][DinkLinkConst.ADE_CREATE_MOTIONMODELPATCHER]`,
    creating the `DinkLinkConst.ADE` sub-dict if it does not exist yet.
    '''
    registry = get_dinklink()
    ade_entries = registry.setdefault(DinkLinkConst.ADE, {})
    ade_entries[DinkLinkConst.ADE_CREATE_MOTIONMODELPATCHER] = create_MotionModelPatcher
|
|
|
|
|
class MotionModelPatcher(ModelPatcher):
    '''Class used only for type hints.

    Marks a ModelPatcher whose `.model` is an AnimateDiffModel. Note that
    `create_MotionModelPatcher` actually returns a plain `ModelPatcher`; this
    subclass only exists so annotations can say which kind of patcher is expected.
    '''
    def __init__(self):
        # Bare annotation with no assignment: declares the type of `.model` for
        # type checkers without setting any attribute. ModelPatcher.__init__ is
        # not called, so an instance of this class would not be a usable patcher.
        self.model: AnimateDiffModel
|
|
|
|
|
class ModelPatcherHelper:
    '''Helper for reading and writing AnimateDiff-Evolved (ADE) state on a ModelPatcher.

    The helper keeps no state of its own besides `self.model`. Everything it sets is
    stored on the wrapped patcher (attachments, additional models, injections,
    wrappers and `model_options`), so helpers can be created and discarded freely.
    '''
    # attachment key for the SampleSettings object
    SAMPLE_SETTINGS = "ADE_sample_settings"
    # attachment key for the InjectionParams object
    PARAMS = "ADE_params"
    # key used for motion models, injections, wrappers and callbacks
    ADE = "ADE"

    def __init__(self, model: ModelPatcher):
        self.model = model

    def set_all_properties(self, outer_sampler_wrapper: Callable, calc_cond_batch_wrapper: Callable,
                           params: 'InjectionParams', sample_settings: SampleSettings=None, motion_models: 'MotionModelGroup'=None):
        '''Install ADE wrappers, sample settings, params and (optionally) motion models in one call.

        If `sample_settings` is None, a default `SampleSettings()` is stored.
        If `motion_models` is None, any previously set motion models and the ADE
        forward_timestep_embed patch are removed instead of being set.
        '''
        self.set_outer_sample_wrapper(outer_sampler_wrapper)
        self.set_calc_cond_batch_wrapper(calc_cond_batch_wrapper)
        self.set_sample_settings(sample_settings=sample_settings if sample_settings is not None else SampleSettings())
        self.set_params(params)
        if motion_models is not None:
            # store a copy of the list so later additions to the group do not leak in
            self.set_motion_models(motion_models.models.copy())
            self.set_forward_timestep_embed_patch()
        else:
            self.remove_motion_models()
            self.remove_forward_timestep_embed_patch()

    # ----- motion models -----
    def get_motion_models(self) -> list[MotionModelPatcher]:
        '''Return the motion-model patchers stored under the ADE key (an empty list if none).'''
        return self.model.additional_models.get(self.ADE, [])

    def set_motion_models(self, motion_models: list[MotionModelPatcher]):
        '''Store `motion_models` as additional models and register their inject/eject hooks.'''
        self.model.set_additional_models(self.ADE, motion_models)
        self.model.set_injections(self.ADE,
                                  [PatcherInjection(inject=inject_motion_models, eject=eject_motion_models)])

    def remove_motion_models(self):
        '''Remove ADE additional models and their injections from the patcher.'''
        self.model.remove_additional_models(self.ADE)
        self.model.remove_injections(self.ADE)

    def cleanup_motion_models(self):
        '''Call `cleanup()` on every stored motion-model patcher.'''
        for motion_model in self.get_motion_models():
            motion_model.cleanup()

    # ----- forward_timestep_embed patch -----
    def set_forward_timestep_embed_patch(self):
        '''(Re)install the ADE forward_timestep_embed patch, removing any previous ADE copy first.'''
        self.remove_forward_timestep_embed_patch()
        self.model.set_model_forward_timestep_embed_patch(create_forward_timestep_embed_patch())

    def remove_forward_timestep_embed_patch(self):
        '''Remove every ADE forward_timestep_embed patch from `model_options`, if present.

        Non-ADE patches in the same list are kept, in their original order. The list
        is filtered in place, so the object stored in `model_options` is the one updated.
        '''
        model_options = self.model.model_options
        if "transformer_options" not in model_options:
            return
        transformer_options = model_options["transformer_options"]
        if "patches" not in transformer_options:
            return
        patches = transformer_options["patches"]
        if "forward_timestep_embed_patch" not in patches:
            return
        forward_timestep_patches: list = patches["forward_timestep_embed_patch"]
        # Filter in place instead of popping collected indices: popping in ascending
        # order shifts later elements, so with 2+ ADE patches the wrong entries were
        # removed (or pop() raised IndexError).
        forward_timestep_patches[:] = [patch for patch in forward_timestep_patches
                                       if patch[1] != forward_timestep_embed_patch_ade]

    def set_video_length(self, video_length: int, full_length: int):
        '''Forward the current and full video lengths to every stored motion model.'''
        for motion_model in self.get_motion_models():
            motion_model.model.set_video_length(video_length=video_length, full_length=full_length)

    def get_name_string(self, show_version=False):
        '''Return the stored motion models' names joined by ", ", optionally suffixed with ":<version>".'''
        identifiers = []
        for motion_model in self.get_motion_models():
            identifier = motion_model.model.mm_info.mm_name
            if show_version:
                identifier += f":{motion_model.model.mm_info.mm_version}"
            identifiers.append(identifier)
        return ", ".join(identifiers)

    # ----- sample settings -----
    def get_sample_settings(self) -> SampleSettings:
        '''Return the SampleSettings attachment (None if it was never set).'''
        return self.model.get_attachment(self.SAMPLE_SETTINGS)

    def set_sample_settings(self, sample_settings: SampleSettings):
        self.model.set_attachments(self.SAMPLE_SETTINGS, sample_settings)

    # ----- params -----
    def get_params(self) -> 'InjectionParams':
        '''Return the InjectionParams attachment (None if it was never set).'''
        return self.model.get_attachment(self.PARAMS)

    def set_params(self, params: 'InjectionParams'):
        '''Store `params` and, when context options are in use, add the ACN outer-sample wrapper.

        With a context_length set, a missing ACN wrapper is tolerated (throw_exception=False).
        With only a context_ref set, it is required (throw_exception=True).
        '''
        self.model.set_attachments(self.PARAMS, params)
        if params.context_options.context_length is not None:
            self.set_ACN_outer_sample_wrapper(throw_exception=False)
        elif params.context_options.extras.context_ref is not None:
            self.set_ACN_outer_sample_wrapper(throw_exception=True)

    def set_ACN_outer_sample_wrapper(self, throw_exception=True):
        '''Add the ACN outer-sample wrapper obtained via DinkLink, unless one is already registered.

        Does nothing if `get_acn_outer_sample_wrapper` returns None. Presumably that
        happens when ACN is unavailable and `throw_exception` is False; confirm in dinklink.
        '''
        wrapper_info = get_acn_outer_sample_wrapper(throw_exception)
        if wrapper_info is None:
            return
        wrapper_type, key, wrapper = wrapper_info
        if len(self.model.get_wrappers(wrapper_type, key)) == 0:
            self.model.add_wrapper_with_key(wrapper_type, key, wrapper)

    # ----- wrappers -----
    def set_outer_sample_wrapper(self, wrapper: Callable):
        '''Replace the ADE OUTER_SAMPLE wrapper with `wrapper`.'''
        self.model.remove_wrappers_with_key(WrappersMP.OUTER_SAMPLE, self.ADE)
        self.model.add_wrapper_with_key(WrappersMP.OUTER_SAMPLE, self.ADE, wrapper)

    def set_calc_cond_batch_wrapper(self, wrapper: Callable):
        '''Replace the ADE CALC_COND_BATCH wrapper with `wrapper`.'''
        self.model.remove_wrappers_with_key(WrappersMP.CALC_COND_BATCH, self.ADE)
        self.model.add_wrapper_with_key(WrappersMP.CALC_COND_BATCH, self.ADE, wrapper)

    def remove_wrappers(self):
        '''Remove both ADE wrappers (OUTER_SAMPLE and CALC_COND_BATCH).'''
        self.model.remove_wrappers_with_key(WrappersMP.OUTER_SAMPLE, self.ADE)
        self.model.remove_wrappers_with_key(WrappersMP.CALC_COND_BATCH, self.ADE)

    def pre_run(self):
        '''Run `pre_run()` on every motion model, then on the stored sample settings.'''
        for motion_model in self.get_motion_models():
            motion_model.pre_run()
        self.get_sample_settings().pre_run(self.model)
|
|
|
|
|
def inject_motion_models(patcher: ModelPatcher):
    '''PatcherInjection `inject` hook: call `inject(patcher)` on each ADE motion model stored on `patcher`.'''
    for motion_patcher in ModelPatcherHelper(patcher).get_motion_models():
        motion_patcher.model.inject(patcher)
|
|
|
|
|
def eject_motion_models(patcher: ModelPatcher):
    '''PatcherInjection `eject` hook: call `eject(patcher)` on each ADE motion model stored on `patcher`.'''
    for motion_patcher in ModelPatcherHelper(patcher).get_motion_models():
        motion_patcher.model.eject(patcher)
|
|
|
|
|
def create_forward_timestep_embed_patch():
    '''Return the `(layer type, handler)` pair ADE registers as a forward_timestep_embed patch.'''
    layer_type = VanillaTemporalModule
    handler = forward_timestep_embed_patch_ade
    return (layer_type, handler)
|
|
|
def forward_timestep_embed_patch_ade(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator, *args, **kwargs):
    '''forward_timestep_embed handler for ADE temporal layers.

    Only `x`, `context` and `transformer_options` are passed on to `layer`.
    `emb`, `output_shape`, `time_context`, `num_video_frames`,
    `image_only_indicator` and any extra arguments are accepted but ignored.
    '''
    forward_kwargs = {"transformer_options": transformer_options}
    return layer(x, context, **forward_kwargs)
|
|
|
|
|
def create_MotionModelPatcher(model, load_device, offload_device) -> MotionModelPatcher:
    '''Wrap a motion model in a ModelPatcher configured for ADE.

    Registers, in this order, under the ADE key:
    - two ON_LOAD callbacks (lowvram extras, then float8 `.pe` handling);
    - an ON_PRE_RUN callback;
    - an ON_CLEANUP callback.
    It then attaches a fresh MotionModelAttachment under the same key.
    '''
    patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
    key = ModelPatcherHelper.ADE
    registrations = (
        (CallbacksMP.ON_LOAD, _mm_patch_lowvram_extras_callback),
        (CallbacksMP.ON_LOAD, _mm_handle_float8_pe_tensors_callback),
        (CallbacksMP.ON_PRE_RUN, _mm_pre_run_callback),
        (CallbacksMP.ON_CLEANUP, _mm_clean_callback),
    )
    for callback_type, callback in registrations:
        patcher.add_callback_with_key(callback_type, key, callback)
    patcher.set_attachments(key, MotionModelAttachment())
    return patcher
|
|
|
|
|
def _mm_patch_lowvram_extras_callback(self: MotionModelPatcher, device_to, lowvram_model_memory, *args, **kwargs): |
|
if lowvram_model_memory > 0: |
|
|
|
remaining_tensors = list(self.model.state_dict().keys()) |
|
named_modules = [] |
|
for n, _ in self.model.named_modules(): |
|
named_modules.append(n) |
|
named_modules.append(f"{n}.weight") |
|
named_modules.append(f"{n}.bias") |
|
for name in named_modules: |
|
if name in remaining_tensors: |
|
remaining_tensors.remove(name) |
|
|
|
for key in remaining_tensors: |
|
self.patch_weight_to_device(key, device_to) |
|
if device_to is not None: |
|
comfy.utils.set_attr(self.model, key, comfy.utils.get_attr(self.model, key).to(device_to)) |
|
|
|
def _mm_handle_float8_pe_tensors_callback(self: MotionModelPatcher, *args, **kwargs):
    '''ON_LOAD callback: cast `.pe` tensors to float16 when they are stored as float8.

    Every state_dict key containing '.pe' is considered. Only the FIRST such tensor's
    dtype is inspected. If it is float8_e5m2 or float8_e4m3fn, all '.pe' tensors are
    replaced by `.half()` copies; otherwise nothing changes.
    '''
    pe_keys = [name for name in self.model.state_dict().keys() if '.pe' in name]
    if len(pe_keys) == 0:
        return
    first_dtype = comfy.utils.get_attr(self.model, pe_keys[0]).dtype
    if first_dtype not in [torch.float8_e5m2, torch.float8_e4m3fn]:
        return
    for name in pe_keys:
        comfy.utils.set_attr(self.model, name, comfy.utils.get_attr(self.model, name).half())
|
|
|
def _mm_pre_run_callback(self: MotionModelPatcher, *args, **kwargs):
    '''ON_PRE_RUN callback: delegate to the patcher's MotionModelAttachment.pre_run.'''
    get_mm_attachment(self).pre_run(self)
|
|
|
def _mm_clean_callback(self: MotionModelPatcher, *args, **kwargs):
    '''ON_CLEANUP callback: delegate to the patcher's MotionModelAttachment.cleanup.'''
    get_mm_attachment(self).cleanup(self)
|
|
|
|
|
def get_mm_attachment(patcher: MotionModelPatcher) -> 'MotionModelAttachment':
    '''Return the MotionModelAttachment stored on `patcher` under the ADE attachment key.'''
    attachment_key = ModelPatcherHelper.ADE
    return patcher.get_attachment(attachment_key)
|
|
|
|
|
class MotionModelAttachment:
    '''Per-motion-model configuration and run state, attached to a motion-model ModelPatcher.

    Holds the user configuration for one motion model:
    - the timestep range;
    - keyframes;
    - scale/effect multivals and per-block settings;
    - optional image-encoder (AnimateLCM-I2V), CameraCtrl, PIA and FancyVideo inputs.
    It also holds bookkeeping for the current sampling run, which `cleanup` resets.
    '''
    def __init__(self):
        # active range as percents; converted to sigmas by initialize_timesteps
        self.timestep_percent_range = (0.0, 1.0)
        self.timestep_range: tuple[float, float] = None
        self.keyframes: ADKeyframeGroup = ADKeyframeGroup()
        # scale/effect settings
        self.scale_multival: Union[float, Tensor, None] = None
        self.effect_multival: Union[float, Tensor, None] = None
        self.per_block_list: Union[list[PerBlock], None] = None
        # image encoder (see prepare_alcmi2v_features)
        self.orig_ref_drift: float = None
        self.orig_insertion_weights: list[float] = None
        self.orig_apply_ref_when_disabled = False
        self.orig_img_latents: Tensor = None
        self.img_features: list[int, Tensor] = None
        self.img_latents_shape: tuple = None
        # CameraCtrl
        self.orig_camera_entries: list[CameraEntry] = None
        self.camera_features: list[Tensor] = None
        self.camera_features_shape: tuple = None
        self.cameractrl_multival: Union[float, Tensor] = None
        # PIA
        self.orig_pia_images: Tensor = None
        self.pia_vae: VAE = None
        self.pia_input: InputPIA = None
        self.cached_pia_c_concat: comfy.conds.CONDNoiseShape = None
        self.prev_pia_latents_shape: tuple = None
        self.prev_current_pia_input: InputPIA = None
        self.pia_multival: Union[float, Tensor] = None
        # FancyVideo
        self.orig_fancy_images: Tensor = None
        self.fancy_vae: VAE = None
        self.cached_fancy_c_concat: comfy.conds.CONDNoiseShape = None
        self.prev_fancy_latents_shape: tuple = None
        self.fancy_multival: Union[float, Tensor] = None
        # keyframe / per-run state
        self.current_used_steps = 0
        self.current_keyframe: ADKeyframe = None
        self.current_index = -1
        self.previous_t = -1
        self.current_scale: Union[float, Tensor] = None
        self.current_effect: Union[float, Tensor] = None
        self.current_cameractrl_effect: Union[float, Tensor] = None
        self.current_pia_input: InputPIA = None
        self.combined_scale: Union[float, Tensor] = None
        self.combined_effect: Union[float, Tensor] = None
        self.combined_per_block_list: Union[list[PerBlock], None] = None
        self.combined_cameractrl_effect: Union[float, Tensor] = None
        self.combined_pia_mask: Union[float, Tensor] = None
        self.combined_pia_effect: Union[float, Tensor] = None
        self.was_within_range = False
        self.prev_sub_idxs = None
        self.prev_batched_number = None

    def pre_run(self, patcher: MotionModelPatcher):
        '''Reset run state, then push the base scale/effect/cameractrl (and image-encoder) settings to the model.'''
        self.cleanup(patcher)
        patcher.model.set_scale(self.scale_multival, self.per_block_list)
        patcher.model.set_effect(self.effect_multival, self.per_block_list)
        patcher.model.set_cameractrl_effect(self.cameractrl_multival)
        if patcher.model.img_encoder is not None:
            patcher.model.img_encoder.set_ref_drift(self.orig_ref_drift)
            patcher.model.img_encoder.set_insertion_weights(self.orig_insertion_weights)

    def initialize_timesteps(self, model: BaseModel):
        '''Convert the percent range and each keyframe's start_percent into sigmas using `model.model_sampling`.'''
        self.timestep_range = (model.model_sampling.percent_to_sigma(self.timestep_percent_range[0]),
                               model.model_sampling.percent_to_sigma(self.timestep_percent_range[1]))
        if self.keyframes is not None:
            for keyframe in self.keyframes.keyframes:
                keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)

    def prepare_current_keyframe(self, patcher: MotionModelPatcher, x: Tensor, t: Tensor):
        '''Advance keyframes for timestep `t[0]` and apply the resulting scale/effect to the model.

        Does nothing if `t[0]` equals the previously handled timestep. Outside
        `timestep_range`, the model's effect is forced to 0.0.
        '''
        curr_t: float = t[0]
        # already handled this timestep
        if curr_t == self.previous_t:
            return
        prev_index = self.current_index
        # only look for a new keyframe once the current one has used up its guaranteed steps
        if self.current_keyframe is None or self.current_used_steps >= self.current_keyframe.guarantee_steps:
            if self.keyframes.has_index(self.current_index+1):
                for i in range(self.current_index+1, len(self.keyframes)):
                    eval_kf = self.keyframes[i]
                    # keyframe is reached once curr_t is at or below its start sigma
                    if eval_kf.start_t >= curr_t:
                        self.current_index = i
                        self.current_keyframe = eval_kf
                        self.current_used_steps = 0
                        # take keyframe values; missing values are inherited only if inherit_missing
                        if self.current_keyframe.has_scale():
                            self.current_scale = self.current_keyframe.scale_multival
                        elif not self.current_keyframe.inherit_missing:
                            self.current_scale = None
                        if self.current_keyframe.has_effect():
                            self.current_effect = self.current_keyframe.effect_multival
                        elif not self.current_keyframe.inherit_missing:
                            self.current_effect = None
                        if self.current_keyframe.has_cameractrl_effect():
                            self.current_cameractrl_effect = self.current_keyframe.cameractrl_multival
                        elif not self.current_keyframe.inherit_missing:
                            self.current_cameractrl_effect = None
                        if self.current_keyframe.has_pia_input():
                            self.current_pia_input = self.current_keyframe.pia_input
                        elif not self.current_keyframe.inherit_missing:
                            self.current_pia_input = None
                        # a keyframe with guaranteed steps ends the search
                        if self.current_keyframe.guarantee_steps > 0:
                            break
                    # not reached yet; stop (presumably keyframes are ordered by start)
                    else:
                        break
        # recompute and apply combined values only when the keyframe changed
        if prev_index != self.current_index:
            self.combined_scale = get_combined_multival(self.scale_multival, self.current_scale)
            self.combined_effect = get_combined_multival(self.effect_multival, self.current_effect)
            self.combined_cameractrl_effect = get_combined_multival(self.cameractrl_multival, self.current_cameractrl_effect)
            self.combined_pia_mask = get_combined_input(self.pia_input, self.current_pia_input, x)
            self.combined_pia_effect = get_combined_input_effect_multival(self.pia_input, self.current_pia_input)
            patcher.model.set_scale(self.combined_scale, self.per_block_list)
            patcher.model.set_effect(self.combined_effect, self.per_block_list)
            patcher.model.set_cameractrl_effect(self.combined_cameractrl_effect)
        # outside [timestep_range[1], timestep_range[0]] the motion model is turned off (effect 0)
        if curr_t > self.timestep_range[0] or curr_t < self.timestep_range[1]:
            patcher.model.set_effect(0.0)
            self.was_within_range = False
        else:
            # just (re-)entered the range: restore the combined effect
            if not self.was_within_range:
                patcher.model.set_effect(self.combined_effect, self.per_block_list)
                self.was_within_range = True
        self.current_used_steps += 1
        self.previous_t = curr_t

    def prepare_alcmi2v_features(self, patcher: MotionModelPatcher, x: Tensor, cond_or_uncond: list[int], ad_params: dict[str], latent_format):
        '''Encode reference image latents with the model's img_encoder and hand the features to the model.

        Features are recomputed only when `sub_idxs`, the batched number, or the
        latent height/width changed since the last call.
        '''
        if patcher.model.img_encoder is None:
            return
        batched_number = len(cond_or_uncond)
        full_length = ad_params["full_length"]
        sub_idxs = ad_params["sub_idxs"]
        goal_length = x.size(0) // batched_number
        # NOTE(review): prev_sub_idxs / prev_batched_number are shared with prepare_camera_features;
        # a model using both encoders could skip a needed recompute — confirm whether that can occur.
        if (self.img_latents_shape is None or sub_idxs != self.prev_sub_idxs or batched_number != self.prev_batched_number
                or x.shape[2] != self.img_latents_shape[2] or x.shape[3] != self.img_latents_shape[3]):
            if sub_idxs is not None and self.orig_img_latents.size(0) >= full_length:
                img_latents = comfy.utils.common_upscale(self.orig_img_latents[sub_idxs], x.shape[3], x.shape[2], 'nearest-exact', 'center').to(x.dtype).to(x.device)
            else:
                img_latents = comfy.utils.common_upscale(self.orig_img_latents, x.shape[3], x.shape[2], 'nearest-exact', 'center').to(x.dtype).to(x.device)
            img_latents: Tensor = latent_format.process_in(img_latents)
            # broadcast latents to the per-cond batch length
            if goal_length != img_latents.shape[0]:
                img_latents = ade_broadcast_image_to(img_latents, goal_length, batched_number)
            img_features = patcher.model.img_encoder(img_latents, goal_length, batched_number)
            patcher.model.set_img_features(img_features=img_features, apply_ref_when_disabled=self.orig_apply_ref_when_disabled)
            # remember inputs to skip recomputation next time
            self.img_latents_shape = img_latents.shape
            self.prev_sub_idxs = sub_idxs
            self.prev_batched_number = batched_number

    def prepare_camera_features(self, patcher: MotionModelPatcher, x: Tensor, cond_or_uncond: list[int], ad_params: dict[str]):
        '''Build Plücker embeddings from camera poses, run the camera_encoder, and hand the features to the model.

        Poses are padded with the last entry up to `full_length`, selected by
        `sub_idxs`, then truncated/padded to the per-cond batch length.
        '''
        if patcher.model.camera_encoder is None:
            return
        batched_number = len(cond_or_uncond)
        full_length = ad_params["full_length"]
        sub_idxs = ad_params["sub_idxs"]
        goal_length = x.size(0) // batched_number
        if self.camera_features_shape is None or sub_idxs != self.prev_sub_idxs or batched_number != self.prev_batched_number:
            camera_poses = self.orig_camera_entries.copy()
            # repeat last pose so there is one per frame of the full video
            if len(camera_poses) < full_length:
                for i in range(full_length-len(camera_poses)):
                    camera_poses.append(camera_poses[-1])
            if sub_idxs is not None:
                camera_poses = [camera_poses[idx] for idx in sub_idxs]
            # match the per-cond batch length
            if len(camera_poses) > goal_length:
                camera_poses = camera_poses[:goal_length]
            elif len(camera_poses) < goal_length:
                for i in range(goal_length-len(camera_poses)):
                    camera_poses.append(camera_poses[-1])
            b, c, h, w = x.shape
            plucker_embedding = prepare_pose_embedding(camera_poses, image_width=w*8, image_height=h*8).to(dtype=x.dtype, device=x.device)
            camera_embedding = patcher.model.camera_encoder(plucker_embedding, video_length=goal_length, batched_number=batched_number)
            patcher.model.set_camera_features(camera_features=camera_embedding)
            # only used as a "was computed" marker
            self.camera_features_shape = len(camera_embedding)
            self.prev_sub_idxs = sub_idxs
            self.prev_batched_number = batched_number

    def get_pia_c_concat(self, model: BaseModel, x: Tensor) -> Tensor:
        '''Return the PIA c_concat conditioning (mask channel followed by reference latents).

        Reuses the cached result when batch size and latent height/width are unchanged.
        If only the PIA input changed, the mask channel is updated in place. Otherwise
        the reference images are resized and VAE-encoded. The models that were loaded
        before encoding are reloaded afterwards.
        '''
        if self.prev_pia_latents_shape is not None:
            if self.prev_pia_latents_shape[0] == x.shape[0] and self.prev_pia_latents_shape[2] == x.shape[2] and self.prev_pia_latents_shape[3] == x.shape[3]:
                # same PIA input as last time: cached value is still valid
                if self.prev_current_pia_input == self.current_pia_input:
                    return self.cached_pia_c_concat
                # otherwise rebuild only the mask
                b, c, h ,w = x.shape
                mask = prepare_mask_batch(self.combined_pia_mask, x.shape)
                mask = extend_to_batch_size(mask, b)
                self.prev_current_pia_input = self.current_pia_input
                # NOTE: combined_pia_effect is not re-applied on this path
                # channel 0 of the cached cond is the mask; update it in place
                self.cached_pia_c_concat.cond[:, :1, :, :] = mask
                return self.cached_pia_c_concat
            self.prev_pia_latents_shape = None
        # remember loaded models so they can be reloaded after VAE encoding
        cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
        try:
            b, c, h ,w = x.shape
            usable_ref = self.orig_pia_images[:b]
            # resize images to the latent size times the VAE downscale ratio
            usable_ref = usable_ref.movedim(-1,1)
            usable_ref = comfy.utils.common_upscale(samples=usable_ref, width=w*self.pia_vae.downscale_ratio, height=h*self.pia_vae.downscale_ratio,
                                                    upscale_method="bilinear", crop="center")
            usable_ref = usable_ref.movedim(1,-1)
            logger.info("VAE Encoding PIA input images...")
            usable_ref = model.process_latent_in(vae_encode_raw_batched(vae=self.pia_vae, pixels=usable_ref, show_pbar=False))
            logger.info("VAE Encoding PIA input images complete.")
            usable_ref = extend_to_batch_size(usable_ref, b)
            self.prev_pia_latents_shape = x.shape
            mask = prepare_mask_batch(self.combined_pia_mask, x.shape)
            mask = extend_to_batch_size(mask, b)
            self.prev_current_pia_input = self.current_pia_input
            # blend mask and reference latents toward zero by pia_effect
            if type(self.combined_pia_effect) == Tensor or not math.isclose(self.combined_pia_effect, 1.0):
                real_pia_effect = self.combined_pia_effect
                if type(self.combined_pia_effect) == Tensor:
                    real_pia_effect = extend_to_batch_size(prepare_mask_batch(self.combined_pia_effect, x.shape), b)
                zero_mask = torch.zeros_like(mask)
                mask = mask * real_pia_effect + zero_mask * (1.0 - real_pia_effect)
                del zero_mask
                zero_usable_ref = torch.zeros_like(usable_ref)
                usable_ref = usable_ref * real_pia_effect + zero_usable_ref * (1.0 - real_pia_effect)
                del zero_usable_ref
            self.cached_pia_c_concat = comfy.conds.CONDNoiseShape(torch.cat([mask, usable_ref], dim=1))
            return self.cached_pia_c_concat
        finally:
            comfy.model_management.load_models_gpu(cached_loaded_models)

    def get_fancy_c_concat(self, model: BaseModel, x: Tensor) -> Tensor:
        '''Return the FancyVideo c_concat conditioning (reference latents followed by a mask channel).

        Cached while batch size and latent height/width are unchanged. Reference
        latents are zero-padded to the batch size. The mask is 1 for frames with a
        reference image and 0 for the padding. Models loaded before VAE encoding are
        reloaded afterwards.
        '''
        if self.prev_fancy_latents_shape is not None:
            if self.prev_fancy_latents_shape[0] == x.shape[0] and self.prev_fancy_latents_shape[-2] == x.shape[-2] and self.prev_fancy_latents_shape[-1] == x.shape[-1]:
                return self.cached_fancy_c_concat
            self.prev_fancy_latents_shape = None
        cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
        try:
            b, c, h, w = x.shape
            usable_ref = self.orig_fancy_images[:b]
            usable_ref = usable_ref.movedim(-1,1)
            usable_ref = comfy.utils.common_upscale(samples=usable_ref, width=w*self.fancy_vae.downscale_ratio, height=h*self.fancy_vae.downscale_ratio,
                                                    upscale_method="bilinear", crop="center")
            usable_ref = usable_ref.movedim(1,-1)
            logger.info("VAE Encoding FancyVideo input images...")
            usable_ref: Tensor = model.process_latent_in(vae_encode_raw_batched(vae=self.fancy_vae, pixels=usable_ref, show_pbar=False))
            logger.info("VAE Encoding FancyVideo input images complete.")
            self.prev_fancy_latents_shape = x.shape
            # zero-pad reference latents up to the batch size
            ref_length = usable_ref.shape[0]
            pad_length = b - ref_length
            zero_ref = torch.zeros([pad_length, c, h, w], dtype=usable_ref.dtype, device=usable_ref.device)
            usable_ref = torch.cat([usable_ref, zero_ref], dim=0)
            del zero_ref
            # mask: 1 for frames with a reference, 0 for padding
            mask_ones = torch.ones([ref_length, 1, h, w], dtype=usable_ref.dtype, device=usable_ref.device)
            mask_zeros = torch.zeros([pad_length, 1, h, w], dtype=usable_ref.dtype, device=usable_ref.device)
            mask = torch.cat([mask_ones, mask_zeros], dim=0)
            self.cached_fancy_c_concat = comfy.conds.CONDNoiseShape(torch.cat([usable_ref, mask], dim=1))
            return self.cached_fancy_c_concat
        finally:
            comfy.model_management.load_models_gpu(cached_loaded_models)

    def is_pia(self, patcher: MotionModelPatcher):
        '''True if the model is PIA-format and PIA images were provided.'''
        return patcher.model.mm_info.mm_format == AnimateDiffFormat.PIA and self.orig_pia_images is not None

    def is_fancyvideo(self, patcher: MotionModelPatcher):
        '''True if the model is FancyVideo-format.'''
        return patcher.model.mm_info.mm_format == AnimateDiffFormat.FANCYVIDEO

    def cleanup(self, patcher: MotionModelPatcher):
        '''Clean up the motion model and reset all per-run state.

        PIA/FancyVideo c_concat caches are intentionally kept.
        '''
        if patcher.model is not None:
            patcher.model.cleanup()
        # image encoder
        self.img_features = None
        self.img_latents_shape = None
        # CameraCtrl
        self.camera_features = None
        self.camera_features_shape = None
        # PIA
        self.combined_pia_mask = None
        self.combined_pia_effect = None
        # keyframe state: every current_*/combined_* value is reset. Otherwise a
        # keyframe with inherit_missing in the next run would inherit stale values.
        self.current_used_steps = 0
        self.current_keyframe = None
        self.current_index = -1
        self.previous_t = -1
        self.current_scale = None
        self.current_effect = None
        self.current_cameractrl_effect = None
        self.current_pia_input = None
        self.combined_scale = None
        self.combined_effect = None
        self.combined_cameractrl_effect = None
        self.combined_per_block_list = None
        self.was_within_range = False
        self.prev_sub_idxs = None
        self.prev_batched_number = None

    def on_model_patcher_clone(self):
        '''Return a new attachment with this one's configuration (run state and caches start fresh).'''
        n = MotionModelAttachment()
        # general settings
        n.timestep_percent_range = self.timestep_percent_range
        n.timestep_range = self.timestep_range
        n.keyframes = self.keyframes.clone() if self.keyframes is not None else None
        n.scale_multival = self.scale_multival
        n.effect_multival = self.effect_multival
        # must be copied too, otherwise a clone loses its per-block scale/effect
        n.per_block_list = self.per_block_list
        # image encoder
        n.orig_img_latents = self.orig_img_latents
        n.orig_ref_drift = self.orig_ref_drift
        n.orig_insertion_weights = self.orig_insertion_weights.copy() if self.orig_insertion_weights is not None else self.orig_insertion_weights
        n.orig_apply_ref_when_disabled = self.orig_apply_ref_when_disabled
        # CameraCtrl
        n.orig_camera_entries = self.orig_camera_entries
        n.cameractrl_multival = self.cameractrl_multival
        # PIA
        n.orig_pia_images = self.orig_pia_images
        n.pia_vae = self.pia_vae
        n.pia_input = self.pia_input
        n.pia_multival = self.pia_multival
        # FancyVideo: without these, get_fancy_c_concat on a clone would slice None
        n.orig_fancy_images = self.orig_fancy_images
        n.fancy_vae = self.fancy_vae
        n.fancy_multival = self.fancy_multival
        return n
|
|
|
|
|
class MotionModelGroup:
    '''Ordered collection of motion-model patchers; most methods forward to every member in order.'''
    def __init__(self, init_motion_model: MotionModelPatcher=None):
        self.models: list[MotionModelPatcher] = []
        if init_motion_model is None:
            return
        # accept either a single patcher or a list of patchers
        members = init_motion_model if isinstance(init_motion_model, list) else [init_motion_model]
        for member in members:
            self.add(member)

    def add(self, mm: MotionModelPatcher):
        '''Append `mm` at the end of the group.'''
        self.models.append(mm)

    def add_to_start(self, mm: MotionModelPatcher):
        '''Insert `mm` at the front of the group.'''
        self.models.insert(0, mm)

    def __getitem__(self, index) -> MotionModelPatcher:
        return self.models[index]

    def is_empty(self) -> bool:
        return not self.models

    def clone(self) -> 'MotionModelGroup':
        '''Return a new group with the same patcher objects (the patchers themselves are not cloned).'''
        return MotionModelGroup(list(self.models))

    def set_sub_idxs(self, sub_idxs: list[int]):
        for member in self.models:
            member.model.set_sub_idxs(sub_idxs=sub_idxs)

    def set_view_options(self, view_options: ContextOptions):
        for member in self.models:
            member.model.set_view_options(view_options)

    def set_video_length(self, video_length: int, full_length: int):
        for member in self.models:
            member.model.set_video_length(video_length=video_length, full_length=full_length)

    def initialize_timesteps(self, model: BaseModel):
        '''Convert each member attachment's percent ranges/keyframes to sigmas for `model`.'''
        for member in self.models:
            get_mm_attachment(member).initialize_timesteps(model)

    def pre_run(self, model: ModelPatcher):
        '''Call `pre_run()` on each member patcher (`model` is unused).'''
        for member in self.models:
            member.pre_run()

    def cleanup(self):
        for member in self.models:
            member.cleanup()

    def prepare_current_keyframe(self, x: Tensor, t: Tensor):
        '''Advance keyframes for timestep `t` on every member.'''
        for member in self.models:
            get_mm_attachment(member).prepare_current_keyframe(member, x=x, t=t)

    def get_special_models(self):
        '''Return members that need c_concat conditioning: PIA (with images) or FancyVideo.'''
        def _needs_c_concat(member: MotionModelPatcher) -> bool:
            attachment = get_mm_attachment(member)
            return attachment.is_pia(member) or attachment.is_fancyvideo(member)
        return [member for member in self.models if _needs_c_concat(member)]

    def get_name_string(self, show_version=False):
        '''Return member names joined by ", ", optionally suffixed with ":<version>".'''
        labels = []
        for member in self.models:
            info = member.model.mm_info
            labels.append(info.mm_name + f":{info.mm_version}" if show_version else info.mm_name)
        return ", ".join(labels)
|
|
|
|
|
def get_vanilla_model_patcher(m: ModelPatcher) -> ModelPatcher:
    '''Create a new plain ModelPatcher over the same underlying model as `m`.

    Carried over from `m`:
    - each per-key patch list, copied;
    - `object_patches`, as a shallow copy;
    - `model_options`, as a deep copy;
    - `model_keys`, if `m` has it.
    Everything else is whatever a freshly constructed ModelPatcher starts with.
    '''
    model = ModelPatcher(m.model, m.load_device, m.offload_device, m.size, weight_inplace_update=m.weight_inplace_update)
    model.patches = {key: patch_list[:] for key, patch_list in m.patches.items()}
    model.object_patches = m.object_patches.copy()
    model.model_options = copy.deepcopy(m.model_options)
    # Check the SOURCE patcher: that is the one being read. Checking the new
    # patcher raised AttributeError when only the destination had the attribute.
    if hasattr(m, "model_keys"):
        model.model_keys = m.model_keys
    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_motion_lora_as_patches(motion_model: MotionModelPatcher, lora: MotionLoraInfo) -> None:
    '''Load a motion LoRA file and add it to `motion_model` as weight patches.

    Only keys containing "temporal" are kept. For each `.down.` key, the matching
    `.up.` key is looked up and `up @ down` (computed in float32) becomes the patch for
    the corresponding motion-model weight. Mid-block keys are skipped when the motion
    model has no mid block. Patches are added with `lora.strength`.

    Raises:
        ValueError: if the file contains no temporal keys.
        KeyError: if a `.down.` key has no matching `.up.` key.
    '''
    def get_version(has_midblock: bool):
        return "v2" if has_midblock else "v1"

    lora_path = get_motion_lora_path(lora.name)
    logger.info(f"Loading motion LoRA {lora.name}")
    state_dict = comfy.utils.load_torch_file(lora_path)

    # keep only temporal (motion module) keys
    state_dict = {key: value for key, value in state_dict.items() if "temporal" in key}
    if len(state_dict) == 0:
        raise ValueError(f"'{lora.name}' contains no temporal keys; it is not a valid motion LoRA!")

    model_has_midblock = motion_model.model.mid_block is not None
    lora_has_midblock = has_mid_block(state_dict)
    logger.info(f"Applying a {get_version(lora_has_midblock)} LoRA ({lora.name}) to a { motion_model.model.mm_info.mm_version} motion model.")

    patches = {}
    for key in state_dict:
        # a motion model without a mid block cannot take mid-block LoRA weights
        if not model_has_midblock and "mid_block" in key:
            continue
        # each pair is processed from its down key; skip up keys here
        if "up." in key:
            continue
        up_key = key.replace(".down.", ".up.")
        # map LoRA key naming to the motion model's parameter name
        model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
        # to_out is a Sequential in the model, so its Linear lives at index 0
        if "to_out.0." not in model_key:
            model_key = model_key.replace("to_out.", "to_out.0.")

        weight_down = state_dict[key]
        weight_up = state_dict[up_key]
        # full weight delta = up @ down, computed in float32
        patches[model_key] = (torch.mm(
            comfy.model_management.cast_to_device(weight_up, weight_up.device, torch.float32),
            comfy.model_management.cast_to_device(weight_down, weight_down.device, torch.float32)
        ),)
    del state_dict

    motion_model.add_patches(patches=patches, strength_patch=lora.strength)
|
|
|
|
|
def load_motion_module_gen1(model_name: str, model: ModelPatcher, motion_lora: MotionLoraList = None, motion_model_settings: AnimateDiffSettings = None) -> MotionModelPatcher:
    '''Load a motion module for Gen1 nodes, matched to `model`'s dtype and devices.

    Steps:
    1. Normalize the state dict.
    2. Check that its SD type matches `model`.
    3. Apply `motion_model_settings`.
    4. Load the weights (verified through `verify_load_result`).
    5. Add each motion LoRA in `motion_lora` as patches.

    Raises:
        MotionCompatibilityError: if the motion module targets a different SD model type.
    '''
    path = get_motion_model_path(model_name)
    logger.info(f"Loading motion module {model_name}")
    raw_state_dict = comfy.utils.load_torch_file(path, safe_load=True)
    # convert to the expected key layout and detect the module's format/version
    normalized_sd, info = normalize_ad_state_dict(mm_state_dict=raw_state_dict, mm_name=model_name)

    sd_type = get_sd_model_type(model)
    if sd_type != info.sd_type:
        raise MotionCompatibilityError(f"Motion module '{info.mm_name}' is intended for {info.sd_type} models, "
                                       f"but the provided model is type {sd_type}.")

    normalized_sd = apply_mm_settings(model_dict=normalized_sd, mm_settings=motion_model_settings)

    ad_model = AnimateDiffModel(mm_state_dict=normalized_sd, mm_info=info)
    ad_model.to(model.model_dtype())
    ad_model.to(model.offload_device)
    verify_load_result(load_result=ad_model.load_state_dict(normalized_sd, strict=False), mm_info=info)

    patcher = create_MotionModelPatcher(model=ad_model, load_device=model.load_device, offload_device=model.offload_device)
    loras = motion_lora.loras if motion_lora is not None else []
    for lora in loras:
        load_motion_lora_as_patches(patcher, lora)
    return patcher
|
|
|
|
|
def load_motion_module_gen2(model_name: str, motion_model_settings: AnimateDiffSettings = None) -> MotionModelPatcher:
    '''Load a motion module for Gen2 nodes using ComfyUI's global UNet dtype and devices.

    Unlike Gen1, no SD-model compatibility check and no LoRA loading happen here.
    '''
    path = get_motion_model_path(model_name)
    logger.info(f"Loading motion module {model_name} via Gen2")
    raw_state_dict = comfy.utils.load_torch_file(path, safe_load=True)
    # convert to the expected key layout and detect the module's format/version
    normalized_sd, info = normalize_ad_state_dict(mm_state_dict=raw_state_dict, mm_name=model_name)
    normalized_sd = apply_mm_settings(model_dict=normalized_sd, mm_settings=motion_model_settings)

    ad_model = AnimateDiffModel(mm_state_dict=normalized_sd, mm_info=info)
    ad_model.to(comfy.model_management.unet_dtype())
    ad_model.to(comfy.model_management.unet_offload_device())
    verify_load_result(load_result=ad_model.load_state_dict(normalized_sd, strict=False), mm_info=info)

    return create_MotionModelPatcher(model=ad_model, load_device=comfy.model_management.get_torch_device(),
                                     offload_device=comfy.model_management.unet_offload_device())
|
|
|
|
|
IncompatibleKeys = namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys']) |
|
def verify_load_result(load_result: IncompatibleKeys, mm_info: AnimateDiffInfo):
    """Raise if a non-strict ``load_state_dict`` of a motion module left unmatched keys.

    For the AnimateLCM format, missing keys containing ``"pos_encoder.pe"`` are tolerated: they are
    removed from ``load_result.missing_keys`` in place before checking. Any remaining missing or
    unexpected keys are reported together.

    Args:
        load_result: the (missing_keys, unexpected_keys) result of ``load_state_dict(..., strict=False)``.
        mm_info: info of the motion module being loaded; ``mm_format`` selects tolerated keys and
            ``mm_name`` is used in the error message.

    Raises:
        RuntimeError: listing the missing keys first, then the unexpected keys.
    """
    if mm_info.mm_format == AnimateDiffFormat.ANIMATELCM:
        # NOTE(review): presumably AnimateLCM checkpoints omit the PE buffers — confirm.
        # Slice assignment filters the caller's list in place, as the previous pop() loop did.
        load_result.missing_keys[:] = [
            key for key in load_result.missing_keys if "pos_encoder.pe" not in key
        ]

    error_msgs: list[str] = []
    if load_result.missing_keys:
        error_msgs.append('Missing key(s) in state_dict: {}. '.format(
            ', '.join(f'"{k}"' for k in load_result.missing_keys)))
    if load_result.unexpected_keys:
        error_msgs.append('Unexpected key(s) in state_dict: {}. '.format(
            ', '.join(f'"{k}"' for k in load_result.unexpected_keys)))
    if error_msgs:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
            mm_info.mm_name, "\n\t".join(error_msgs)))
|
|
|
|
|
def create_fresh_motion_module(motion_model: MotionModelPatcher) -> MotionModelPatcher:
    """Rebuild ``motion_model``'s AnimateDiffModel as a new module wrapped in a new patcher.

    The new model is constructed from the source model's state dict and ``mm_info``, moved to the
    UNet dtype and offload device, and then loaded (strictly) with the source state dict.
    """
    source = motion_model.model
    mgmt = comfy.model_management
    # state_dict() is requested separately for construction and loading, as before
    rebuilt = AnimateDiffModel(mm_state_dict=source.state_dict(), mm_info=source.mm_info)
    rebuilt.to(mgmt.unet_dtype())
    rebuilt.to(mgmt.unet_offload_device())
    rebuilt.load_state_dict(source.state_dict())
    return create_MotionModelPatcher(
        model=rebuilt,
        load_device=mgmt.get_torch_device(),
        offload_device=mgmt.unet_offload_device(),
    )
|
|
|
|
|
def create_fresh_encoder_only_model(motion_model: MotionModelPatcher) -> MotionModelPatcher:
    """Build an EncoderOnlyAnimateDiffModel from ``motion_model``'s weights, wrapped in a new patcher.

    Loading is non-strict, so keys of the full model that the encoder-only model lacks are ignored.
    """
    source = motion_model.model
    mgmt = comfy.model_management
    encoder_only = EncoderOnlyAnimateDiffModel(mm_state_dict=source.state_dict(), mm_info=source.mm_info)
    encoder_only.to(mgmt.unet_dtype())
    encoder_only.to(mgmt.unet_offload_device())
    encoder_only.load_state_dict(source.state_dict(), strict=False)
    return create_MotionModelPatcher(
        model=encoder_only,
        load_device=mgmt.get_torch_device(),
        offload_device=mgmt.unet_offload_device(),
    )
|
|
|
|
|
def inject_img_encoder_into_model(motion_model: MotionModelPatcher, w_encoder: MotionModelPatcher):
    """Create an image encoder on ``motion_model`` and copy ``w_encoder``'s image-encoder weights into it.

    The new encoder is moved to the UNet dtype and offload device before the weights are loaded.
    """
    target = motion_model.model
    target.init_img_encoder()
    img_encoder = target.img_encoder
    img_encoder.to(comfy.model_management.unet_dtype())
    img_encoder.to(comfy.model_management.unet_offload_device())
    img_encoder.load_state_dict(w_encoder.model.img_encoder.state_dict())
|
|
|
|
|
def inject_pia_conv_in_into_model(motion_model: MotionModelPatcher, w_pia: MotionModelPatcher):
    """Give ``motion_model`` the PIA ``conv_in`` layer from ``w_pia`` and mark it as PIA format.

    ``conv_in`` is initialized from ``w_pia``'s state dict, moved to the UNet dtype/offload device,
    loaded with ``w_pia``'s ``conv_in`` weights, and then ``mm_info.mm_format`` is set to PIA.
    """
    target = motion_model.model
    target.init_conv_in(w_pia.model.state_dict())
    conv_in = target.conv_in
    conv_in.to(comfy.model_management.unet_dtype())
    conv_in.to(comfy.model_management.unet_offload_device())
    conv_in.load_state_dict(w_pia.model.conv_in.state_dict())
    target.mm_info.mm_format = AnimateDiffFormat.PIA
|
|
|
|
|
def inject_camera_encoder_into_model(motion_model: MotionModelPatcher, camera_ctrl_name: str):
    """Load the CameraCtrl checkpoint ``camera_ctrl_name`` and graft it onto ``motion_model``.

    The checkpoint's keys are split into two groups:
      * keys starting with ``encoder`` -> loaded into a new ``CameraPoseEncoder`` that is attached
        via ``set_camera_encoder``;
      * keys containing ``qkv_merge`` -> for each ``...processor.qkv_merge.bias`` key, the attention
        module at the path before ``.processor.qkv_merge`` gets ``init_qkv_merge`` and is loaded
        with the matching weight/bias pair.

    Raises:
        Exception: if the checkpoint has no ``encoder`` keys or no ``qkv_merge`` keys.
        KeyError: if a qkv_merge bias has no matching weight entry.
    """
    camera_ctrl_path = get_motion_model_path(camera_ctrl_name)
    full_state_dict = comfy.utils.load_torch_file(camera_ctrl_path, safe_load=True)

    camera_state_dict: dict[str, Tensor] = {}
    attention_state_dict: dict[str, Tensor] = {}
    for name, tensor in full_state_dict.items():
        if name.startswith("encoder"):
            camera_state_dict[name] = tensor
        elif "qkv_merge" in name:
            attention_state_dict[name] = tensor

    if not camera_state_dict:
        raise Exception("Provided CameraCtrl model had no Camera Encoder-related keys; not a valid CameraCtrl model!")
    if not attention_state_dict:
        raise Exception("Provided CameraCtrl model had no qkv_merge keys; not a valid CameraCtrl model!")

    # camera pose encoder: built with the motion model's channels/ops, then loaded
    camera_encoder = CameraPoseEncoder(channels=motion_model.model.layer_channels, nums_rb=2, ops=motion_model.model.ops).to(
        device=comfy.model_management.unet_offload_device(),
        dtype=comfy.model_management.unet_dtype(),
    )
    camera_encoder.load_state_dict(camera_state_dict)
    camera_encoder.temporal_pe_max_len = get_position_encoding_max_len(
        camera_state_dict, mm_name=camera_ctrl_name, mm_format=AnimateDiffFormat.ANIMATEDIFF)
    motion_model.model.set_camera_encoder(camera_encoder=camera_encoder)

    # qkv_merge layers: each is handled once, when its non-weight (bias) key is visited
    for raw_key in attention_state_dict:
        key = raw_key.strip()
        if key.endswith("weight"):
            continue
        attr_path = key.partition(".processor.qkv_merge")[0]
        base_key = key.partition(".bias")[0]

        attention_obj: VersatileAttention = comfy.utils.get_attr(motion_model.model, attr_path)
        attention_obj.init_qkv_merge(ops=motion_model.model.ops)

        qkv_merge_state_dict = {
            "weight": attention_state_dict[f"{base_key}.weight"],
            "bias": attention_state_dict[f"{base_key}.bias"],
        }
        attention_obj.qkv_merge.load_state_dict(qkv_merge_state_dict)
        attention_obj.qkv_merge = attention_obj.qkv_merge.to(
            device=comfy.model_management.unet_offload_device(),
            dtype=comfy.model_management.unet_dtype(),
        )
|
|
|
|
|
def validate_model_compatibility_gen2(model: ModelPatcher, motion_model: MotionModelPatcher):
    """Ensure ``motion_model`` was made for the same SD model type as ``model``.

    Raises:
        MotionCompatibilityError: if ``get_sd_model_type(model)`` differs from the motion module's ``sd_type``.
    """
    model_sd_type = get_sd_model_type(model)
    mm_info = motion_model.model.mm_info
    if model_sd_type != mm_info.sd_type:
        message = (f"Motion module '{mm_info.mm_name}' is intended for {mm_info.sd_type} models, "
                   f"but the provided model is type {model_sd_type}.")
        raise MotionCompatibilityError(message)
|
|
|
|
|
def validate_per_block_compatibility(motion_model: MotionModelPatcher, all_per_blocks: AllPerBlocks):
    """Ensure per-block settings, if they declare an SD type, match the motion module's SD type.

    A missing ``all_per_blocks`` or one whose ``sd_type`` is None is accepted without checks.

    Raises:
        Exception: if ``all_per_blocks.sd_type`` differs from the motion module's ``sd_type``.
    """
    if all_per_blocks is not None and all_per_blocks.sd_type is not None:
        mm_info = motion_model.model.mm_info
        if all_per_blocks.sd_type != mm_info.sd_type:
            raise Exception(f"Per-Block provided is meant for {all_per_blocks.sd_type}, but provided motion module is for {mm_info.sd_type}.")
|
|
|
|
|
def interpolate_pe_to_length(model_dict: dict[str, Tensor], key: str, new_length: int):
    """Resample the positional-encoding tensor ``model_dict[key]`` to ``new_length`` positions.

    The tensor is indexed as ``(batch, length, dim)``; only the length axis is resized, using
    bilinear interpolation (``align_corners=False``), while the last axis keeps its size. The dict
    entry is replaced by the resized tensor, which keeps the original dtype.

    Interpolation runs in a float dtype at least as wide as float32 and is cast back afterwards, so
    half/bfloat16 PEs can be resized on backends without a reduced-precision bilinear kernel. This
    includes fp16 checkpoints still on CPU.

    Raises:
        ValueError: if the tensor is not 3-dimensional.
    """
    pe = model_dict[key]
    if pe.dim() != 3:
        raise ValueError(f"Expected a 3D (batch, length, dim) PE tensor for '{key}', got shape {tuple(pe.shape)}.")
    compute_dtype = torch.promote_types(pe.dtype, torch.float32)
    # (b, f, d) -> (1, b, f, d): bilinear mode needs a 4D (N, C, H, W) input; H is the PE length
    as_image = pe.unsqueeze(0).to(compute_dtype)
    resized = F.interpolate(as_image, size=(new_length, pe.shape[-1]), mode="bilinear")
    model_dict[key] = resized.squeeze(0).to(pe.dtype)
|
|
|
|
|
def interpolate_pe_to_length_diffs(model_dict: dict[str, Tensor], key: str, new_length: int):
    """Resample ``model_dict[key]`` to ``new_length`` positions along dim 1.

    This was a line-for-line copy of ``interpolate_pe_to_length``. It now delegates to that
    function so the two cannot drift apart.
    """
    interpolate_pe_to_length(model_dict, key, new_length=new_length)
|
|
|
|
|
def interpolate_pe_to_length_pingpong(model_dict: dict[str, Tensor], key: str, new_length: int):
    """Extend ``model_dict[key]`` along dim 1 to ``new_length`` positions by ping-ponging over them.

    For n existing positions, the extended sequence visits 0, 1, ..., n-1, n-2, ..., 1, 0, 1, ...
    (end points are not repeated) and is truncated to exactly ``new_length``. With n == 1 the single
    position is repeated. A tensor that already has at least ``new_length`` positions is left as is.

    Raises:
        ValueError: if the tensor has zero positions along dim 1 but needs extending.
    """
    pe = model_dict[key]
    length = pe.shape[1]
    if length >= new_length:
        return
    if length == 0:
        raise ValueError(f"Cannot ping-pong extend '{key}': it has no positions along dim 1.")

    # one forward + one backward sweep, without repeating the end points
    period = 2 * (length - 1)
    if period == 0:
        indices = [0] * new_length
    else:
        indices = [pos if pos < length else period - pos
                   for pos in (i % period for i in range(new_length))]
    # a single gather instead of repeatedly concatenating an ever-growing tensor
    index = torch.tensor(indices, dtype=torch.long, device=pe.device)
    model_dict[key] = pe.index_select(1, index)
|
|
|
|
|
def freeze_mask_of_pe(model_dict: dict[str, Tensor], key: str):
    """Copy position 0's trailing channels to every position of ``model_dict[key]``, in place.

    The tensor is indexed as ``[:, position, channel]``. The first ``shape[2] // 64`` channels keep
    their per-position values; every later channel is overwritten with its value at position 0.
    """
    pe = model_dict[key]
    pe_portion = pe.shape[2] // 64
    # clone: position 0 lies inside the region being written, so detach the source from it first
    first_pe = pe[:, :1, pe_portion:].clone()
    pe[:, :, pe_portion:] = first_pe
|
|
|
|
|
def freeze_mask_of_attn(model_dict: dict[str, Tensor], key: str):
    """Scale the top-left ``h x h`` block of ``model_dict[key]`` by 1.5 in place (h = shape[0] // 2)."""
    weight = model_dict[key]
    half = weight.shape[0] // 2
    weight[:half, :half].mul_(1.5)
|
|
|
|
|
def apply_mm_settings(model_dict: dict[str, Tensor], mm_settings: AnimateDiffSettings) -> dict[str, Tensor]:
    """Apply the PE and weight adjustments in ``mm_settings`` to a motion-module state dict.

    PE adjustments (``mm_settings.adjust_pe.adjusts``) run first, in order, followed by weight
    adjustments (``mm_settings.adjust_weight.adjusts``), in order. Entries of ``model_dict`` are
    replaced or modified in place.

    Args:
        model_dict: motion-module state dict keyed by parameter name.
        mm_settings: adjustments to apply. ``None``, or settings with nothing to apply, is a no-op.

    Returns:
        The same ``model_dict`` object.
    """
    if mm_settings is None or not mm_settings.has_anything_to_apply():
        return model_dict

    for adjust_pe in mm_settings.adjust_pe.adjusts:
        _apply_adjust_pe(model_dict, adjust_pe)

    for adjust_w in mm_settings.adjust_weight.adjusts:
        _apply_adjust_weight(model_dict, adjust_w)

    return model_dict


def _apply_adjust_pe(model_dict: dict[str, Tensor], adjust_pe: AdjustPE):
    """Apply one AdjustPE to every PE key (``attention_blocks`` and ``pos_encoder``) in ``model_dict``.

    Per key, operations run in this order: stretch, initial offset, initial cap, interpolate,
    final offset. When ``adjust_pe.print_adjustment`` is set, each one is logged for the first
    matching key only.
    """
    if not adjust_pe.has_anything_to_apply():
        return
    already_printed = False
    for key in model_dict:
        if not ("attention_blocks" in key and "pos_encoder" in key):
            continue
        should_log = adjust_pe.print_adjustment and not already_printed

        if adjust_pe.has_motion_pe_stretch():
            original_length = model_dict[key].shape[1]
            new_pe_length = original_length + adjust_pe.motion_pe_stretch
            interpolate_pe_to_length(model_dict, key, new_length=new_pe_length)
            if should_log:
                logger.info(f"[Adjust PE]: PE Stretch from {original_length} to {new_pe_length}.")

        if adjust_pe.has_initial_pe_idx_offset():
            original_length = model_dict[key].shape[1]
            model_dict[key] = model_dict[key][:, adjust_pe.initial_pe_idx_offset:]
            if should_log:
                logger.info(f"[Adjust PE]: Offsetting PEs by {adjust_pe.initial_pe_idx_offset}; PE length shortens from {original_length} to {model_dict[key].shape[1]}.")

        if adjust_pe.has_cap_initial_pe_length():
            original_length = model_dict[key].shape[1]
            model_dict[key] = model_dict[key][:, :adjust_pe.cap_initial_pe_length]
            if should_log:
                logger.info(f"[Adjust PE]: Capping PEs (initial) from {original_length} to {model_dict[key].shape[1]}.")

        if adjust_pe.has_interpolate_pe_to_length():
            original_length = model_dict[key].shape[1]
            interpolate_pe_to_length(model_dict, key, new_length=adjust_pe.interpolate_pe_to_length)
            if should_log:
                logger.info(f"[Adjust PE]: Interpolating PE length from {original_length} to {model_dict[key].shape[1]}.")

        if adjust_pe.has_final_pe_idx_offset():
            original_length = model_dict[key].shape[1]
            # this drops leading positions (an offset), so log it as such rather than as a cap
            model_dict[key] = model_dict[key][:, adjust_pe.final_pe_idx_offset:]
            if should_log:
                logger.info(f"[Adjust PE]: Offsetting PEs (final) by {adjust_pe.final_pe_idx_offset}; PE length shortens from {original_length} to {model_dict[key].shape[1]}.")

        already_printed = True


def _apply_adjust_weight(model_dict: dict[str, Tensor], adjust_w: AdjustWeight):
    """Apply one AdjustWeight to every key of ``model_dict``, dispatching on the key name.

    Every key gets the ATTR_ALL ops first. Then:
      * ``attention_blocks`` keys containing ``pos_encoder`` get ATTR_PE;
      * other ``attention_blocks`` keys get ATTR_ATTN, plus ATTR_ATTN_Q / _K / _V for ``to_q`` /
        ``to_k`` / ``to_v`` keys, or ATTR_ATTN_OUT_WEIGHT / _BIAS for ``to_out`` weight / bias keys;
      * all other keys get ATTR_OTHER.
    """
    if not adjust_w.has_anything_to_apply():
        return
    adjust_w.mark_attrs_as_unprinted()
    for key in model_dict:
        adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_ALL, model_dict=model_dict, key=key)
        if "attention_blocks" not in key:
            adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_OTHER, model_dict=model_dict, key=key)
        elif "pos_encoder" in key:
            adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_PE, model_dict=model_dict, key=key)
        else:
            adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_ATTN, model_dict=model_dict, key=key)
            if "to_q" in key:
                adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_ATTN_Q, model_dict=model_dict, key=key)
            elif "to_k" in key:
                adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_ATTN_K, model_dict=model_dict, key=key)
            elif "to_v" in key:
                adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_ATTN_V, model_dict=model_dict, key=key)
            elif "to_out" in key:
                stripped = key.strip()
                if stripped.endswith("weight"):
                    adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_ATTN_OUT_WEIGHT, model_dict=model_dict, key=key)
                elif stripped.endswith("bias"):
                    adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_ATTN_OUT_BIAS, model_dict=model_dict, key=key)
|
|
|
|
|
class InjectionParams:
    """Options that control how the motion model is injected for sampling.

    Holds the behavioral flags given at construction, the context options group, the (Gen1)
    motion-model settings, and transient state (``full_length``, ``sub_idxs``).
    """

    def __init__(self, unlimited_area_hack: bool=False, apply_mm_groupnorm_hack: bool=True,
                 apply_v2_properly: bool=True) -> None:
        # flags supplied by the caller
        self.unlimited_area_hack = unlimited_area_hack
        self.apply_mm_groupnorm_hack = apply_mm_groupnorm_hack
        self.apply_v2_properly = apply_v2_properly
        # state that is set or replaced later
        self.full_length = None
        self.context_options: ContextOptionsGroup = ContextOptionsGroup.default()
        self.motion_model_settings = AnimateDiffSettings()
        # not copied by clone(), so every clone starts with None here
        self.sub_idxs = None

    def set_noise_extra_args(self, noise_extra_args: dict):
        """Store a clone of the current context options under ``noise_extra_args["context_options"]``."""
        noise_extra_args["context_options"] = self.context_options.clone()

    def set_context(self, context_options: ContextOptionsGroup):
        """Use a clone of ``context_options``, or the default group when it is falsy."""
        if context_options:
            self.context_options = context_options.clone()
        else:
            self.context_options = ContextOptionsGroup.default()

    def is_using_sliding_context(self) -> bool:
        """Return True when the context options define a ``context_length``."""
        return self.context_options.context_length is not None

    def set_motion_model_settings(self, motion_model_settings: AnimateDiffSettings):
        """Use ``motion_model_settings`` as-is, or fresh default settings when it is None."""
        self.motion_model_settings = AnimateDiffSettings() if motion_model_settings is None else motion_model_settings

    def reset_context(self):
        """Replace the context options with the default group."""
        self.context_options = ContextOptionsGroup.default()

    def clone(self) -> 'InjectionParams':
        """Return a new InjectionParams with the same flags, full_length, context and settings.

        The context options are cloned (via ``set_context``); the settings object is shared;
        ``sub_idxs`` is not carried over.
        """
        duplicate = InjectionParams(
            unlimited_area_hack=self.unlimited_area_hack,
            apply_mm_groupnorm_hack=self.apply_mm_groupnorm_hack,
            apply_v2_properly=self.apply_v2_properly,
        )
        duplicate.full_length = self.full_length
        duplicate.set_context(self.context_options)
        duplicate.set_motion_model_settings(self.motion_model_settings)
        return duplicate

    def on_model_patcher_clone(self):
        """Return a clone of these params when the owning model patcher is cloned."""
        return self.clone()
|
|