import copy
from typing import Union, Callable
from collections import namedtuple

from einops import rearrange
from torch import Tensor
import torch.nn.functional as F
import torch
import uuid
import math

import comfy.conds
import comfy.lora
import comfy.model_management
import comfy.utils
from comfy.model_patcher import ModelPatcher
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
from comfy.model_base import BaseModel
from comfy.sd import CLIP, VAE

from .ad_settings import AnimateDiffSettings, AdjustPE, AdjustWeight
from .adapter_cameractrl import CameraPoseEncoder, CameraEntry, prepare_pose_embedding
from .context import ContextOptions, ContextOptionsGroup
from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, AnimateDiffInfo, EncoderOnlyAnimateDiffModel,
                               VersatileAttention, PerBlock, AllPerBlocks, VanillaTemporalModule,
                               has_mid_block, normalize_ad_state_dict, get_position_encoding_max_len)
from .logger import logger
from .utils_motion import (ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, InputPIA,
                           get_combined_multival, get_combined_input, get_combined_input_effect_multival,
                           ade_broadcast_image_to, extend_to_batch_size, prepare_mask_batch)
from .conditioning import HookRef, LoraHook, LoraHookGroup, LoraHookMode
from .motion_lora import MotionLoraInfo, MotionLoraList
from .utils_model import get_motion_lora_path, get_motion_model_path, get_sd_model_type, vae_encode_raw_batched
from .sample_settings import SampleSettings, SeedNoiseGeneration
from .dinklink import DinkLinkConst, get_dinklink, get_acn_outer_sample_wrapper


def prepare_dinklink_register_definitions():
    # expose create_MotionModelPatcher
    d = get_dinklink()
    link_ade = d.setdefault(DinkLinkConst.ADE, {})
    link_ade[DinkLinkConst.ADE_CREATE_MOTIONMODELPATCHER] = create_MotionModelPatcher


class MotionModelPatcher(ModelPatcher):
    '''Class used only for type hints.'''
    def __init__(self):
        self.model: AnimateDiffModel
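
# How the DinkLink registration above is meant to be consumed: another node
# pack can pull the factory out of the shared dict and create motion model
# patchers without importing this package directly. A minimal sketch (the
# `other_pack_setup` function and its inputs are hypothetical, not part of
# this repo):
#
#   def other_pack_setup(my_ad_model, load_device, offload_device):
#       d = get_dinklink()
#       create_mm = d[DinkLinkConst.ADE][DinkLinkConst.ADE_CREATE_MOTIONMODELPATCHER]
#       return create_mm(model=my_ad_model, load_device=load_device, offload_device=offload_device)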


class ModelPatcherHelper:
    SAMPLE_SETTINGS = "ADE_sample_settings"
    PARAMS = "ADE_params"
    ADE = "ADE"

    def __init__(self, model: ModelPatcher):
        self.model = model

    def set_all_properties(self, outer_sampler_wrapper: Callable, calc_cond_batch_wrapper: Callable,
                           params: 'InjectionParams', sample_settings: SampleSettings=None,
                           motion_models: 'MotionModelGroup'=None):
        self.set_outer_sample_wrapper(outer_sampler_wrapper)
        self.set_calc_cond_batch_wrapper(calc_cond_batch_wrapper)
        self.set_sample_settings(sample_settings=sample_settings if sample_settings is not None else SampleSettings())
        self.set_params(params)
        if motion_models is not None:
            self.set_motion_models(motion_models.models.copy())
            self.set_forward_timestep_embed_patch()
        else:
            self.remove_motion_models()
            self.remove_forward_timestep_embed_patch()

    def get_motion_models(self) -> list[MotionModelPatcher]:
        return self.model.additional_models.get(self.ADE, [])

    def set_motion_models(self, motion_models: list[MotionModelPatcher]):
        self.model.set_additional_models(self.ADE, motion_models)
        self.model.set_injections(self.ADE,
                                  [PatcherInjection(inject=inject_motion_models, eject=eject_motion_models)])

    def remove_motion_models(self):
        self.model.remove_additional_models(self.ADE)
        self.model.remove_injections(self.ADE)

    def cleanup_motion_models(self):
        for motion_model in self.get_motion_models():
            motion_model.cleanup()

    def set_forward_timestep_embed_patch(self):
        self.remove_forward_timestep_embed_patch()
        self.model.set_model_forward_timestep_embed_patch(create_forward_timestep_embed_patch())

    def remove_forward_timestep_embed_patch(self):
        if "transformer_options" in self.model.model_options:
            transformer_options = self.model.model_options["transformer_options"]
            if "patches" in transformer_options:
                patches = transformer_options["patches"]
                if "forward_timestep_embed_patch" in patches:
                    forward_timestep_patches: list = patches["forward_timestep_embed_patch"]
                    to_remove = []
                    for idx, patch in enumerate(forward_timestep_patches):
                        if patch[1] == forward_timestep_embed_patch_ade:
                            to_remove.append(idx)
                    # pop in reverse order so earlier removals don't shift the remaining indices
                    for idx in reversed(to_remove):
                        forward_timestep_patches.pop(idx)

    ##########################
    # motion models helpers
    def set_video_length(self, video_length: int, full_length: int):
        for motion_model in self.get_motion_models():
            motion_model.model.set_video_length(video_length=video_length, full_length=full_length)

    def get_name_string(self, show_version=False):
        identifiers = []
        for motion_model in self.get_motion_models():
            id = motion_model.model.mm_info.mm_name
            if show_version:
                id += f":{motion_model.model.mm_info.mm_version}"
            identifiers.append(id)
        return ", ".join(identifiers)
    ##########################

    def get_sample_settings(self) -> SampleSettings:
        return self.model.get_attachment(self.SAMPLE_SETTINGS)

    def set_sample_settings(self, sample_settings: SampleSettings):
        self.model.set_attachments(self.SAMPLE_SETTINGS, sample_settings)

    def get_params(self) -> 'InjectionParams':
        return self.model.get_attachment(self.PARAMS)

    def set_params(self, params: 'InjectionParams'):
        self.model.set_attachments(self.PARAMS, params)
        if params.context_options.context_length is not None:
            self.set_ACN_outer_sample_wrapper(throw_exception=False)
        elif params.context_options.extras.context_ref is not None:
            self.set_ACN_outer_sample_wrapper(throw_exception=True)

    def set_ACN_outer_sample_wrapper(self, throw_exception=True):
        # get wrapper to register from Advanced-ControlNet via DinkLink shared dict
        wrapper_info = get_acn_outer_sample_wrapper(throw_exception)
        if wrapper_info is None:
            return
        wrapper_type, key, wrapper = wrapper_info
        if len(self.model.get_wrappers(wrapper_type, key)) == 0:
            self.model.add_wrapper_with_key(wrapper_type, key, wrapper)

    def set_outer_sample_wrapper(self, wrapper: Callable):
        self.model.remove_wrappers_with_key(WrappersMP.OUTER_SAMPLE, self.ADE)
        self.model.add_wrapper_with_key(WrappersMP.OUTER_SAMPLE, self.ADE, wrapper)

    def set_calc_cond_batch_wrapper(self, wrapper: Callable):
        self.model.remove_wrappers_with_key(WrappersMP.CALC_COND_BATCH, self.ADE)
        self.model.add_wrapper_with_key(WrappersMP.CALC_COND_BATCH, self.ADE, wrapper)

    def remove_wrappers(self):
        self.model.remove_wrappers_with_key(WrappersMP.OUTER_SAMPLE, self.ADE)
        self.model.remove_wrappers_with_key(WrappersMP.CALC_COND_BATCH, self.ADE)

    def pre_run(self):
        # TODO: could implement this as a ModelPatcher ON_PRE_RUN callback
        for motion_model in self.get_motion_models():
            motion_model.pre_run()
        self.get_sample_settings().pre_run(self.model)


def inject_motion_models(patcher: ModelPatcher):
    helper = ModelPatcherHelper(patcher)
    motion_models = helper.get_motion_models()
    for mm in motion_models:
        mm.model.inject(patcher)


def eject_motion_models(patcher: ModelPatcher):
    helper = ModelPatcherHelper(patcher)
    motion_models = helper.get_motion_models()
    for mm in motion_models:
        mm.model.eject(patcher)


def create_forward_timestep_embed_patch():
    return (VanillaTemporalModule, forward_timestep_embed_patch_ade)


def forward_timestep_embed_patch_ade(layer, x, emb, context, transformer_options, output_shape, time_context,
                                     num_video_frames, image_only_indicator, *args, **kwargs):
    return layer(x, context, transformer_options=transformer_options)
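
# The tuple returned by create_forward_timestep_embed_patch pairs a module
# class with a callable: ComfyUI's forward_timestep_embed checks each layer
# against the class (via isinstance) and, on a match, routes the call through
# the paired function instead of the default path. That is how the injected
# VanillaTemporalModule layers get invoked with only the arguments they need,
# while the full UNet-layer signature above is accepted and mostly ignored.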


def create_MotionModelPatcher(model, load_device, offload_device) -> MotionModelPatcher:
    patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
    ade = ModelPatcherHelper.ADE
    patcher.add_callback_with_key(CallbacksMP.ON_LOAD, ade, _mm_patch_lowvram_extras_callback)
    patcher.add_callback_with_key(CallbacksMP.ON_LOAD, ade, _mm_handle_float8_pe_tensors_callback)
    patcher.add_callback_with_key(CallbacksMP.ON_PRE_RUN, ade, _mm_pre_run_callback)
    patcher.add_callback_with_key(CallbacksMP.ON_CLEANUP, ade, _mm_clean_callback)
    patcher.set_attachments(ade, MotionModelAttachment())
    return patcher


def _mm_patch_lowvram_extras_callback(self: MotionModelPatcher, device_to, lowvram_model_memory, *args, **kwargs):
    if lowvram_model_memory > 0:
        # figure out the tensors (likely pe's) that should be cast to device besides just the named_modules
        remaining_tensors = list(self.model.state_dict().keys())
        named_modules = []
        for n, _ in self.model.named_modules():
            named_modules.append(n)
            named_modules.append(f"{n}.weight")
            named_modules.append(f"{n}.bias")
        for name in named_modules:
            if name in remaining_tensors:
                remaining_tensors.remove(name)

        for key in remaining_tensors:
            self.patch_weight_to_device(key, device_to)
            if device_to is not None:
                comfy.utils.set_attr(self.model, key, comfy.utils.get_attr(self.model, key).to(device_to))


def _mm_handle_float8_pe_tensors_callback(self: MotionModelPatcher, *args, **kwargs):
    remaining_tensors = list(self.model.state_dict().keys())
    pe_tensors = [x for x in remaining_tensors if '.pe' in x]
    is_first = True
    for key in pe_tensors:
        if is_first:
            is_first = False
            # only do anything if the pe tensors are stored as float8; otherwise bail out early
            if comfy.utils.get_attr(self.model, key).dtype not in [torch.float8_e5m2, torch.float8_e4m3fn]:
                break
        comfy.utils.set_attr(self.model, key, comfy.utils.get_attr(self.model, key).half())


def _mm_pre_run_callback(self: MotionModelPatcher, *args, **kwargs):
    attachment = get_mm_attachment(self)
    attachment.pre_run(self)


def _mm_clean_callback(self: MotionModelPatcher, *args, **kwargs):
    attachment = get_mm_attachment(self)
    attachment.cleanup(self)


def get_mm_attachment(patcher: MotionModelPatcher) -> 'MotionModelAttachment':
    return patcher.get_attachment(ModelPatcherHelper.ADE)
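
# Attachments are ComfyUI's mechanism for hanging extra state off a
# ModelPatcher so it survives patcher clones: set_attachments stores an object
# under a key, get_attachment retrieves it, and an attachment that defines
# on_model_patcher_clone (as MotionModelAttachment below does) gets to produce
# a fresh copy whenever the patcher itself is cloned. This keeps per-sampling
# state out of the AnimateDiffModel module itself.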


class MotionModelAttachment:
    def __init__(self):
        self.timestep_percent_range = (0.0, 1.0)
        self.timestep_range: tuple[float, float] = None
        self.keyframes: ADKeyframeGroup = ADKeyframeGroup()

        self.scale_multival: Union[float, Tensor, None] = None
        self.effect_multival: Union[float, Tensor, None] = None
        self.per_block_list: Union[list[PerBlock], None] = None

        # AnimateLCM-I2V
        self.orig_ref_drift: float = None
        self.orig_insertion_weights: list[float] = None
        self.orig_apply_ref_when_disabled = False
        self.orig_img_latents: Tensor = None
        self.img_features: list[Tensor] = None  # temporary
        self.img_latents_shape: tuple = None

        # CameraCtrl
        self.orig_camera_entries: list[CameraEntry] = None
        self.camera_features: list[Tensor] = None  # temporary
        self.camera_features_shape: tuple = None
        self.cameractrl_multival: Union[float, Tensor] = None

        # PIA
        self.orig_pia_images: Tensor = None
        self.pia_vae: VAE = None
        self.pia_input: InputPIA = None
        self.cached_pia_c_concat: comfy.conds.CONDNoiseShape = None  # cached
        self.prev_pia_latents_shape: tuple = None
        self.prev_current_pia_input: InputPIA = None
        self.pia_multival: Union[float, Tensor] = None

        # FancyVideo
        self.orig_fancy_images: Tensor = None
        self.fancy_vae: VAE = None
        self.cached_fancy_c_concat: comfy.conds.CONDNoiseShape = None  # cached
        self.prev_fancy_latents_shape: tuple = None
        self.fancy_multival: Union[float, Tensor] = None

        # temporary variables
        self.current_used_steps = 0
        self.current_keyframe: ADKeyframe = None
        self.current_index = -1
        self.previous_t = -1
        self.current_scale: Union[float, Tensor] = None
        self.current_effect: Union[float, Tensor] = None
        self.current_cameractrl_effect: Union[float, Tensor] = None
        self.current_pia_input: InputPIA = None
        self.combined_scale: Union[float, Tensor] = None
        self.combined_effect: Union[float, Tensor] = None
        self.combined_per_block_list: Union[float, Tensor] = None
        self.combined_cameractrl_effect: Union[float, Tensor] = None
        self.combined_pia_mask: Union[float, Tensor] = None
        self.combined_pia_effect: Union[float, Tensor] = None
        self.was_within_range = False
        self.prev_sub_idxs = None
        self.prev_batched_number = None

    def pre_run(self, patcher: MotionModelPatcher):
        self.cleanup(patcher)
        patcher.model.set_scale(self.scale_multival, self.per_block_list)
        patcher.model.set_effect(self.effect_multival, self.per_block_list)
        patcher.model.set_cameractrl_effect(self.cameractrl_multival)
        if patcher.model.img_encoder is not None:
            patcher.model.img_encoder.set_ref_drift(self.orig_ref_drift)
            patcher.model.img_encoder.set_insertion_weights(self.orig_insertion_weights)

    def initialize_timesteps(self, model: BaseModel):
        self.timestep_range = (model.model_sampling.percent_to_sigma(self.timestep_percent_range[0]),
                               model.model_sampling.percent_to_sigma(self.timestep_percent_range[1]))
        if self.keyframes is not None:
            for keyframe in self.keyframes.keyframes:
                keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
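
    # Timestep scheduling note: percent_to_sigma converts a sampling-progress
    # percent (0.0 = first step, 1.0 = last) into a sigma value, and sigmas
    # shrink as sampling progresses. A *larger* start_t therefore means an
    # *earlier* step, which is why prepare_current_keyframe below compares
    # with `eval_kf.start_t >= curr_t` to decide a keyframe has started.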

    def prepare_current_keyframe(self, patcher: MotionModelPatcher, x: Tensor, t: Tensor):
        curr_t: float = t[0]
        # if curr_t was previous_t, then do nothing (already accounted for this step)
        if curr_t == self.previous_t:
            return
        prev_index = self.current_index
        # if met guaranteed steps, look for next keyframe in case need to switch
        if self.current_keyframe is None or self.current_used_steps >= self.current_keyframe.guarantee_steps:
            # if has next index, loop through and see if need to switch
            if self.keyframes.has_index(self.current_index+1):
                for i in range(self.current_index+1, len(self.keyframes)):
                    eval_kf = self.keyframes[i]
                    # check if start_t is greater or equal to curr_t
                    # NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
                    if eval_kf.start_t >= curr_t:
                        self.current_index = i
                        self.current_keyframe = eval_kf
                        self.current_used_steps = 0
                        # keep track of scale and effect multivals, accounting for inherit_missing
                        if self.current_keyframe.has_scale():
                            self.current_scale = self.current_keyframe.scale_multival
                        elif not self.current_keyframe.inherit_missing:
                            self.current_scale = None
                        if self.current_keyframe.has_effect():
                            self.current_effect = self.current_keyframe.effect_multival
                        elif not self.current_keyframe.inherit_missing:
                            self.current_effect = None
                        if self.current_keyframe.has_cameractrl_effect():
                            self.current_cameractrl_effect = self.current_keyframe.cameractrl_multival
                        elif not self.current_keyframe.inherit_missing:
                            self.current_cameractrl_effect = None
                        if self.current_keyframe.has_pia_input():
                            self.current_pia_input = self.current_keyframe.pia_input
                        elif not self.current_keyframe.inherit_missing:
                            self.current_pia_input = None
                        # if guarantee_steps greater than zero, stop searching for other keyframes
                        if self.current_keyframe.guarantee_steps > 0:
                            break
                    # if eval_kf is outside the percent range, stop looking further
                    else:
                        break
        # if index changed, apply new combined values
        if prev_index != self.current_index:
            # combine model's scale and effect with keyframe's scale and effect
            self.combined_scale = get_combined_multival(self.scale_multival, self.current_scale)
            self.combined_effect = get_combined_multival(self.effect_multival, self.current_effect)
            self.combined_cameractrl_effect = get_combined_multival(self.cameractrl_multival, self.current_cameractrl_effect)
            self.combined_pia_mask = get_combined_input(self.pia_input, self.current_pia_input, x)
            self.combined_pia_effect = get_combined_input_effect_multival(self.pia_input, self.current_pia_input)
            # apply scale and effect
            patcher.model.set_scale(self.combined_scale, self.per_block_list)
            patcher.model.set_effect(self.combined_effect, self.per_block_list)  # TODO: set combined_per_block_list
            patcher.model.set_cameractrl_effect(self.combined_cameractrl_effect)
        # apply effect - if not within range, set effect to 0, effectively turning model off
        if curr_t > self.timestep_range[0] or curr_t < self.timestep_range[1]:
            patcher.model.set_effect(0.0)
            self.was_within_range = False
        else:
            # if was not in range last step, apply effect to toggle AD status
            if not self.was_within_range:
                patcher.model.set_effect(self.combined_effect, self.per_block_list)
                self.was_within_range = True
        # update steps current keyframe is used
        self.current_used_steps += 1
        # update previous_t
        self.previous_t = curr_t

    def prepare_alcmi2v_features(self, patcher: MotionModelPatcher, x: Tensor, cond_or_uncond: list[int], ad_params: dict[str], latent_format):
        # if no img_encoder, done
        if patcher.model.img_encoder is None:
            return
        batched_number = len(cond_or_uncond)
        full_length = ad_params["full_length"]
        sub_idxs = ad_params["sub_idxs"]
        goal_length = x.size(0) // batched_number
        # calculate img_features if needed
        if (self.img_latents_shape is None or sub_idxs != self.prev_sub_idxs or batched_number != self.prev_batched_number
                or x.shape[2] != self.img_latents_shape[2] or x.shape[3] != self.img_latents_shape[3]):
            if sub_idxs is not None and self.orig_img_latents.size(0) >= full_length:
                img_latents = comfy.utils.common_upscale(self.orig_img_latents[sub_idxs], x.shape[3], x.shape[2],
                                                         'nearest-exact', 'center').to(x.dtype).to(x.device)
            else:
                img_latents = comfy.utils.common_upscale(self.orig_img_latents, x.shape[3], x.shape[2],
                                                         'nearest-exact', 'center').to(x.dtype).to(x.device)
            img_latents: Tensor = latent_format.process_in(img_latents)
            # make sure img_latents matches goal_length
            if goal_length != img_latents.shape[0]:
                img_latents = ade_broadcast_image_to(img_latents, goal_length, batched_number)
            img_features = patcher.model.img_encoder(img_latents, goal_length, batched_number)
            patcher.model.set_img_features(img_features=img_features, apply_ref_when_disabled=self.orig_apply_ref_when_disabled)
            # cache values for next step
            self.img_latents_shape = img_latents.shape
        self.prev_sub_idxs = sub_idxs
        self.prev_batched_number = batched_number
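
    # Batching note for the feature-prep methods above and below: ComfyUI
    # concatenates the cond and uncond passes into one batch, so with 16
    # frames and cond_or_uncond == [0, 1], x has 32 entries and goal_length
    # works out to 32 // 2 = 16 - the per-condition frame count that the
    # image/camera encoders must produce features for.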

    def prepare_camera_features(self, patcher: MotionModelPatcher, x: Tensor, cond_or_uncond: list[int], ad_params: dict[str]):
        # if no camera_encoder, done
        if patcher.model.camera_encoder is None:
            return
        batched_number = len(cond_or_uncond)
        full_length = ad_params["full_length"]
        sub_idxs = ad_params["sub_idxs"]
        goal_length = x.size(0) // batched_number
        # calculate camera_features if needed
        if self.camera_features_shape is None or sub_idxs != self.prev_sub_idxs or batched_number != self.prev_batched_number:
            # make sure there are enough camera_poses to match full_length
            camera_poses = self.orig_camera_entries.copy()
            if len(camera_poses) < full_length:
                for i in range(full_length-len(camera_poses)):
                    camera_poses.append(camera_poses[-1])
            if sub_idxs is not None:
                camera_poses = [camera_poses[idx] for idx in sub_idxs]
            # make sure camera_poses matches goal_length
            if len(camera_poses) > goal_length:
                camera_poses = camera_poses[:goal_length]
            elif len(camera_poses) < goal_length:
                # pad the camera_poses with the last element to match goal_length
                for i in range(goal_length-len(camera_poses)):
                    camera_poses.append(camera_poses[-1])
            # create encoded embeddings
            b, c, h, w = x.shape
            plucker_embedding = prepare_pose_embedding(camera_poses, image_width=w*8, image_height=h*8).to(dtype=x.dtype, device=x.device)
            camera_embedding = patcher.model.camera_encoder(plucker_embedding, video_length=goal_length, batched_number=batched_number)
            patcher.model.set_camera_features(camera_features=camera_embedding)
            self.camera_features_shape = len(camera_embedding)
        self.prev_sub_idxs = sub_idxs
        self.prev_batched_number = batched_number

    def get_pia_c_concat(self, model: BaseModel, x: Tensor) -> Tensor:
        # if have cached shape, check if matches - if so, return cached pia_latents
        if self.prev_pia_latents_shape is not None:
            if self.prev_pia_latents_shape[0] == x.shape[0] and self.prev_pia_latents_shape[2] == x.shape[2] and self.prev_pia_latents_shape[3] == x.shape[3]:
                # if mask is also the same for this timestep, then return cached
                if self.prev_current_pia_input == self.current_pia_input:
                    return self.cached_pia_c_concat
                # otherwise, adjust new mask, and create new cached_pia_c_concat
                b, c, h, w = x.shape
                mask = prepare_mask_batch(self.combined_pia_mask, x.shape)
                mask = extend_to_batch_size(mask, b)
                # make sure to update prev_current_pia_input to know when is changed
                self.prev_current_pia_input = self.current_pia_input
                # TODO: handle self.combined_pia_effect eventually (feature hidden for now)
                # the first index in dim=1 is the mask that needs to be updated - update in place
                self.cached_pia_c_concat.cond[:, :1, :, :] = mask
                return self.cached_pia_c_concat
            self.prev_pia_latents_shape = None
        # otherwise, x shape should be the cached pia_latents_shape
        # get currently used models so they can be properly reloaded after performing VAE Encoding
        cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
        try:
            b, c, h, w = x.shape
            usable_ref = self.orig_pia_images[:b]
            # in diffusers, the image is scaled from [-1, 1] instead of default [0, 1],
            # but from my testing, that blows out the images here, so I skip it
            # usable_images = usable_images * 2 - 1
            # resize images to latent's dims
            usable_ref = usable_ref.movedim(-1, 1)
            usable_ref = comfy.utils.common_upscale(samples=usable_ref, width=w*self.pia_vae.downscale_ratio, height=h*self.pia_vae.downscale_ratio,
                                                    upscale_method="bilinear", crop="center")
            usable_ref = usable_ref.movedim(1, -1)
            # VAE encode images
            logger.info("VAE Encoding PIA input images...")
            usable_ref = model.process_latent_in(vae_encode_raw_batched(vae=self.pia_vae, pixels=usable_ref, show_pbar=False))
            logger.info("VAE Encoding PIA input images complete.")
            # make pia_latents match expected length
            usable_ref = extend_to_batch_size(usable_ref, b)
            self.prev_pia_latents_shape = x.shape
            # now, take care of the mask
            mask = prepare_mask_batch(self.combined_pia_mask, x.shape)
            mask = extend_to_batch_size(mask, b)
            #mask = mask.unsqueeze(1)
            self.prev_current_pia_input = self.current_pia_input
            if type(self.combined_pia_effect) == Tensor or not math.isclose(self.combined_pia_effect, 1.0):
                real_pia_effect = self.combined_pia_effect
                if type(self.combined_pia_effect) == Tensor:
                    real_pia_effect = extend_to_batch_size(prepare_mask_batch(self.combined_pia_effect, x.shape), b)
                zero_mask = torch.zeros_like(mask)
                mask = mask * real_pia_effect + zero_mask * (1.0 - real_pia_effect)
                del zero_mask
                zero_usable_ref = torch.zeros_like(usable_ref)
                usable_ref = usable_ref * real_pia_effect + zero_usable_ref * (1.0 - real_pia_effect)
                del zero_usable_ref
            # cache pia c_concat
            self.cached_pia_c_concat = comfy.conds.CONDNoiseShape(torch.cat([mask, usable_ref], dim=1))
            return self.cached_pia_c_concat
        finally:
            comfy.model_management.load_models_gpu(cached_loaded_models)
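
    # Channel layout note: PIA's c_concat puts the mask first, then the
    # VAE-encoded reference latents (torch.cat([mask, usable_ref], dim=1)),
    # which is why the cache-update path above rewrites cond[:, :1]. FancyVideo
    # below uses the opposite order: reference latents first, then mask.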

    def get_fancy_c_concat(self, model: BaseModel, x: Tensor) -> Tensor:
        # if have cached shape, check if matches - if so, return cached fancy_latents
        if self.prev_fancy_latents_shape is not None:
            if self.prev_fancy_latents_shape[0] == x.shape[0] and self.prev_fancy_latents_shape[-2] == x.shape[-2] and self.prev_fancy_latents_shape[-1] == x.shape[-1]:
                # TODO: if mask is also the same for this timestep, then return cached
                return self.cached_fancy_c_concat
            self.prev_fancy_latents_shape = None
        # otherwise, x shape should be the cached fancy_latents_shape
        # get currently used models so they can be properly reloaded after performing VAE Encoding
        cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
        try:
            b, c, h, w = x.shape
            usable_ref = self.orig_fancy_images[:b]
            # resize images to latent's dims
            usable_ref = usable_ref.movedim(-1, 1)
            usable_ref = comfy.utils.common_upscale(samples=usable_ref, width=w*self.fancy_vae.downscale_ratio, height=h*self.fancy_vae.downscale_ratio,
                                                    upscale_method="bilinear", crop="center")
            usable_ref = usable_ref.movedim(1, -1)
            # VAE encode images
            logger.info("VAE Encoding FancyVideo input images...")
            usable_ref: Tensor = model.process_latent_in(vae_encode_raw_batched(vae=self.fancy_vae, pixels=usable_ref, show_pbar=False))
            logger.info("VAE Encoding FancyVideo input images complete.")
            self.prev_fancy_latents_shape = x.shape
            # TODO: experiment with indexes that aren't the first
            # pad usable_ref with zeros
            ref_length = usable_ref.shape[0]
            pad_length = b - ref_length
            zero_ref = torch.zeros([pad_length, c, h, w], dtype=usable_ref.dtype, device=usable_ref.device)
            usable_ref = torch.cat([usable_ref, zero_ref], dim=0)
            del zero_ref
            # create mask
            mask_ones = torch.ones([ref_length, 1, h, w], dtype=usable_ref.dtype, device=usable_ref.device)
            mask_zeros = torch.zeros([pad_length, 1, h, w], dtype=usable_ref.dtype, device=usable_ref.device)
            mask = torch.cat([mask_ones, mask_zeros], dim=0)
            # TODO: experiment with mask strength
            # cache fancy c_concat - ref first, then mask
            self.cached_fancy_c_concat = comfy.conds.CONDNoiseShape(torch.cat([usable_ref, mask], dim=1))
            return self.cached_fancy_c_concat
        finally:
            comfy.model_management.load_models_gpu(cached_loaded_models)

    def is_pia(self, patcher: MotionModelPatcher):
        return patcher.model.mm_info.mm_format == AnimateDiffFormat.PIA and self.orig_pia_images is not None

    def is_fancyvideo(self, patcher: MotionModelPatcher):
        return patcher.model.mm_info.mm_format == AnimateDiffFormat.FANCYVIDEO

    def cleanup(self, patcher: MotionModelPatcher):
        if patcher.model is not None:
            patcher.model.cleanup()
        # AnimateLCM-I2V
        del self.img_features
        self.img_features = None
        self.img_latents_shape = None
        # CameraCtrl
        del self.camera_features
        self.camera_features = None
        self.camera_features_shape = None
        # PIA
        self.combined_pia_mask = None
        self.combined_pia_effect = None
        # Default
        self.current_used_steps = 0
        self.current_keyframe = None
        self.current_index = -1
        self.previous_t = -1
        self.current_scale = None
        self.current_effect = None
        self.combined_scale = None
        self.combined_effect = None
        self.combined_per_block_list = None
        self.was_within_range = False
        self.prev_sub_idxs = None
        self.prev_batched_number = None

    def on_model_patcher_clone(self):
        n = MotionModelAttachment()
        # extra cloned params
        n.timestep_percent_range = self.timestep_percent_range
        n.timestep_range = self.timestep_range
        n.keyframes = self.keyframes.clone()
        n.scale_multival = self.scale_multival
        n.effect_multival = self.effect_multival
        # AnimateLCM-I2V
        n.orig_img_latents = self.orig_img_latents
        n.orig_ref_drift = self.orig_ref_drift
        n.orig_insertion_weights = self.orig_insertion_weights.copy() if self.orig_insertion_weights is not None else self.orig_insertion_weights
        n.orig_apply_ref_when_disabled = self.orig_apply_ref_when_disabled
        # CameraCtrl
        n.orig_camera_entries = self.orig_camera_entries
        n.cameractrl_multival = self.cameractrl_multival
        # PIA
        n.orig_pia_images = self.orig_pia_images
        n.pia_vae = self.pia_vae
        n.pia_input = self.pia_input
        n.pia_multival = self.pia_multival
        return n


class MotionModelGroup:
    def __init__(self, init_motion_model: MotionModelPatcher=None):
        self.models: list[MotionModelPatcher] = []
        if init_motion_model is not None:
            if isinstance(init_motion_model, list):
                for m in init_motion_model:
                    self.add(m)
            else:
                self.add(init_motion_model)

    def add(self, mm: MotionModelPatcher):
        # add to end of list
        self.models.append(mm)

    def add_to_start(self, mm: MotionModelPatcher):
        self.models.insert(0, mm)

    def __getitem__(self, index) -> MotionModelPatcher:
        return self.models[index]

    def is_empty(self) -> bool:
        return len(self.models) == 0

    def clone(self) -> 'MotionModelGroup':
        cloned = MotionModelGroup()
        for mm in self.models:
            cloned.add(mm)
        return cloned

    def set_sub_idxs(self, sub_idxs: list[int]):
        for motion_model in self.models:
            motion_model.model.set_sub_idxs(sub_idxs=sub_idxs)

    def set_view_options(self, view_options: ContextOptions):
        for motion_model in self.models:
            motion_model.model.set_view_options(view_options)

    def set_video_length(self, video_length: int, full_length: int):
        for motion_model in self.models:
            motion_model.model.set_video_length(video_length=video_length, full_length=full_length)

    def initialize_timesteps(self, model: BaseModel):
        for motion_model in self.models:
            attachment = get_mm_attachment(motion_model)
            attachment.initialize_timesteps(model)

    def pre_run(self, model: ModelPatcher):
        for motion_model in self.models:
            motion_model.pre_run()

    def cleanup(self):
        for motion_model in self.models:
            motion_model.cleanup()

    def prepare_current_keyframe(self, x: Tensor, t: Tensor):
        for motion_model in self.models:
            attachment = get_mm_attachment(motion_model)
            attachment.prepare_current_keyframe(motion_model, x=x, t=t)

    def get_special_models(self):
        pia_motion_models: list[MotionModelPatcher] = []
        for motion_model in self.models:
            attachment = get_mm_attachment(motion_model)
            if attachment.is_pia(motion_model) or attachment.is_fancyvideo(motion_model):
                pia_motion_models.append(motion_model)
        return pia_motion_models

    def get_name_string(self, show_version=False):
        identifiers = []
        for motion_model in self.models:
            id = motion_model.model.mm_info.mm_name
            if show_version:
                id += f":{motion_model.model.mm_info.mm_version}"
            identifiers.append(id)
        return ", ".join(identifiers)
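
# Note on the copying strategy in get_vanilla_model_patcher below: patches are
# copied per-key with `m.patches[k][:]` so the fresh patcher gets its own list
# objects (appending to one side won't leak into the other) while the tensors
# inside remain shared; object_patches gets a shallow dict copy for the same
# reason, and only model_options is deep-copied.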


def get_vanilla_model_patcher(m: ModelPatcher) -> ModelPatcher:
    model = ModelPatcher(m.model, m.load_device, m.offload_device, m.size, weight_inplace_update=m.weight_inplace_update)
    model.patches = {}
    for k in m.patches:
        model.patches[k] = m.patches[k][:]

    model.object_patches = m.object_patches.copy()
    model.model_options = copy.deepcopy(m.model_options)
    if hasattr(model, "model_keys"):
        model.model_keys = m.model_keys
    return model


# adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/utils/convert_lora_safetensor_to_diffusers.py
# Example LoRA keys:
#   down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.processor.to_q_lora.down.weight
#   down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.processor.to_q_lora.up.weight
#
# Example model keys:
#   down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight
#
def load_motion_lora_as_patches(motion_model: MotionModelPatcher, lora: MotionLoraInfo) -> None:
    def get_version(has_midblock: bool):
        return "v2" if has_midblock else "v1"

    lora_path = get_motion_lora_path(lora.name)
    logger.info(f"Loading motion LoRA {lora.name}")
    state_dict = comfy.utils.load_torch_file(lora_path)

    # remove all non-temporal keys (in case model has extra stuff in it)
    for key in list(state_dict.keys()):
        if "temporal" not in key:
            del state_dict[key]
    if len(state_dict) == 0:
        raise ValueError(f"'{lora.name}' contains no temporal keys; it is not a valid motion LoRA!")

    model_has_midblock = motion_model.model.mid_block is not None
    lora_has_midblock = has_mid_block(state_dict)
    logger.info(f"Applying a {get_version(lora_has_midblock)} LoRA ({lora.name}) to a {motion_model.model.mm_info.mm_version} motion model.")

    patches = {}
    # convert lora state dict to one that matches motion_module keys and tensors
    for key in state_dict:
        # if motion_module doesn't have a midblock, skip mid_block entries
        if not model_has_midblock:
            if "mid_block" in key:
                continue
        # only process lora down key (we will process up at the same time as down)
        if "up." in key:
            continue

        # get up key version of down key
        up_key = key.replace(".down.", ".up.")

        # adapt key to match motion_module key format - remove 'processor.', '_lora', 'down.', and 'up.'
        model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
        # motion_module keys have a '0.' after all 'to_out.' weight keys
        if "to_out.0." not in model_key:
            model_key = model_key.replace("to_out.", "to_out.0.")

        weight_down = state_dict[key]
        weight_up = state_dict[up_key]
        # actual weights obtained by matrix multiplication of up and down weights
        # save as a tuple, so that (Motion)ModelPatcher's calculate_weight function detects len==1, applying it correctly
        patches[model_key] = (torch.mm(
            comfy.model_management.cast_to_device(weight_up, weight_up.device, torch.float32),
            comfy.model_management.cast_to_device(weight_down, weight_down.device, torch.float32)
            ),)
    del state_dict
    # add patches to motion ModelPatcher
    motion_model.add_patches(patches=patches, strength_patch=lora.strength)
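
# Worked example of the key conversion above, using the sample keys from the
# comment block: starting from
#   ...attention_blocks.0.processor.to_q_lora.down.weight
# stripping 'processor.', '_lora', and 'down.' yields
#   ...attention_blocks.0.to_q.weight
# and the stored patch tensor is up @ down - i.e. for a rank-r LoRA with
# shapes up (out, r) and down (r, in), a full (out, in) delta-weight matrix.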


def load_motion_module_gen1(model_name: str, model: ModelPatcher, motion_lora: MotionLoraList = None, motion_model_settings: AnimateDiffSettings = None) -> MotionModelPatcher:
    model_path = get_motion_model_path(model_name)
    logger.info(f"Loading motion module {model_name}")
    mm_state_dict = comfy.utils.load_torch_file(model_path, safe_load=True)
    # TODO: check for empty state dict?
    # get normalized state_dict and motion model info
    mm_state_dict, mm_info = normalize_ad_state_dict(mm_state_dict=mm_state_dict, mm_name=model_name)
    # check that motion model is compatible with sd model
    model_sd_type = get_sd_model_type(model)
    if model_sd_type != mm_info.sd_type:
        raise MotionCompatibilityError(f"Motion module '{mm_info.mm_name}' is intended for {mm_info.sd_type} models, " \
                                       + f"but the provided model is type {model_sd_type}.")
    # apply motion model settings
    mm_state_dict = apply_mm_settings(model_dict=mm_state_dict, mm_settings=motion_model_settings)
    # initialize AnimateDiffModelWrapper
    ad_wrapper = AnimateDiffModel(mm_state_dict=mm_state_dict, mm_info=mm_info)
    ad_wrapper.to(model.model_dtype())
    ad_wrapper.to(model.offload_device)
    load_result = ad_wrapper.load_state_dict(mm_state_dict, strict=False)
    verify_load_result(load_result=load_result, mm_info=mm_info)
    # wrap motion_module into a ModelPatcher, to allow motion lora patches
    motion_model = create_MotionModelPatcher(model=ad_wrapper, load_device=model.load_device, offload_device=model.offload_device)
    # load motion_lora, if present
    if motion_lora is not None:
        for lora in motion_lora.loras:
            load_motion_lora_as_patches(motion_model, lora)
    return motion_model


def load_motion_module_gen2(model_name: str, motion_model_settings: AnimateDiffSettings = None) -> MotionModelPatcher:
    model_path = get_motion_model_path(model_name)
    logger.info(f"Loading motion module {model_name} via Gen2")
    mm_state_dict = comfy.utils.load_torch_file(model_path, safe_load=True)
    # TODO: check for empty state dict?
    # get normalized state_dict and motion model info (converts alternate AD models like HotshotXL into AD keys)
    mm_state_dict, mm_info = normalize_ad_state_dict(mm_state_dict=mm_state_dict, mm_name=model_name)
    # apply motion model settings
    mm_state_dict = apply_mm_settings(model_dict=mm_state_dict, mm_settings=motion_model_settings)
    # initialize AnimateDiffModelWrapper
    ad_wrapper = AnimateDiffModel(mm_state_dict=mm_state_dict, mm_info=mm_info)
    ad_wrapper.to(comfy.model_management.unet_dtype())
    ad_wrapper.to(comfy.model_management.unet_offload_device())
    load_result = ad_wrapper.load_state_dict(mm_state_dict, strict=False)
    verify_load_result(load_result=load_result, mm_info=mm_info)
    # wrap motion_module into a ModelPatcher, to allow motion lora patches
    motion_model = create_MotionModelPatcher(model=ad_wrapper, load_device=comfy.model_management.get_torch_device(),
                                             offload_device=comfy.model_management.unet_offload_device())
    return motion_model


IncompatibleKeys = namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])
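
# torch's nn.Module.load_state_dict(strict=False) returns a NamedTuple with
# `missing_keys` and `unexpected_keys` list fields; the IncompatibleKeys
# namedtuple above mirrors that shape so verify_load_result can be typed
# without importing torch's private _IncompatibleKeys class.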


def verify_load_result(load_result: IncompatibleKeys, mm_info: AnimateDiffInfo):
    error_msgs: list[str] = []
    is_animatelcm = mm_info.mm_format == AnimateDiffFormat.ANIMATELCM
    remove_missing_idxs = []
    remove_unexpected_idxs = []
    for idx, key in enumerate(load_result.missing_keys):
        # NOTE: AnimateLCM has no pe keys in the model file, so any errors associated with missing pe keys can be ignored
        if is_animatelcm and "pos_encoder.pe" in key:
            remove_missing_idxs.append(idx)
    # remove any keys to ignore in reverse order (to preserve idx correlation)
    for idx in reversed(remove_unexpected_idxs):
        load_result.unexpected_keys.pop(idx)
    for idx in reversed(remove_missing_idxs):
        load_result.missing_keys.pop(idx)
    # copied over from torch.nn.Module's load_state_dict func
    if len(load_result.unexpected_keys) > 0:
        error_msgs.insert(
            0, 'Unexpected key(s) in state_dict: {}. '.format(
                ', '.join(f'"{k}"' for k in load_result.unexpected_keys)))
    if len(load_result.missing_keys) > 0:
        error_msgs.insert(
            0, 'Missing key(s) in state_dict: {}. '.format(
                ', '.join(f'"{k}"' for k in load_result.missing_keys)))
    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
            mm_info.mm_name, "\n\t".join(error_msgs)))


def create_fresh_motion_module(motion_model: MotionModelPatcher) -> MotionModelPatcher:
    ad_wrapper = AnimateDiffModel(mm_state_dict=motion_model.model.state_dict(), mm_info=motion_model.model.mm_info)
    ad_wrapper.to(comfy.model_management.unet_dtype())
    ad_wrapper.to(comfy.model_management.unet_offload_device())
    ad_wrapper.load_state_dict(motion_model.model.state_dict())
    return create_MotionModelPatcher(model=ad_wrapper, load_device=comfy.model_management.get_torch_device(),
                                     offload_device=comfy.model_management.unet_offload_device())


def create_fresh_encoder_only_model(motion_model: MotionModelPatcher) -> MotionModelPatcher:
    ad_wrapper = EncoderOnlyAnimateDiffModel(mm_state_dict=motion_model.model.state_dict(), mm_info=motion_model.model.mm_info)
    ad_wrapper.to(comfy.model_management.unet_dtype())
    ad_wrapper.to(comfy.model_management.unet_offload_device())
    ad_wrapper.load_state_dict(motion_model.model.state_dict(), strict=False)
    return create_MotionModelPatcher(model=ad_wrapper, load_device=comfy.model_management.get_torch_device(),
                                     offload_device=comfy.model_management.unet_offload_device())


def inject_img_encoder_into_model(motion_model: MotionModelPatcher, w_encoder: MotionModelPatcher):
    motion_model.model.init_img_encoder()
    motion_model.model.img_encoder.to(comfy.model_management.unet_dtype())
    motion_model.model.img_encoder.to(comfy.model_management.unet_offload_device())
    motion_model.model.img_encoder.load_state_dict(w_encoder.model.img_encoder.state_dict())


def inject_pia_conv_in_into_model(motion_model: MotionModelPatcher, w_pia: MotionModelPatcher):
    motion_model.model.init_conv_in(w_pia.model.state_dict())
    motion_model.model.conv_in.to(comfy.model_management.unet_dtype())
    motion_model.model.conv_in.to(comfy.model_management.unet_offload_device())
    motion_model.model.conv_in.load_state_dict(w_pia.model.conv_in.state_dict())
    motion_model.model.mm_info.mm_format = AnimateDiffFormat.PIA


def inject_camera_encoder_into_model(motion_model: MotionModelPatcher, camera_ctrl_name: str):
    camera_ctrl_path = get_motion_model_path(camera_ctrl_name)
    full_state_dict = comfy.utils.load_torch_file(camera_ctrl_path, safe_load=True)
    camera_state_dict: dict[str, Tensor] = dict()
    attention_state_dict: dict[str, Tensor] = dict()
    for key in full_state_dict:
        if key.startswith("encoder"):
            camera_state_dict[key] = full_state_dict[key]
        elif "qkv_merge" in key:
            attention_state_dict[key] = full_state_dict[key]
    # verify has necessary keys
    if len(camera_state_dict) == 0:
        raise Exception("Provided CameraCtrl model had no Camera Encoder-related keys; not a valid CameraCtrl model!")
    if len(attention_state_dict) == 0:
        raise Exception("Provided CameraCtrl model had no qkv_merge keys; not a valid CameraCtrl model!")
    # initialize CameraPoseEncoder on motion model, and load keys
    camera_encoder = CameraPoseEncoder(channels=motion_model.model.layer_channels, nums_rb=2, ops=motion_model.model.ops).to(
        device=comfy.model_management.unet_offload_device(),
        dtype=comfy.model_management.unet_dtype()
    )
    camera_encoder.load_state_dict(camera_state_dict)
    camera_encoder.temporal_pe_max_len = get_position_encoding_max_len(camera_state_dict, mm_name=camera_ctrl_name,
                                                                       mm_format=AnimateDiffFormat.ANIMATEDIFF)
    motion_model.model.set_camera_encoder(camera_encoder=camera_encoder)
    # initialize qkv_merge on specific attention blocks, and load keys
    for key in attention_state_dict:
        key = key.strip()
        # to avoid handling the same qkv_merge twice, only pay attention to the bias keys (bias+weight handled together)
        if key.endswith("weight"):
            continue
        attr_path = key.split(".processor.qkv_merge")[0]
        base_key = key.split(".bias")[0]
        # first, initialize qkv_merge on model
        attention_obj: VersatileAttention = comfy.utils.get_attr(motion_model.model, attr_path)
        attention_obj.init_qkv_merge(ops=motion_model.model.ops)
        # then, apply weights to qkv_merge
        qkv_merge_state_dict = {}
        qkv_merge_state_dict["weight"] = attention_state_dict[f"{base_key}.weight"]
        qkv_merge_state_dict["bias"] = attention_state_dict[f"{base_key}.bias"]
        attention_obj.qkv_merge.load_state_dict(qkv_merge_state_dict)
        attention_obj.qkv_merge = attention_obj.qkv_merge.to(
            device=comfy.model_management.unet_offload_device(),
            dtype=comfy.model_management.unet_dtype()
        )


def validate_model_compatibility_gen2(model: ModelPatcher, motion_model: MotionModelPatcher):
    # check that motion model is compatible with sd model
    model_sd_type = get_sd_model_type(model)
    mm_info = motion_model.model.mm_info
    if model_sd_type != mm_info.sd_type:
        raise MotionCompatibilityError(f"Motion module '{mm_info.mm_name}' is intended for {mm_info.sd_type} models, " \
                                       + f"but the provided model is type {model_sd_type}.")


def validate_per_block_compatibility(motion_model: MotionModelPatcher, all_per_blocks: AllPerBlocks):
    if all_per_blocks is None or all_per_blocks.sd_type is None:
        return
    mm_info = motion_model.model.mm_info
    if all_per_blocks.sd_type != mm_info.sd_type:
        raise Exception(f"Per-Block provided is meant for {all_per_blocks.sd_type}, but provided motion module is for {mm_info.sd_type}.")


def interpolate_pe_to_length(model_dict: dict[str, Tensor], key: str, new_length: int):
    pe_shape = model_dict[key].shape
    temp_pe = rearrange(model_dict[key], "(t b) f d -> t b f d", t=1)
    temp_pe = F.interpolate(temp_pe, size=(new_length, pe_shape[-1]), mode="bilinear")
    temp_pe = rearrange(temp_pe, "t b f d -> (t b) f d", t=1)
    model_dict[key] = temp_pe
    del temp_pe


def interpolate_pe_to_length_diffs(model_dict: dict[str, Tensor], key: str, new_length: int):
    # TODO: fill out and try out
    pe_shape = model_dict[key].shape
    temp_pe = rearrange(model_dict[key], "(t b) f d -> t b f d", t=1)
    temp_pe = F.interpolate(temp_pe, size=(new_length, pe_shape[-1]), mode="bilinear")
    temp_pe = rearrange(temp_pe, "t b f d -> (t b) f d", t=1)
    model_dict[key] = temp_pe
    del temp_pe


def interpolate_pe_to_length_pingpong(model_dict: dict[str, Tensor], key: str, new_length: int):
    if model_dict[key].shape[1] < new_length:
        temp_pe = model_dict[key]
        flipped_temp_pe = torch.flip(temp_pe[:, 1:-1, :], [1])
        use_flipped = True
        preview_pe = None
        while model_dict[key].shape[1] < new_length:
            preview_pe = model_dict[key]
            model_dict[key] = torch.cat([model_dict[key], flipped_temp_pe if use_flipped else temp_pe], dim=1)
            use_flipped = not use_flipped
        del temp_pe
        del flipped_temp_pe
        del preview_pe
    model_dict[key] = model_dict[key][:, :new_length]


def freeze_mask_of_pe(model_dict: dict[str, Tensor], key: str):
    pe_portion = model_dict[key].shape[2] // 64
    first_pe = model_dict[key][:, :1, :]
    model_dict[key][:, :, pe_portion:] = first_pe[:, :, pe_portion:]
    del first_pe
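
# Worked example of the pingpong extension above: with PE frames [a b c d],
# flipped_temp_pe is the interior reversed, [c b]. Appending alternately gives
# [a b c d c b], then [a b c d c b a b c d], and so on until the length
# reaches new_length, where the final slice trims any overshoot. Excluding the
# endpoints from the flip ensures no frame repeats back-to-back at a seam.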


def freeze_mask_of_attn(model_dict: dict[str, Tensor], key: str):
    attn_portion = model_dict[key].shape[0] // 2
    model_dict[key][:attn_portion, :attn_portion] *= 1.5


def apply_mm_settings(model_dict: dict[str, Tensor], mm_settings: AnimateDiffSettings) -> dict[str, Tensor]:
    if mm_settings is None:
        return model_dict
    if not mm_settings.has_anything_to_apply():
        return model_dict
    # first, handle PE Adjustments
    for adjust_pe in mm_settings.adjust_pe.adjusts:
        adjust_pe: AdjustPE
        if adjust_pe.has_anything_to_apply():
            already_printed = False
            for key in model_dict:
                if "attention_blocks" in key and "pos_encoder" in key:
                    # apply simple motion pe stretch, if needed
                    if adjust_pe.has_motion_pe_stretch():
                        original_length = model_dict[key].shape[1]
                        new_pe_length = original_length + adjust_pe.motion_pe_stretch
                        interpolate_pe_to_length(model_dict, key, new_length=new_pe_length)
                        if adjust_pe.print_adjustment and not already_printed:
                            logger.info(f"[Adjust PE]: PE Stretch from {original_length} to {new_pe_length}.")
                    # apply pe_idx_offset, if needed
                    if adjust_pe.has_initial_pe_idx_offset():
                        original_length = model_dict[key].shape[1]
                        model_dict[key] = model_dict[key][:, adjust_pe.initial_pe_idx_offset:]
                        if adjust_pe.print_adjustment and not already_printed:
                            logger.info(f"[Adjust PE]: Offsetting PEs by {adjust_pe.initial_pe_idx_offset}; PE length shortens from {original_length} to {model_dict[key].shape[1]}.")
                    # apply has_cap_initial_pe_length, if needed
                    if adjust_pe.has_cap_initial_pe_length():
                        original_length = model_dict[key].shape[1]
                        model_dict[key] = model_dict[key][:, :adjust_pe.cap_initial_pe_length]
                        if adjust_pe.print_adjustment and not already_printed:
                            logger.info(f"[Adjust PE]: Capping PEs (initial) from {original_length} to {model_dict[key].shape[1]}.")
                    # apply interpolate_pe_to_length, if needed
                    if adjust_pe.has_interpolate_pe_to_length():
                        original_length = model_dict[key].shape[1]
                        interpolate_pe_to_length(model_dict, key, new_length=adjust_pe.interpolate_pe_to_length)
                        if adjust_pe.print_adjustment and not already_printed:
                            logger.info(f"[Adjust PE]: Interpolating PE length from {original_length} to {model_dict[key].shape[1]}.")
                    # apply final_pe_idx_offset, if needed
                    if adjust_pe.has_final_pe_idx_offset():
                        original_length = model_dict[key].shape[1]
                        model_dict[key] = model_dict[key][:, adjust_pe.final_pe_idx_offset:]
                        if adjust_pe.print_adjustment and not already_printed:
                            logger.info(f"[Adjust PE]: Capping PEs (final) from {original_length} to {model_dict[key].shape[1]}.")
                    already_printed = True
    # finally, handle Weight Adjustments
    for adjust_w in mm_settings.adjust_weight.adjusts:
        adjust_w: AdjustWeight
        if adjust_w.has_anything_to_apply():
            adjust_w.mark_attrs_as_unprinted()
            for key in model_dict:
                # apply global weight adjustments, if needed
                adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_ALL, model_dict=model_dict, key=key)
                if "attention_blocks" in key:
                    # apply pe change, if needed
                    if "pos_encoder" in key:
                        adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_PE, model_dict=model_dict, key=key)
                    else:
                        # apply attn change, if needed
                        adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_ATTN, model_dict=model_dict, key=key)
                        # apply specific attn changes, if needed
                        # apply attn_q change, if needed
                        if "to_q" in key:
                            adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_ATTN_Q, model_dict=model_dict, key=key)
                        # apply attn_k change, if needed
                        elif "to_k" in key:
                            adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_ATTN_K, model_dict=model_dict, key=key)
                        # apply attn_v change, if needed
                        elif "to_v" in key:
                            adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_ATTN_V, model_dict=model_dict, key=key)
                        # apply to_out changes, if needed
                        elif "to_out" in key:
                            if key.strip().endswith("weight"):
                                adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_ATTN_OUT_WEIGHT, model_dict=model_dict, key=key)
                            elif key.strip().endswith("bias"):
                                adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_ATTN_OUT_BIAS, model_dict=model_dict, key=key)
                else:
                    adjust_w.perform_applicable_ops(attr=AdjustWeight.ATTR_OTHER, model_dict=model_dict, key=key)
    return model_dict


class InjectionParams:
    def __init__(self, unlimited_area_hack: bool=False, apply_mm_groupnorm_hack: bool=True, apply_v2_properly: bool=True) -> None:
        self.full_length = None
        self.unlimited_area_hack = unlimited_area_hack
        self.apply_mm_groupnorm_hack = apply_mm_groupnorm_hack
        self.apply_v2_properly = apply_v2_properly
        self.context_options: ContextOptionsGroup = ContextOptionsGroup.default()
        self.motion_model_settings = AnimateDiffSettings()  # Gen1
        self.sub_idxs = None  # value should NOT be included in clone, so it will auto reset

    def set_noise_extra_args(self, noise_extra_args: dict):
        noise_extra_args["context_options"] = self.context_options.clone()

    def set_context(self, context_options: ContextOptionsGroup):
        self.context_options = context_options.clone() if context_options else ContextOptionsGroup.default()

    def is_using_sliding_context(self) -> bool:
        return self.context_options.context_length is not None

    def set_motion_model_settings(self, motion_model_settings: AnimateDiffSettings):  # Gen1
        if motion_model_settings is None:
            self.motion_model_settings = AnimateDiffSettings()
        else:
            self.motion_model_settings = motion_model_settings

    def reset_context(self):
        self.context_options = ContextOptionsGroup.default()

    def clone(self) -> 'InjectionParams':
        new_params = InjectionParams(
            self.unlimited_area_hack,
            self.apply_mm_groupnorm_hack,
            apply_v2_properly=self.apply_v2_properly,
        )
        new_params.full_length = self.full_length
        new_params.set_context(self.context_options)
        new_params.set_motion_model_settings(self.motion_model_settings)  # Gen1
        return new_params

    def on_model_patcher_clone(self):
        return self.clone()
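
# End-to-end usage sketch (commented out; node classes normally drive this).
# `sd_model` is a hypothetical already-loaded ModelPatcher, the motion module
# filename is illustrative, and the two wrappers stand in for the sampling
# wrappers defined elsewhere in this repo:
#
#   motion_model = load_motion_module_gen2("mm_sd_v15_v2.ckpt")
#   params = InjectionParams()
#   helper = ModelPatcherHelper(sd_model)
#   helper.set_all_properties(outer_sampler_wrapper, calc_cond_batch_wrapper,
#                             params=params, sample_settings=SampleSettings(),
#                             motion_models=MotionModelGroup(motion_model))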