|
from typing import Union |
|
import torch |
|
import torch.nn.functional as F |
|
from torch import Tensor, nn |
|
from abc import ABC, abstractmethod |
|
from collections.abc import Iterable |
|
|
|
import comfy.model_management as model_management |
|
import comfy.ops |
|
import comfy.utils |
|
from comfy.cli_args import args |
|
from comfy.ldm.modules.attention import attention_basic, attention_pytorch, attention_split, attention_sub_quad, default |
|
|
|
from .logger import logger |
|
|
|
|
|
|
|
|
|
|
|
# Attention kernels used by CrossAttentionMM. Both start at the basic
# implementation and are upgraded below according to ComfyUI's runtime flags.
optimized_attention_mm = attention_basic
fallback_attention_mm = attention_basic
if model_management.xformers_enabled():
    # NOTE(review): this branch is a no-op, so xformers is never selected for
    # motion-module attention; the capability query is kept as before.
    pass

if model_management.pytorch_attention_enabled():
    # PyTorch attention is the primary kernel; the fallback is what
    # CrossAttentionMM switches to after a CUDA
    # "invalid configuration argument" RuntimeError.
    optimized_attention_mm = attention_pytorch
    fallback_attention_mm = attention_split if args.use_split_cross_attention else attention_sub_quad
else:
    # Without PyTorch attention the fallback stays attention_basic.
    optimized_attention_mm = attention_split if args.use_split_cross_attention else attention_sub_quad
|
|
|
|
|
class CrossAttentionMM(nn.Module):
    """Multi-head (cross) attention layer used by the motion modules.

    Queries are projected from ``x``. Keys come from ``context`` (``x`` when omitted).
    Values come from ``value`` when given, otherwise from ``context``.

    The attention kernel starts as the module-level ``optimized_attention_mm``.
    If it raises a RuntimeError starting with
    "CUDA error: invalid configuration argument", this instance switches to
    ``fallback_attention_mm`` permanently (until ``reset_attention_type``) and
    retries the call once.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None,
                 operations=comfy.ops.disable_weight_init):
        super().__init__()
        context_dim = default(context_dim, query_dim)
        inner_dim = heads * dim_head

        # Attention settings (plain attributes, not submodules).
        self.heads = heads
        self.dim_head = dim_head
        self.default_scale = dim_head ** -0.5
        # Optional multiplier applied to the keys in forward(); None disables it.
        self.scale = None
        self.actual_attention = optimized_attention_mm

        # Projections are created in the same order as before so submodule /
        # parameter registration order (and thus state_dict order) is unchanged.
        linear_kwargs = dict(dtype=dtype, device=device)
        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, **linear_kwargs)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, **linear_kwargs)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, **linear_kwargs)
        self.to_out = nn.Sequential(
            operations.Linear(inner_dim, query_dim, **linear_kwargs),
            nn.Dropout(dropout),
        )

    def reset_attention_type(self):
        """Go back to the module-level ``optimized_attention_mm`` kernel (e.g. after a fallback)."""
        self.actual_attention = optimized_attention_mm

    def forward(self, x, context=None, value=None, mask=None, scale_mask=None, mm_kwargs=None, transformer_options=None):
        """Attend from ``x`` over ``context`` and project the result back to ``query_dim``.

        ``self.scale`` and ``scale_mask``, when not None, are multiplied into the
        keys in place before attention. ``mm_kwargs`` and
        ``transformer_options`` are accepted but not used here.
        """
        q = self.to_q(x)
        context = default(context, x)
        k: Tensor = self.to_k(context)
        if value is None:
            v = self.to_v(context)
        else:
            v = self.to_v(value)
            del value

        # In-place: the result must keep k's shape (scale_mask has to broadcast into k).
        if self.scale is not None:
            k *= self.scale
        if scale_mask is not None:
            k *= scale_mask

        try:
            out = self.actual_attention(q, k, v, self.heads, mask)
        except RuntimeError as err:
            if not str(err).startswith("CUDA error: invalid configuration argument"):
                raise
            # Remember the fallback so later calls skip the failing kernel.
            self.actual_attention = fallback_attention_mm
            out = self.actual_attention(q, k, v, self.heads, mask)
        return self.to_out(out)
|
|
|
|
|
class GroupNormAD(torch.nn.GroupNorm):
    """``torch.nn.GroupNorm`` whose forward calls ``F.group_norm`` directly.

    All constructor arguments are forwarded unchanged to ``torch.nn.GroupNorm``.
    """

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True,
                 device=None, dtype=None) -> None:
        super().__init__(
            num_channels=num_channels,
            num_groups=num_groups,
            affine=affine,
            eps=eps,
            dtype=dtype,
            device=device,
        )

    def forward(self, input: Tensor) -> Tensor:
        """Group-normalize ``input`` using this module's group count, affine params and eps."""
        return F.group_norm(input, self.num_groups, weight=self.weight, bias=self.bias, eps=self.eps)
