# jaxmetaverse's picture
# Upload folder using huggingface_hub
# 82ea528 verified
from abc import ABC, abstractmethod
from typing import Union
from torch import Tensor
import math
from .utils_motion import normalize_min_max
from .logger import logger
class AnimateDiffSettings:
    """Bundle of optional motion-model adjustments for AnimateDiff.

    Holds a group of positional-encoding (PE) adjustments, a group of weight
    adjustments, and the deprecated attention-scale settings (which have been
    superseded by scale_multival).
    """
    def __init__(self,
                 adjust_pe: 'AdjustGroup'=None,
                 adjust_weight: 'AdjustGroup'=None,
                 attn_scale: float=1.0,
                 mask_attn_scale: Tensor=None,
                 mask_attn_scale_min: float=1.0,
                 mask_attn_scale_max: float=1.0,
                 ):
        # PE-interpolation adjustments; fall back to an empty group.
        if adjust_pe is None:
            adjust_pe = AdjustGroup()
        self.adjust_pe = adjust_pe
        # Weight adjustments; fall back to an empty group.
        if adjust_weight is None:
            adjust_weight = AdjustGroup()
        self.adjust_weight = adjust_weight
        # DEPRECATED (part of scale_multival now).
        self.attn_scale = attn_scale
        # DEPRECATED (part of scale_multival now). Stored as a clone so this
        # object does not share the caller's tensor.
        if mask_attn_scale is not None:
            mask_attn_scale = mask_attn_scale.clone()
        self.mask_attn_scale = mask_attn_scale
        self.mask_attn_scale_min = mask_attn_scale_min
        self.mask_attn_scale_max = mask_attn_scale_max
        self._prepare_mask_attn_scale()

    def _prepare_mask_attn_scale(self):
        """Pass the stored mask through normalize_min_max with the configured bounds.

        Does nothing when no mask is stored. (Presumably rescales the mask
        into [mask_attn_scale_min, mask_attn_scale_max] -- confirm in utils_motion.)
        """
        if self.mask_attn_scale is None:
            return
        self.mask_attn_scale = normalize_min_max(self.mask_attn_scale, self.mask_attn_scale_min, self.mask_attn_scale_max)

    def has_mask_attn_scale(self) -> bool:
        """Return True when a (deprecated) attention-scale mask is stored."""
        return self.mask_attn_scale is not None

    def has_anything_to_apply(self) -> bool:
        """Return whether either adjustment group has something to apply.

        The deprecated attention-scale settings are not consulted here.
        """
        return (self.adjust_pe.has_anything_to_apply()
                or self.adjust_weight.has_anything_to_apply())
class AdjustAbstract(ABC):
    """Common base for adjustment-settings objects.

    Concrete subclasses must implement ``has_anything_to_apply``.
    """
    def __init__(self, print_adjustment=False):
        # Flag telling subclasses whether to log the adjustments they perform.
        self.print_adjustment = print_adjustment

    @abstractmethod
    def has_anything_to_apply(self):
        """Report whether this object would change anything (base answer: False)."""
        return False
class AdjustPE(AdjustAbstract):
    """Positional-encoding (PE) adjustment settings.

    Each setting is an integer; a setting counts as active only when it is
    strictly greater than zero, so the default of 0 means "disabled".
    """
    def __init__(self,
                 cap_initial_pe_length: int=0, interpolate_pe_to_length: int=0,
                 initial_pe_idx_offset: int=0, final_pe_idx_offset: int=0,
                 motion_pe_stretch: int=0, print_adjustment=False):
        super().__init__(print_adjustment=print_adjustment)
        # PE-interpolation settings
        self.cap_initial_pe_length = cap_initial_pe_length
        self.interpolate_pe_to_length = interpolate_pe_to_length
        self.initial_pe_idx_offset = initial_pe_idx_offset
        self.final_pe_idx_offset = final_pe_idx_offset
        self.motion_pe_stretch = motion_pe_stretch

    def has_cap_initial_pe_length(self) -> bool:
        """Return True when cap_initial_pe_length is positive."""
        return 0 < self.cap_initial_pe_length

    def has_interpolate_pe_to_length(self) -> bool:
        """Return True when interpolate_pe_to_length is positive."""
        return 0 < self.interpolate_pe_to_length

    def has_initial_pe_idx_offset(self) -> bool:
        """Return True when initial_pe_idx_offset is positive."""
        return 0 < self.initial_pe_idx_offset

    def has_final_pe_idx_offset(self) -> bool:
        """Return True when final_pe_idx_offset is positive."""
        return 0 < self.final_pe_idx_offset

    def has_motion_pe_stretch(self) -> bool:
        """Return True when motion_pe_stretch is positive."""
        return 0 < self.motion_pe_stretch

    def has_anything_to_apply(self) -> bool:
        """Return True when at least one PE setting is active."""
        predicates = (
            self.has_cap_initial_pe_length,
            self.has_interpolate_pe_to_length,
            self.has_initial_pe_idx_offset,
            self.has_final_pe_idx_offset,
            self.has_motion_pe_stretch,
        )
        return any(predicate() for predicate in predicates)
