daquanzhou
merge github repos and lfs track ckpt/path/safetensors/pt
613c9ab
raw
history blame
5.31 kB
from torch import Tensor
from .utils_motion import normalize_min_max
class AnimateDiffSettings:
    """Bundle of adjustment settings: PE adjustments, strengths, attn scale.

    Every ``*_strength`` value defaults to 1.0, and the ``has_*_strength``
    predicates report True only when the stored value differs from 1.0.
    The attention-scale fields are kept but are marked deprecated.
    """

    def __init__(self,
                 adjust_pe: 'AdjustPEGroup' = None,
                 pe_strength: float = 1.0,
                 attn_strength: float = 1.0,
                 attn_q_strength: float = 1.0,
                 attn_k_strength: float = 1.0,
                 attn_v_strength: float = 1.0,
                 attn_out_weight_strength: float = 1.0,
                 attn_out_bias_strength: float = 1.0,
                 other_strength: float = 1.0,
                 attn_scale: float = 1.0,
                 mask_attn_scale: Tensor = None,
                 mask_attn_scale_min: float = 1.0,
                 mask_attn_scale_max: float = 1.0,
                 ):
        """Store the given settings.

        When ``adjust_pe`` is None an empty ``AdjustPEGroup`` is created.
        When ``mask_attn_scale`` is given, a clone of it is stored and then
        passed through ``normalize_min_max`` together with
        ``mask_attn_scale_min`` / ``mask_attn_scale_max``.
        """
        # PE-interpolation settings
        if adjust_pe is None:
            adjust_pe = AdjustPEGroup()
        self.adjust_pe = adjust_pe

        # general strengths
        self.pe_strength = pe_strength
        self.attn_strength = attn_strength
        self.other_strength = other_strength

        # specific attn strengths
        self.attn_q_strength = attn_q_strength
        self.attn_k_strength = attn_k_strength
        self.attn_v_strength = attn_v_strength
        self.attn_out_weight_strength = attn_out_weight_strength
        self.attn_out_bias_strength = attn_out_bias_strength

        # attention scale settings - DEPRECATED
        self.attn_scale = attn_scale

        # attention scale mask settings - DEPRECATED
        # A clone is stored rather than the caller's object; presumably so the
        # caller's tensor is never modified - confirm against normalize_min_max.
        self.mask_attn_scale = None if mask_attn_scale is None else mask_attn_scale.clone()
        self.mask_attn_scale_min = mask_attn_scale_min
        self.mask_attn_scale_max = mask_attn_scale_max
        self._prepare_mask_attn_scale()

    @staticmethod
    def _is_active(strength) -> bool:
        """Return whether ``strength`` differs from the default 1.0."""
        return strength != 1.0

    def _prepare_mask_attn_scale(self):
        """Replace the stored mask with its ``normalize_min_max`` result, if any."""
        if self.mask_attn_scale is None:
            return
        self.mask_attn_scale = normalize_min_max(
            self.mask_attn_scale,
            self.mask_attn_scale_min,
            self.mask_attn_scale_max,
        )

    def has_mask_attn_scale(self) -> bool:
        """Return True when a mask attention scale is stored."""
        return self.mask_attn_scale is not None

    def has_pe_strength(self) -> bool:
        """Return True when ``pe_strength`` is not 1.0."""
        return self._is_active(self.pe_strength)

    def has_attn_strength(self) -> bool:
        """Return True when ``attn_strength`` is not 1.0."""
        return self._is_active(self.attn_strength)

    def has_other_strength(self) -> bool:
        """Return True when ``other_strength`` is not 1.0."""
        return self._is_active(self.other_strength)

    def has_anything_to_apply(self) -> bool:
        """Return True when the PE group or any strength has something to apply.

        The deprecated attention-scale fields are not consulted here.
        """
        return (
            self.adjust_pe.has_anything_to_apply()
            or self.has_pe_strength()
            or self.has_attn_strength()
            or self.has_other_strength()
            or self.has_any_attn_sub_strength()
        )

    def has_any_attn_sub_strength(self) -> bool:
        """Return True when any q/k/v/out-weight/out-bias strength is not 1.0."""
        return (
            self.has_attn_q_strength()
            or self.has_attn_k_strength()
            or self.has_attn_v_strength()
            or self.has_attn_out_weight_strength()
            or self.has_attn_out_bias_strength()
        )

    def has_attn_q_strength(self) -> bool:
        """Return True when ``attn_q_strength`` is not 1.0."""
        return self._is_active(self.attn_q_strength)

    def has_attn_k_strength(self) -> bool:
        """Return True when ``attn_k_strength`` is not 1.0."""
        return self._is_active(self.attn_k_strength)

    def has_attn_v_strength(self) -> bool:
        """Return True when ``attn_v_strength`` is not 1.0."""
        return self._is_active(self.attn_v_strength)

    def has_attn_out_weight_strength(self) -> bool:
        """Return True when ``attn_out_weight_strength`` is not 1.0."""
        return self._is_active(self.attn_out_weight_strength)

    def has_attn_out_bias_strength(self) -> bool:
        """Return True when ``attn_out_bias_strength`` is not 1.0."""
        return self._is_active(self.attn_out_bias_strength)
