from einops import rearrange
import torch
import torch.nn.functional as F


def flow_attention(img_attn, num_heads, dim_head, transformer_options):
    # Temporal attention guided by optical-flow trajectories. If no flow data is
    # available (or none matches this layer's token count), return the input unchanged.
    if transformer_options.get('FLOW', None) is None:
        return img_attn

    hidden_states = img_attn
    cond_list = transformer_options['cond_or_uncond']
    # Input is (cond_branches * frames, tokens, channels); frames per branch:
    batch_size = len(hidden_states) // len(cond_list)

    # Pick the flow entry whose trajectories match this layer's spatial token count.
    flow = None
    for possible_flow in transformer_options['FLOW']['forward_flows']:
        if possible_flow['forward_trajectory'].shape[2] == hidden_states.shape[1]:
            flow = possible_flow
            break
    if flow is None:
        return hidden_states

    backward_trajectory = flow['backward_trajectory'].to(hidden_states.device)
    forward_trajectory = flow['forward_trajectory'].to(hidden_states.device)
    attn_mask = flow['attn_masks'].to(hidden_states.device)

    # Reorder tokens along the forward trajectories so that corresponding pixels
    # across frames land on the same index.
    hidden_states = rearrange(hidden_states, "(b f) d c -> f (b c) d", f=batch_size)
    hidden_states = torch.gather(hidden_states, 2, forward_trajectory.expand(-1, hidden_states.shape[1], -1))

    # Attend over the frame axis: one attention problem per (branch, token) pair.
    hidden_states = rearrange(hidden_states, "f (b c) d -> (b d) f c", b=len(cond_list))
    hidden_states = hidden_states.view(-1, batch_size, num_heads, dim_head).transpose(1, 2).detach()
    hidden_states = F.scaled_dot_product_attention(
        hidden_states, hidden_states, hidden_states,
        attn_mask=attn_mask.repeat(len(cond_list), 1, 1, 1),
    )

    # Scatter tokens back to their original positions via the backward trajectories,
    # then restore the (batch, tokens, channels) layout of the input.
    hidden_states = rearrange(hidden_states, "(b d) h f c -> f (b h c) d", b=len(cond_list))
    hidden_states = torch.gather(hidden_states, 2, backward_trajectory.expand(-1, hidden_states.shape[1], -1)).detach()
    hidden_states = rearrange(hidden_states, "f (b h c) d -> (b f) h d c", b=len(cond_list), h=num_heads)
    hidden_states = rearrange(hidden_states, 'b d q f -> b q (d f)')
    return hidden_states
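

# --- Minimal usage sketch (illustrative, not part of the original module). It builds a
# transformer_options dict with the keys the function reads ('cond_or_uncond', 'FLOW',
# 'forward_flows', 'forward_trajectory', 'backward_trajectory', 'attn_masks'); the
# concrete sizes, identity trajectories, and all-True mask below are assumptions chosen
# only to exercise the shape handling.
if __name__ == "__main__":
    num_heads, dim_head = 2, 8
    frames, tokens = 4, 16                      # frames per cond branch, spatial tokens
    channels = num_heads * dim_head
    cond_or_uncond = [0, 1]                     # cond and uncond branches

    img_attn = torch.randn(len(cond_or_uncond) * frames, tokens, channels)

    # Identity trajectory: each token maps to itself in every frame, i.e. no real motion.
    traj = torch.arange(tokens).view(1, 1, tokens).expand(frames, 1, tokens)
    transformer_options = {
        'cond_or_uncond': cond_or_uncond,
        'FLOW': {
            'forward_flows': [{
                'forward_trajectory': traj,
                'backward_trajectory': traj,
                # One (frames x frames) boolean mask per spatial token; True = attend.
                'attn_masks': torch.ones(tokens, 1, frames, frames, dtype=torch.bool),
            }],
        },
    }

    out = flow_attention(img_attn, num_heads, dim_head, transformer_options)
    print(out.shape)  # torch.Size([8, 16, 16]) -- same shape as img_attn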