|
|
|
|
|
|
|
|
|
def normalize_min_max(x: Tensor, new_min=0.0, new_max=1.0):
    """Linearly rescale ``x`` so its minimum maps to ``new_min`` and its maximum to ``new_max``.

    Delegates to ``linear_conversion`` with ``x.min()`` / ``x.max()`` as the
    source range. If every element of ``x`` is equal, that range is zero and
    the conversion divides by zero (for float tensors this yields NaN).
    """
    lowest, highest = x.min(), x.max()
    return linear_conversion(x, x_min=lowest, x_max=highest, new_min=new_min, new_max=new_max)
|
|
|
|
|
def linear_conversion(x, x_min=0.0, x_max=1.0, new_min=0.0, new_max=1.0):
    """Map ``x`` linearly from the range [x_min, x_max] onto [new_min, new_max].

    Works element-wise on tensors as well as on plain numbers. Values outside
    the source range are extrapolated, not clamped. ``x_max == x_min`` divides by zero.
    """
    fraction = (x - x_min) / (x_max - x_min)
    return fraction * (new_max - new_min) + new_min
|
|
|
|
|
|
|
def prepare_mask_batch(mask: Tensor, shape: Tensor, multiplier: int=1, match_dim1=False):
    """Resize ``mask`` to the spatial size described by ``shape`` (scaled by ``multiplier``).

    ``mask`` is flattened to ``(N, 1, H, W)`` using its last two dims. It is
    then bilinearly resized to ``(shape[2] * multiplier, shape[3] * multiplier)``.
    ``shape`` is indexed like a 4-D size (e.g. a tensor's ``.shape``), despite
    its annotation. With ``match_dim1`` the single channel is tiled
    ``shape[1]`` times along dim 1. The input mask is not modified.
    """
    target_size = (shape[2] * multiplier, shape[3] * multiplier)
    flat = mask.clone()
    flat = flat.reshape((-1, 1, flat.shape[-2], flat.shape[-1]))
    resized = torch.nn.functional.interpolate(flat, size=target_size, mode="bilinear")
    if not match_dim1:
        return resized
    return torch.cat([resized] * shape[1], dim=1)
|
|
|
|
|
def extend_to_batch_size(tensor: Tensor, batch_size: int):
    """Return ``tensor`` with ``batch_size`` entries along dim 0.

    Longer tensors are truncated. Shorter ones are padded by repeating the last
    entry. If the size already matches, the same tensor object is returned.
    """
    current = tensor.shape[0]
    if current == batch_size:
        return tensor
    if current > batch_size:
        return tensor[:batch_size]
    padding = [tensor[-1:]] * (batch_size - current)
    return torch.cat([tensor, *padding], dim=0)
|
|
|
|
|
def extend_list_to_batch_size(_list: list, batch_size: int):
    """Return a new list sized to ``batch_size`` from ``_list``.

    Longer lists are truncated. Shorter ones are padded by repeating the last item.
    An empty list stays empty. The input list itself is never returned or modified.
    """
    length = len(_list)
    if length == batch_size:
        return _list.copy()
    if length > batch_size:
        return _list[:batch_size]
    return _list + _list[-1:] * (batch_size - length)