class AdjustWeight(AdjustAbstract):
    """Additive and multiplicative adjustments for groups of model weights.

    For every attribute group in ``ATTRS`` there is an ``<attr>_ADD`` value
    (a no-op at 0.0) and an ``<attr>_MULT`` value (a no-op at 1.0). The
    instance attribute for an (attr, op) pair is always named ``attr + op``,
    e.g. ``"pe" + "_ADD" -> self.pe_ADD``. An operation is "applicable" only
    when its value differs from its no-op value.
    """
    # possible operations
    OP_ANY = "_____ANY"  # query-only pseudo-op meaning "ADD or MULT"; never applied itself
    OP_ADD = "_ADD"
    OP_MULT = "_MULT"
    OPS = [OP_ADD, OP_MULT]
    # possible attributes
    ATTR_ALL = "all"
    ATTR_PE = "pe"
    ATTR_ATTN = "attn"
    ATTR_ATTN_Q = "attn_q"
    ATTR_ATTN_K = "attn_k"
    ATTR_ATTN_V = "attn_v"
    ATTR_ATTN_OUT_WEIGHT = "attn_out_weight"
    ATTR_ATTN_OUT_BIAS = "attn_out_bias"
    ATTR_OTHER = "other"
    ATTRS = [ATTR_ALL, ATTR_PE, ATTR_ATTN, ATTR_ATTN_Q, ATTR_ATTN_K, ATTR_ATTN_V, ATTR_ATTN_OUT_WEIGHT, ATTR_ATTN_OUT_BIAS, ATTR_OTHER]

    def __init__(self,
                 all_ADD=0.0, all_MULT=1.0,
                 pe_ADD=0.0, pe_MULT=1.0,
                 attn_ADD=0.0, attn_MULT=1.0,
                 attn_q_ADD=0.0, attn_q_MULT=1.0,
                 attn_k_ADD=0.0, attn_k_MULT=1.0,
                 attn_v_ADD=0.0, attn_v_MULT=1.0,
                 attn_out_weight_ADD=0.0, attn_out_weight_MULT=1.0,
                 attn_out_bias_ADD=0.0, attn_out_bias_MULT=1.0,
                 other_ADD=0.0, other_MULT=1.0,
                 print_adjustment=False):
        # Let the base class own print_adjustment instead of setting it here by hand.
        super().__init__(print_adjustment=print_adjustment)
        # all
        self.all_ADD = all_ADD
        self.all_MULT = all_MULT
        # pe
        self.pe_ADD = pe_ADD
        self.pe_MULT = pe_MULT
        # attn
        self.attn_ADD = attn_ADD
        self.attn_MULT = attn_MULT
        # attn_q
        self.attn_q_ADD = attn_q_ADD
        self.attn_q_MULT = attn_q_MULT
        # attn_k
        self.attn_k_ADD = attn_k_ADD
        self.attn_k_MULT = attn_k_MULT
        # attn_v
        self.attn_v_ADD = attn_v_ADD
        self.attn_v_MULT = attn_v_MULT
        # attn_out_weight
        self.attn_out_weight_ADD = attn_out_weight_ADD
        self.attn_out_weight_MULT = attn_out_weight_MULT
        # attn_out_bias
        self.attn_out_bias_ADD = attn_out_bias_ADD
        self.attn_out_bias_MULT = attn_out_bias_MULT
        # other
        self.other_ADD = other_ADD
        self.other_MULT = other_MULT
        # Per-(attr+op) "already logged" flags, so each adjustment is logged
        # at most once until mark_attrs_as_unprinted() resets them.
        self.already_printed: dict[str, bool] = {}
        self.mark_attrs_as_unprinted()

    def mark_attrs_as_unprinted(self):
        """Reset the log-once flags for every (attr, op) pair."""
        for attr in self.ATTRS:
            for op in self.OPS:
                self.already_printed[attr+op] = False

    def mark_as_printed(self, attr: str, op: str):
        """Record that the (attr, op) adjustment has been logged."""
        self.already_printed[attr+op] = True

    def mask_as_printed(self, attr: str, op: str):
        """Backward-compatible alias of mark_as_printed (original misspelled name)."""
        self.mark_as_printed(attr=attr, op=op)

    def is_already_printed(self, attr: str, op: str):
        """Return whether the (attr, op) adjustment has been logged (unknown pairs: False)."""
        return self.already_printed.get(attr+op, False)

    def _get_val(self, op: str, attr: str) -> float:
        """Return the value stored in ``self.<attr+op>``.

        Raises:
            Exception: when no such attribute exists (original AttributeError is chained).
        """
        try:
            return getattr(self, attr+op)
        except AttributeError as err:
            raise Exception(f"Parameter '{attr+op}' could not be found in AdjustWeight class.") from err

    def _has_OP(self, op: str, attr: str):
        """Return True when the (attr, op) value differs from that op's no-op value.

        Raises:
            Exception: when ``op`` is not OP_ADD/OP_MULT (checked before the value
                lookup so the error names the bad operation).
        """
        if op == self.OP_ADD:
            neutral = 0.0
        elif op == self.OP_MULT:
            neutral = 1.0
        else:
            raise Exception(f"Operation '{op}' not recognized in AdjustWeight.")
        value = self._get_val(op=op, attr=attr)
        # NOTE: with the default abs_tol=0.0, math.isclose(x, 0.0) is True only
        # for x == 0.0 exactly, so any nonzero ADD value counts as applicable.
        return not math.isclose(value, neutral)

    def _has_apply(self, op: str, attr: str):
        """Return whether ``attr`` has ``op`` to apply; OP_ANY means "ADD or MULT"."""
        if op == self.OP_ANY:
            return any(self._has_OP(op=one_op, attr=attr) for one_op in self.OPS)
        return self._has_OP(op=op, attr=attr)

    def has_all(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_ALL)

    def has_pe(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_PE)

    def has_attn(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_ATTN)

    def has_attn_q(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_ATTN_Q)

    def has_attn_k(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_ATTN_K)

    def has_attn_v(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_ATTN_V)

    def has_attn_out_weight(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_ATTN_OUT_WEIGHT)

    def has_attn_out_bias(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_ATTN_OUT_BIAS)

    def has_other(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_OTHER)

    def has_anything_to_apply(self):
        """Return True when any attribute group has an ADD or MULT to apply."""
        checks = (
            self.has_all,
            self.has_pe,
            self.has_attn,
            self.has_attn_q,
            self.has_attn_k,
            self.has_attn_v,
            self.has_attn_out_weight,
            self.has_attn_out_bias,
            self.has_other,
        )
        return any(check(self.OP_ANY) for check in checks)

    def _perform_op(self, model_dict: dict[str, Tensor], key: str, op: str, attr: str):
        """Apply the (attr, op) value to ``model_dict[key]``, logging it once if enabled.

        ``+=`` / ``*=`` mutate the stored object in place when it supports it
        (torch Tensors do), so other references to that tensor see the change.

        Raises:
            Exception: for an unknown attr/op combination.
        """
        val = self._get_val(op=op, attr=attr)
        specific_str = f"'{attr}' weights" if attr == self.ATTR_ALL else f"every '{attr}' weight"
        if op == self.OP_ADD:
            model_dict[key] += val
            if self.print_adjustment and not self.is_already_printed(attr=attr, op=op):
                logger.info(f"[Adjust Weight]: Adding to {specific_str} value {val}")
                self.mark_as_printed(attr=attr, op=op)
        elif op == self.OP_MULT:
            model_dict[key] *= val
            if self.print_adjustment and not self.is_already_printed(attr=attr, op=op):
                logger.info(f"[Adjust Weight]: Multiplying {specific_str} by {val}")
                self.mark_as_printed(attr=attr, op=op)
        else:
            raise Exception(f"Operation '{op}' not recognized in AdjustWeight.")

    def perform_applicable_ops(self, attr: str, model_dict: dict[str, Tensor], key: str):
        """Apply ADD then MULT for ``attr`` to ``model_dict[key]``, skipping no-op values."""
        for op in self.OPS:
            if self._has_apply(op=op, attr=attr):
                self._perform_op(model_dict=model_dict, key=key, op=op, attr=attr)
# Type alias for any adjustment object an AdjustGroup may hold.
ADJUST_TYPES = Union[
    AdjustPE,
    AdjustWeight,
]
class AdjustGroup:
    """Ordered collection of adjustment objects (AdjustPE / AdjustWeight)."""
    def __init__(self, initial: ADJUST_TYPES=None):
        self.adjusts: list[ADJUST_TYPES] = []
        if initial is None:
            return
        self.add(initial)

    def add(self, adjust: ADJUST_TYPES):
        """Append ``adjust`` to the end of the group."""
        self.adjusts.append(adjust)

    def has_anything_to_apply(self):
        """Return True as soon as one member reports something to apply."""
        return any(item.has_anything_to_apply() for item in self.adjusts)

    def clone(self):
        """Return a new AdjustGroup holding the same adjustment objects.

        This is a shallow copy: the list is new, but the member objects are
        shared with the original group.
        """
        duplicate = AdjustGroup()
        for item in self.adjusts:
            duplicate.add(adjust=item)
        return duplicate