daquanzhou
merge github repos and lfs track ckpt/path/safetensors/pt
613c9ab
raw
history blame
28.8 kB
from typing import Callable
import math
import torch
from torch import Tensor
from torch.nn.functional import group_norm
from einops import rearrange
import comfy.ldm.modules.attention as attention
from comfy.ldm.modules.diffusionmodules import openaimodel
import comfy.model_management as model_management
import comfy.samplers
import comfy.sample
import comfy.utils
from comfy.controlnet import ControlBase
import comfy.ops
from .context import ContextFuseMethod, ContextSchedules, get_context_weights, get_context_windows
from .sample_settings import IterationOptions, SampleSettings, SeedNoiseGeneration, prepare_mask_ad
from .utils_model import ModelTypeSD, wrap_function_to_inject_xformers_bug_info
from .model_injection import InjectionParams, ModelPatcherAndInjector, MotionModelGroup, MotionModelPatcher
from .motion_module_ad import AnimateDiffFormat, AnimateDiffInfo, AnimateDiffVersion, VanillaTemporalModule
from .logger import logger
##################################################################################
######################################################################
# Global variable to use to more conveniently hack variable access into samplers
class AnimateDiffHelper_GlobalState:
    """Mutable state shared between the sampling wrapper and the patched samplers.

    Holds the motion models, injection params and sample settings of the run
    in progress, together with step counters for that run. ``reset`` returns
    every field to its idle value.
    """

    # attributes holding per-run objects; reset() drops each of them back to None
    _RUN_OBJECT_ATTRS = ("motion_models", "params", "sample_settings")

    def __init__(self):
        self.motion_models: MotionModelGroup = None
        self.params: InjectionParams = None
        self.sample_settings: SampleSettings = None
        self.reset()

    def initialize(self, model):
        """Run the one-time ``initialize_timesteps`` hooks for the current run.

        Meant to be called from the sampling function; every call after the
        first (until the next ``reset``) is a no-op.
        """
        if self.initialized:
            return
        self.initialized = True

        motion_models = self.motion_models
        if motion_models is not None:
            motion_models.initialize_timesteps(model)

        context_options = self.params.context_options
        if context_options is not None:
            context_options.initialize_timesteps(model)

        custom_cfg = self.sample_settings.custom_cfg
        if custom_cfg is not None:
            custom_cfg.initialize_timesteps(model)

    def reset(self):
        """Clear the step counters and release the per-run objects."""
        self.initialized = False
        self.start_step: int = 0
        self.last_step: int = 0
        self.current_step: int = 0
        self.total_steps: int = 0
        for attr_name in self._RUN_OBJECT_ATTRS:
            if getattr(self, attr_name) is not None:
                delattr(self, attr_name)
                setattr(self, attr_name, None)

    def update_with_inject_params(self, params: InjectionParams):
        """Store the injection params to use for the current run."""
        self.params = params

    def is_using_sliding_context(self):
        """Return False when no params are set, otherwise defer to the params."""
        if self.params is None:
            return False
        return self.params.is_using_sliding_context()

    def create_exposed_params(self):
        """Build the dict of run parameters that other extensions may read."""
        # This dict will be exposed to be used by other extensions
        # DO NOT change any of the key names
        # or I will find you 👁.👁
        current = self.params
        return {
            "full_length": current.full_length,
            "context_length": current.context_options.context_length,
            "sub_idxs": current.sub_idxs,
        }