class AdjustPE:
    """A single set of PE (positional-encoding) adjustment values.

    All numeric settings default to 0, and each ``has_*`` predicate reports
    True only when its value is strictly greater than 0.
    """

    def __init__(self,
                 cap_initial_pe_length: int = 0, interpolate_pe_to_length: int = 0,
                 initial_pe_idx_offset: int = 0, final_pe_idx_offset: int = 0,
                 motion_pe_stretch: int = 0, print_adjustment=False):
        """Store the adjustment values exactly as given."""
        # PE-interpolation settings
        self.cap_initial_pe_length = cap_initial_pe_length
        self.interpolate_pe_to_length = interpolate_pe_to_length
        self.initial_pe_idx_offset = initial_pe_idx_offset
        self.final_pe_idx_offset = final_pe_idx_offset
        self.motion_pe_stretch = motion_pe_stretch
        self.print_adjustment = print_adjustment

    def has_cap_initial_pe_length(self) -> bool:
        """Return True when ``cap_initial_pe_length`` is greater than 0."""
        return self.cap_initial_pe_length > 0

    def has_interpolate_pe_to_length(self) -> bool:
        """Return True when ``interpolate_pe_to_length`` is greater than 0."""
        return self.interpolate_pe_to_length > 0

    def has_initial_pe_idx_offset(self) -> bool:
        """Return True when ``initial_pe_idx_offset`` is greater than 0."""
        return self.initial_pe_idx_offset > 0

    def has_final_pe_idx_offset(self) -> bool:
        """Return True when ``final_pe_idx_offset`` is greater than 0."""
        return self.final_pe_idx_offset > 0

    def has_motion_pe_stretch(self) -> bool:
        """Return True when ``motion_pe_stretch`` is greater than 0."""
        return self.motion_pe_stretch > 0

    def has_anything_to_apply(self) -> bool:
        """Return True when at least one numeric setting is greater than 0.

        ``print_adjustment`` is not considered.
        """
        return (
            self.has_cap_initial_pe_length()
            or self.has_interpolate_pe_to_length()
            or self.has_initial_pe_idx_offset()
            or self.has_final_pe_idx_offset()
            or self.has_motion_pe_stretch()
        )
class AdjustPEGroup:
    """Ordered collection of ``AdjustPE`` entries."""

    def __init__(self, initial: AdjustPE = None):
        """Start with an empty list, optionally seeded with ``initial``."""
        self.adjusts: list[AdjustPE] = []
        if initial is not None:
            self.add(initial)

    def add(self, adjust_pe: AdjustPE):
        """Append ``adjust_pe`` to the end of the group."""
        self.adjusts.append(adjust_pe)

    def has_anything_to_apply(self):
        """Return True when any entry reports something to apply.

        An empty group returns False.
        """
        return any(entry.has_anything_to_apply() for entry in self.adjusts)

    def clone(self):
        """Return a new group holding the same entries in the same order.

        The copy is shallow: the ``AdjustPE`` objects themselves are shared
        between this group and the returned one.
        """
        duplicate = AdjustPEGroup()
        for entry in self.adjusts:
            duplicate.add(entry)
        return duplicate