File size: 5,305 Bytes
613c9ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from torch import Tensor

from .utils_motion import normalize_min_max


class AnimateDiffSettings:
    """Bundle of AnimateDiff tuning options.

    Holds the PE-adjustment group, per-weight strength multipliers (a value of
    1.0 means "leave that weight untouched"), and the deprecated attention-scale
    / attention-scale-mask settings.
    """
    def __init__(self,
                 adjust_pe: 'AdjustPEGroup'=None,
                 pe_strength: float=1.0,
                 attn_strength: float=1.0,
                 attn_q_strength: float=1.0,
                 attn_k_strength: float=1.0,
                 attn_v_strength: float=1.0,
                 attn_out_weight_strength: float=1.0,
                 attn_out_bias_strength: float=1.0,
                 other_strength: float=1.0,
                 attn_scale: float=1.0,
                 mask_attn_scale: Tensor=None,
                 mask_attn_scale_min: float=1.0,
                 mask_attn_scale_max: float=1.0,
                 ):
        # PE-interpolation settings; fall back to an empty group when not given
        self.adjust_pe = AdjustPEGroup() if adjust_pe is None else adjust_pe
        # general strengths
        self.pe_strength = pe_strength
        self.attn_strength = attn_strength
        self.other_strength = other_strength
        # specific attn strengths
        self.attn_q_strength = attn_q_strength
        self.attn_k_strength = attn_k_strength
        self.attn_v_strength = attn_v_strength
        self.attn_out_weight_strength = attn_out_weight_strength
        self.attn_out_bias_strength = attn_out_bias_strength
        # attention scale settings - DEPRECATED
        self.attn_scale = attn_scale
        # attention scale mask settings - DEPRECATED
        # clone so later normalization does not mutate the caller's tensor
        if mask_attn_scale is None:
            self.mask_attn_scale = None
        else:
            self.mask_attn_scale = mask_attn_scale.clone()
        self.mask_attn_scale_min = mask_attn_scale_min
        self.mask_attn_scale_max = mask_attn_scale_max
        self._prepare_mask_attn_scale()

    def _prepare_mask_attn_scale(self):
        # Normalize the mask into [min, max] once, at construction time.
        if self.mask_attn_scale is None:
            return
        self.mask_attn_scale = normalize_min_max(self.mask_attn_scale, self.mask_attn_scale_min, self.mask_attn_scale_max)

    def has_mask_attn_scale(self) -> bool:
        return self.mask_attn_scale is not None

    def has_pe_strength(self) -> bool:
        return self.pe_strength != 1.0

    def has_attn_strength(self) -> bool:
        return self.attn_strength != 1.0

    def has_other_strength(self) -> bool:
        return self.other_strength != 1.0

    def has_anything_to_apply(self) -> bool:
        """True when any adjustment or non-default strength is configured."""
        checks = (
            self.adjust_pe.has_anything_to_apply,
            self.has_pe_strength,
            self.has_attn_strength,
            self.has_other_strength,
            self.has_any_attn_sub_strength,
        )
        return any(check() for check in checks)

    def has_any_attn_sub_strength(self) -> bool:
        """True when any of the per-projection attn strengths is non-default."""
        sub_checks = (
            self.has_attn_q_strength,
            self.has_attn_k_strength,
            self.has_attn_v_strength,
            self.has_attn_out_weight_strength,
            self.has_attn_out_bias_strength,
        )
        return any(check() for check in sub_checks)

    def has_attn_q_strength(self) -> bool:
        return self.attn_q_strength != 1.0

    def has_attn_k_strength(self) -> bool:
        return self.attn_k_strength != 1.0

    def has_attn_v_strength(self) -> bool:
        return self.attn_v_strength != 1.0

    def has_attn_out_weight_strength(self) -> bool:
        return self.attn_out_weight_strength != 1.0

    def has_attn_out_bias_strength(self) -> bool:
        return self.attn_out_bias_strength != 1.0


class AdjustPE:
    """Single set of positional-encoding adjustments.

    All integer settings default to 0, meaning "not active"; a setting counts
    as active only when it is greater than zero.
    """
    def __init__(self,
                 cap_initial_pe_length: int=0, interpolate_pe_to_length: int=0,
                 initial_pe_idx_offset: int=0, final_pe_idx_offset: int=0,
                 motion_pe_stretch: int=0, print_adjustment=False):
        # PE-interpolation settings
        self.cap_initial_pe_length = cap_initial_pe_length
        self.interpolate_pe_to_length = interpolate_pe_to_length
        self.initial_pe_idx_offset = initial_pe_idx_offset
        self.final_pe_idx_offset = final_pe_idx_offset
        self.motion_pe_stretch = motion_pe_stretch
        # when True, downstream code logs what adjustment was applied
        self.print_adjustment = print_adjustment

    def has_cap_initial_pe_length(self) -> bool:
        return self.cap_initial_pe_length > 0

    def has_interpolate_pe_to_length(self) -> bool:
        return self.interpolate_pe_to_length > 0

    def has_initial_pe_idx_offset(self) -> bool:
        return self.initial_pe_idx_offset > 0

    def has_final_pe_idx_offset(self) -> bool:
        return self.final_pe_idx_offset > 0

    def has_motion_pe_stretch(self) -> bool:
        return self.motion_pe_stretch > 0

    def has_anything_to_apply(self) -> bool:
        """True when at least one adjustment setting is active."""
        active_checks = (
            self.has_cap_initial_pe_length,
            self.has_interpolate_pe_to_length,
            self.has_initial_pe_idx_offset,
            self.has_final_pe_idx_offset,
            self.has_motion_pe_stretch,
        )
        return any(check() for check in active_checks)


class AdjustPEGroup:
    """Ordered collection of AdjustPE entries, applied in insertion order."""
    def __init__(self, initial: AdjustPE=None):
        self.adjusts: list[AdjustPE] = []
        if initial is not None:
            self.add(initial)

    def add(self, adjust_pe: AdjustPE):
        """Append one AdjustPE to the group."""
        self.adjusts.append(adjust_pe)

    def has_anything_to_apply(self):
        """True when any contained AdjustPE has an active setting."""
        return any(adjust.has_anything_to_apply() for adjust in self.adjusts)

    def clone(self):
        """Return a new group with its own list; AdjustPE entries are shared
        (shallow copy), matching the original behavior."""
        cloned = AdjustPEGroup()
        for entry in self.adjusts:
            cloned.add(entry)
        return cloned