from einops import repeat, rearrange
from typing import Callable, Optional, Union
from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention
# from t2v_enhanced.model.diffusers_conditional.controldiffusers.models.attention import Attention
from diffusers.utils.import_utils import is_xformers_available
from t2v_enhanced.model.pl_module_params_controlnet import AttentionMaskParams
import torch
import torch.nn.functional as F
if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


def set_use_memory_efficient_attention_xformers(
    model,
    num_frame_conditioning: int,
    num_frames: int,
    attention_mask_params: AttentionMaskParams,
    valid: bool = True,
    attention_op: Optional[Callable] = None,
) -> None:
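    """Recursively install :class:`XFormersAttnProcessor` on ``model``.

    Every sub-module that exposes ``set_processor`` receives a processor
    configured with the given frame counts, the attention-mask parameters,
    and (optionally) a specific xformers ``attention_op``.
    """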
    # Walk the module tree; any child exposing ``set_processor`` gets the
    # xformers processor installed.
    def fn_recursive_set_mem_eff(module: torch.nn.Module):
        if hasattr(module, "set_processor"):

            module.set_processor(XFormersAttnProcessor(attention_op=attention_op,
                                                       num_frame_conditioning=num_frame_conditioning,
                                                       num_frames=num_frames,
                                                       attention_mask_params=attention_mask_params,)
                                 )

        for child in module.children():
            fn_recursive_set_mem_eff(child)

    for module in model.children():
        if isinstance(module, torch.nn.Module):
            fn_recursive_set_mem_eff(module)


class XFormersAttnProcessor:
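    """Attention processor backed by ``xformers.ops.memory_efficient_attention``.

    On top of the default path it implements two frame-conditioning variants
    controlled by :class:`AttentionMaskParams`: temporal attention restricted to
    a spatial neighborhood of the conditioning frames, and spatial attention in
    which generated frames attend to the last conditioning frame. It can also
    mix an image embedding into the text conditioning for cross-attention.
    """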
    def __init__(self,
                 attention_mask_params: AttentionMaskParams,
                 attention_op: Optional[Callable] = None,
                 num_frame_conditioning: Optional[int] = None,
                 num_frames: Optional[int] = None,
                 use_image_embedding: bool = False,
                 ):
        self.attention_op = attention_op
        self.num_frame_conditioning = num_frame_conditioning
        self.num_frames = num_frames
        self.temp_attend_on_neighborhood_of_condition_frames = attention_mask_params.temp_attend_on_neighborhood_of_condition_frames
        self.spatial_attend_on_condition_frames = attention_mask_params.spatial_attend_on_condition_frames
        self.use_image_embedding = use_image_embedding

    def __call__(self, attn: Attention, hidden_states, hidden_state_height=None, hidden_state_width=None, encoder_hidden_states=None, attention_mask=None):
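        """Compute attention for one block.

        ``hidden_state_height`` and ``hidden_state_width`` describe the latent's
        spatial grid and are only needed for the neighborhood-restricted temporal
        attention; passing ``encoder_hidden_states`` switches to cross-attention.
        """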
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            # Broadcast the frame-level mask over the batch dimension.
            attention_mask = repeat(
                attention_mask, "1 F D -> B F D", B=batch_size)

        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)

        is_cross_attention = encoder_hidden_states is not None

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states)
        default_attention = not hasattr(attn, "is_spatial_attention")
        if default_attention:
            assert not self.temp_attend_on_neighborhood_of_condition_frames, "special attention must be implemented with new interface"
            assert not self.spatial_attend_on_condition_frames, "special attention must be implemented with new interface"
        is_spatial_attention = attn.is_spatial_attention if hasattr(
            attn, "is_spatial_attention") else False
        use_image_embedding = attn.use_image_embedding if hasattr(
            attn, "use_image_embedding") else False

        if is_spatial_attention and use_image_embedding and attn.cross_attention_mode:
            assert not self.spatial_attend_on_condition_frames, "Not implemented together with image embedding"

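            # Mix the CLIP text tokens (first 77) with the projected image
            # embedding; the mixed portion is gated by ``F.silu(alpha)``.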
            alpha = attn.alpha
            encoder_hidden_states_txt = encoder_hidden_states[:, :77, :]

            encoder_hidden_states_mixed = attn.conv(encoder_hidden_states)
            encoder_hidden_states_mixed = attn.conv_ln(encoder_hidden_states_mixed)
            encoder_hidden_states = encoder_hidden_states_txt + encoder_hidden_states_mixed * F.silu(alpha)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if not default_attention and not is_spatial_attention and self.temp_attend_on_neighborhood_of_condition_frames and not attn.cross_attention_mode:
            # Temporal attention with a neighborhood restriction:
            # the conditioning frames themselves attend to all frames as usual.
            query_condition = query[:, :self.num_frame_conditioning]
            query_condition = attn.head_to_batch_dim(
                query_condition).contiguous()
            key_condition = key
            value_condition = value
            key_condition = attn.head_to_batch_dim(key_condition).contiguous()
            value_condition = attn.head_to_batch_dim(
                value_condition).contiguous()
            hidden_states_condition = xformers.ops.memory_efficient_attention(
                query_condition, key_condition, value_condition, attn_bias=None, op=self.attention_op, scale=attn.scale
            )
            hidden_states_condition = hidden_states_condition.to(query.dtype)
            hidden_states_condition = attn.batch_to_head_dim(
                hidden_states_condition)
            # The remaining (generated) frames attend only to the conditioning
            # frames, restricted to a 3x3 spatial neighborhood.
            query_uncondition = query[:, self.num_frame_conditioning:]

            key = key[:, :self.num_frame_conditioning]
            value = value[:, :self.num_frame_conditioning]
            key = rearrange(key, "(B W H) F C -> B W H F C",
                            H=hidden_state_height, W=hidden_state_width)
            value = rearrange(value, "(B W H) F C -> B W H F C",
                              H=hidden_state_height, W=hidden_state_width)

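            # Gather that neighborhood by rolling the key/value grids by
            # -1/0/+1 along width and height and concatenating the shifted copies.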
            keys = []
            values = []
            for shifts_width in [-1, 0, 1]:
                for shifts_height in [-1, 0, 1]:
                    keys.append(torch.roll(key, shifts=(
                        shifts_width, shifts_height), dims=(1, 2)))
                    values.append(torch.roll(value, shifts=(
                        shifts_width, shifts_height), dims=(1, 2)))
            key = rearrange(torch.cat(keys, dim=3), "B W H F C -> (B W H) F C")
            value = rearrange(torch.cat(values, dim=3),
                              'B W H F C -> (B W H) F C')

            query = attn.head_to_batch_dim(query_uncondition).contiguous()
            key = attn.head_to_batch_dim(key).contiguous()
            value = attn.head_to_batch_dim(value).contiguous()

            hidden_states = xformers.ops.memory_efficient_attention(
                query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
            )
            hidden_states = hidden_states.to(query.dtype)
            hidden_states = attn.batch_to_head_dim(hidden_states)
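            # Re-assemble the sequence: conditioning frames first, then the
            # neighborhood-restricted frames.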
            hidden_states = torch.cat(
                [hidden_states_condition, hidden_states], dim=1)
        elif not default_attention and is_spatial_attention and self.spatial_attend_on_condition_frames and not attn.cross_attention_mode:
            # Spatial attention with conditioning restriction: the conditioning
            # frames attend to their own spatial tokens as usual.
            query_condition = rearrange(
                query, "(B F) S C -> B F S C", F=self.num_frames)
            query_condition = query_condition[:, :self.num_frame_conditioning]
            query_condition = rearrange(
                query_condition, "B F S C -> (B F) S C")
            query_condition = attn.head_to_batch_dim(
                query_condition).contiguous()

            key_condition = rearrange(
                key, "(B F) S C -> B F S C", F=self.num_frames)
            key_condition = key_condition[:, :self.num_frame_conditioning]
            key_condition = rearrange(key_condition, "B F S C -> (B F) S C")

            value_condition = rearrange(
                value, "(B F) S C -> B F S C", F=self.num_frames)
            value_condition = value_condition[:, :self.num_frame_conditioning]
            value_condition = rearrange(
                value_condition, "B F S C -> (B F) S C")

            key_condition = attn.head_to_batch_dim(key_condition).contiguous()
            value_condition = attn.head_to_batch_dim(
                value_condition).contiguous()
            hidden_states_condition = xformers.ops.memory_efficient_attention(
                query_condition, key_condition, value_condition, attn_bias=None, op=self.attention_op, scale=attn.scale
            )
            hidden_states_condition = hidden_states_condition.to(query.dtype)
            hidden_states_condition = attn.batch_to_head_dim(
                hidden_states_condition)

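            # The generated frames attend to the spatial tokens of the last
            # conditioning frame instead of their own tokens.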
            query_uncondition = rearrange(
                query, "(B F) S C -> B F S C", F=self.num_frames)
            query_uncondition = query_uncondition[:, self.num_frame_conditioning:]
            key_uncondition = rearrange(
                key, "(B F) S C -> B F S C", F=self.num_frames)
            value_uncondition = rearrange(
                value, "(B F) S C -> B F S C", F=self.num_frames)
            key_uncondition = key_uncondition[:, self.num_frame_conditioning - 1, None]
            value_uncondition = value_uncondition[:, self.num_frame_conditioning - 1, None]
            query_uncondition = rearrange(
                query_uncondition, "B F S C -> (B F) S C")
            key_uncondition = repeat(rearrange(
                key_uncondition, "B F S C -> B (F S) C"), "B T C -> (B F) T C", F=self.num_frames-self.num_frame_conditioning)
            value_uncondition = repeat(rearrange(
                value_uncondition, "B F S C -> B (F S) C"), "B T C -> (B F) T C", F=self.num_frames-self.num_frame_conditioning)
            query_uncondition = attn.head_to_batch_dim(
                query_uncondition).contiguous()
            key_uncondition = attn.head_to_batch_dim(
                key_uncondition).contiguous()
            value_uncondition = attn.head_to_batch_dim(
                value_uncondition).contiguous()
            hidden_states_uncondition = xformers.ops.memory_efficient_attention(
                query_uncondition, key_uncondition, value_uncondition, attn_bias=None, op=self.attention_op, scale=attn.scale
            )
            hidden_states_uncondition = hidden_states_uncondition.to(
                query.dtype)
            hidden_states_uncondition = attn.batch_to_head_dim(
                hidden_states_uncondition)
            hidden_states = torch.cat([
                rearrange(hidden_states_condition, "(B F) S C -> B F S C",
                          F=self.num_frame_conditioning),
                rearrange(hidden_states_uncondition, "(B F) S C -> B F S C",
                          F=self.num_frames - self.num_frame_conditioning),
            ], dim=1)
            hidden_states = rearrange(hidden_states, "B F S C -> (B F) S C")
        else:
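            # Default memory-efficient attention (plain self- or cross-attention).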
            query = attn.head_to_batch_dim(query).contiguous()
            key = attn.head_to_batch_dim(key).contiguous()
            value = attn.head_to_batch_dim(value).contiguous()

            hidden_states = xformers.ops.memory_efficient_attention(
                query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
            )

            hidden_states = hidden_states.to(query.dtype)
            hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states