import torch
import torch.nn.functional as F
from einops import rearrange
def flow_attention(img_attn, num_heads, dim_head, transformer_options):
    flow_options = transformer_options.get('FLOW', None)
    if flow_options is None:
        return img_attn

    hidden_states = img_attn
    cond_list = transformer_options['cond_or_uncond']
    # Number of frames per cond/uncond batch.
    batch_size = len(hidden_states) // len(cond_list)

    # Pick the flow whose trajectory length matches the current token count.
    flow = None
    for possible_flow in flow_options['forward_flows']:
        if possible_flow['forward_trajectory'].shape[2] == hidden_states.shape[1]:
            flow = possible_flow
            break
    if flow is None:
        return hidden_states

    backward_trajectory = flow['backward_trajectory'].to(hidden_states.device)
    forward_trajectory = flow['forward_trajectory'].to(hidden_states.device)
    attn_mask = flow['attn_masks'].to(hidden_states.device)

    # Gather tokens along the forward flow trajectory so that corresponding
    # tokens across frames line up on the same spatial index.
    hidden_states = rearrange(hidden_states, "(b f) d c -> f (b c) d", f=batch_size)
    hidden_states = torch.gather(hidden_states, 2, forward_trajectory.expand(-1, hidden_states.shape[1], -1))
    hidden_states = rearrange(hidden_states, "f (b c) d -> (b d) f c", b=len(cond_list))

    # Self-attention over the temporal (frame) axis for each aligned token.
    hidden_states = hidden_states.view(-1, batch_size, num_heads, dim_head).transpose(1, 2).detach()
    hidden_states = F.scaled_dot_product_attention(
        hidden_states, hidden_states, hidden_states,
        attn_mask=attn_mask.repeat(len(cond_list), 1, 1, 1),
    )

    # Scatter tokens back to their original positions via the backward trajectory.
    hidden_states = rearrange(hidden_states, "(b d) h f c -> f (b h c) d", b=len(cond_list))
    hidden_states = torch.gather(hidden_states, 2, backward_trajectory.expand(-1, hidden_states.shape[1], -1)).detach()
    hidden_states = rearrange(hidden_states, "f (b h c) d -> (b f) h d c", b=len(cond_list), h=num_heads)
    hidden_states = rearrange(hidden_states, "b h d c -> b d (h c)")
    return hidden_states
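

# --- Minimal usage sketch (not part of the original upload) ---
# The shapes below are assumptions inferred from the rearrange patterns above:
# `img_attn` is (len(cond_or_uncond) * frames, tokens, num_heads * dim_head),
# the trajectories are (frames, 1, tokens) index tensors, and `attn_masks` is a
# (tokens, 1, frames, frames) boolean mask. All tensors here are random
# stand-ins purely to illustrate the expected layout.
if __name__ == "__main__":
    num_heads, dim_head = 8, 64
    frames, tokens = 4, 16
    cond_or_uncond = [0, 1]  # hypothetical cond/uncond markers

    x = torch.randn(len(cond_or_uncond) * frames, tokens, num_heads * dim_head)
    # Identity trajectory: each token maps to itself in every frame.
    trajectory = torch.arange(tokens).view(1, 1, tokens).repeat(frames, 1, 1)
    options = {
        'cond_or_uncond': cond_or_uncond,
        'FLOW': {
            'forward_flows': [{
                'forward_trajectory': trajectory,
                'backward_trajectory': trajectory,
                'attn_masks': torch.ones(tokens, 1, frames, frames, dtype=torch.bool),
            }],
        },
    }
    out = flow_attention(x, num_heads, dim_head, options)
    print(out.shape)  # torch.Size([8, 16, 512]), same layout as the input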