# Global variable to use to more conveniently hack variable access into samplers
ADGS = AnimateDiffHelper_GlobalState()
######################################################################
##################################################################################
##################################################################################
#### Code Injection ##################################################
# refer to forward_timestep_embed in comfy/ldm/modules/diffusionmodules/openaimodel.py
def forward_timestep_embed_factory() -> Callable:
    """Create the replacement for ``openaimodel.forward_timestep_embed``.

    The returned function dispatches on layer type the way the visible code
    below spells out, with one extra case: ``VanillaTemporalModule`` layers
    are called with ``(x, context)``.
    """

    def _advance_transformer_counters(options):
        # "current_index" is kept for backward compat, for now
        for counter_key in ("transformer_index", "current_index"):
            if counter_key in options:
                options[counter_key] += 1

    def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
        for layer in ts:
            if isinstance(layer, openaimodel.VideoResBlock):
                x = layer(x, emb, num_video_frames, image_only_indicator)
                continue
            if isinstance(layer, openaimodel.TimestepBlock):
                x = layer(x, emb)
                continue
            if isinstance(layer, VanillaTemporalModule):
                x = layer(x, context)
                continue
            if isinstance(layer, attention.SpatialVideoTransformer):
                x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
                _advance_transformer_counters(transformer_options)
                continue
            if isinstance(layer, attention.SpatialTransformer):
                x = layer(x, context, transformer_options)
                _advance_transformer_counters(transformer_options)
                continue
            if isinstance(layer, openaimodel.Upsample):
                x = layer(x, output_shape=output_shape)
                continue
            # any other layer only takes the hidden state
            x = layer(x)
        return x

    return forward_timestep_embed
def unlimited_memory_required(*args, **kwargs):
    """Report zero required memory, whatever the arguments.

    Installed in place of a model's ``memory_required`` for the
    "unlimited area hack".
    """
    return 0
def groupnorm_mm_factory(params: InjectionParams, manual_cast=False):
    """Create a GroupNorm ``forward`` that normalizes across the frame axis.

    The returned function folds the ``(b f)`` batch axis into separate
    ``b`` and ``f`` axes before calling ``group_norm`` and restores the
    original layout afterwards.

    params: supplies the frame count per cond (``full_length``, or the
        context length when a sliding context is in use).
    manual_cast: when true, weight and bias come from
        ``comfy.ops.cast_bias_weight`` instead of the module attributes.
    """

    def groupnorm_mm_forward(self, input: Tensor) -> Tensor:
        # axes_factor normalizes batch based on total conds and unconds passed in batch;
        # the conds and unconds per batch can change based on VRAM optimizations that may kick in
        if params.is_using_sliding_context():
            frames_per_cond = params.context_options.context_length
        else:
            frames_per_cond = params.full_length
        batched_conds = input.size(0) // frames_per_cond

        input = rearrange(input, "(b f) c h w -> b c f h w", b=batched_conds)
        if manual_cast:
            weight, bias = comfy.ops.cast_bias_weight(self, input)
        else:
            weight = self.weight
            bias = self.bias
        normalized = group_norm(input, self.num_groups, weight, bias, self.eps)
        return rearrange(normalized, "b c f h w -> (b f) c h w", b=batched_conds)

    return groupnorm_mm_forward
def get_additional_models_factory(orig_get_additional_models: Callable, motion_models: MotionModelGroup):
    """Wrap ``orig_get_additional_models`` so the motion models are included.

    The wrapper forwards all arguments to the original, appends every entry
    of ``motion_models.models`` (when ``motion_models`` is not None) to the
    returned model list, and passes the inference-memory value through
    unchanged.
    """

    def get_additional_models_with_motion(*args, **kwargs):
        models, inference_memory = orig_get_additional_models(*args, **kwargs)
        if motion_models is not None:
            models.extend(motion_models.models)
            # TODO: account for inference memory as well?
        return models, inference_memory

    return get_additional_models_with_motion
