from copy import deepcopy
from typing import Callable, Union

import math

import torch
from torch import Tensor
import torch.nn.functional
from einops import rearrange
import numpy as np

import comfy.ops
import comfy.utils

from comfy.controlnet import ControlBase
from comfy.model_patcher import ModelPatcher
from comfy.sd import VAE

from .logger import logger

# 2**53 - 1 is the largest "safe" integer in a float64 (cf. JS MAX_SAFE_INTEGER)
BIGMIN = -(2**53-1)
BIGMAX = (2**53-1)

ORIG_PREVIOUS_CONTROLNET = "_orig_previous_controlnet"
CONTROL_INIT_BY_ACN = "_control_init_by_ACN"


class Extras:
    MIDDLE_MULT = "middle_mult"


def load_torch_file_with_dict_factory(controlnet_data: dict[str, Tensor], orig_load_torch_file: Callable):
    def load_torch_file_with_dict(*args, **kwargs):
        # restore the original loader, then return the already-loaded state dict
        # instead of reading anything from disk
        comfy.utils.load_torch_file = orig_load_torch_file
        return controlnet_data
    return load_torch_file_with_dict


class WrapperConsts:
    ACN = "ACN"
    VERSION = "version"
    ACN_OUTER_SAMPLE_WRAPPER_KEY = "ACN_outer_sample_wrapper"
    ACN_CREATE_SAMPLER_SAMPLE_WRAPPER = "create_outer_sample_wrapper"


def get_properly_arranged_t2i_weights(initial_weights: list[float]):
    # expand the 4 T2I-Adapter weights to 12 by repeating each one 3 times, in order
    new_weights = []
    new_weights.extend([initial_weights[0]]*3)
    new_weights.extend([initial_weights[1]]*3)
    new_weights.extend([initial_weights[2]]*3)
    new_weights.extend([initial_weights[3]]*3)
    return new_weights
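
# e.g. get_properly_arranged_t2i_weights([a, b, c, d]) -> [a, a, a, b, b, b, c, c, c, d, d, d]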


class ControlWeightType:
    DEFAULT = "default"
    UNIVERSAL = "universal"
    T2IADAPTER = "t2iadapter"
    CONTROLNET = "controlnet"
    CONTROLNETPLUSPLUS = "controlnet++"
    CONTROLLORA = "controllora"
    CONTROLLLLITE = "controllllite"
    SVD_CONTROLNET = "svd_controlnet"
    SPARSECTRL = "sparsectrl"
    CTRLORA = "ctrlora"


class ControlWeights:
    def __init__(self, weight_type: str, base_multiplier: float=1.0,
                 weights_input: list[float]=None, weights_middle: list[float]=None, weights_output: list[float]=None,
                 weight_func: Callable=None, weight_mask: Tensor=None,
                 uncond_multiplier=1.0, uncond_mask: Tensor=None,
                 extras: dict[str]={}, disable_applied_to=False):
        self.weight_type = weight_type
        self.base_multiplier = base_multiplier
        self.weights_input = weights_input
        self.weights_middle = weights_middle
        self.weights_output = weights_output
        self.weight_func = weight_func
        self.weight_mask = weight_mask
        self.uncond_multiplier = float(uncond_multiplier)
        self.has_uncond_multiplier = not math.isclose(self.uncond_multiplier, 1.0)
        # 1.0 acts as a no-op stand-in when no uncond_mask is provided
        self.uncond_mask = uncond_mask if uncond_mask is not None else 1.0
        self.has_uncond_mask = uncond_mask is not None
        # copy so the shared {} default and caller-owned dicts are never mutated
        self.extras = extras.copy()
        self.disable_applied_to = disable_applied_to

    def get(self, idx: int, control: dict[str, list[Tensor]], key: str, default=1.0) -> Union[float, Tensor]:
        # a custom weight function takes precedence over any weight lists
        if self.weight_func is not None:
            return self.weight_func(idx=idx, control=control, key=key)
        effective_mult = 1.0

        relevant_weights = None
        if key == "middle":
            relevant_weights = self.weights_middle
            effective_mult *= self.extras.get(Extras.MIDDLE_MULT, 1.0)
        elif key == "input":
            relevant_weights = self.weights_input
            # input weight lists are supplied in the opposite order from how
            # input blocks are indexed here, so reverse before lookup
            if relevant_weights is not None:
                relevant_weights = list(reversed(relevant_weights))
        else:
            relevant_weights = self.weights_output
        # fall back to the default when no list is set or idx is out of range
        if relevant_weights is None or idx >= len(relevant_weights):
            return default * effective_mult
        return relevant_weights[idx] * effective_mult

    def copy_with_new_weights(self, new_weights_input: list[float]=None, new_weights_middle: list[float]=None, new_weights_output: list[float]=None,
                              new_weight_func: Callable=None):
        return ControlWeights(weight_type=self.weight_type, base_multiplier=self.base_multiplier,
                              weights_input=new_weights_input, weights_middle=new_weights_middle, weights_output=new_weights_output,
                              weight_func=new_weight_func, weight_mask=self.weight_mask,
                              uncond_multiplier=self.uncond_multiplier,
                              # carry over uncond_mask (self.uncond_mask is the 1.0 stand-in when unset)
                              uncond_mask=self.uncond_mask if self.has_uncond_mask else None,
                              extras=self.extras, disable_applied_to=self.disable_applied_to)

    @classmethod
    def default(cls, extras: dict[str]={}):
        return cls(ControlWeightType.DEFAULT, extras=extras)

    @classmethod
    def universal(cls, base_multiplier: float, uncond_multiplier: float=1.0, extras: dict[str]={}):
        return cls(ControlWeightType.UNIVERSAL, base_multiplier=base_multiplier, uncond_multiplier=uncond_multiplier, disable_applied_to=True, extras=extras)

    @classmethod
    def universal_mask(cls, weight_mask: Tensor, uncond_multiplier: float=1.0, extras: dict[str]={}):
        return cls(ControlWeightType.UNIVERSAL, weight_mask=weight_mask, uncond_multiplier=uncond_multiplier, disable_applied_to=True, extras=extras)

    @classmethod
    def t2iadapter(cls, weights_input: list[float]=None, uncond_multiplier: float=1.0, extras: dict[str]={}, disable_applied_to=False):
        return cls(ControlWeightType.T2IADAPTER, weights_input=weights_input, uncond_multiplier=uncond_multiplier, extras=extras, disable_applied_to=disable_applied_to)

    @classmethod
    def controlnet(cls, weights_output: list[float]=None, weights_middle: list[float]=None, weights_input: list[float]=None, uncond_multiplier: float=1.0, extras: dict[str]={}, disable_applied_to=False):
        return cls(ControlWeightType.CONTROLNET, weights_output=weights_output, weights_middle=weights_middle, weights_input=weights_input, uncond_multiplier=uncond_multiplier, extras=extras, disable_applied_to=disable_applied_to)

    @classmethod
    def controllora(cls, weights_output: list[float]=None, weights_middle: list[float]=None, weights_input: list[float]=None, uncond_multiplier: float=1.0, extras: dict[str]={}, disable_applied_to=False):
        return cls(ControlWeightType.CONTROLLORA, weights_output=weights_output, weights_middle=weights_middle, weights_input=weights_input, uncond_multiplier=uncond_multiplier, extras=extras, disable_applied_to=disable_applied_to)

    @classmethod
    def controllllite(cls, weights_output: list[float]=None, weights_middle: list[float]=None, weights_input: list[float]=None, uncond_multiplier: float=1.0, extras: dict[str]={}, disable_applied_to=False):
        return cls(ControlWeightType.CONTROLLLLITE, weights_output=weights_output, weights_middle=weights_middle, weights_input=weights_input, uncond_multiplier=uncond_multiplier, extras=extras, disable_applied_to=disable_applied_to)


class StrengthInterpolation:
    LINEAR = "linear"
    EASE_IN = "ease-in"
    EASE_OUT = "ease-out"
    EASE_IN_OUT = "ease-in-out"
    NONE = "none"

    _LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
    _LIST_WITH_NONE = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT, NONE]

    @classmethod
    def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
        diff = num_to - num_from
        # NOTE: the numpy ufuncs below operate on torch tensors and hand torch
        # tensors back, so weights.flip() further down keeps working
        if method == cls.LINEAR:
            weights = torch.linspace(num_from, num_to, length)
        elif method == cls.EASE_IN:
            # quadratic ease-in: diff * t**2 + num_from
            index = torch.linspace(0, 1, length)
            weights = diff * np.power(index, 2) + num_from
        elif method == cls.EASE_OUT:
            # quadratic ease-out: diff * (1 - (1-t)**2) + num_from
            index = torch.linspace(0, 1, length)
            weights = diff * (1 - np.power(1 - index, 2)) + num_from
        elif method == cls.EASE_IN_OUT:
            # cosine ease-in-out: diff * (1 - cos(pi*t)) / 2 + num_from
            index = torch.linspace(0, 1, length)
            weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
        else:
            raise ValueError(f"Unrecognized interpolation method '{method}'.")
        if reverse:
            weights = weights.flip(dims=(0,))
        return weights
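
# Illustrative usage: StrengthInterpolation.get_weights(0.0, 1.0, 5, StrengthInterpolation.LINEAR)
# -> tensor([0.00, 0.25, 0.50, 0.75, 1.00])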


class LatentKeyframe:
    def __init__(self, batch_index: int, strength: float) -> None:
        self.batch_index = batch_index
        self.strength = strength


class LatentKeyframeGroup:
    def __init__(self) -> None:
        self.keyframes: list[LatentKeyframe] = []

    def add(self, keyframe: LatentKeyframe) -> None:
        # replace any existing keyframe with the same batch_index
        added = False
        for i in range(len(self.keyframes)):
            if self.keyframes[i].batch_index == keyframe.batch_index:
                self.keyframes[i] = keyframe
                added = True
                break
        if not added:
            self.keyframes.append(keyframe)
        self.keyframes.sort(key=lambda k: k.batch_index)

    def get_index(self, index: int) -> Union[LatentKeyframe, None]:
        try:
            return self.keyframes[index]
        except IndexError:
            return None

    def __getitem__(self, index) -> LatentKeyframe:
        return self.keyframes[index]

    def is_empty(self) -> bool:
        return len(self.keyframes) == 0

    def clone(self) -> 'LatentKeyframeGroup':
        cloned = LatentKeyframeGroup()
        for tk in self.keyframes:
            cloned.add(tk)
        return cloned


class TimestepKeyframe:
    def __init__(self,
                 start_percent: float = 0.0,
                 strength: float = 1.0,
                 control_weights: ControlWeights = None,
                 latent_keyframes: LatentKeyframeGroup = None,
                 null_latent_kf_strength: float = 0.0,
                 inherit_missing: bool = True,
                 guarantee_steps: int = 1,
                 mask_hint_orig: Tensor = None) -> None:
        self.start_percent = float(start_percent)
        # sentinel value; replaced with a real sigma by pre_run_advanced via percent_to_timestep_function
        self.start_t = 999999999.9
        self.strength = strength
        self.control_weights = control_weights
        self.latent_keyframes = latent_keyframes
        self.null_latent_kf_strength = null_latent_kf_strength
        self.inherit_missing = inherit_missing
        self.guarantee_steps = guarantee_steps
        self.mask_hint_orig = mask_hint_orig

    def has_control_weights(self):
        return self.control_weights is not None

    def has_latent_keyframes(self):
        return self.latent_keyframes is not None

    def has_mask_hint(self):
        return self.mask_hint_orig is not None

    @staticmethod
    def default() -> 'TimestepKeyframe':
        return TimestepKeyframe(start_percent=0.0, guarantee_steps=0)
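
    # NOTE: guarantee_steps=0 means this default keyframe holds no minimum step count,
    # so a user-provided keyframe starting at 0.0 can replace it on the very first step.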


class TimestepKeyframeGroup:
    def __init__(self, add_default=True) -> None:
        self.keyframes: list[TimestepKeyframe] = []
        if add_default:
            self.keyframes.append(TimestepKeyframe.default())

    def add(self, keyframe: TimestepKeyframe) -> None:
        # keep keyframes sorted by start_percent after every addition
        self.keyframes.append(keyframe)
        self.keyframes = get_sorted_list_via_attr(self.keyframes, attr="start_percent")

    def get_index(self, index: int) -> Union[TimestepKeyframe, None]:
        try:
            return self.keyframes[index]
        except IndexError:
            return None

    def has_index(self, index: int) -> bool:
        return 0 <= index < len(self.keyframes)

    def __getitem__(self, index) -> TimestepKeyframe:
        return self.keyframes[index]

    def __len__(self) -> int:
        return len(self.keyframes)

    def is_empty(self) -> bool:
        return len(self.keyframes) == 0

    def clone(self) -> 'TimestepKeyframeGroup':
        cloned = TimestepKeyframeGroup(add_default=False)
        for tk in self.keyframes:
            cloned.keyframes.append(tk)
        return cloned

    @classmethod
    def default(cls, keyframe: TimestepKeyframe) -> 'TimestepKeyframeGroup':
        group = cls()
        group.keyframes[0] = keyframe
        return group


class AbstractPreprocWrapper:
    error_msg = "Invalid use of [InsertHere] output. The output of the [InsertHere] preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."

    def __init__(self, condhint):
        self.condhint = condhint

    def movedim(self, *args, **kwargs):
        return self

    def __getattr__(self, *args, **kwargs):
        raise AttributeError(self.error_msg)

    def __setattr__(self, name, value):
        if name != "condhint":
            raise AttributeError(self.error_msg)
        super().__setattr__(name, value)

    def __iter__(self, *args, **kwargs):
        raise AttributeError(self.error_msg)

    def __next__(self, *args, **kwargs):
        raise AttributeError(self.error_msg)

    def __len__(self, *args, **kwargs):
        raise AttributeError(self.error_msg)

    def __getitem__(self, *args, **kwargs):
        raise AttributeError(self.error_msg)

    def __setitem__(self, *args, **kwargs):
        raise AttributeError(self.error_msg)
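
# NOTE: AbstractPreprocWrapper fails loudly (except for movedim, a deliberate no-op)
# whenever the wrapped latent is treated like a regular IMAGE; subclasses presumably
# substitute the actual preprocessor name for the [InsertHere] placeholder in error_msg.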


class disable_weight_init_clean_groupnorm(comfy.ops.disable_weight_init):
    class GroupNorm(comfy.ops.disable_weight_init.GroupNorm):
        def forward_comfy_cast_weights(self, input):
            weight, bias = comfy.ops.cast_bias_weight(self, input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

        def forward(self, input):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(input)
            else:
                return torch.nn.functional.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)


class manual_cast_clean_groupnorm(comfy.ops.manual_cast):
    class GroupNorm(disable_weight_init_clean_groupnorm.GroupNorm):
        comfy_cast_weights = True


def prepare_mask_batch(mask: Tensor, shape: torch.Size, multiplier: float=1, match_dim1=False, match_shape=False, flux_shape=None):
    mask = mask.clone()
    if flux_shape is not None:
        # halve the multiplier (presumably due to flux's packed 2x2 latent layout),
        # then flatten spatial dims to match flux's (batch, tokens, channels) layout
        multiplier = multiplier * 0.5
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(round(flux_shape[-2]*multiplier), round(flux_shape[-1]*multiplier)), mode="bilinear")
        mask = rearrange(mask, "b c h w -> b (h w) c")
    else:
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(round(shape[-2]*multiplier), round(shape[-1]*multiplier)), mode="bilinear")
        if match_dim1:
            if match_shape and len(shape) < 4:
                raise Exception(f"match_dim1 cannot be True if shape is under 4 dims; was {len(shape)}.")
            mask = torch.cat([mask] * shape[1], dim=1)
    if match_shape and len(shape) == 3 and len(mask.shape) != 3:
        mask = mask.squeeze(1)
    return mask
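
# e.g. a (B, H, W) mask with a (b, c, h, w) latent shape becomes (B, 1, h*multiplier, w*multiplier);
# with flux_shape set, it is instead flattened to (B, tokens, 1) to match flux's sequence layout.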


def normalize_min_max(x: Tensor, new_min=0.0, new_max=1.0):
    x_min, x_max = x.min(), x.max()
    return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min


def linear_conversion(x, x_min=0.0, x_max=1.0, new_min=0.0, new_max=1.0):
    return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min
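
# normalize_min_max is linear_conversion with x_min/x_max taken from the tensor itself;
# both divide by (x_max - x_min), so a constant tensor would divide by zero.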


def extend_to_batch_size(tensor: Tensor, batch_size: int):
    # truncate if too long, repeat the last entry if too short
    if tensor.shape[0] > batch_size:
        return tensor[:batch_size]
    elif tensor.shape[0] < batch_size:
        remainder = batch_size-tensor.shape[0]
        return torch.cat([tensor] + [tensor[-1:]]*remainder, dim=0)
    return tensor
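
# e.g. extend_to_batch_size(t, 5) on a batch of 2 returns [t[0], t[1], t[1], t[1], t[1]].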


def broadcast_image_to_extend(tensor, target_batch_size, batched_number, except_one=True):
    current_batch_size = tensor.shape[0]
    # a single image broadcasts implicitly downstream, so leave it alone
    if except_one and current_batch_size == 1:
        return tensor

    per_batch = target_batch_size // batched_number
    tensor = tensor[:per_batch]

    # pad by repeating the last entry if there are not enough images for one sub-batch
    if per_batch > tensor.shape[0]:
        tensor = extend_to_batch_size(tensor=tensor, batch_size=per_batch)

    current_batch_size = tensor.shape[0]
    if current_batch_size == target_batch_size:
        return tensor
    else:
        # repeat the sub-batch once per batched_number (e.g. cond + uncond)
        return torch.cat([tensor] * batched_number, dim=0)


def deepcopy_with_sharing(obj, shared_attribute_names, memo=None):
    '''
    Deepcopy an object, except for a given list of attributes, which should
    be shared between the original object and its copy.

    obj: some object
    shared_attribute_names: list of strings identifying the attributes that
        should be shared between the original and its copy.
    memo: the dictionary passed into __deepcopy__. Ignore this argument if
        not calling from within __deepcopy__.
    '''
    assert isinstance(shared_attribute_names, (list, tuple))

    shared_attributes = {k: getattr(obj, k) for k in shared_attribute_names}

    if hasattr(obj, '__deepcopy__'):
        # temporarily disable the custom __deepcopy__ so the deepcopy call
        # below does not recurse back into it
        deepcopy_method = obj.__deepcopy__
        obj.__deepcopy__ = None

    # detach the shared attributes so deepcopy does not duplicate them
    for attr in shared_attribute_names:
        del obj.__dict__[attr]

    clone = deepcopy(obj)

    # reattach the shared attributes to both the original and the clone
    for attr, val in shared_attributes.items():
        setattr(obj, attr, val)
        setattr(clone, attr, val)

    if hasattr(obj, '__deepcopy__'):
        # restore the original __deepcopy__; the clone received the None
        # placeholder, so remove it
        obj.__deepcopy__ = deepcopy_method
        del clone.__deepcopy__

    return clone
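
# Illustrative usage (hypothetical class): share a heavy attribute between copies
# by routing __deepcopy__ through deepcopy_with_sharing.
#
#   class Holder:
#       def __init__(self, model, data):
#           self.model = model  # shared, not duplicated
#           self.data = data    # deep-copied as usual
#
#       def __deepcopy__(self, memo):
#           return deepcopy_with_sharing(self, shared_attribute_names=['model'], memo=memo)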


def get_sorted_list_via_attr(objects: list, attr: str) -> list:
    if not objects:
        return objects
    elif len(objects) <= 1:
        return [x for x in objects]

    # group objects by attr value, preserving insertion order within each group
    unique_attrs = {}
    for o in objects:
        val_attr = getattr(o, attr)
        attr_list: list = unique_attrs.get(val_attr, list())
        attr_list.append(o)
        if val_attr not in unique_attrs:
            unique_attrs[val_attr] = attr_list

    # sort the groups by attr value, then flatten back into a single list
    sorted_attrs = dict(sorted(unique_attrs.items()))

    sorted_list = []
    for object_list in sorted_attrs.values():
        sorted_list.extend(object_list)
    return sorted_list
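
# Behaviorally equivalent to sorted(objects, key=lambda o: getattr(o, attr)):
# Python's sort is stable, so objects with equal attr values keep their relative order.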


def torch_dfs(model: torch.nn.Module):
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result
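
# Returns the module tree flattened in pre-order (the root module first, then each
# child subtree in turn), handy for visiting every submodule deterministically.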


class WeightTypeException(TypeError):
    "Raised when a weight is not compatible with an AdvancedControlBase object."


class AdvancedControlBase:
    def __init__(self, base: ControlBase, timestep_keyframes: TimestepKeyframeGroup, weights_default: ControlWeights, require_vae=False, allow_condhint_latents=False):
        self.base = base
        self.compatible_weights = [ControlWeightType.UNIVERSAL, ControlWeightType.DEFAULT]
        self.add_compatible_weight(weights_default.weight_type)
        # masks applied to the controlnet output
        self.mask_cond_hint_original = None
        self.mask_cond_hint = None
        self.tk_mask_cond_hint_original = None
        self.tk_mask_cond_hint = None
        self.weight_mask_cond_hint = None
        # sub-index bookkeeping for when only a window of the full latents is sampled
        self.sub_idxs = None
        self.full_latent_length = 0
        self.context_length = 0
        # current sampling state
        self.t: float = None
        self.prev_t: float = None
        self.batched_number: int = None
        self.batch_size: int = 0
        self.cond_or_uncond: list[int] = None
        # weights + override
        self.weights: ControlWeights = None
        self.weights_default: ControlWeights = weights_default
        self.weights_override: ControlWeights = None
        # latent keyframes + override
        self.latent_keyframes: LatentKeyframeGroup = None
        self.latent_keyframe_override: LatentKeyframeGroup = None
        # initialize timestep keyframes
        self.set_timestep_keyframes(timestep_keyframes)
        # redirect the base ControlBase entry points through the advanced versions
        self.get_control = self.get_control_inject
        self.control_merge = self.control_merge_inject
        self.pre_run = self.pre_run_inject
        self.cleanup = self.cleanup_inject
        self.set_previous_controlnet = self.set_previous_controlnet_inject
        self.set_cond_hint = self.set_cond_hint_inject
        # VAE handling
        self.adv_vae = None
        self.mult_by_ratio_when_vae = True

        self.real_compression_ratio = None

        self.require_vae = require_vae
        self.allow_condhint_latents = allow_condhint_latents
        self.postpone_condhint_latents_check = False
        # starts armed; disarm() must be called (by the Apply Advanced ControlNet node)
        # before set_previous_controlnet, or an exception is raised
        self.disarmed = False

    def add_compatible_weight(self, control_weight_type: str):
        self.compatible_weights.append(control_weight_type)

    def verify_all_weights(self, throw_error=True):
        # if an override exists, only the override needs checking
        if self.weights_override is not None:
            if self.weights_override.weight_type not in self.compatible_weights:
                msg = f"Weight override is type {self.weights_override.weight_type}, but loaded {type(self).__name__} " + \
                    f"only supports {self.compatible_weights} weights."
                raise WeightTypeException(msg)
        # otherwise, check the weights on all timestep keyframes
        else:
            for tk in self.timestep_keyframes.keyframes:
                if tk.has_control_weights() and tk.control_weights.weight_type not in self.compatible_weights:
                    msg = f"Weight on Timestep Keyframe with start_percent={tk.start_percent} is type " + \
                        f"{tk.control_weights.weight_type}, but loaded {type(self).__name__} only supports {self.compatible_weights} weights."
                    raise WeightTypeException(msg)

    def set_timestep_keyframes(self, timestep_keyframes: TimestepKeyframeGroup):
        self.timestep_keyframes = timestep_keyframes if timestep_keyframes else TimestepKeyframeGroup()
        # reset all state that depends on the current timestep keyframe
        self._current_timestep_keyframe = None
        self._current_timestep_index = -1
        self._current_used_steps = 0
        self.weights = None
        self.latent_keyframes = None

    def prepare_current_timestep(self, t: Tensor, batched_number: int=1):
        self.t = float(t[0])
        # if same t as last call, nothing to do
        if self.t == self.prev_t:
            return

        curr_t: float = self.t
        prev_index = self._current_timestep_index
        # if the current keyframe has served its guaranteed steps, look for the next one
        if self._current_timestep_keyframe is None or self._current_used_steps >= self._current_timestep_keyframe.guarantee_steps:
            # if there is a next keyframe, loop through the remainder to see if a switch is due
            if self.timestep_keyframes.has_index(self._current_timestep_index+1):
                for i in range(self._current_timestep_index+1, len(self.timestep_keyframes)):
                    eval_tk = self.timestep_keyframes[i]
                    # NOTE: t is a sigma, not a percent - bigger t means earlier in sampling,
                    # so a keyframe has been reached once its start_t is >= curr_t
                    if eval_tk.start_t >= curr_t:
                        self._current_timestep_index = i
                        self._current_timestep_keyframe = eval_tk
                        self._current_used_steps = 0
                        # take over weights, latent keyframes, and mask hint,
                        # respecting inherit_missing
                        if self._current_timestep_keyframe.has_control_weights():
                            self.weights = self._current_timestep_keyframe.control_weights
                        elif not self._current_timestep_keyframe.inherit_missing:
                            self.weights = self.weights_default
                        if self._current_timestep_keyframe.has_latent_keyframes():
                            self.latent_keyframes = self._current_timestep_keyframe.latent_keyframes
                        elif not self._current_timestep_keyframe.inherit_missing:
                            self.latent_keyframes = None
                        if self._current_timestep_keyframe.has_mask_hint():
                            self.tk_mask_cond_hint_original = self._current_timestep_keyframe.mask_hint_orig
                        elif not self._current_timestep_keyframe.inherit_missing:
                            del self.tk_mask_cond_hint_original
                            self.tk_mask_cond_hint_original = None
                        # if guarantee_steps is nonzero, this keyframe must be used; stop searching
                        if self._current_timestep_keyframe.guarantee_steps > 0:
                            break
                    # this keyframe (and every later one) has not been reached yet; stop searching
                    else:
                        break

        self.prev_t = self.t
        # count the steps the current keyframe has been used
        self._current_used_steps += 1
        # if the index changed, apply any overrides on top of the new keyframe's values
        if prev_index != self._current_timestep_index:
            if self.weights_override is not None:
                self.weights = self.weights_override
            if self.latent_keyframe_override is not None:
                self.latent_keyframes = self.latent_keyframe_override

        # make sure weights are resolved
        self.prepare_weights()

    def prepare_weights(self):
        if self.weights is None:
            self.weights = self.weights_default
        elif self.weights.weight_type == ControlWeightType.UNIVERSAL:
            # universal mask weights are used as-is; other universal weights are
            # converted (by subclasses overriding get_universal_weights) to this
            # controlnet's native weight layout
            if self.weights.weight_mask is not None:
                return
            self.weights = self.get_universal_weights()

    def get_universal_weights(self) -> ControlWeights:
        return self.weights

    def set_cond_hint_mask(self, mask_hint):
        self.mask_cond_hint_original = mask_hint
        return self

    def set_cond_hint_inject(self, *args, **kwargs):
        to_return = self.base.set_cond_hint(*args, **kwargs)
        # if this controlnet requires a VAE, remember any VAE that was passed in
        if self.require_vae:
            for arg in args:
                if isinstance(arg, VAE):
                    self.adv_vae = arg
                    self.vae = arg
                    break
            if self.adv_vae is None:
                if 'vae' in kwargs:
                    self.adv_vae = kwargs['vae']
                    self.vae = kwargs['vae']
        return to_return

    def pre_run_inject(self, model, percent_to_timestep_function):
        self.base.pre_run(model, percent_to_timestep_function)
        self.pre_run_advanced(model, percent_to_timestep_function)

    def pre_run_advanced(self, model, percent_to_timestep_function):
        # turn each keyframe's start_percent into a comparable timestep (sigma)
        for tk in self.timestep_keyframes.keyframes:
            tk.start_t = percent_to_timestep_function(tk.start_percent)

        if hasattr(self, "compression_ratio"):
            self.real_compression_ratio = self.compression_ratio

        self.cleanup_advanced()

    def set_previous_controlnet_inject(self, *args, **kwargs):
        to_return = self.base.set_previous_controlnet(*args, **kwargs)
        if not self.disarmed:
            raise Exception(f"Type '{type(self).__name__}' must be used with Apply Advanced ControlNet 🛂🅐🅒🅝 node (with model_optional passed in); otherwise, it will not work.")
        return to_return

    def disarm(self):
        self.disarmed = True

    def should_run(self):
        # strength and timestep_range are provided by the ControlBase this class is combined with
        if math.isclose(self.strength, 0.0) or math.isclose(self._current_timestep_keyframe.strength, 0.0):
            return False
        if self.timestep_range is not None:
            # t outside of the [start, end] sigma range means control should not run
            if self.t > self.timestep_range[0] or self.t < self.timestep_range[1]:
                return False
        return True

    def get_control_inject(self, x_noisy, t, cond, batched_number, transformer_options: dict):
        self.batched_number = batched_number
        self.batch_size = len(t)
        self.cond_or_uncond = transformer_options.get("cond_or_uncond", None)
        # prepare timestep and everything derived from it
        self.prepare_current_timestep(t=t, batched_number=batched_number)
        # if strength is zero, control does nothing; just pass through the previous controlnet
        if self.strength == 0.0 or self._current_timestep_keyframe.strength == 0.0:
            return self.default_control_actions(x_noisy, t, cond, batched_number, transformer_options)
        return self.get_control_advanced(x_noisy, t, cond, batched_number, transformer_options)

    def get_control_advanced(self, x_noisy, t, cond, batched_number, transformer_options):
        # subclasses override this with their actual control logic
        return self.default_control_actions(x_noisy, t, cond, batched_number, transformer_options)

    def default_control_actions(self, x_noisy, t, cond, batched_number, transformer_options):
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
        return control_prev

    def calc_weight(self, idx: int, x: Tensor, control: dict[str, list[Tensor]], key: str) -> Union[float, Tensor]:
        if self.weights.weight_mask is not None:
            # prepare the weight mask to match x
            self.prepare_weight_mask_cond_hint(x, self.batched_number)
            # scale the mask's effect by block depth: raise it to a per-block power
            return torch.pow(self.weight_mask_cond_hint, self.get_calc_pow(idx=idx, control=control, key=key))
        return self.weights.get(idx=idx, control=control, key=key)

    def get_calc_pow(self, idx: int, control: dict[str, list[Tensor]], key: str) -> int:
        if key == "middle":
            return 0
        else:
            c_len = len(control[key])
            real_idx = c_len - idx
            if key == "input":
                # inputs use the inverted ordering compared to outputs
                real_idx = c_len - real_idx + 1
            return real_idx
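
    # e.g. with 12 tensors in a list: key="output" maps idx 0..11 to powers 12..1,
    # key="input" maps idx 0..11 to powers 1..12, and "middle" always returns 0
    # (mask**0 == 1, so the middle block effectively ignores the weight mask).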

    def calc_latent_keyframe_mults(self, x: Tensor, batched_number: int) -> Tensor:
        # build one multiplier per entry in x's batch; default 1.0 (no change)
        final_mults = [1.0] * x.shape[0]
        if self.latent_keyframes:
            latent_count = x.shape[0] // batched_number
            indices_to_null = set(range(latent_count))
            mapped_indices = None
            # when only a window (sub_idxs) of the full latents is being processed,
            # map full-batch indices to window-relative ones
            if self.sub_idxs:
                mapped_indices = {}
                for i, actual in enumerate(self.sub_idxs):
                    mapped_indices[actual] = i
            for keyframe in self.latent_keyframes:
                real_index = keyframe.batch_index
                # negative batch_index counts from the end
                if real_index < 0:
                    real_index += latent_count if self.sub_idxs is None else self.full_latent_length

                if mapped_indices is None:
                    if real_index in indices_to_null:
                        indices_to_null.remove(real_index)
                else:
                    real_index = mapped_indices.get(real_index, None)
                    if real_index is None:
                        continue
                    indices_to_null.remove(real_index)

                # ignore indices that fall outside this batch
                if real_index >= latent_count or real_index < 0:
                    continue

                # apply the keyframe's strength to the matching entry in every sub-batch
                for b in range(batched_number):
                    final_mults[(latent_count*b)+real_index] = keyframe.strength

            # indices without a keyframe get the null strength
            for batch_index in indices_to_null:
                for b in range(batched_number):
                    final_mults[(latent_count*b)+batch_index] = self._current_timestep_keyframe.null_latent_kf_strength

        final_tensor = torch.tensor(final_mults, dtype=x.dtype, device=x.device)
        # unsqueeze so the multipliers broadcast over x's remaining dims
        while len(final_tensor.shape) < len(x.shape):
            final_tensor = final_tensor.unsqueeze(-1)
        return final_tensor
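
    # e.g. with batched_number=2 (cond + uncond) and latent_count=16, a keyframe with
    # batch_index=3 and strength=0.5 sets final_mults[3] and final_mults[19] to 0.5,
    # while indices without a keyframe fall back to null_latent_kf_strength.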

    def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int, flux_shape: tuple=None):
        # apply the uncond multiplier to the uncond portions of the batch
        if self.weights.has_uncond_multiplier:
            actual_length = x.size(0) // batched_number
            for idx, cond_type in enumerate(self.cond_or_uncond):
                # cond_type == 1 means uncond
                if cond_type == 1:
                    x[actual_length*idx:actual_length*(idx+1)] *= self.weights.uncond_multiplier
        if self.weights.has_uncond_mask:
            # currently a no-op; uncond masks are not applied here
            pass

        if self.latent_keyframes is not None:
            x[:] = x[:] * self.calc_latent_keyframe_mults(x=x, batched_number=batched_number)
        # apply masks, resizing them to x's shape
        if self.mask_cond_hint is not None:
            masks = prepare_mask_batch(self.mask_cond_hint, x.shape, match_shape=True, flux_shape=flux_shape)
            x[:] = x[:] * masks
        if self.tk_mask_cond_hint is not None:
            masks = prepare_mask_batch(self.tk_mask_cond_hint, x.shape, match_shape=True, flux_shape=flux_shape)
            x[:] = x[:] * masks
        # apply the current timestep keyframe's strength
        if self._current_timestep_keyframe.strength != 1.0:
            x[:] *= self._current_timestep_keyframe.strength

    def control_merge_inject(self: 'AdvancedControlBase', control: dict[str, list[Tensor]], control_prev: dict, output_dtype):
        out = {'input': [], 'middle': [], 'output': []}

        for key in control:
            control_output = control[key]
            applied_to = set()
            for i in range(len(control_output)):
                x = control_output[i]
                if x is not None:
                    if self.global_average_pooling:
                        x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])

                    # if the same tensor appears more than once and repeats must not
                    # be re-modified, work on a clone instead
                    if self.weights.disable_applied_to and x in applied_to:
                        x = x.clone()
                    # only apply strengths and masks the first time a tensor is seen
                    if x not in applied_to:
                        applied_to.add(x)
                        self.apply_advanced_strengths_and_masks(x, self.batched_number)

                    x *= self.strength * self.calc_weight(i, x, control, key)
                    if output_dtype is not None and x.dtype != output_dtype:
                        x = x.to(output_dtype)

                out[key].append(x)
        # merge in the previous controlnet's outputs
        if control_prev is not None:
            for x in ['input', 'middle', 'output']:
                o = out[x]
                for i in range(len(control_prev[x])):
                    prev_val = control_prev[x][i]
                    if i >= len(o):
                        o.append(prev_val)
                    elif prev_val is not None:
                        if o[i] is None:
                            o[i] = prev_val
                        else:
                            # out-of-place add so tensors shared with other owners are not mutated
                            o[i] = prev_val + o[i]
        return out

    def prepare_mask_cond_hint(self, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
        self._prepare_mask("mask_cond_hint", self.mask_cond_hint_original, x_noisy, t, cond, batched_number, dtype, direct_attn=direct_attn)
        self.prepare_tk_mask_cond_hint(x_noisy, t, cond, batched_number, dtype, direct_attn=direct_attn)

    def prepare_tk_mask_cond_hint(self, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
        return self._prepare_mask("tk_mask_cond_hint", self._current_timestep_keyframe.mask_hint_orig, x_noisy, t, cond, batched_number, dtype, direct_attn=direct_attn)

    def prepare_weight_mask_cond_hint(self, x_noisy: Tensor, batched_number, dtype=None):
        return self._prepare_mask("weight_mask_cond_hint", self.weights.weight_mask, x_noisy, t=None, cond=None, batched_number=batched_number, dtype=dtype, direct_attn=True)

    def _prepare_mask(self, attr_name, orig_mask: Tensor, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
        # resize the mask to the appropriate dimensions, if one is present
        if orig_mask is not None:
            out_mask = getattr(self, attr_name)
            # direct_attn masks match latent size; otherwise 8x (pixel-space) size
            multiplier = 1 if direct_attn else 8
            # rebuild the cached mask if latent size changed or a context window is active
            if self.sub_idxs is not None or out_mask is None or x_noisy.shape[2] * multiplier != out_mask.shape[1] or x_noisy.shape[3] * multiplier != out_mask.shape[2]:
                self._reset_attr(attr_name)
                del out_mask

                # resize the mask, then match batch length to the full latent count
                out_mask = prepare_mask_batch(orig_mask, x_noisy.shape, multiplier=multiplier, match_shape=True)
                actual_latent_length = x_noisy.shape[0] // batched_number
                out_mask = extend_to_batch_size(out_mask, actual_latent_length if self.sub_idxs is None else self.full_latent_length)
                if self.sub_idxs is not None:
                    out_mask = out_mask[self.sub_idxs]

                # make the mask's batch length match x_noisy
                if x_noisy.shape[0] != out_mask.shape[0]:
                    out_mask = broadcast_image_to_extend(out_mask, x_noisy.shape[0], batched_number)

                if dtype is None:
                    dtype = x_noisy.dtype
                setattr(self, attr_name, out_mask.to(dtype=dtype).to(x_noisy.device))
                del out_mask

    def _reset_attr(self, attr_name, new_value=None):
        if hasattr(self, attr_name):
            delattr(self, attr_name)
        setattr(self, attr_name, new_value)

    def cleanup_inject(self):
        self.base.cleanup()
        self.cleanup_advanced()

    def cleanup_advanced(self):
        self.sub_idxs = None
        self.full_latent_length = 0
        self.context_length = 0
        self.t = None
        self.prev_t = None
        self.batched_number = None
        self.batch_size = 0
        self.weights = None
        self.latent_keyframes = None

        if hasattr(self, "compression_ratio"):
            self.real_compression_ratio = self.compression_ratio

        # reset timestep keyframe state
        self._current_timestep_keyframe = None
        self._current_timestep_index = -1
        self._current_used_steps = 0

        # clear cached mask hints
        if self.mask_cond_hint is not None:
            del self.mask_cond_hint
            self.mask_cond_hint = None
        if self.tk_mask_cond_hint_original is not None:
            del self.tk_mask_cond_hint_original
            self.tk_mask_cond_hint_original = None
        if self.tk_mask_cond_hint is not None:
            del self.tk_mask_cond_hint
            self.tk_mask_cond_hint = None
        if self.weight_mask_cond_hint is not None:
            del self.weight_mask_cond_hint
            self.weight_mask_cond_hint = None

    def copy_to_advanced(self, copied: 'AdvancedControlBase'):
        copied.mask_cond_hint_original = self.mask_cond_hint_original
        copied.weights_override = self.weights_override
        copied.latent_keyframe_override = self.latent_keyframe_override
        copied.adv_vae = self.adv_vae
        copied.mult_by_ratio_when_vae = self.mult_by_ratio_when_vae
        copied.require_vae = self.require_vae
        copied.allow_condhint_latents = self.allow_condhint_latents
        copied.postpone_condhint_latents_check = self.postpone_condhint_latents_check
        copied.disarmed = self.disarmed