from typing import Callable import math import torch from torch import Tensor from torch.nn.functional import group_norm from einops import rearrange import comfy.model_management import comfy.model_patcher import comfy.patcher_extension import comfy.samplers import comfy.sampler_helpers import comfy.utils from comfy.controlnet import ControlBase from comfy.model_base import BaseModel from comfy.model_patcher import ModelPatcher from comfy.patcher_extension import WrapperExecutor, WrappersMP import comfy.conds import comfy.ops from .context import ContextFuseMethod, ContextSchedules, get_context_weights, get_context_windows from .context_extras import ContextRefMode from .sample_settings import SampleSettings, NoisedImageToInject from .utils_model import MachineState, vae_encode_raw_batched, vae_decode_raw_batched from .utils_motion import composite_extend, prepare_mask_batch, extend_to_batch_size from .model_injection import InjectionParams, ModelPatcherHelper, MotionModelGroup, get_mm_attachment from .motion_module_ad import AnimateDiffFormat, AnimateDiffInfo, AnimateDiffVersion from .logger import logger ################################################################################## ###################################################################### # Global variable to use to more conveniently hack variable access into samplers class AnimateDiffGlobalState: def __init__(self): self.model_patcher: ModelPatcher = None self.motion_models: MotionModelGroup = None self.params: InjectionParams = None self.sample_settings: SampleSettings = None self.callback_output_dict: dict[str] = {} self.function_injections: FunctionInjectionHolder = None self.reset() def initialize(self, model: BaseModel): # this function is to be run in sampling func if not self.initialized: self.initialized = True if self.motion_models is not None: self.motion_models.initialize_timesteps(model) if self.params.context_options is not None: 
self.params.context_options.initialize_timesteps(model) if self.sample_settings.custom_cfg is not None: self.sample_settings.custom_cfg.initialize_timesteps(model) def prepare_current_keyframes(self, x: Tensor, timestep: Tensor): if self.motion_models is not None: self.motion_models.prepare_current_keyframe(x=x, t=timestep) if self.params.context_options is not None: self.params.context_options.prepare_current(t=timestep) if self.sample_settings.custom_cfg is not None: self.sample_settings.custom_cfg.prepare_current_keyframe(t=timestep) def perform_special_model_features(self, model: BaseModel, conds: list, x_in: Tensor, model_options: dict[str]): if self.motion_models is not None: special_models = self.motion_models.get_special_models() if len(special_models) > 0: for special_model in special_models: if special_model.model.is_in_effect(): attachment = get_mm_attachment(special_model) if attachment.is_pia(special_model): special_model.model.inject_unet_conv_in_pia_fancyvideo(model) conds = get_conds_with_c_concat(conds, attachment.get_pia_c_concat(model, x_in)) elif attachment.is_fancyvideo(special_model): # TODO: handle other weights special_model.model.inject_unet_conv_in_pia_fancyvideo(model) conds = get_conds_with_c_concat(conds, attachment.get_fancy_c_concat(model, x_in)) # add fps_embedding/motion_embedding patches emb_patches = special_model.model.get_fancyvideo_emb_patches(dtype=x_in.dtype, device=x_in.device) transformer_patches = model_options["transformer_options"].get("patches", {}) transformer_patches["emb_patch"] = emb_patches model_options["transformer_options"]["patches"] = transformer_patches return conds def restore_special_model_features(self, model: BaseModel): if self.motion_models is not None: special_models = self.motion_models.get_special_models() if len(special_models) > 0: for special_model in reversed(special_models): attachment = get_mm_attachment(special_model) if attachment.is_pia(special_model): 
special_model.model.restore_unet_conv_in_pia_fancyvideo(model) elif attachment.is_fancyvideo(special_model): # TODO: fill out special_model.model.restore_unet_conv_in_pia_fancyvideo(model) def reset(self): self.initialized = False self.hooks_initialized = False self.start_step: int = 0 self.last_step: int = 0 self.current_step: int = 0 self.total_steps: int = 0 self.callback_output_dict.clear() self.callback_output_dict = {} if self.model_patcher is not None: self.model_patcher.clean_hooks() del self.model_patcher self.model_patcher = None if self.motion_models is not None: del self.motion_models self.motion_models = None if self.params is not None: self.params.context_options.reset() del self.params self.params = None if self.sample_settings is not None: del self.sample_settings self.sample_settings = None if self.function_injections is not None: del self.function_injections self.function_injections = None def update_with_inject_params(self, params: InjectionParams): self.params = params def is_using_sliding_context(self): return self.params is not None and self.params.is_using_sliding_context() def create_exposed_params(self): # This dict will be exposed to be used by other extensions # DO NOT change any of the key names # or I will find you 👁.👁 return { "full_length": self.params.full_length, "context_length": self.params.context_options.context_length, "sub_idxs": self.params.sub_idxs, } ###################################################################### ################################################################################## ################################################################################## #### Code Injection ################################################## def unlimited_memory_required(*args, **kwargs): return 0 def groupnorm_mm_factory(params: InjectionParams, manual_cast=False): def groupnorm_mm_forward(self, input: Tensor) -> Tensor: # axes_factor normalizes batch based on total conds and unconds passed in batch; # the conds 
and unconds per batch can change based on VRAM optimizations that may kick in if not params.is_using_sliding_context(): batched_conds = input.size(0)//params.full_length else: batched_conds = input.size(0)//params.context_options.context_length input = rearrange(input, "(b f) c h w -> b c f h w", b=batched_conds) if manual_cast: weight, bias = comfy.ops.cast_bias_weight(self, input) else: weight, bias = self.weight, self.bias input = group_norm(input, self.num_groups, weight, bias, self.eps) input = rearrange(input, "b c f h w -> (b f) c h w", b=batched_conds) return input return groupnorm_mm_forward def create_special_model_apply_model_wrapper(model_options: dict): comfy.patcher_extension.add_wrapper_with_key(WrappersMP.APPLY_MODEL, "ADE_special_model_apply_model", _apply_model_wrapper, model_options, is_model_options=True) def _apply_model_wrapper(executor, *args, **kwargs): # args (from BaseModel._apply_model): # 0: x # 1: t # 2: c_concat # 3: c_crossattn # 4: control # 5: transformer_options x: Tensor = args[0] transformer_options = args[5] cond_or_uncond = transformer_options["cond_or_uncond"] ad_params = transformer_options["ad_params"] ADGS: AnimateDiffGlobalState = transformer_options["ADGS"] if ADGS.motion_models is not None: for motion_model in ADGS.motion_models.models: attachment = get_mm_attachment(motion_model) attachment.prepare_alcmi2v_features(motion_model, x=x, cond_or_uncond=cond_or_uncond, ad_params=ad_params, latent_format=executor.class_obj.latent_format) attachment.prepare_camera_features(motion_model, x=x, cond_or_uncond=cond_or_uncond, ad_params=ad_params) del x return executor(*args, **kwargs) def create_diffusion_model_groupnormed_wrapper(model_options: dict, inject_helper: 'GroupnormInjectHelper'): comfy.patcher_extension.add_wrapper_with_key(WrappersMP.DIFFUSION_MODEL, "ADE_groupnormed_diffusion_model", _diffusion_model_groupnormed_wrapper_factory(inject_helper), model_options, is_model_options=True) def 
_diffusion_model_groupnormed_wrapper_factory(inject_helper: 'GroupnormInjectHelper'): def _diffusion_model_groupnormed_wrapper(executor, *args, **kwargs): with inject_helper: return executor(*args, **kwargs) return _diffusion_model_groupnormed_wrapper ###################################################################### ################################################################################## def apply_params_to_motion_models(helper: ModelPatcherHelper, params: InjectionParams): params = params.clone() for context in params.context_options.contexts: if context.context_schedule == ContextSchedules.VIEW_AS_CONTEXT: context.context_length = params.full_length # TODO: check (and message) should be different based on use_on_equal_length setting if params.context_options.context_length: pass allow_equal = params.context_options.use_on_equal_length if params.context_options.context_length: enough_latents = params.full_length >= params.context_options.context_length if allow_equal else params.full_length > params.context_options.context_length else: enough_latents = False if params.context_options.context_length and enough_latents: logger.info(f"Sliding context window sampling activated - latents passed in ({params.full_length}) greater than context_length {params.context_options.context_length}.") else: logger.info(f"Regular sampling activated - latents passed in ({params.full_length}) less or equal to context_length {params.context_options.context_length}.") params.reset_context() if helper.get_motion_models(): # if no context_length, treat video length as intended AD frame window if not params.context_options.context_length: for motion_model in helper.get_motion_models(): if not motion_model.model.is_length_valid_for_encoding_max_len(params.full_length): raise ValueError(f"Without a context window, AnimateDiff model {motion_model.model.mm_info.mm_name} has upper limit of {motion_model.model.encoding_max_len} frames, but received {params.full_length} latents.") 
helper.set_video_length(params.full_length, params.full_length) # otherwise, treat context_length as intended AD frame window else: for motion_model in helper.get_motion_models(): view_options = params.context_options.view_options context_length = view_options.context_length if view_options else params.context_options.context_length if not motion_model.model.is_length_valid_for_encoding_max_len(context_length): raise ValueError(f"AnimateDiff model {motion_model.model.mm_info.mm_name} has upper limit of {motion_model.model.encoding_max_len} frames for a context window, but received context length of {params.context_options.context_length}.") helper.set_video_length(params.context_options.context_length, params.full_length) # inject model module_str = "modules" if len(helper.get_motion_models()) > 1 else "module" logger.info(f"Using motion {module_str} {helper.get_name_string(show_version=True)}.") return params class FunctionInjectionHolder: def __init__(self): self.temp_uninjector: GroupnormUninjectHelper = GroupnormUninjectHelper() self.groupnorm_injector: GroupnormInjectHelper = GroupnormInjectHelper() def inject_functions(self, helper: ModelPatcherHelper, params: InjectionParams, model_options: dict): # Save Original Functions - order must match between here and restore_functions self.orig_memory_required = None self.orig_groupnorm_forward = torch.nn.GroupNorm.forward # used to normalize latents to remove "flickering" of colors/brightness between frames self.orig_groupnorm_forward_comfy_cast_weights = comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights self.orig_sampling_function = comfy.samplers.sampling_function # used to support sliding context windows in samplers # Inject Functions if params.unlimited_area_hack: # allows for "unlimited area hack" to prevent halving of conds/unconds self.orig_memory_required = helper.model.model.memory_required helper.model.model.memory_required = unlimited_memory_required if helper.get_motion_models(): # only 
apply groupnorm hack if PIA, v2 and not properly applied, or v1 info: AnimateDiffInfo = helper.get_motion_models()[0].model.mm_info if ((info.mm_format == AnimateDiffFormat.PIA) or (info.mm_version == AnimateDiffVersion.V2 and not params.apply_v2_properly) or (info.mm_version == AnimateDiffVersion.V1)): self.inject_groupnorm_forward = groupnorm_mm_factory(params) self.inject_groupnorm_forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True) self.groupnorm_injector = GroupnormInjectHelper(self) create_diffusion_model_groupnormed_wrapper(model_options, self.groupnorm_injector) # if mps device (Apple Silicon), disable batched conds to avoid black images with groupnorm hack try: if helper.model.load_device.type == "mps": self.orig_memory_required = helper.model.model.memory_required helper.model.model.memory_required = unlimited_memory_required except Exception: pass # if img_encoder or camera_encoder present, inject apply_model to handle correctly for motion_model in helper.get_motion_models(): if (motion_model.model.img_encoder is not None) or (motion_model.model.camera_encoder is not None): create_special_model_apply_model_wrapper(model_options) break del info comfy.samplers.sampling_function = evolved_sampling_function # create temp_uninjector to help facilitate uninjecting functions self.temp_uninjector = GroupnormUninjectHelper(self) def restore_functions(self, helper: ModelPatcherHelper): # Restoration try: if self.orig_memory_required is not None: helper.model.model.memory_required = self.orig_memory_required torch.nn.GroupNorm.forward = self.orig_groupnorm_forward comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_forward_comfy_cast_weights comfy.samplers.sampling_function = self.orig_sampling_function except AttributeError: logger.error("Encountered AttributeError while attempting to restore functions - likely, an error occured while trying " + \ "to save original functions before injection, and a more 
specific error was thrown by ComfyUI.") class GroupnormUninjectHelper: def __init__(self, holder: FunctionInjectionHolder=None): self.holder = holder self.previous_gn_forward = None self.previous_dwi_gn_cast_weights = None def __enter__(self): if self.holder is None: return self # backup current groupnorm funcs self.previous_gn_forward = torch.nn.GroupNorm.forward self.previous_dwi_gn_cast_weights = comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights # restore groupnorm to default state torch.nn.GroupNorm.forward = self.holder.orig_groupnorm_forward comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.holder.orig_groupnorm_forward_comfy_cast_weights return self def __exit__(self, *args, **kwargs): if self.holder is None: return # bring groupnorm back to previous state torch.nn.GroupNorm.forward = self.previous_gn_forward comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.previous_dwi_gn_cast_weights self.previous_gn_forward = None self.previous_dwi_gn_cast_weights = None class GroupnormInjectHelper: def __init__(self, holder: FunctionInjectionHolder=None): self.holder = holder self.previous_gn_forward = None self.previous_dwi_gn_cast_weights = None def __enter__(self): if self.holder is None: return self # store previous gn_forward self.previous_gn_forward = torch.nn.GroupNorm.forward self.previous_dwi_gn_cast_weights = comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights # inject groupnorm functions torch.nn.GroupNorm.forward = self.holder.inject_groupnorm_forward comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.holder.inject_groupnorm_forward_comfy_cast_weights return self def __exit__(self, *args, **kwargs): if self.holder is None: return # bring groupnorm back to previous state torch.nn.GroupNorm.forward = self.previous_gn_forward comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.previous_dwi_gn_cast_weights self.previous_gn_forward = None 
self.previous_dwi_gn_cast_weights = None def outer_sample_wrapper(executor: WrapperExecutor, *args, **kwargs): # NOTE: OUTER_SAMPLE wrapper patch in ModelPatcher latents = None cached_latents = None cached_noise = None function_injections = FunctionInjectionHolder() try: guider: comfy.samplers.CFGGuider = executor.class_obj helper = ModelPatcherHelper(guider.model_patcher) orig_model_options = guider.model_options guider.model_options = comfy.model_patcher.create_model_options_clone(guider.model_options) # create ADGS in transformer_options ADGS = AnimateDiffGlobalState() guider.model_options["transformer_options"]["ADGS"] = ADGS args = list(args) # clone params from model params = helper.get_params().clone() # get amount of latents passed in, and store in params noise: Tensor = args[0] latents: Tensor = args[1] params.full_length = latents.size(0) # reset global state ADGS.reset() # apply custom noise, if needed disable_noise = math.isclose(noise.max(), 0.0) seed = args[-1] # apply params to motion model params = apply_params_to_motion_models(helper, params) # store and inject funtions function_injections.inject_functions(helper, params, guider.model_options) # prepare noise_extra_args for noise generation purposes noise_extra_args = {"disable_noise": disable_noise} params.set_noise_extra_args(noise_extra_args) # if noise is not disabled, do noise stuff if not disable_noise: noise = helper.get_sample_settings().prepare_noise(seed, latents, noise, extra_args=noise_extra_args, force_create_noise=False) # callback setup original_callback = args[-3] def ad_callback(step, x0, x, total_steps): if original_callback is not None: original_callback(step, x0, x, total_steps) # store denoised latents if image_injection will be used if not helper.get_sample_settings().image_injection.is_empty(): ADGS.callback_output_dict["x0"] = x0 # update GLOBALSTATE for next iteration ADGS.current_step = ADGS.start_step + step + 1 args[-3] = ad_callback ADGS.model_patcher = helper.model 
ADGS.motion_models = MotionModelGroup(helper.get_motion_models()) ADGS.sample_settings = helper.get_sample_settings() ADGS.function_injections = function_injections # apply adapt_denoise_steps - does not work here! would need to mess with this elsewhere... # TODO: implement proper wrapper to handle this feature... iter_opts = helper.get_sample_settings().iteration_opts iter_opts.initialize(latents) # cache initial noise and latents, if needed if iter_opts.cache_init_latents: cached_latents = latents.clone() if iter_opts.cache_init_noise: cached_noise = noise.clone() # prepare iter opts preprocess kwargs, if needed iter_kwargs = {} # NOTE: original KSampler stuff is not doable here, so skipping... for curr_i in range(iter_opts.iterations): # handle GLOBALSTATE vars and step tally # NOTE: only KSampler/KSampler (Advanced) would have steps; # explore modifying ComfyUI to provide this when possible? ADGS.update_with_inject_params(params) ADGS.start_step = kwargs.get("start_step") or 0 ADGS.current_step = ADGS.start_step ADGS.last_step = kwargs.get("last_step") or 0 if iter_opts.iterations > 1: logger.info(f"Iteration {curr_i+1}/{iter_opts.iterations}") # perform any iter_opts preprocessing on latents latents, noise = iter_opts.preprocess_latents(curr_i=curr_i, model=helper.model, latents=latents, noise=noise, cached_latents=cached_latents, cached_noise=cached_noise, seed=seed, sample_settings=helper.get_sample_settings(), noise_extra_args=noise_extra_args, **iter_kwargs) if helper.get_sample_settings().noise_calibration is not None: latents, noise = helper.get_sample_settings().noise_calibration.perform_calibration(sample_func=executor, model=helper.model, latents=latents, noise=noise, is_custom=True, args=args, kwargs=kwargs) # finalize latent_image in args args[0] = noise args[1] = latents helper.pre_run() if ADGS.sample_settings.image_injection.is_empty(): latents = executor(*tuple(args), **kwargs) else: 
ADGS.sample_settings.image_injection.initialize_timesteps(helper.model.model) sigmas = args[3] sigmas_list, injection_list = ADGS.sample_settings.image_injection.custom_ksampler_get_injections(helper.model, sigmas) # useful logging if len(injection_list) > 0: inj_str = "s" if len(injection_list) > 1 else "" logger.info(f"Found {len(injection_list)} applicable image injection{inj_str}; sampling will be split into {len(sigmas_list)}.") else: logger.info(f"Found 0 applicable image injections within the step bounds of this sampler; sampling unaffected.") is_first = True new_noise = noise for i in range(len(sigmas_list)): args[0] = new_noise args[1] = latents args[3] = sigmas_list[i] latents = executor(*tuple(args), **kwargs) if is_first: new_noise = torch.zeros_like(latents) # if injection expected, perform injection if i < len(injection_list): to_inject = injection_list[i] latents = perform_image_injection(ADGS, helper.model.model, latents, to_inject) return latents finally: guider.model_options = orig_model_options del noise del latents del cached_latents del cached_noise del orig_model_options # reset global state ADGS.reset() # clean motion_models helper.cleanup_motion_models() # restore injected functions function_injections.restore_functions(helper) del function_injections del helper def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond, cond_scale, model_options: dict={}, seed=None): ADGS: AnimateDiffGlobalState = model_options["transformer_options"]["ADGS"] ADGS.initialize(model) ADGS.prepare_current_keyframes(x=x, timestep=timestep) try: # add AD/evolved-sampling params to model_options (transformer_options) model_options = model_options.copy() if "transformer_options" not in model_options: model_options["transformer_options"] = {} else: model_options["transformer_options"] = model_options["transformer_options"].copy() model_options["transformer_options"]["ad_params"] = ADGS.create_exposed_params() cond, uncond = 
ADGS.perform_special_model_features(model, [cond, uncond], x, model_options) # only use cfg1_optimization if not using custom_cfg or explicitly set to 1.0 uncond_ = uncond if ADGS.sample_settings.custom_cfg is None and math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: uncond_ = None elif ADGS.sample_settings.custom_cfg is not None: cfg_multival = ADGS.sample_settings.custom_cfg.cfg_multival if type(cfg_multival) != Tensor and math.isclose(cfg_multival, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: uncond_ = None del cfg_multival cond_pred, uncond_pred = comfy.samplers.calc_cond_batch(model, [cond, uncond_], x, timestep, model_options) if ADGS.sample_settings.custom_cfg is not None: cond_scale = ADGS.sample_settings.custom_cfg.get_cfg_scale(cond_pred) model_options = ADGS.sample_settings.custom_cfg.get_model_options(model_options) return comfy.samplers.cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options, cond, uncond) finally: ADGS.restore_special_model_features(model) def perform_image_injection(ADGS: AnimateDiffGlobalState, model: BaseModel, latents: Tensor, to_inject: NoisedImageToInject) -> Tensor: # NOTE: the latents here have already been process_latent_out'ed # get currently used models so they can be properly reloaded after perfoming VAE Encoding cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True) try: orig_device = latents.device orig_dtype = latents.dtype # follow same steps as in KSampler Custom to get same denoised_x0 value x0 = ADGS.callback_output_dict.get("x0", None) if x0 is None: return latents # x0 should be process_latent_out'ed to match expected state of latents between nodes x0 = model.process_latent_out(x0) # first, decode x0 into images, and then re-encode decoded_images = vae_decode_raw_batched(to_inject.vae, x0) encoded_x0 = vae_encode_raw_batched(to_inject.vae, decoded_images) # get difference between 
sampled latents and encoded_x0 latents = latents.to(device=encoded_x0.device) encoded_x0 = latents - encoded_x0 # get mask, or default to full mask mask = to_inject.mask b, c, h, w = encoded_x0.shape # need to resize images and masks to match expected dims if mask is None: mask = torch.ones(1, h, w) if to_inject.invert_mask: mask = 1.0 - mask opts = to_inject.img_inject_opts # composite decoded_x0 with image to inject; # make sure to move dims to match expectation of (b,c,h,w) composited = composite_extend(destination=decoded_images.movedim(-1, 1), source=to_inject.image.movedim(-1, 1), x=opts.x, y=opts.y, mask=mask, multiplier=to_inject.vae.downscale_ratio, resize_source=to_inject.resize_image).movedim(1, -1) # encode composited to get latent representation composited = vae_encode_raw_batched(to_inject.vae, composited) # add encoded_x0 diff to composited composited += encoded_x0 if type(to_inject.strength_multival) == float and math.isclose(1.0, to_inject.strength_multival): return composited.to(dtype=orig_dtype, device=orig_device) strength = to_inject.strength_multival if type(strength) == Tensor: strength = extend_to_batch_size(prepare_mask_batch(strength, composited.shape), b) return (composited * strength + latents * (1.0 - strength)).to(dtype=orig_dtype, device=orig_device) finally: comfy.model_management.load_models_gpu(cached_loaded_models) # initial sliding_calc_conds_batch inspired by ashen's initial hack for 16-frame sliding context: # https://github.com/comfyanonymous/ComfyUI/compare/master...ashen-sensored:ComfyUI:master def sliding_calc_cond_batch(executor: Callable, model, conds: list[list[dict]], x_in: Tensor, timestep, model_options): ADGS: AnimateDiffGlobalState = model_options["transformer_options"]["ADGS"] if not ADGS.is_using_sliding_context(): return executor(model, conds, x_in, timestep, model_options) def prepare_control_objects(control: ControlBase, full_idxs: list[int]): if control.previous_controlnet is not None: 
prepare_control_objects(control.previous_controlnet, full_idxs) if not hasattr(control, "sub_idxs"): raise ValueError(f"Control type {type(control).__name__} may not support required features for sliding context window; \ use ControlNet nodes from Kosinkadink/ComfyUI-Advanced-ControlNet, or make sure ComfyUI-Advanced-ControlNet is updated.") control.sub_idxs = full_idxs control.full_latent_length = ADGS.params.full_length control.context_length = ADGS.params.context_options.context_length def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list: if cond_in is None: return None # reuse or resize cond items to match context requirements resized_cond = [] # cond object is a list containing a dict - outer list is irrelevant, so just loop through it for actual_cond in cond_in: resized_actual_cond = actual_cond.copy() # now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary for key in actual_cond: try: cond_item = actual_cond[key] if isinstance(cond_item, Tensor): # check that tensor is the expected length - x.size(0) if cond_item.size(0) == x_in.size(0): # if so, it's subsetting time - tell controls the expected indeces so they can handle them actual_cond_item = cond_item[full_idxs] resized_actual_cond[key] = actual_cond_item else: resized_actual_cond[key] = cond_item # look for control elif key == "control": control_item = cond_item prepare_control_objects(control_item, full_idxs) resized_actual_cond[key] = control_item del control_item elif isinstance(cond_item, dict): new_cond_item = cond_item.copy() # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor) for cond_key, cond_value in new_cond_item.items(): if isinstance(cond_value, Tensor): if cond_value.size(0) == x_in.size(0): new_cond_item[cond_key] = cond_value[full_idxs] # if has cond that is a Tensor, check if needs to be subset elif hasattr(cond_value, "cond") and 
isinstance(cond_value.cond, Tensor): if cond_value.cond.size(0) == x_in.size(0): new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond[full_idxs]) elif cond_key == "num_video_frames": # for SVD new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond) new_cond_item[cond_key].cond = context_length resized_actual_cond[key] = new_cond_item else: resized_actual_cond[key] = cond_item finally: del cond_item # just in case to prevent VRAM issues resized_cond.append(resized_actual_cond) return resized_cond # get context windows ADGS.params.context_options.step = ADGS.current_step context_windows = get_context_windows(ADGS.params.full_length, ADGS.params.context_options) if ADGS.motion_models is not None: ADGS.motion_models.set_view_options(ADGS.params.context_options.view_options) # prepare final conds, out_counts, and biases conds_final = [torch.zeros_like(x_in) for _ in conds] if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE: # counts_final not used for RELATIVE fuse_method counts_final = [torch.ones((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds] else: # default counts_final initialization counts_final = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds] biases_final = [([0.0] * x_in.shape[0]) for _ in conds] CONTEXTREF_CONTROL_LIST_ALL = "contextref_control_list_all" CONTEXTREF_MACHINE_STATE = "contextref_machine_state" CONTEXTREF_CLEAN_FUNC = "contextref_clean_func" contextref_active = False contextref_mode = None contextref_idxs_set = None first_context = True # need to make sure that contextref stuff gets cleaned up, no matter what try: if ADGS.params.context_options.extras.should_run_context_ref(): # check that ACN provided ContextRef as requested temp_refcn_list = model_options["transformer_options"].get(CONTEXTREF_CONTROL_LIST_ALL, None) if temp_refcn_list is None: raise Exception("Advanced-ControlNet nodes are either missing or too outdated to support ContextRef. 
Update/install ComfyUI-Advanced-ControlNet to use ContextRef.") if len(temp_refcn_list) == 0: raise Exception("Unexpected ContextRef issue; Advanced-ControlNet did not provide any ContextRef objs for AnimateDiff-Evolved.") del temp_refcn_list # check if ContextRef ReferenceAdvanced ACN objs should_run actually_should_run = True for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]: refcn.prepare_current_timestep(timestep) if not refcn.should_run(): actually_should_run = False if actually_should_run: contextref_active = True for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]: # get mode_override if present, mode otherwise contextref_mode = refcn.get_contextref_mode_replace() or ADGS.params.context_options.extras.context_ref.mode contextref_idxs_set = contextref_mode.indexes.copy() curr_window_idx = -1 naivereuse_active = False cached_naive_conds = None cached_naive_ctx_idxs = None if ADGS.params.context_options.extras.should_run_naive_reuse(): cached_naive_conds = [torch.zeros_like(x_in) for _ in conds] #cached_naive_counts = [torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device) for _ in conds] naivereuse_active = True # perform calc_conds_batch per context window for ctx_idxs in context_windows: # allow processing to end between context window executions for faster Cancel comfy.model_management.throw_exception_if_processing_interrupted() curr_window_idx += 1 ADGS.params.sub_idxs = ctx_idxs if ADGS.motion_models is not None: ADGS.motion_models.set_sub_idxs(ctx_idxs) ADGS.motion_models.set_video_length(len(ctx_idxs), ADGS.params.full_length) # update exposed params model_options["transformer_options"]["ad_params"]["sub_idxs"] = ctx_idxs model_options["transformer_options"]["ad_params"]["context_length"] = len(ctx_idxs) # get subsections of x, timestep, conds sub_x = x_in[ctx_idxs] sub_timestep = timestep[ctx_idxs] sub_conds = [get_resized_cond(cond, ctx_idxs, len(ctx_idxs)) for cond in conds] if 
contextref_active: # set cond counter to 0 (each cond encountered will increment it by 1) for refcn in model_options["transformer_options"][CONTEXTREF_CONTROL_LIST_ALL]: refcn.contextref_cond_idx = 0 if first_context: model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.WRITE else: model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ if contextref_mode.mode == ContextRefMode.SLIDING: # if sliding, check if time to READ and WRITE if curr_window_idx % (contextref_mode.sliding_width-1) == 0: model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ_WRITE # override with indexes mode, if set if contextref_mode.mode == ContextRefMode.INDEXES: contains_idx = False for i in ctx_idxs: if i in contextref_idxs_set: contains_idx = True # single trigger decides if each index should only trigger READ_WRITE once per step if not contextref_mode.single_trigger: break contextref_idxs_set.remove(i) if contains_idx: model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ_WRITE if first_context: model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.WRITE else: model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.READ else: model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.OFF #logger.info(f"window: {curr_window_idx} - {model_options['transformer_options'][CONTEXTREF_MACHINE_STATE]}") sub_conds_out = executor(model, sub_conds, sub_x, sub_timestep, model_options) if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE: full_length = ADGS.params.full_length for pos, idx in enumerate(ctx_idxs): # bias is the influence of a specific index in relation to the whole context window bias = 1 - abs(idx - (ctx_idxs[0] + ctx_idxs[-1]) / 2) / ((ctx_idxs[-1] - ctx_idxs[0] + 1e-2) / 2) bias = max(1e-2, bias) # take weighted average relative to total bias of current idx for i in 
range(len(sub_conds_out)): bias_total = biases_final[i][idx] prev_weight = (bias_total / (bias_total + bias)) new_weight = (bias / (bias_total + bias)) conds_final[i][idx] = conds_final[i][idx] * prev_weight + sub_conds_out[i][pos] * new_weight biases_final[i][idx] = bias_total + bias else: # add conds and counts based on weights of fuse method weights = get_context_weights(len(ctx_idxs), ADGS.params.context_options.fuse_method, sigma=timestep) weights_tensor = torch.Tensor(weights).to(device=x_in.device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) for i in range(len(sub_conds_out)): conds_final[i][ctx_idxs] += sub_conds_out[i] * weights_tensor counts_final[i][ctx_idxs] += weights_tensor # handle NaiveReuse if naivereuse_active: cached_naive_ctx_idxs = ctx_idxs for i in range(len(sub_conds)): cached_naive_conds[i][ctx_idxs] = conds_final[i][ctx_idxs] / counts_final[i][ctx_idxs] naivereuse_active = False # toggle first_context off, if needed if first_context: first_context = False finally: # clean contextref stuff with provided ACN function, if applicable if contextref_active: model_options["transformer_options"][CONTEXTREF_CLEAN_FUNC]() # handle NaiveReuse if cached_naive_conds is not None: start_idx = cached_naive_ctx_idxs[0] for z in range(0, ADGS.params.full_length, len(cached_naive_ctx_idxs)): for i in range(len(cached_naive_conds)): # get the 'true' idxs of this window new_ctx_idxs = [(zz+start_idx) % ADGS.params.full_length for zz in list(range(z, z+len(cached_naive_ctx_idxs))) if zz < ADGS.params.full_length] # make sure when getting cached_naive idxs, they are adjusted for actual length leftover length adjusted_naive_ctx_idxs = cached_naive_ctx_idxs[:len(new_ctx_idxs)] weighted_mean = ADGS.params.context_options.extras.naive_reuse.get_effective_weighted_mean(x_in, new_ctx_idxs) conds_final[i][new_ctx_idxs] = (weighted_mean * (cached_naive_conds[i][adjusted_naive_ctx_idxs]*counts_final[i][new_ctx_idxs])) + ((1.-weighted_mean) * conds_final[i][new_ctx_idxs]) 
def get_conds_with_c_concat(conds: list[dict], c_concat: comfy.conds.CONDNoiseShape):
    """Return a copy of ``conds`` with ``c_concat`` attached to every "model_conds" dict.

    Each non-None item of ``conds`` is iterated as a sequence of cond dicts.
    Every dict is shallow-copied; if it has a "model_conds" entry, that inner
    dict is shallow-copied as well and its "c_concat" key is set:

    * if "c_concat" is already present, the existing ``.cond`` tensor and
      ``c_concat.cond`` are concatenated along dim 1 and wrapped in a new
      ``comfy.conds.CONDNoiseShape``;
    * otherwise ``c_concat`` itself is stored.

    Dicts without a "model_conds" key are copied through unchanged, and
    ``None`` items stay ``None`` so the output is index-aligned with the input.
    The input dicts are not mutated.

    Args:
        conds: list whose items are either ``None`` or a sequence of cond dicts.
        c_concat: object exposing a ``.cond`` tensor to attach or append.

    Returns:
        A new list with the same length and ordering as ``conds``.
    """
    new_conds = []
    for cond in conds:
        if cond is None:
            # keep the placeholder so positions still line up with the input
            new_conds.append(None)
            continue
        resized_cond = []
        # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
        for actual_cond in cond:
            resized_actual_cond = actual_cond.copy()
            # only "model_conds" is touched; every other key is carried over by the shallow copy
            if "model_conds" in actual_cond:
                new_model_conds = actual_cond["model_conds"].copy()
                if "c_concat" in new_model_conds:
                    # torch.cat takes a single sequence of tensors; passing the tensors as
                    # separate positional arguments raises a TypeError
                    # NOTE(review): dim=1 is presumably the channel dimension - confirm
                    merged = torch.cat([new_model_conds["c_concat"].cond, c_concat.cond], dim=1)
                    new_model_conds["c_concat"] = comfy.conds.CONDNoiseShape(merged)
                else:
                    new_model_conds["c_concat"] = c_concat
                resized_actual_cond["model_conds"] = new_model_conds
            resized_cond.append(resized_actual_cond)
        new_conds.append(resized_cond)
    return new_conds