######################################################################
##################################################################################
def apply_params_to_motion_models(motion_models: MotionModelGroup, params: InjectionParams):
    """Decide between sliding-context and regular sampling and size the motion models.

    Works on a clone of ``params``; the clone is what gets returned.

    Steps:
      1. Contexts scheduled as ``VIEW_AS_CONTEXT`` get ``context_length`` set
         to ``params.full_length``.
      2. If there is a context length and enough latents for it (``>=`` when
         ``use_on_equal_length`` is set, ``>`` otherwise), sliding context
         stays active; otherwise ``params.reset_context()`` is called.
      3. When ``motion_models`` is given, each model's encoding limit is
         checked against the frame window it will see, and the group's video
         length is set.

    Raises:
        ValueError: a motion model cannot encode the required frame window.
    """
    params = params.clone()
    for context in params.context_options.contexts:
        if context.context_schedule == ContextSchedules.VIEW_AS_CONTEXT:
            context.context_length = params.full_length

    context_length = params.context_options.context_length
    allow_equal = params.context_options.use_on_equal_length
    if not context_length:
        enough_latents = False
    elif allow_equal:
        enough_latents = params.full_length >= context_length
    else:
        enough_latents = params.full_length > context_length

    # wording follows the comparison that was actually made
    if enough_latents:
        comparison = "greater than or equal to" if allow_equal else "greater than"
        logger.info(f"Sliding context window activated - latents passed in ({params.full_length}) {comparison} context_length {context_length}.")
    else:
        comparison = "less than" if (allow_equal and context_length) else "less or equal to"
        logger.info(f"Regular AnimateDiff activated - latents passed in ({params.full_length}) {comparison} context_length {context_length}.")
        params.reset_context()

    if motion_models is not None:
        # re-read after the possible reset_context() above instead of reusing the earlier value
        context_length = params.context_options.context_length
        if not context_length:
            # if no context_length, treat video length as intended AD frame window
            for motion_model in motion_models.models:
                if not motion_model.model.is_length_valid_for_encoding_max_len(params.full_length):
                    raise ValueError(f"Without a context window, AnimateDiff model {motion_model.model.mm_info.mm_name} has upper limit of {motion_model.model.encoding_max_len} frames, but received {params.full_length} latents.")
            motion_models.set_video_length(params.full_length, params.full_length)
        else:
            # otherwise, treat context_length as intended AD frame window;
            # when view options are present, their context_length is what gets validated
            view_options = params.context_options.view_options
            window_length = view_options.context_length if view_options else context_length
            for motion_model in motion_models.models:
                if not motion_model.model.is_length_valid_for_encoding_max_len(window_length):
                    # report the length that failed validation, not an unrelated one
                    raise ValueError(f"AnimateDiff model {motion_model.model.mm_info.mm_name} has upper limit of {motion_model.model.encoding_max_len} frames for a context window, but received context length of {window_length}.")
            motion_models.set_video_length(context_length, params.full_length)
        # inject model
        module_str = "modules" if len(motion_models.models) > 1 else "module"
        logger.info(f"Using motion {module_str} {motion_models.get_name_string(show_version=True)}.")
    return params
class FunctionInjectionHolder:
    """Temporarily swaps ComfyUI/torch functions for AnimateDiff-aware ones.

    ``inject_functions`` records the current functions on this object and
    installs the replacements; ``restore_functions`` puts the recorded
    functions back.
    """

    def __init__(self):
        pass

    def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionParams):
        """Save the original functions, then install the AnimateDiff versions."""
        # --- Save Original Functions ---
        # needed to account for VanillaTemporalModule
        self.orig_forward_timestep_embed = openaimodel.forward_timestep_embed
        # allows for "unlimited area hack" to prevent halving of conds/unconds
        self.orig_memory_required = model.model.memory_required
        # used to normalize latents to remove "flickering" of colors/brightness between frames
        self.orig_groupnorm_forward = torch.nn.GroupNorm.forward
        self.orig_groupnorm_manual_cast_forward = comfy.ops.manual_cast.GroupNorm.forward_comfy_cast_weights
        # used to support sliding context windows in samplers
        self.orig_sampling_function = comfy.samplers.sampling_function
        self.orig_prepare_mask = comfy.sample.prepare_mask
        self.orig_get_additional_models = comfy.sample.get_additional_models

        # --- Inject Functions ---
        openaimodel.forward_timestep_embed = forward_timestep_embed_factory()
        if params.unlimited_area_hack:
            model.model.memory_required = unlimited_memory_required

        motion_models = model.motion_models
        if motion_models is not None:
            # only apply groupnorm hack if not [v3 or ([not Hotshot] and SD1.5 and v2 and apply_v2_properly)]
            info: AnimateDiffInfo = motion_models[0].model.mm_info
            skip_groupnorm_hack = info.mm_version == AnimateDiffVersion.V3 or (
                info.mm_format not in [AnimateDiffFormat.HOTSHOTXL]
                and info.sd_type == ModelTypeSD.SD1_5
                and info.mm_version == AnimateDiffVersion.V2
                and params.apply_v2_properly
            )
            if not skip_groupnorm_hack:
                torch.nn.GroupNorm.forward = groupnorm_mm_factory(params)
                comfy.ops.manual_cast.GroupNorm.forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True)
                # if mps device (Apple Silicon), disable batched conds to avoid black images with groupnorm hack
                try:
                    if model.load_device.type == "mps":
                        model.model.memory_required = unlimited_memory_required
                except Exception:
                    # best effort only: a failed device probe must not abort injection
                    pass
            del info

        comfy.samplers.sampling_function = evolved_sampling_function
        comfy.sample.prepare_mask = prepare_mask_ad
        comfy.sample.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, motion_models)

    def restore_functions(self, model: ModelPatcherAndInjector):
        """Reinstall every function recorded by ``inject_functions``.

        An ``AttributeError`` (an ``orig_*`` attribute was never recorded) is
        logged rather than raised.
        """
        try:
            model.model.memory_required = self.orig_memory_required
            openaimodel.forward_timestep_embed = self.orig_forward_timestep_embed
            torch.nn.GroupNorm.forward = self.orig_groupnorm_forward
            comfy.ops.manual_cast.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_manual_cast_forward
            comfy.samplers.sampling_function = self.orig_sampling_function
            comfy.sample.prepare_mask = self.orig_prepare_mask
            comfy.sample.get_additional_models = self.orig_get_additional_models
        except AttributeError:
            logger.error(
                "Encountered AttributeError while attempting to restore functions - likely, an error occured while trying "
                "to save original functions before injection, and a more specific error was thrown by ComfyUI."
            )
