from einops import rearrange
import torch
import torch.nn.functional as F
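
# Flow-guided temporal attention: image tokens are gathered along
# precomputed optical-flow trajectories, attended across frames within each
# trajectory, and scattered back to their original positions (apparently in
# the style of FLATTEN's optical-flow-guided attention).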


def flow_attention(img_attn, num_heads, dim_head, transformer_options):
    flow_options = transformer_options.get('FLOW', None)
    if flow_options is None:
        return img_attn

    hidden_states = img_attn
    cond_list = transformer_options['cond_or_uncond']
    # The leading dim of hidden_states is (len(cond_list) * frames), so this
    # is the number of frames per cond/uncond branch.
    batch_size = len(hidden_states) // len(cond_list)

    # Select the precomputed flow whose trajectory length matches the current
    # token count (trajectories are stored per latent resolution).
    flow = None
    flows = flow_options['forward_flows']

    for possible_flow in flows:
        if possible_flow['forward_trajectory'].shape[2] == hidden_states.shape[1]:
            flow = possible_flow
            break

    if flow is None:
        # No trajectory matches this resolution; leave the attention untouched.
        return hidden_states

    backward_trajectory = flow['backward_trajectory'].to(hidden_states.device)
    forward_trajectory = flow['forward_trajectory'].to(hidden_states.device)
    attn_mask = flow['attn_masks'].to(hidden_states.device)

    # (batch * frames, tokens, channels) -> (frames, batch * channels, tokens)
    hidden_states = rearrange(hidden_states, "(b f) d c -> f (b c) d", f=batch_size)

    # Gather tokens along the forward trajectories so that each spatial slot
    # holds the same tracked point in every frame.
    hidden_states = torch.gather(hidden_states, 2, forward_trajectory.expand(-1, hidden_states.shape[1], -1))

    # Fold tokens into the batch dim and split channels into attention heads:
    # each trajectory becomes an independent sequence of `batch_size` frames.
    hidden_states = rearrange(hidden_states, "f (b c) d -> (b d) f c", b=len(cond_list))
    hidden_states = hidden_states.view(-1, batch_size, num_heads, dim_head).transpose(1, 2).detach()

    # Self-attention across frames within each trajectory; the per-trajectory
    # mask restricts which frame pairs may attend to each other.
    hidden_states = F.scaled_dot_product_attention(
        hidden_states, hidden_states, hidden_states,
        attn_mask=attn_mask.repeat(len(cond_list), 1, 1, 1)
    )

    # Scatter tokens back to their original spatial positions via the backward
    # trajectories, then merge heads into channels and restore the
    # (batch * frames, tokens, channels) layout.
    hidden_states = rearrange(hidden_states, "(b d) h f c -> f (b h c) d", b=len(cond_list))
    hidden_states = torch.gather(hidden_states, 2, backward_trajectory.expand(-1, hidden_states.shape[1], -1)).detach()
    hidden_states = rearrange(hidden_states, "f (b h c) d -> (b f) h d c", b=len(cond_list), h=num_heads)
    hidden_states = rearrange(hidden_states, "b d q f -> b q (d f)")

    return hidden_states
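

# A minimal smoke test with synthetic data. The FLOW dict layout below
# (forward_trajectory/backward_trajectory of shape (frames, 1, tokens) and
# attn_masks of shape (tokens, 1, frames, frames)) is an assumption inferred
# from the indexing above, not taken from the upstream flow producer.
if __name__ == "__main__":
    frames, tokens, heads, head_dim = 4, 16, 2, 8
    cond_or_uncond = [0, 1]  # cond + uncond branches
    x = torch.randn(len(cond_or_uncond) * frames, tokens, heads * head_dim)

    # Identity trajectories: every token maps to itself in every frame.
    idx = torch.arange(tokens).view(1, 1, tokens).expand(frames, 1, tokens).clone()
    options = {
        'cond_or_uncond': cond_or_uncond,
        'FLOW': {'forward_flows': [{
            'forward_trajectory': idx,
            'backward_trajectory': idx.clone(),
            # All-True mask: every frame may attend to every other frame.
            'attn_masks': torch.ones(tokens, 1, frames, frames, dtype=torch.bool),
        }]},
    }
    out = flow_attention(x, heads, head_dim, options)
    assert out.shape == x.shape  # layout is preserved end to end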