|
|
|
|
|
|
|
def ade_broadcast_image_to(tensor, target_batch_size, batched_number):
    """Tile ``tensor`` along dim 0 to fill ``target_batch_size`` over ``batched_number`` groups.

    A tensor with a single entry is returned untouched. Otherwise:
    - The first ``target_batch_size // batched_number`` entries form one group.
    - If there are too few entries, they are cycled to fill the group.
    - The group is returned directly if it already has ``target_batch_size``
      entries; otherwise it is repeated ``batched_number`` times.
    """
    if tensor.shape[0] == 1:
        return tensor

    per_batch = target_batch_size // batched_number
    group = tensor[:per_batch]
    available = group.shape[0]
    if per_batch > available:
        full_repeats, leftover = divmod(per_batch, available)
        group = torch.cat([group] * full_repeats + [group[:leftover]], dim=0)

    if group.shape[0] == target_batch_size:
        return group
    return torch.cat([group] * batched_number, dim=0)
|
|
|
|
|
|
|
def composite_extend(destination: Tensor, source: Tensor, x: int, y: int, mask: Tensor = None, multiplier = 8, resize_source = False):
    """Blend ``source`` into ``destination`` at offset (x, y), in place, and return ``destination``.

    Both tensors are indexed as (batch, channel, height, width). ``x``/``y``
    are divided (floor) by ``multiplier`` to get the placement in
    ``destination``'s grid. They are first clamped so the source is at most
    entirely outside the destination on any side. Negative offsets place the
    source partially off the left/top edge.

    ``source`` is moved to ``destination``'s device and, with
    ``resize_source``, bilinearly resized to ``destination``'s spatial size.
    It is then extended/truncated to ``destination``'s batch size.

    ``mask`` (optional) is reshaped to ``(N, 1, H, W)`` from its last two dims,
    bilinearly resized to ``source``'s spatial size and batch-extended. Each
    covered pixel becomes ``mask * source + (1 - mask) * destination``. Without
    a mask the source replaces the covered region.
    """
    source = source.to(destination.device)
    if resize_source:
        source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")

    source = extend_to_batch_size(source, destination.shape[0])

    dest_height, dest_width = destination.shape[2], destination.shape[3]
    src_height, src_width = source.shape[2], source.shape[3]

    # Clamp so the source can be at most fully outside the destination.
    x = max(-src_width * multiplier, min(x, dest_width * multiplier))
    y = max(-src_height * multiplier, min(y, dest_height * multiplier))

    left, top = x // multiplier, y // multiplier
    right, bottom = left + src_width, top + src_height

    if mask is None:
        mask = torch.ones_like(source)
    else:
        mask = mask.to(destination.device, copy=True)
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(src_height, src_width), mode="bilinear")
        mask = extend_to_batch_size(mask, source.shape[0])

    # Overlap of the placed source with the destination, in destination
    # coordinates. Clipping BOTH edges keeps a negative left/top from being
    # interpreted as a Python wrap-around slice index.
    dst_left, dst_top = max(left, 0), max(top, 0)
    dst_right, dst_bottom = min(right, dest_width), min(bottom, dest_height)
    if dst_right <= dst_left or dst_bottom <= dst_top:
        # Nothing overlaps; the destination is left unchanged.
        return destination

    # The same overlap expressed in source coordinates (skips the off-canvas part).
    src_left, src_top = dst_left - left, dst_top - top
    src_right = src_left + (dst_right - dst_left)
    src_bottom = src_top + (dst_bottom - dst_top)

    mask = mask[:, :, src_top:src_bottom, src_left:src_right]
    inverse_mask = torch.ones_like(mask) - mask

    source_portion = mask * source[:, :, src_top:src_bottom, src_left:src_right]
    destination_portion = inverse_mask * destination[:, :, dst_top:dst_bottom, dst_left:dst_right]

    destination[:, :, dst_top:dst_bottom, dst_left:dst_right] = source_portion + destination_portion
    return destination
