from typing import Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from abc import ABC, abstractmethod
from collections.abc import Iterable

import comfy.model_management as model_management
import comfy.ops
import comfy.utils
from comfy.cli_args import args
from comfy.ldm.modules.attention import attention_basic, attention_pytorch, attention_split, attention_sub_quad, default

from .logger import logger


# until xformers bug is fixed, do not use xformers for VersatileAttention! TODO: change this when fix is out
# logic for choosing optimized_attention method taken from comfy/ldm/modules/attention.py
# a fallback_attention_mm is selected to avoid a CUDA configuration limitation in pytorch's scaled_dot_product
optimized_attention_mm = attention_basic
fallback_attention_mm = attention_basic
if model_management.xformers_enabled():
    pass
    #optimized_attention_mm = attention_xformers
if model_management.pytorch_attention_enabled():
    optimized_attention_mm = attention_pytorch
    if args.use_split_cross_attention:
        fallback_attention_mm = attention_split
    else:
        fallback_attention_mm = attention_sub_quad
else:
    if args.use_split_cross_attention:
        optimized_attention_mm = attention_split
    else:
        optimized_attention_mm = attention_sub_quad


class CrossAttentionMM(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None,
                 operations=comfy.ops.disable_weight_init):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.actual_attention = optimized_attention_mm
        self.heads = heads
        self.dim_head = dim_head
        self.scale = None
        self.default_scale = dim_head ** -0.5

        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))

    def reset_attention_type(self):
        self.actual_attention = optimized_attention_mm

    def forward(self, x, context=None, value=None, mask=None, scale_mask=None, mm_kwargs=None, transformer_options=None):
        q = self.to_q(x)
        context = default(context, x)
        k: Tensor = self.to_k(context)
        if value is not None:
            v = self.to_v(value)
            del value
        else:
            v = self.to_v(context)

        # apply custom scale by multiplying k by scale factor
        if self.scale is not None:
            k *= self.scale

        # apply scale mask, if present
        if scale_mask is not None:
            k *= scale_mask

        try:
            out = self.actual_attention(q, k, v, self.heads, mask)
        except RuntimeError as e:
            # fall back to a non-sdp attention implementation if CUDA rejects the configuration
            if str(e).startswith("CUDA error: invalid configuration argument"):
                self.actual_attention = fallback_attention_mm
                out = self.actual_attention(q, k, v, self.heads, mask)
            else:
                raise
        return self.to_out(out)


# TODO: set up comfy.ops style classes for groupnorm and other functions
class GroupNormAD(torch.nn.GroupNorm):
    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True,
                 device=None, dtype=None) -> None:
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine, device=device, dtype=dtype)

    def forward(self, input: Tensor) -> Tensor:
        return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
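
# Usage sketch (illustrative only; not called anywhere in this module): CrossAttentionMM
# multiplies k by `scale`/`scale_mask` before attention, which is equivalent to scaling the
# attention logits q @ k^T. The dimensions below are assumptions chosen for demonstration.
def _example_cross_attention_mm_scale():
    attn = CrossAttentionMM(query_dim=320, heads=8, dim_head=40)  # inner_dim = 8 * 40 = 320
    x = torch.randn(2, 16, 320)  # (batch, tokens, query_dim)
    attn.scale = 0.5             # halves the attention logits by scaling k
    out = attn(x)                # self-attention: context defaults to x
    return out.shape             # torch.Size([2, 16, 320])
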
# applies min-max normalization, from:
# https://stackoverflow.com/questions/68791508/min-max-normalization-of-a-tensor-in-pytorch
def normalize_min_max(x: Tensor, new_min=0.0, new_max=1.0):
    return linear_conversion(x, x_min=x.min(), x_max=x.max(), new_min=new_min, new_max=new_max)


def linear_conversion(x, x_min=0.0, x_max=1.0, new_min=0.0, new_max=1.0):
    return (((x - x_min) / (x_max - x_min)) * (new_max - new_min)) + new_min


# adapted from comfy/sample.py
def prepare_mask_batch(mask: Tensor, shape: Tensor, multiplier: int=1, match_dim1=False):
    mask = mask.clone()
    mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[2]*multiplier, shape[3]*multiplier), mode="bilinear")
    if match_dim1:
        mask = torch.cat([mask] * shape[1], dim=1)
    return mask


def extend_to_batch_size(tensor: Tensor, batch_size: int):
    if tensor.shape[0] > batch_size:
        return tensor[:batch_size]
    elif tensor.shape[0] < batch_size:
        remainder = batch_size - tensor.shape[0]
        return torch.cat([tensor] + [tensor[-1:]] * remainder, dim=0)
    return tensor


def extend_list_to_batch_size(_list: list, batch_size: int):
    if len(_list) > batch_size:
        return _list[:batch_size]
    elif len(_list) < batch_size:
        return _list + _list[-1:] * (batch_size - len(_list))
    return _list.copy()


# from comfy/controlnet.py
def ade_broadcast_image_to(tensor, target_batch_size, batched_number):
    current_batch_size = tensor.shape[0]
    #print(current_batch_size, target_batch_size)
    if current_batch_size == 1:
        return tensor

    per_batch = target_batch_size // batched_number
    tensor = tensor[:per_batch]

    if per_batch > tensor.shape[0]:
        tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)

    current_batch_size = tensor.shape[0]
    if current_batch_size == target_batch_size:
        return tensor
    else:
        return torch.cat([tensor] * batched_number, dim=0)


# originally from comfy_extras/nodes_mask.py::composite function
def composite_extend(destination: Tensor, source: Tensor, x: int, y: int, mask: Tensor = None, multiplier = 8, resize_source = False):
    source = source.to(destination.device)
    if resize_source:
        source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")

    source = extend_to_batch_size(source, destination.shape[0])

    x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
    y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))

    left, top = (x // multiplier, y // multiplier)
    right, bottom = (left + source.shape[3], top + source.shape[2],)

    if mask is None:
        mask = torch.ones_like(source)
    else:
        mask = mask.to(destination.device, copy=True)
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
        mask = extend_to_batch_size(mask, source.shape[0])

    # calculate the bounds of the source that will be overlapping the destination
    # this prevents the source trying to overwrite latent pixels that are out of bounds
    # of the destination
    visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)

    mask = mask[:, :, :visible_height, :visible_width]
    inverse_mask = torch.ones_like(mask) - mask

    source_portion = mask * source[:, :, :visible_height, :visible_width]
    destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]

    destination[:, :, top:bottom, left:right] = source_portion + destination_portion
    return destination
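
# Usage sketch (illustrative only): extend_to_batch_size repeats the last frame to pad,
# or truncates, so a mask batch matches a target batch size. Shapes are assumptions.
def _example_extend_to_batch_size():
    mask = torch.ones((3, 64, 64))           # 3 frames of a 64x64 mask
    padded = extend_to_batch_size(mask, 8)   # repeats mask[-1:] five times
    truncated = extend_to_batch_size(mask, 2)
    return padded.shape, truncated.shape     # (8, 64, 64), (2, 64, 64)
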
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
    if not objects:
        return objects
    elif len(objects) <= 1:
        return [x for x in objects]
    # now that we know we have to sort, do it following these rules:
    # a) if objects have the same value of the attribute, maintain their relative order
    # b) sort the groups of objects with the same attribute value by that value
    unique_attrs = {}
    for o in objects:
        val_attr = getattr(o, attr)
        attr_list: list = unique_attrs.get(val_attr, list())
        attr_list.append(o)
        if val_attr not in unique_attrs:
            unique_attrs[val_attr] = attr_list
    # now that we have the unique attr values grouped together in relative order, sort them by key
    sorted_attrs = dict(sorted(unique_attrs.items()))
    # now flatten out the dict into a list to return
    sorted_list = []
    for object_list in sorted_attrs.values():
        sorted_list.extend(object_list)
    return sorted_list


class MotionCompatibilityError(ValueError):
    pass


class InputPIA(ABC):
    def __init__(self, effect_multival: Union[float, Tensor]=None):
        self.effect_multival = effect_multival if effect_multival is not None else 1.0

    @abstractmethod
    def get_mask(self, x: Tensor):
        pass


class InputPIA_Multival(InputPIA):
    def __init__(self, multival: Union[float, Tensor], effect_multival: Union[float, Tensor]=None):
        super().__init__(effect_multival=effect_multival)
        self.multival = multival

    def get_mask(self, x: Tensor):
        if type(self.multival) is Tensor:
            return self.multival
        # if not a Tensor, it is a float - return a mask with the right dimensions and value
        b, c, h, w = x.shape
        mask = torch.ones(size=(b, h, w))
        return mask * self.multival


def create_multival_combo(float_val: Union[float, list[float]], mask_optional: Tensor=None):
    # first, normalize inputs
    # if float_val is iterable, treat as a list and assume inputs are floats
    float_is_iterable = False
    if isinstance(float_val, Iterable):
        float_is_iterable = True
        float_val = list(float_val)
        # if mask is present, make sure float_val list length matches the mask batch size
        if mask_optional is not None:
            if len(float_val) < mask_optional.shape[0]:
                # copies last entry enough times to match mask shape
                float_val = extend_list_to_batch_size(float_val, mask_optional.shape[0])
            if mask_optional.shape[0] < len(float_val):
                mask_optional = extend_to_batch_size(mask_optional, len(float_val))
            float_val = float_val[:mask_optional.shape[0]]
        float_val: Tensor = torch.tensor(float_val).unsqueeze(-1).unsqueeze(-1)
    # now that inputs are normalized, figure out what value to actually return
    if mask_optional is not None:
        mask_optional = mask_optional.clone()
        if float_is_iterable:
            mask_optional = mask_optional[:] * float_val.to(mask_optional.dtype).to(mask_optional.device)
        else:
            mask_optional = mask_optional * float_val
        return mask_optional
    else:
        if not float_is_iterable:
            return float_val
        # create a dummy mask of b,h,w=float_len,1,1 (single pixel)
        # purpose is for float input to work with mask code, without special cases
        float_len = float_val.shape[0] if float_is_iterable else 1
        shape = (float_len, 1, 1)
        mask_optional = torch.ones(shape)
        mask_optional = mask_optional[:] * float_val.to(mask_optional.dtype).to(mask_optional.device)
        return mask_optional
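
# Usage sketch (illustrative only): create_multival_combo broadcasts a per-frame float
# list over a mask batch; without a mask, it builds a (len, 1, 1) single-pixel mask so
# downstream mask code needs no special cases. Values below are assumptions.
def _example_create_multival_combo():
    strengths = [0.25, 0.5, 1.0]
    combo = create_multival_combo(strengths)         # single-pixel mask, shape (3, 1, 1)
    mask = torch.ones((3, 32, 32))
    masked = create_multival_combo(strengths, mask)  # per-frame scaled mask, shape (3, 32, 32)
    return combo.shape, masked.shape
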
def get_combined_multival(multivalA: Union[float, Tensor], multivalB: Union[float, Tensor], force_leader_A=False) -> Union[float, Tensor]:
    if multivalA is None and multivalB is None:
        return 1.0
    # if one is None, use the other
    if multivalA is None:
        return multivalB
    elif multivalB is None:
        return multivalA
    # both have a value - combine them based on type
    # if both are Tensors, make dims match before multiplying
    if type(multivalA) == Tensor and type(multivalB) == Tensor:
        if force_leader_A:
            leader, follower = (multivalA, multivalB)
            batch_size = multivalA.shape[0]
        else:
            areaA = multivalA.shape[1] * multivalA.shape[2]
            areaB = multivalB.shape[1] * multivalB.shape[2]
            # match height/width to mask with larger area
            leader, follower = (multivalA, multivalB) if areaA >= areaB else (multivalB, multivalA)
            batch_size = multivalA.shape[0] if multivalA.shape[0] >= multivalB.shape[0] else multivalB.shape[0]
        # make follower same dimensions as leader
        follower = torch.unsqueeze(follower, 1)
        follower = comfy.utils.common_upscale(follower, leader.shape[-1], leader.shape[-2], "bilinear", "center")
        follower = torch.squeeze(follower, 1)
        # make sure batch size will match
        leader = extend_to_batch_size(leader, batch_size)
        follower = extend_to_batch_size(follower, batch_size)
        return leader * follower
    # otherwise, just multiply them together - one of them is a float
    return multivalA * multivalB


def resize_multival(multival: Union[float, Tensor], batch_size: int, height: int, width: int):
    if multival is None:
        return 1.0
    if type(multival) != Tensor:
        return multival
    multival = torch.unsqueeze(multival, 1)
    multival = comfy.utils.common_upscale(multival, height, width, "bilinear", "center")
    multival = torch.squeeze(multival, 1)
    multival = extend_to_batch_size(multival, batch_size)
    return multival


def get_combined_input(inputA: Union[InputPIA, None], inputB: Union[InputPIA, None], x: Tensor):
    if inputA is None:
        inputA = InputPIA_Multival(1.0)
    if inputB is None:
        inputB = InputPIA_Multival(1.0)
    return get_combined_multival(inputA.get_mask(x), inputB.get_mask(x))


def get_combined_input_effect_multival(inputA: Union[InputPIA, None], inputB: Union[InputPIA, None]):
    if inputA is None:
        inputA = InputPIA_Multival(1.0)
    if inputB is None:
        inputB = InputPIA_Multival(1.0)
    return get_combined_multival(inputA.effect_multival, inputB.effect_multival)


class ADKeyframe:
    def __init__(self,
                 start_percent: float = 0.0,
                 scale_multival: Union[float, Tensor]=None,
                 effect_multival: Union[float, Tensor]=None,
                 cameractrl_multival: Union[float, Tensor]=None,
                 pia_input: InputPIA=None,
                 inherit_missing: bool=True,
                 guarantee_steps: int=1,
                 default: bool=False,
                 ):
        self.start_percent = start_percent
        self.start_t = 999999999.9
        self.scale_multival = scale_multival
        self.effect_multival = effect_multival
        self.cameractrl_multival = cameractrl_multival
        self.pia_input = pia_input
        self.inherit_missing = inherit_missing
        self.guarantee_steps = guarantee_steps
        self.default = default

    def has_scale(self):
        return self.scale_multival is not None

    def has_effect(self):
        return self.effect_multival is not None

    def has_cameractrl_effect(self):
        return self.cameractrl_multival is not None

    def has_pia_input(self):
        return self.pia_input is not None
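
# Usage sketch (illustrative only): get_combined_multival upscales the smaller-area mask
# to match the larger one and pads batch sizes before multiplying; plain floats simply
# multiply through. Shapes and values below are assumptions.
def _example_get_combined_multival():
    a = torch.ones((2, 32, 32)) * 0.5
    b = torch.ones((4, 64, 64))
    combined = get_combined_multival(a, b)    # b leads: result shape (4, 64, 64)
    scalar = get_combined_multival(0.5, 2.0)  # floats multiply: 1.0
    return combined.shape, scalar
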
class ADKeyframeGroup:
    def __init__(self):
        self.keyframes: list[ADKeyframe] = []
        self.keyframes.append(ADKeyframe(guarantee_steps=1, default=True))

    def add(self, keyframe: ADKeyframe):
        # remove any default keyframes that match start_percent of new keyframe
        default_to_delete = []
        for i in range(len(self.keyframes)):
            if self.keyframes[i].default and self.keyframes[i].start_percent == keyframe.start_percent:
                default_to_delete.append(i)
        for i in reversed(default_to_delete):
            self.keyframes.pop(i)
        # add to end of list, then sort
        self.keyframes.append(keyframe)
        self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")

    def get_index(self, index: int) -> Union[ADKeyframe, None]:
        try:
            return self.keyframes[index]
        except IndexError:
            return None

    def has_index(self, index: int) -> bool:
        return index >= 0 and index < len(self.keyframes)

    def __getitem__(self, index) -> ADKeyframe:
        return self.keyframes[index]

    def __len__(self) -> int:
        return len(self.keyframes)

    def is_empty(self) -> bool:
        return len(self.keyframes) == 0

    def clone(self) -> 'ADKeyframeGroup':
        cloned = ADKeyframeGroup()
        for tk in self.keyframes:
            if not tk.default:
                cloned.add(tk)
        return cloned


class DummyNNModule(nn.Module):
    '''
    Class that does not throw exceptions for almost anything you throw at it. As name implies, does nothing.
    '''
    class DoNothingWhenCalled:
        def __call__(self, *args, **kwargs):
            return

    def __init__(self):
        super().__init__()

    def __getattr__(self, *args, **kwargs):
        return self.DoNothingWhenCalled()

    def __setattr__(self, name, value):
        pass

    def __iter__(self, *args, **kwargs):
        pass

    def __next__(self, *args, **kwargs):
        pass

    def __len__(self, *args, **kwargs):
        pass

    def __getitem__(self, *args, **kwargs):
        pass

    def __setitem__(self, *args, **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        pass
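
# Usage sketch (illustrative only): ADKeyframeGroup always starts with a default keyframe
# at start_percent 0.0; adding a user keyframe with the same start_percent replaces it,
# and keyframes stay stably sorted by start_percent.
def _example_ad_keyframe_group():
    group = ADKeyframeGroup()
    group.add(ADKeyframe(start_percent=0.5, effect_multival=0.8))
    group.add(ADKeyframe(start_percent=0.0, scale_multival=1.2))  # replaces the default
    return [kf.start_percent for kf in group.keyframes]           # [0.0, 0.5]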