import math
from abc import ABC, abstractmethod
from typing import Union

from torch import Tensor

from .utils_motion import normalize_min_max
from .logger import logger


class AnimateDiffSettings:
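    """Container for general AnimateDiff settings: groups of PE and weight
    adjustments, a global attention scale, and an optional attention-scale mask
    that gets normalized to the given min/max range."""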
    def __init__(self,
                 adjust_pe: 'AdjustGroup'=None,
                 adjust_weight: 'AdjustGroup'=None,
                 attn_scale: float=1.0,
                 mask_attn_scale: Tensor=None,
                 mask_attn_scale_min: float=1.0,
                 mask_attn_scale_max: float=1.0,
                 ):
        self.adjust_pe = adjust_pe if adjust_pe is not None else AdjustGroup()
        self.adjust_weight = adjust_weight if adjust_weight is not None else AdjustGroup()
        self.attn_scale = attn_scale
        # clone the mask so normalization does not mutate the caller's Tensor
        self.mask_attn_scale = mask_attn_scale.clone() if mask_attn_scale is not None else mask_attn_scale
        self.mask_attn_scale_min = mask_attn_scale_min
        self.mask_attn_scale_max = mask_attn_scale_max
        self._prepare_mask_attn_scale()

    def _prepare_mask_attn_scale(self):
        if self.mask_attn_scale is not None:
            self.mask_attn_scale = normalize_min_max(self.mask_attn_scale, self.mask_attn_scale_min, self.mask_attn_scale_max)

    def has_mask_attn_scale(self) -> bool:
        return self.mask_attn_scale is not None

    def has_anything_to_apply(self) -> bool:
        return self.adjust_pe.has_anything_to_apply() \
            or self.adjust_weight.has_anything_to_apply()
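
# Example usage (illustrative sketch; parameter values are arbitrary):
#   settings = AnimateDiffSettings(
#       adjust_weight=AdjustGroup(AdjustWeight(attn_MULT=0.5, print_adjustment=True)),
#       attn_scale=1.25,
#   )
#   if settings.has_anything_to_apply():
#       ...  # apply the contained adjustments when loading a motion model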


class AdjustAbstract(ABC):
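    """Base class for adjustment definitions that can report whether they
    contain any non-default values to apply."""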
    def __init__(self, print_adjustment=False):
        self.print_adjustment = print_adjustment

    @abstractmethod
    def has_anything_to_apply(self) -> bool:
        return False


class AdjustPE(AdjustAbstract):
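    """Adjustments for the motion module's positional encodings (PE): capping,
    interpolation, index offsets, and stretching. A value of 0 means 'not set'."""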
    def __init__(self,
                 cap_initial_pe_length: int=0, interpolate_pe_to_length: int=0,
                 initial_pe_idx_offset: int=0, final_pe_idx_offset: int=0,
                 motion_pe_stretch: int=0, print_adjustment=False):
        super().__init__(print_adjustment=print_adjustment)
        self.cap_initial_pe_length = cap_initial_pe_length
        self.interpolate_pe_to_length = interpolate_pe_to_length
        self.initial_pe_idx_offset = initial_pe_idx_offset
        self.final_pe_idx_offset = final_pe_idx_offset
        self.motion_pe_stretch = motion_pe_stretch

    def has_cap_initial_pe_length(self) -> bool:
        return self.cap_initial_pe_length > 0

    def has_interpolate_pe_to_length(self) -> bool:
        return self.interpolate_pe_to_length > 0

    def has_initial_pe_idx_offset(self) -> bool:
        return self.initial_pe_idx_offset > 0

    def has_final_pe_idx_offset(self) -> bool:
        return self.final_pe_idx_offset > 0

    def has_motion_pe_stretch(self) -> bool:
        return self.motion_pe_stretch > 0

    def has_anything_to_apply(self) -> bool:
        return self.has_cap_initial_pe_length() \
            or self.has_interpolate_pe_to_length() \
            or self.has_initial_pe_idx_offset() \
            or self.has_final_pe_idx_offset() \
            or self.has_motion_pe_stretch()


class AdjustWeight(AdjustAbstract):
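    """Additive (_ADD) and multiplicative (_MULT) adjustments for motion model
    weights, grouped by attribute (pe, attn q/k/v, attn output weight/bias, etc.).
    The defaults of 0.0 (ADD) and 1.0 (MULT) are no-ops."""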

    OP_ANY = "_____ANY"
    OP_ADD = "_ADD"
    OP_MULT = "_MULT"
    OPS = [OP_ADD, OP_MULT]

    ATTR_ALL = "all"
    ATTR_PE = "pe"
    ATTR_ATTN = "attn"
    ATTR_ATTN_Q = "attn_q"
    ATTR_ATTN_K = "attn_k"
    ATTR_ATTN_V = "attn_v"
    ATTR_ATTN_OUT_WEIGHT = "attn_out_weight"
    ATTR_ATTN_OUT_BIAS = "attn_out_bias"
    ATTR_OTHER = "other"
    ATTRS = [ATTR_ALL, ATTR_PE, ATTR_ATTN, ATTR_ATTN_Q, ATTR_ATTN_K, ATTR_ATTN_V, ATTR_ATTN_OUT_WEIGHT, ATTR_ATTN_OUT_BIAS, ATTR_OTHER]

    def __init__(self,
                 all_ADD=0.0, all_MULT=1.0,
                 pe_ADD=0.0, pe_MULT=1.0,
                 attn_ADD=0.0, attn_MULT=1.0,
                 attn_q_ADD=0.0, attn_q_MULT=1.0,
                 attn_k_ADD=0.0, attn_k_MULT=1.0,
                 attn_v_ADD=0.0, attn_v_MULT=1.0,
                 attn_out_weight_ADD=0.0, attn_out_weight_MULT=1.0,
                 attn_out_bias_ADD=0.0, attn_out_bias_MULT=1.0,
                 other_ADD=0.0, other_MULT=1.0,
                 print_adjustment=False):
        super().__init__(print_adjustment=print_adjustment)
        self.all_ADD = all_ADD
        self.all_MULT = all_MULT
        self.pe_ADD = pe_ADD
        self.pe_MULT = pe_MULT
        self.attn_ADD = attn_ADD
        self.attn_MULT = attn_MULT
        self.attn_q_ADD = attn_q_ADD
        self.attn_q_MULT = attn_q_MULT
        self.attn_k_ADD = attn_k_ADD
        self.attn_k_MULT = attn_k_MULT
        self.attn_v_ADD = attn_v_ADD
        self.attn_v_MULT = attn_v_MULT
        self.attn_out_weight_ADD = attn_out_weight_ADD
        self.attn_out_weight_MULT = attn_out_weight_MULT
        self.attn_out_bias_ADD = attn_out_bias_ADD
        self.attn_out_bias_MULT = attn_out_bias_MULT
        self.other_ADD = other_ADD
        self.other_MULT = other_MULT
        # track which (attr, op) adjustments have already been logged
        self.already_printed: dict[str, bool] = {}
        self.mark_attrs_as_unprinted()

    def mark_attrs_as_unprinted(self):
        for attr in self.ATTRS:
            for op in self.OPS:
                self.already_printed[attr+op] = False

    def mark_as_printed(self, attr: str, op: str):
        self.already_printed[attr+op] = True

    def is_already_printed(self, attr: str, op: str) -> bool:
        return self.already_printed.get(attr+op, False)

    def _get_val(self, op: str, attr: str) -> float:
        try:
            return getattr(self, attr+op)
        except AttributeError:
            raise Exception(f"Parameter '{attr+op}' could not be found in AdjustWeight class.")

    def _has_OP(self, op: str, attr: str) -> bool:
        # values matching the no-op defaults (0.0 for ADD, 1.0 for MULT) need no adjustment
        value = self._get_val(op=op, attr=attr)
        if op == self.OP_ADD:
            return not math.isclose(value, 0.0)
        elif op == self.OP_MULT:
            return not math.isclose(value, 1.0)
        else:
            raise Exception(f"Operation '{op}' not recognized in AdjustWeight.")

    def _has_apply(self, op: str, attr: str) -> bool:
        # OP_ANY checks whether any real op (ADD or MULT) applies for this attr
        if op == self.OP_ANY:
            return any(self._has_OP(op=one_op, attr=attr) for one_op in self.OPS)
        return self._has_OP(op=op, attr=attr)

    def has_all(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_ALL)

    def has_pe(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_PE)

    def has_attn(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_ATTN)

    def has_attn_q(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_ATTN_Q)

    def has_attn_k(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_ATTN_K)

    def has_attn_v(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_ATTN_V)

    def has_attn_out_weight(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_ATTN_OUT_WEIGHT)

    def has_attn_out_bias(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_ATTN_OUT_BIAS)

    def has_other(self, op: str) -> bool:
        return self._has_apply(op, self.ATTR_OTHER)

    def has_anything_to_apply(self) -> bool:
        return self.has_all(self.OP_ANY) \
            or self.has_pe(self.OP_ANY) \
            or self.has_attn(self.OP_ANY) \
            or self.has_attn_q(self.OP_ANY) \
            or self.has_attn_k(self.OP_ANY) \
            or self.has_attn_v(self.OP_ANY) \
            or self.has_attn_out_weight(self.OP_ANY) \
            or self.has_attn_out_bias(self.OP_ANY) \
            or self.has_other(self.OP_ANY)

    def _perform_op(self, model_dict: dict[str, Tensor], key: str, op: str, attr: str):
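        """Apply a single ADD or MULT adjustment to model_dict[key] in place,
        logging each (attr, op) combination at most once."""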
        val = self._get_val(op=op, attr=attr)
        specific_str = f"'{attr}' weights" if attr == self.ATTR_ALL else f"every '{attr}' weight"
        if op == self.OP_ADD:
            model_dict[key] += val
            if self.print_adjustment and not self.is_already_printed(attr=attr, op=op):
                logger.info(f"[Adjust Weight]: Adding {val} to {specific_str}")
                self.mark_as_printed(attr=attr, op=op)
        elif op == self.OP_MULT:
            model_dict[key] *= val
            if self.print_adjustment and not self.is_already_printed(attr=attr, op=op):
                logger.info(f"[Adjust Weight]: Multiplying {specific_str} by {val}")
                self.mark_as_printed(attr=attr, op=op)
        else:
            raise Exception(f"Operation '{op}' not recognized in AdjustWeight.")

    def perform_applicable_ops(self, attr: str, model_dict: dict[str, Tensor], key: str):
        for op in self.OPS:
            if self._has_apply(op=op, attr=attr):
                self._perform_op(model_dict=model_dict, key=key, op=op, attr=attr)
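
# Example usage (illustrative sketch; the state-dict key is hypothetical):
#   adjust = AdjustWeight(attn_MULT=0.8, pe_ADD=0.1, print_adjustment=True)
#   if adjust.has_anything_to_apply():
#       adjust.perform_applicable_ops(attr=AdjustWeight.ATTR_PE,
#                                     model_dict=state_dict, key="pos_encoder.pe")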


ADJUST_TYPES = Union[AdjustPE, AdjustWeight]


class AdjustGroup:
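    """Ordered collection of AdjustPE/AdjustWeight objects to be applied together."""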
    def __init__(self, initial: ADJUST_TYPES=None):
        self.adjusts: list[ADJUST_TYPES] = []
        if initial is not None:
            self.add(initial)

    def add(self, adjust: ADJUST_TYPES):
        self.adjusts.append(adjust)

    def has_anything_to_apply(self) -> bool:
        for adjust in self.adjusts:
            if adjust.has_anything_to_apply():
                return True
        return False

    def clone(self):
        # shallow clone: the new group references the same Adjust objects
        new_group = AdjustGroup()
        for adjust in self.adjusts:
            new_group.add(adjust=adjust)
        return new_group