def motion_sample_factory(orig_comfy_sample: Callable, is_custom: bool=False) -> Callable:
    """Wrap a ComfyUI sample function so AnimateDiff injection surrounds it.

    orig_comfy_sample: the sample function being wrapped; it is called
        unchanged when the model is not exactly a ``ModelPatcherAndInjector``.
    is_custom: selects the custom-sampler code paths (no ``adapt_denoise_steps``
        handling, no ``KSampler`` built for the iteration options).

    The wrapper resets ``ADGS``, injects the patched functions, runs the
    wrapped sampler once per configured iteration, and always resets
    ``ADGS`` and restores the injected functions on the way out.
    """

    def motion_sample(model: ModelPatcherAndInjector, noise: Tensor, *args, **kwargs):
        # models that are not meant for injection go straight to the original sampler
        if type(model) != ModelPatcherAndInjector:
            return orig_comfy_sample(model, noise, *args, **kwargs)

        # otherwise, injection time
        latents = None
        cached_latents = None
        cached_noise = None
        injections = FunctionInjectionHolder()
        try:
            custom_cfg = model.sample_settings.custom_cfg
            if custom_cfg is not None:
                model = custom_cfg.patch_model(model)

            # clone params from model
            params = model.motion_injection_params.clone()
            # the latents are the last positional argument; their count is the full length
            latents: Tensor = args[-1]
            params.full_length = latents.size(0)
            # reset global state
            ADGS.reset()

            # inputs for custom noise, if needed
            disable_noise = kwargs.get("disable_noise") or False
            seed = kwargs["seed"]

            # apply params to motion model
            params = apply_params_to_motion_models(model.motion_models, params)

            # store and inject functions
            injections.inject_functions(model, params)

            # prepare noise_extra_args for noise generation purposes
            noise_extra_args = {"disable_noise": disable_noise}
            params.set_noise_extra_args(noise_extra_args)
            # if noise is not disabled, do noise stuff
            if not disable_noise:
                noise = model.sample_settings.prepare_noise(seed, latents, noise, extra_args=noise_extra_args, force_create_noise=False)

            # callback setup: chain to the caller's callback, then advance the global step
            user_callback = kwargs.get("callback", None)

            def step_tracking_callback(step, x0, x, total_steps):
                if user_callback is not None:
                    user_callback(step, x0, x, total_steps)
                # update GLOBALSTATE for next iteration
                ADGS.current_step = ADGS.start_step + step + 1

            kwargs["callback"] = step_tracking_callback
            ADGS.motion_models = model.motion_models
            ADGS.sample_settings = model.sample_settings

            # apply adapt_denoise_steps
            args = list(args)
            if model.sample_settings.adapt_denoise_steps and not is_custom:
                # only applicable when denoise and steps are provided (from simple KSampler nodes)
                denoise = kwargs.get("denoise", None)
                steps = args[0]
                if denoise is not None and type(steps) == int:
                    args[0] = max(int(denoise * steps), 1)

            iteration_opts = IterationOptions()
            if model.sample_settings is not None:
                iteration_opts = model.sample_settings.iteration_opts
            iteration_opts.initialize(latents)

            # cache initial noise and latents, if needed
            if iteration_opts.cache_init_latents:
                cached_latents = latents.clone()
            if iteration_opts.cache_init_noise:
                cached_noise = noise.clone()

            # prepare iter opts preprocess kwargs, if needed
            preprocess_kwargs = {}
            if iteration_opts.need_sampler:
                # -5 for sampler_name (not custom) and sampler (custom)
                model_management.load_model_gpu(model)
                if is_custom:
                    preprocess_kwargs[IterationOptions.SAMPLER] = None #args[-5]
                else:
                    preprocess_kwargs[IterationOptions.SAMPLER] = comfy.samplers.KSampler(
                        model.model,
                        steps=999, #steps=args[-7],
                        device=model.current_device,
                        sampler=args[-5],
                        scheduler=args[-4],
                        denoise=kwargs.get("denoise", None),
                        model_options=model.model_options,
                    )

            for iteration in range(iteration_opts.iterations):
                # handle GLOBALSTATE vars and step tally
                ADGS.update_with_inject_params(params)
                ADGS.start_step = kwargs.get("start_step") or 0
                ADGS.current_step = ADGS.start_step
                ADGS.last_step = kwargs.get("last_step") or 0
                if iteration_opts.iterations > 1:
                    logger.info(f"Iteration {iteration+1}/{iteration_opts.iterations}")

                # perform any iter_opts preprocessing on latents
                latents, noise = iteration_opts.preprocess_latents(
                    curr_i=iteration,
                    model=model,
                    latents=latents,
                    noise=noise,
                    cached_latents=cached_latents,
                    cached_noise=cached_noise,
                    seed=seed,
                    sample_settings=model.sample_settings,
                    noise_extra_args=noise_extra_args,
                    **preprocess_kwargs,
                )
                args[-1] = latents

                if model.motion_models is not None:
                    model.motion_models.pre_run(model)
                if model.sample_settings is not None:
                    model.sample_settings.pre_run(model)

                latents = wrap_function_to_inject_xformers_bug_info(orig_comfy_sample)(model, noise, *args, **kwargs)
            return latents
        finally:
            # drop local references before tearing down the injected state
            del latents, noise, cached_latents, cached_noise
            # reset global state
            ADGS.reset()
            # restore injected functions
            injections.restore_functions(model)
            del injections

    return motion_sample