|
|
|
|
|
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
    """Return ``objects`` stably sorted by the value of attribute ``attr``.

    Objects whose ``attr`` values are equal keep their original relative order.
    The input is not modified.
    - A falsy ``objects`` (e.g. an empty list) is returned as-is.
    - A single-item input is returned as a new list without reading ``attr``.
    - Otherwise a new sorted list is returned.
    """
    if not objects:
        return objects
    if len(objects) <= 1:
        return list(objects)
    # sorted() is guaranteed stable, which yields the same ordering as grouping
    # objects by attribute value and concatenating the groups in key order,
    # without requiring the attribute values to be hashable.
    return sorted(objects, key=lambda obj: getattr(obj, attr))
|
|
|
|
|
class MotionCompatibilityError(ValueError):
    """``ValueError`` subtype used to signal motion-model compatibility errors."""
|
|
|
|
|
class InputPIA(ABC):
    """Abstract PIA input: subclasses supply ``get_mask``; all carry ``effect_multival``.

    ``effect_multival`` (float or tensor) defaults to 1.0 when not given.
    """

    def __init__(self, effect_multival: Union[float, Tensor]=None):
        if effect_multival is None:
            effect_multival = 1.0
        self.effect_multival = effect_multival

    @abstractmethod
    def get_mask(self, x: Tensor):
        """Return the mask to use for input ``x``; must be implemented by subclasses."""
|
|
|
|
|
class InputPIA_Multival(InputPIA):
    """PIA input backed by a single float or by a mask tensor (``multival``)."""

    def __init__(self, multival: Union[float, Tensor], effect_multival: Union[float, Tensor]=None):
        super().__init__(effect_multival=effect_multival)
        self.multival = multival

    def get_mask(self, x: Tensor):
        """Return the stored tensor as-is, or a ``(b, h, w)`` tensor filled with the float.

        ``x`` is unpacked as ``(b, c, h, w)``, so it must be 4-D in the float case.
        NOTE(review): the filled tensor is created with ``torch.ones`` defaults
        (default dtype, default device), not on ``x``'s device/dtype.
        """
        if type(self.multival) is Tensor:
            return self.multival
        batch, _, height, width = x.shape
        return torch.ones(size=(batch, height, width)) * self.multival
|
|
|
|
|
def create_multival_combo(float_val: Union[float, list[float]], mask_optional: Tensor=None):
    """Build a multival from a float (or an iterable of floats) and an optional mask.

    - Float, no mask: the float is returned unchanged.
    - Iterable, no mask: a ``(len, 1, 1)`` tensor of the values.
    - Float with mask: a copy of the mask multiplied by the float.
    - Iterable with mask: the mask times one value per entry along dim 0. Before
      multiplying, the shorter of the value list and the mask batch is padded
      by repeating its last element. The values are cast to the mask's
      dtype/device.
    The input mask is never modified.
    """
    is_list = isinstance(float_val, Iterable)
    if is_list:
        values = list(float_val)
        if mask_optional is not None:
            mask_len = mask_optional.shape[0]
            if len(values) < mask_len:
                values = extend_list_to_batch_size(values, mask_len)
            if mask_optional.shape[0] < len(values):
                mask_optional = extend_to_batch_size(mask_optional, len(values))
            values = values[:mask_optional.shape[0]]
        # One value per batch entry, shaped to broadcast over (H, W).
        float_val = torch.tensor(values).unsqueeze(-1).unsqueeze(-1)

    if mask_optional is not None:
        mask_optional = mask_optional.clone()
        if not is_list:
            return mask_optional * float_val
        return mask_optional[:] * float_val.to(mask_optional.dtype).to(mask_optional.device)

    if not is_list:
        return float_val
    ones = torch.ones((float_val.shape[0], 1, 1))
    return ones[:] * float_val.to(ones.dtype).to(ones.device)
