from typing import Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn

import comfy.model_management as model_management
import comfy.ops
import comfy.utils
from comfy.cli_args import args
from comfy.ldm.modules.attention import attention_basic, attention_pytorch, attention_split, attention_sub_quad, default

from .logger import logger


# until xformers bug is fixed, do not use xformers for VersatileAttention! TODO: change this when fix is out
# logic for choosing optimized_attention method taken from comfy/ldm/modules/attention.py
optimized_attention_mm = attention_basic
if model_management.xformers_enabled():
    pass
    #optimized_attention_mm = attention_xformers
if model_management.pytorch_attention_enabled():
    optimized_attention_mm = attention_pytorch
else:
    if args.use_split_cross_attention:
        optimized_attention_mm = attention_split
    else:
        optimized_attention_mm = attention_sub_quad
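
# Usage sketch (illustrative, following comfy's attention call convention):
# q, k, v are (batch, seq_len, heads * dim_head), heads is passed separately,
# and the output has the same shape as q:
#   out = optimized_attention_mm(q, k, v, heads, mask)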


class CrossAttentionMM(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None,
                 operations=comfy.ops.disable_weight_init):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head
        self.scale = None
        self.default_scale = dim_head ** -0.5

        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))

    def forward(self, x, context=None, value=None, mask=None, scale_mask=None):
        q = self.to_q(x)
        context = default(context, x)
        k: Tensor = self.to_k(context)
        if value is not None:
            v = self.to_v(value)
            del value
        else:
            v = self.to_v(context)

        # apply custom scale by multiplying k by scale factor
        if self.scale is not None:
            k *= self.scale

        # apply scale mask, if present
        if scale_mask is not None:
            k *= scale_mask

        out = optimized_attention_mm(q, k, v, self.heads, mask)
        return self.to_out(out)
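
# Shape-check sketch (dims here are arbitrary, not from any real model; note that
# disable_weight_init ops leave weights uninitialized, so outputs are only
# meaningful once a state dict is loaded):
#   attn = CrossAttentionMM(query_dim=320, context_dim=768, heads=8, dim_head=40)
#   x = torch.randn(2, 4096, 320)
#   ctx = torch.randn(2, 77, 768)
#   out = attn(x, context=ctx)  # -> (2, 4096, 320)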


# TODO: set up comfy.ops style classes for groupnorm and other functions
class GroupNormAD(torch.nn.GroupNorm):
    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True,
                 device=None, dtype=None) -> None:
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine, device=device, dtype=dtype)

    def forward(self, input: Tensor) -> Tensor:
        return F.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps)
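
# Illustrative check (sizes are arbitrary): GroupNormAD currently behaves like
# torch.nn.GroupNorm, with forward() overridden explicitly (see TODO above):
#   gn = GroupNormAD(num_groups=8, num_channels=32)
#   y = gn(torch.randn(2, 32, 16, 16))  # -> same shape, normalized per group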


# applies min-max normalization, from:
# https://stackoverflow.com/questions/68791508/min-max-normalization-of-a-tensor-in-pytorch
def normalize_min_max(x: Tensor, new_min=0.0, new_max=1.0):
    return linear_conversion(x, x_min=x.min(), x_max=x.max(), new_min=new_min, new_max=new_max)


def linear_conversion(x, x_min=0.0, x_max=1.0, new_min=0.0, new_max=1.0):
    x_min = float(x_min)
    x_max = float(x_max)
    new_min = float(new_min)
    new_max = float(new_max)
    return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min
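
# Quick sanity examples (values chosen for illustration):
#   normalize_min_max(torch.tensor([2.0, 4.0, 6.0]))  # -> tensor([0.0, 0.5, 1.0])
#   linear_conversion(0.5, x_min=0.0, x_max=1.0, new_min=-1.0, new_max=1.0)  # -> 0.0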


# adapted from comfy/sample.py
def prepare_mask_batch(mask: Tensor, shape: Tensor, multiplier: int=1, match_dim1=False):
    mask = mask.clone()
    mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[2]*multiplier, shape[3]*multiplier), mode="bilinear")
    if match_dim1:
        mask = torch.cat([mask] * shape[1], dim=1)
    return mask
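
# Illustrative call (made-up sizes; despite the annotation, shape is used as an
# indexable latent shape like (batch, channels, height, width)):
#   mask = torch.rand(16, 32, 32)
#   batched = prepare_mask_batch(mask, shape=(16, 4, 64, 64))  # -> (16, 1, 64, 64)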


def extend_to_batch_size(tensor: Tensor, batch_size: int):
    if tensor.shape[0] > batch_size:
        return tensor[:batch_size]
    elif tensor.shape[0] < batch_size:
        remainder = batch_size-tensor.shape[0]
        return torch.cat([tensor] + [tensor[-1:]]*remainder, dim=0)
    return tensor
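
# Illustrative behavior (made-up sizes): pads by repeating the last entry, or truncates:
#   extend_to_batch_size(torch.zeros(4, 8), batch_size=6).shape  # -> torch.Size([6, 8])
#   extend_to_batch_size(torch.zeros(4, 8), batch_size=2).shape  # -> torch.Size([2, 8])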


def get_sorted_list_via_attr(objects: list, attr: str) -> list:
    if not objects:
        return objects
    elif len(objects) <= 1:
        return list(objects)
    # now that we know we have to sort, do it following these rules:
    # a) if objects have same value of attribute, maintain their relative order
    # b) perform sorting of the groups of objects with same attributes
    unique_attrs = {}
    for o in objects:
        val_attr = getattr(o, attr)
        attr_list: list = unique_attrs.setdefault(val_attr, [])
        attr_list.append(o)
    # now that we have the unique attr values grouped together in relative order, sort them by key
    sorted_attrs = dict(sorted(unique_attrs.items()))
    # now flatten out the dict into a list to return
    sorted_list = []
    for object_list in sorted_attrs.values():
        sorted_list.extend(object_list)
    return sorted_list
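
# Illustrative behavior: a stable sort grouped by attribute value, using
# ADKeyframe (defined below) as the example object:
#   kfs = [ADKeyframe(start_percent=0.5), ADKeyframe(start_percent=0.0), ADKeyframe(start_percent=0.5)]
#   get_sorted_list_via_attr(kfs, "start_percent")
#   # -> [kfs[1], kfs[0], kfs[2]]; the two 0.5 keyframes keep their relative order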


class MotionCompatibilityError(ValueError):
    pass


def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[float, Tensor]) -> Union[float, Tensor]:
    # if one is None, use the other
    if multivalA is None:
        return multivalB
    elif multivalB is None:
        return multivalA
    # both have a value - combine them based on type
    # if both are Tensors, make dims match before multiplying
    if isinstance(multivalA, Tensor) and isinstance(multivalB, Tensor):
        areaA = multivalA.shape[1]*multivalA.shape[2]
        areaB = multivalB.shape[1]*multivalB.shape[2]
        # match height/width to mask with larger area
        leader, follower = (multivalA, multivalB) if areaA >= areaB else (multivalB, multivalA)
        batch_size = max(multivalA.shape[0], multivalB.shape[0])
        # make follower same dimensions as leader
        follower = torch.unsqueeze(follower, 1)
        follower = comfy.utils.common_upscale(follower, leader.shape[2], leader.shape[1], "bilinear", "center")
        follower = torch.squeeze(follower, 1)
        # make sure batch size will match
        leader = extend_to_batch_size(leader, batch_size)
        follower = extend_to_batch_size(follower, batch_size)
        return leader * follower
    # otherwise, just multiply them together - one of them is a float
    return multivalA * multivalB
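
# Illustrative combinations (made-up shapes; Tensor multivals are (batch, height, width) masks):
#   get_combined_multival(None, 0.5)                 # -> 0.5
#   get_combined_multival(0.5, torch.ones(1, 8, 8))  # -> (1, 8, 8) tensor of 0.5s
#   get_combined_multival(torch.ones(1, 4, 4), torch.ones(2, 8, 8))  # -> (2, 8, 8); smaller mask is upscaled first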


class ADKeyframe:
    def __init__(self,
                 start_percent: float = 0.0,
                 scale_multival: Union[float, Tensor]=None,
                 effect_multival: Union[float, Tensor]=None,
                 inherit_missing: bool=True,
                 guarantee_steps: int=1,
                 default: bool=False,
                 ):
        self.start_percent = start_percent
        self.start_t = 999999999.9
        self.scale_multival = scale_multival
        self.effect_multival = effect_multival
        self.inherit_missing = inherit_missing
        self.guarantee_steps = guarantee_steps
        self.default = default

    def has_scale(self):
        return self.scale_multival is not None

    def has_effect(self):
        return self.effect_multival is not None


class ADKeyframeGroup:
    def __init__(self):
        self.keyframes: list[ADKeyframe] = []
        self.keyframes.append(ADKeyframe(guarantee_steps=1, default=True))

    def add(self, keyframe: ADKeyframe):
        # remove any default keyframes that match start_percent of new keyframe
        default_to_delete = []
        for i in range(len(self.keyframes)):
            if self.keyframes[i].default and self.keyframes[i].start_percent == keyframe.start_percent:
                default_to_delete.append(i)
        for i in reversed(default_to_delete):
            self.keyframes.pop(i)
        # add to end of list, then sort
        self.keyframes.append(keyframe)
        self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")

    def get_index(self, index: int) -> Union[ADKeyframe, None]:
        try:
            return self.keyframes[index]
        except IndexError:
            return None

    def has_index(self, index: int) -> bool:
        return 0 <= index < len(self.keyframes)

    def __getitem__(self, index) -> ADKeyframe:
        return self.keyframes[index]

    def __len__(self) -> int:
        return len(self.keyframes)

    def is_empty(self) -> bool:
        return len(self.keyframes) == 0

    def clone(self) -> 'ADKeyframeGroup':
        cloned = ADKeyframeGroup()
        for tk in self.keyframes:
            if not tk.default:
                cloned.add(tk)
        return cloned
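
# Illustrative usage (percent values are made up):
#   group = ADKeyframeGroup()  # starts with a default keyframe at start_percent=0.0
#   group.add(ADKeyframe(start_percent=0.5, scale_multival=1.2))
#   group.add(ADKeyframe(start_percent=0.0, effect_multival=0.8))  # replaces the default at 0.0
#   len(group)  # -> 2, keyframes sorted by start_percent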