def evolved_sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options: dict={}, seed=None):
    """Replacement for ``comfy.samplers.sampling_function`` with AnimateDiff hooks.

    Prepares the per-timestep state held in ``ADGS``, exposes the AnimateDiff
    params under ``model_options["transformer_options"]["ad_params"]``,
    computes the cond/uncond predictions (per context window when sliding
    context is active), and then applies CFG plus any
    ``sampler_cfg_function`` / ``sampler_post_cfg_function`` hooks found in
    ``model_options``.

    Returns the CFG-combined result.
    """
    ADGS.initialize(model)
    if ADGS.motion_models is not None:
        ADGS.motion_models.prepare_current_keyframe(t=timestep)
    if ADGS.params.context_options is not None:
        ADGS.params.context_options.prepare_current_context(t=timestep)
    if ADGS.sample_settings.custom_cfg is not None:
        ADGS.sample_settings.custom_cfg.prepare_current_keyframe(t=timestep)

    # never use cfg1 optimization if using custom_cfg (since can have timesteps and such)
    # NOTE: `== False` is kept on purpose; `not value` would treat None differently
    if ADGS.sample_settings.custom_cfg is None and math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
        uncond_ = None
    else:
        uncond_ = uncond

    # add AD/evolved-sampling params to model_options (transformer_options)
    # Copy both levels so neither the caller's dict nor its nested
    # transformer_options is mutated, and make sure the correctly spelled key
    # exists before writing into it.
    model_options = model_options.copy()
    model_options["transformer_options"] = model_options.get("transformer_options", {}).copy()
    model_options["transformer_options"]["ad_params"] = ADGS.create_exposed_params()

    if not ADGS.is_using_sliding_context():
        cond_pred, uncond_pred = comfy.samplers.calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
    else:
        cond_pred, uncond_pred = sliding_calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)

    if "sampler_cfg_function" in model_options:
        args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
                "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
        cfg_result = x - model_options["sampler_cfg_function"](args)
    else:
        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale

    # post-CFG hooks receive the original `uncond`, not the possibly-None `uncond_`
    for fn in model_options.get("sampler_post_cfg_function", []):
        args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
                "sigma": timestep, "model_options": model_options, "input": x}
        cfg_result = fn(args)

    return cfg_result