|
|
|
|
|
def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[float, Tensor], force_leader_A=False) -> Union[float, Tensor]:
    """Multiply two multivals (floats or mask tensors) together.

    ``None`` acts as the identity; if both are ``None`` the result is 1.0.
    If either is not a ``Tensor``, they are simply multiplied.

    When both are tensors, they are indexed as (batch, height, width):
    - The "leader" is A when ``force_leader_A``; otherwise it is the tensor
      with the larger height*width (ties favor A).
    - The other tensor is bilinearly resized (center crop) to the leader's
      spatial size.
    - Both are extended to a common batch size and multiplied. That size is
      A's batch when ``force_leader_A``, otherwise the larger of the two.
    """
    if multivalA is None:
        return 1.0 if multivalB is None else multivalB
    if multivalB is None:
        return multivalA

    if not (type(multivalA) is Tensor and type(multivalB) is Tensor):
        return multivalA * multivalB

    if force_leader_A:
        leader, follower = multivalA, multivalB
        batch_size = multivalA.shape[0]
    else:
        a_is_leader = multivalA.shape[1] * multivalA.shape[2] >= multivalB.shape[1] * multivalB.shape[2]
        leader, follower = (multivalA, multivalB) if a_is_leader else (multivalB, multivalA)
        batch_size = max(multivalA.shape[0], multivalB.shape[0])

    # Add a channel dim for the upscale; target size is passed as
    # (leader.shape[-1], leader.shape[-2]), i.e. (width, height).
    follower = comfy.utils.common_upscale(follower.unsqueeze(1), leader.shape[-1], leader.shape[-2], "bilinear", "center")
    follower = follower.squeeze(1)

    return extend_to_batch_size(leader, batch_size) * extend_to_batch_size(follower, batch_size)
|
|
|
|
|
def resize_multival(multival: Union[float, Tensor], batch_size: int, height: int, width: int):
    """Resize a mask multival to ``(batch_size, height, width)``; non-tensors pass through.

    ``None`` becomes 1.0 and any non-``Tensor`` value is returned unchanged.
    A tensor (indexed as (batch, h, w)) is bilinearly resized with a center
    crop to ``height`` x ``width``. It is then extended or truncated to
    ``batch_size`` along dim 0.
    """
    if multival is None:
        return 1.0
    if type(multival) != Tensor:
        return multival
    multival = torch.unsqueeze(multival, 1)
    # comfy.utils.common_upscale's signature is
    # (samples, width, height, upscale_method, crop): width must come first,
    # otherwise non-square targets come out transposed.
    multival = comfy.utils.common_upscale(multival, width, height, "bilinear", "center")
    multival = torch.squeeze(multival, 1)
    multival = extend_to_batch_size(multival, batch_size)
    return multival
|
|
|
|
|
def get_combined_input(inputA: Union[InputPIA, None], inputB: Union[InputPIA, None], x: Tensor):
    """Combine the masks two PIA inputs produce for ``x`` via ``get_combined_multival``.

    A missing input is replaced by ``InputPIA_Multival(1.0)``, whose mask is all ones.
    """
    first = inputA if inputA is not None else InputPIA_Multival(1.0)
    second = inputB if inputB is not None else InputPIA_Multival(1.0)
    mask_first = first.get_mask(x)
    mask_second = second.get_mask(x)
    return get_combined_multival(mask_first, mask_second)
|
|
|
|
|
def get_combined_input_effect_multival(inputA: Union[InputPIA, None], inputB: Union[InputPIA, None]):
    """Combine the ``effect_multival`` of two PIA inputs via ``get_combined_multival``.

    A missing input is replaced by ``InputPIA_Multival(1.0)``, contributing an effect of 1.0.
    """
    effects = [
        (pia if pia is not None else InputPIA_Multival(1.0)).effect_multival
        for pia in (inputA, inputB)
    ]
    return get_combined_multival(*effects)
