from typing import Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
import comfy.model_management as model_management
import comfy.ops
import comfy.utils
from comfy.cli_args import args
from comfy.ldm.modules.attention import attention_basic, attention_pytorch, attention_split, attention_sub_quad, default
from .logger import logger
# until xformers bug is fixed, do not use xformers for VersatileAttention! TODO: change this when fix is out
# logic for choosing optimized_attention method taken from comfy/ldm/modules/attention.py
optimized_attention_mm = attention_basic
if model_management.xformers_enabled():
pass
#optimized_attention_mm = attention_xformers
if model_management.pytorch_attention_enabled():
optimized_attention_mm = attention_pytorch
else:
if args.use_split_cross_attention:
optimized_attention_mm = attention_split
else:
optimized_attention_mm = attention_sub_quad
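# Note (illustrative sketch, not part of the original module): the backend
# chosen above can be inspected at runtime, e.g.:
#   logger.info(f"using attention implementation: {optimized_attention_mm.__name__}")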
class CrossAttentionMM(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None,
operations=comfy.ops.disable_weight_init):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
self.scale = None
self.default_scale = dim_head ** -0.5
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, value=None, mask=None, scale_mask=None):
q = self.to_q(x)
context = default(context, x)
k: Tensor = self.to_k(context)
if value is not None:
v = self.to_v(value)
del value
else:
v = self.to_v(context)
# apply custom scale by multiplying k by scale factor
if self.scale is not None:
k *= self.scale
# apply scale mask, if present
if scale_mask is not None:
k *= scale_mask
out = optimized_attention_mm(q, k, v, self.heads, mask)
return self.to_out(out)
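# Illustrative usage (a sketch; the dims below are assumptions, and the default
# disable_weight_init ops skip weight initialization, so outputs are arbitrary):
#   attn = CrossAttentionMM(query_dim=320, context_dim=768, heads=8, dim_head=40)
#   x = torch.randn(2, 16, 320)    # (batch, query tokens, query_dim)
#   ctx = torch.randn(2, 77, 768)  # (batch, context tokens, context_dim)
#   out = attn(x, context=ctx)     # -> (2, 16, 320)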
# TODO: set up comfy.ops style classes for groupnorm and other functions
class GroupNormAD(torch.nn.GroupNorm):
def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True,
device=None, dtype=None) -> None:
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine, device=device, dtype=dtype)
def forward(self, input: Tensor) -> Tensor:
return F.group_norm(
input, self.num_groups, self.weight, self.bias, self.eps)
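# Illustrative usage (sketch): currently behaves identically to torch.nn.GroupNorm.
#   gn = GroupNormAD(num_groups=32, num_channels=320)
#   y = gn(torch.randn(1, 320, 8, 8))  # -> same shape, group-normalized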
# applies min-max normalization, from:
# https://stackoverflow.com/questions/68791508/min-max-normalization-of-a-tensor-in-pytorch
def normalize_min_max(x: Tensor, new_min = 0.0, new_max = 1.0):
return linear_conversion(x, x_min=x.min(), x_max=x.max(), new_min=new_min, new_max=new_max)
def linear_conversion(x, x_min=0.0, x_max=1.0, new_min=0.0, new_max=1.0):
x_min = float(x_min)
x_max = float(x_max)
new_min = float(new_min)
new_max = float(new_max)
return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min
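# Illustrative worked example (sketch): rescale values from [0, 10] into [-1, 1].
#   x = torch.tensor([0.0, 5.0, 10.0])
#   linear_conversion(x, x_min=0.0, x_max=10.0, new_min=-1.0, new_max=1.0)
#   # -> tensor([-1., 0., 1.])
#   normalize_min_max(x)  # uses x's own min/max -> tensor([0.0000, 0.5000, 1.0000])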
# adapted from comfy/sample.py
def prepare_mask_batch(mask: Tensor, shape: torch.Size, multiplier: int=1, match_dim1=False):
mask = mask.clone()
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[2]*multiplier, shape[3]*multiplier), mode="bilinear")
if match_dim1:
mask = torch.cat([mask] * shape[1], dim=1)
return mask
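# Illustrative usage (sketch): resize a (B, H, W) mask to match a latent shape.
#   mask = torch.rand(1, 64, 64)
#   latent_shape = torch.Size([4, 4, 32, 32])
#   m = prepare_mask_batch(mask, latent_shape, match_dim1=True)  # -> (1, 4, 32, 32)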
def extend_to_batch_size(tensor: Tensor, batch_size: int):
if tensor.shape[0] > batch_size:
return tensor[:batch_size]
elif tensor.shape[0] < batch_size:
remainder = batch_size-tensor.shape[0]
return torch.cat([tensor] + [tensor[-1:]]*remainder, dim=0)
return tensor
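# Illustrative usage (sketch): pad by repeating the last entry, or truncate.
#   t = torch.arange(3)
#   extend_to_batch_size(t, 5)  # -> tensor([0, 1, 2, 2, 2])
#   extend_to_batch_size(t, 2)  # -> tensor([0, 1])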
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
if not objects:
return objects
elif len(objects) <= 1:
return [x for x in objects]
# now that we know we have to sort, do it following these rules:
# a) if objects have same value of attribute, maintain their relative order
# b) perform sorting of the groups of objects with same attributes
    unique_attrs = {}
    for o in objects:
        unique_attrs.setdefault(getattr(o, attr), []).append(o)
# now that we have the unique attr values grouped together in relative order, sort them by key
sorted_attrs = dict(sorted(unique_attrs.items()))
# now flatten out the dict into a list to return
sorted_list = []
for object_list in sorted_attrs.values():
sorted_list.extend(object_list)
return sorted_list
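# Illustrative usage (sketch, with a hypothetical Item type):
#   class Item:
#       def __init__(self, start_percent, name):
#           self.start_percent, self.name = start_percent, name
#   items = [Item(0.5, "b"), Item(0.0, "a"), Item(0.5, "c")]
#   [i.name for i in get_sorted_list_via_attr(items, "start_percent")]
#   # -> ['a', 'b', 'c']  (relative order of equal keys preserved)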
class MotionCompatibilityError(ValueError):
pass
def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[float, Tensor]) -> Union[float, Tensor]:
# if one is None, use the other
    if multivalA is None:
        return multivalB
    elif multivalB is None:
        return multivalA
# both have a value - combine them based on type
# if both are Tensors, make dims match before multiplying
    if isinstance(multivalA, Tensor) and isinstance(multivalB, Tensor):
        areaA = multivalA.shape[1] * multivalA.shape[2]
        areaB = multivalB.shape[1] * multivalB.shape[2]
        # match height/width to mask with larger area
        leader, follower = (multivalA, multivalB) if areaA >= areaB else (multivalB, multivalA)
        batch_size = max(multivalA.shape[0], multivalB.shape[0])
# make follower same dimensions as leader
follower = torch.unsqueeze(follower, 1)
follower = comfy.utils.common_upscale(follower, leader.shape[2], leader.shape[1], "bilinear", "center")
follower = torch.squeeze(follower, 1)
# make sure batch size will match
leader = extend_to_batch_size(leader, batch_size)
follower = extend_to_batch_size(follower, batch_size)
return leader * follower
# otherwise, just multiply them together - one of them is a float
return multivalA * multivalB
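# Illustrative usage (sketch): combining a float strength with a mask tensor,
# shapes assumed to be (batch, height, width).
#   maskA = torch.rand(1, 64, 64)
#   get_combined_multival(0.5, maskA).shape    # -> torch.Size([1, 64, 64])
#   maskB = torch.rand(2, 32, 32)
#   get_combined_multival(maskA, maskB).shape  # -> torch.Size([2, 64, 64])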
class ADKeyframe:
def __init__(self,
start_percent: float = 0.0,
scale_multival: Union[float, Tensor]=None,
effect_multival: Union[float, Tensor]=None,
inherit_missing: bool=True,
guarantee_steps: int=1,
default: bool=False,
):
self.start_percent = start_percent
self.start_t = 999999999.9
self.scale_multival = scale_multival
self.effect_multival = effect_multival
self.inherit_missing = inherit_missing
self.guarantee_steps = guarantee_steps
self.default = default
def has_scale(self):
return self.scale_multival is not None
def has_effect(self):
return self.effect_multival is not None
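# Illustrative usage (sketch): a keyframe with only scale set; with
# inherit_missing=True, unset values are intended to carry over from earlier keyframes.
#   kf = ADKeyframe(start_percent=0.5, scale_multival=1.2)
#   kf.has_scale(), kf.has_effect()  # -> (True, False)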
class ADKeyframeGroup:
def __init__(self):
self.keyframes: list[ADKeyframe] = []
self.keyframes.append(ADKeyframe(guarantee_steps=1, default=True))
def add(self, keyframe: ADKeyframe):
# remove any default keyframes that match start_percent of new keyframe
default_to_delete = []
for i in range(len(self.keyframes)):
if self.keyframes[i].default and self.keyframes[i].start_percent == keyframe.start_percent:
default_to_delete.append(i)
for i in reversed(default_to_delete):
self.keyframes.pop(i)
# add to end of list, then sort
self.keyframes.append(keyframe)
self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
def get_index(self, index: int) -> Union[ADKeyframe, None]:
try:
return self.keyframes[index]
except IndexError:
return None
    def has_index(self, index: int) -> bool:
        return 0 <= index < len(self.keyframes)
def __getitem__(self, index) -> ADKeyframe:
return self.keyframes[index]
def __len__(self) -> int:
return len(self.keyframes)
def is_empty(self) -> bool:
return len(self.keyframes) == 0
def clone(self) -> 'ADKeyframeGroup':
cloned = ADKeyframeGroup()
for tk in self.keyframes:
if not tk.default:
cloned.add(tk)
return cloned
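# Illustrative usage (sketch): scheduling two keyframes; the auto-created
# default keyframe at start_percent=0.0 is replaced by the explicit one.
#   group = ADKeyframeGroup()
#   group.add(ADKeyframe(start_percent=0.0, scale_multival=1.0))
#   group.add(ADKeyframe(start_percent=0.5, effect_multival=0.8))
#   [kf.start_percent for kf in group.keyframes]  # -> [0.0, 0.5]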