# sliding_calc_cond_uncond_batch inspired by ashen's initial hack for 16-frame sliding context:
# https://github.com/comfyanonymous/ComfyUI/compare/master...ashen-sensored:ComfyUI:master
def sliding_calc_cond_uncond_batch(model, cond, uncond, x_in: Tensor, timestep, model_options):
    """Compute cond/uncond predictions window by window and fuse them.

    For every context window from ``get_context_windows`` the matching rows
    of ``x_in``, ``timestep`` and the conds are selected, passed to
    ``comfy.samplers.calc_cond_uncond_batch``, and blended into full-size
    outputs according to the configured fuse method.

    Expects ``model_options["transformer_options"]["ad_params"]`` to exist;
    its ``sub_idxs`` and ``context_length`` entries are updated per window.

    Returns ``(cond_final, uncond_final)``, each shaped like ``x_in``.

    Raises:
        ValueError: a cond carries a "control" object without a ``sub_idxs``
            attribute.
    """
    def prepare_control_objects(control: ControlBase, full_idxs: list[int]):
        # tag the whole chain of controls, earliest first
        if control.previous_controlnet is not None:
            prepare_control_objects(control.previous_controlnet, full_idxs)
        control.sub_idxs = full_idxs
        control.full_latent_length = ADGS.params.full_length
        control.context_length = ADGS.params.context_options.context_length

    def get_resized_cond(cond_in, full_idxs) -> list:
        # reuse or resize cond items to match context requirements
        resized_cond = []
        # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
        for actual_cond in cond_in:
            resized_actual_cond = actual_cond.copy()
            # now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary
            for key in actual_cond:
                try:
                    cond_item = actual_cond[key]
                    if isinstance(cond_item, Tensor):
                        # check that tensor is the expected length - x.size(0)
                        if cond_item.size(0) == x_in.size(0):
                            # if so, it's subsetting time - tell controls the expected indeces so they can handle them
                            resized_actual_cond[key] = cond_item[full_idxs]
                        else:
                            resized_actual_cond[key] = cond_item
                    # look for control
                    elif key == "control":
                        control_item = cond_item
                        if hasattr(control_item, "sub_idxs"):
                            prepare_control_objects(control_item, full_idxs)
                        else:
                            # adjacent literals keep source indentation out of the message
                            raise ValueError(
                                f"Control type {type(control_item).__name__} may not support required features for sliding context window; "
                                "use Control objects from Kosinkadink/ComfyUI-Advanced-ControlNet nodes, or make sure Advanced-ControlNet is updated."
                            )
                        resized_actual_cond[key] = control_item
                        del control_item
                    elif isinstance(cond_item, dict):
                        new_cond_item = cond_item.copy()
                        # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
                        for cond_key, cond_value in new_cond_item.items():
                            if isinstance(cond_value, Tensor):
                                if cond_value.size(0) == x_in.size(0):
                                    new_cond_item[cond_key] = cond_value[full_idxs]
                            # if has cond that is a Tensor, check if needs to be subset
                            elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, Tensor):
                                if cond_value.cond.size(0) == x_in.size(0):
                                    new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond[full_idxs])
                        resized_actual_cond[key] = new_cond_item
                    else:
                        resized_actual_cond[key] = cond_item
                finally:
                    del cond_item  # just in case to prevent VRAM issues
            resized_cond.append(resized_actual_cond)
        return resized_cond

    # get context windows
    ADGS.params.context_options.step = ADGS.current_step
    context_windows = get_context_windows(ADGS.params.full_length, ADGS.params.context_options)
    # figure out how input is split
    batched_conds = x_in.size(0)//ADGS.params.full_length

    if ADGS.motion_models is not None:
        ADGS.motion_models.set_view_options(ADGS.params.context_options.view_options)

    # prepare final cond, uncond, and out_count
    cond_final = torch.zeros_like(x_in)
    uncond_final = torch.zeros_like(x_in)
    out_count_final = torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device)
    bias_final = [0.0] * x_in.shape[0]

    # perform calc_cond_uncond_batch per context window
    for ctx_idxs in context_windows:
        ADGS.params.sub_idxs = ctx_idxs
        if ADGS.motion_models is not None:
            ADGS.motion_models.set_sub_idxs(ctx_idxs)
            ADGS.motion_models.set_video_length(len(ctx_idxs), ADGS.params.full_length)
        # update exposed params
        model_options["transformer_options"]["ad_params"]["sub_idxs"] = ctx_idxs
        model_options["transformer_options"]["ad_params"]["context_length"] = len(ctx_idxs)
        # account for all portions of input frames:
        # row (n, pos) of the window-sized sub-batch maps to full-batch row full_length*n + ctx_idxs[pos]
        full_idxs = []
        for n in range(batched_conds):
            for ind in ctx_idxs:
                full_idxs.append((ADGS.params.full_length*n)+ind)
        # get subsections of x, timestep, cond, uncond, cond_concat
        sub_x = x_in[full_idxs]
        sub_timestep = timestep[full_idxs]
        sub_cond = get_resized_cond(cond, full_idxs) if cond is not None else None
        sub_uncond = get_resized_cond(uncond, full_idxs) if uncond is not None else None

        sub_cond_out, sub_uncond_out = comfy.samplers.calc_cond_uncond_batch(model, sub_cond, sub_uncond, sub_x, sub_timestep, model_options)

        if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE:
            full_length = ADGS.params.full_length
            window_length = len(ctx_idxs)
            for pos, idx in enumerate(ctx_idxs):
                # bias is the influence of a specific index in relation to the whole context window
                bias = 1 - abs(idx - (ctx_idxs[0] + ctx_idxs[-1]) / 2) / ((ctx_idxs[-1] - ctx_idxs[0] + 1e-2) / 2)
                bias = max(1e-2, bias)
                # take weighted average relative to total bias of current idx
                # and account for batched_conds
                for n in range(batched_conds):
                    full_pos = (full_length*n)+idx
                    # the sub-batch holds window_length rows per batched cond, so its
                    # stride is window_length (not full_length) - see full_idxs above
                    sub_pos = (window_length*n)+pos
                    bias_total = bias_final[full_pos]
                    prev_weight = (bias_total / (bias_total + bias))
                    new_weight = (bias / (bias_total + bias))
                    cond_final[full_pos] = cond_final[full_pos] * prev_weight + sub_cond_out[sub_pos] * new_weight
                    uncond_final[full_pos] = uncond_final[full_pos] * prev_weight + sub_uncond_out[sub_pos] * new_weight
                    bias_final[full_pos] = bias_total + bias
        else:
            # add conds and counts based on weights of fuse method
            weights = get_context_weights(len(ctx_idxs), ADGS.params.context_options.fuse_method) * batched_conds
            weights_tensor = torch.Tensor(weights).to(device=x_in.device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            cond_final[full_idxs] += sub_cond_out * weights_tensor
            uncond_final[full_idxs] += sub_uncond_out * weights_tensor
            out_count_final[full_idxs] += weights_tensor

    if ADGS.params.context_options.fuse_method != ContextFuseMethod.RELATIVE:
        # normalize cond and uncond via division by context usage counts
        # (the RELATIVE path is already normalized by its running weighted average)
        cond_final /= out_count_final
        uncond_final /= out_count_final
    del out_count_final
    return cond_final, uncond_final