|
|
|
|
|
class ADKeyframe:
    """AnimateDiff keyframe settings anchored at ``start_percent``.

    Each optional setting is ``None`` when not specified; the ``has_*`` methods
    report which ones are set. ``default=True`` marks a placeholder keyframe.
    """

    def __init__(self,
                 start_percent: float = 0.0,
                 scale_multival: Union[float, Tensor]=None,
                 effect_multival: Union[float, Tensor]=None,
                 cameractrl_multival: Union[float, Tensor]=None,
                 pia_input: InputPIA=None,
                 inherit_missing: bool=True,
                 guarantee_steps: int=1,
                 default: bool=False,
                 ):
        # Position and bookkeeping.
        self.start_percent = start_percent
        # Sentinel value; NOTE(review): presumably replaced by a real timestep elsewhere.
        self.start_t = 999999999.9
        self.inherit_missing = inherit_missing
        self.guarantee_steps = guarantee_steps
        self.default = default
        # Optional per-keyframe settings (None means "not set").
        self.scale_multival = scale_multival
        self.effect_multival = effect_multival
        self.cameractrl_multival = cameractrl_multival
        self.pia_input = pia_input

    def has_scale(self):
        """True when a scale multival is set."""
        return self.scale_multival is not None

    def has_effect(self):
        """True when an effect multival is set."""
        return self.effect_multival is not None

    def has_cameractrl_effect(self):
        """True when a CameraCtrl multival is set."""
        return self.cameractrl_multival is not None

    def has_pia_input(self):
        """True when a PIA input is set."""
        return self.pia_input is not None
|
|
|
|
|
class ADKeyframeGroup:
    """Collection of ``ADKeyframe`` objects kept ordered by ``start_percent``.

    A new group holds one placeholder keyframe (``default=True``) at
    ``start_percent`` 0.0. Adding any keyframe with the same ``start_percent``
    as a placeholder removes that placeholder.
    """

    def __init__(self):
        self.keyframes: list[ADKeyframe] = [ADKeyframe(guarantee_steps=1, default=True)]

    def add(self, keyframe: ADKeyframe):
        """Insert ``keyframe`` and re-sort by ``start_percent``.

        Placeholder (default) keyframes at the same ``start_percent`` are
        dropped first. Keyframes sharing a ``start_percent`` keep insertion order.
        """
        self.keyframes = [
            existing for existing in self.keyframes
            if not (existing.default and existing.start_percent == keyframe.start_percent)
        ]
        self.keyframes.append(keyframe)
        self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")

    def get_index(self, index: int) -> Union[ADKeyframe, None]:
        """Return the keyframe at ``index``, or ``None`` if the index is out of range.

        Negative indices follow normal list semantics (counting from the end),
        even though ``has_index`` reports them as absent.
        """
        try:
            return self.keyframes[index]
        except IndexError:
            return None

    def has_index(self, index: int) -> bool:
        """Return True if ``index`` is a valid non-negative position in the group."""
        return 0 <= index < len(self.keyframes)

    def __getitem__(self, index) -> ADKeyframe:
        return self.keyframes[index]

    def __len__(self) -> int:
        return len(self.keyframes)

    def is_empty(self) -> bool:
        return len(self.keyframes) == 0

    def clone(self) -> 'ADKeyframeGroup':
        """Return a new group containing this group's non-default keyframes.

        Shallow: the ``ADKeyframe`` objects themselves are shared, not copied.
        """
        cloned = ADKeyframeGroup()
        for keyframe in self.keyframes:
            if not keyframe.default:
                cloned.add(keyframe)
        return cloned
|
|
|
|
|
class DummyNNModule(nn.Module):
    '''
    Class that does not throw exceptions for almost anything you throw at it. As name implies, does nothing.

    - Unknown attribute lookups return a callable that accepts anything and returns None.
    - Attribute and item assignment are silently ignored.
    - Calling the module, indexing it, or calling ``next()`` on it returns None.
    - It behaves as an empty container: ``len()`` is 0, iteration yields
      nothing, and it is therefore falsy.
    '''

    class DoNothingWhenCalled:
        """Callable that accepts any arguments and returns None."""

        def __call__(self, *args, **kwargs):
            return

    def __init__(self):
        super().__init__()

    def __getattr__(self, *args, **kwargs):
        # Only invoked when normal attribute lookup fails.
        return self.DoNothingWhenCalled()

    def __setattr__(self, name, value):
        pass

    def __iter__(self, *args, **kwargs):
        # __iter__ must return an iterator; returning None made iter()/for-loops raise TypeError.
        return iter(())

    def __next__(self, *args, **kwargs):
        pass

    def __len__(self, *args, **kwargs):
        # __len__ must return an int; returning None made len() and bool() raise TypeError.
        return 0

    def __getitem__(self, *args, **kwargs):
        pass

    def __setitem__(self, *args, **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        pass
|
|