from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention import Attention
from einops import rearrange


class InflatedConv3d(nn.Conv2d):
    """A 2D convolution applied independently to every frame of a (b, c, f, h, w) video tensor."""

    def forward(self, x):
        video_length = x.shape[2]

        # Fold the frame axis into the batch axis, run the 2D conv, then unfold.
        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x


class FFInflatedConv3d(nn.Conv2d):
    """Inflated 2D convolution with a lightweight temporal mixing layer.

    After the per-frame 2D convolution, each spatial position receives a residual
    computed from its features at the first frame, the previous frame, and the
    current frame.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            **kwargs,
        )
        self.conv_temp = nn.Linear(3 * out_channels, out_channels)
        # Zero-initialize the temporal layer so it starts out as a no-op residual.
        nn.init.zeros_(self.conv_temp.weight.data)
        nn.init.zeros_(self.conv_temp.bias.data)

    def forward(self, x):
        video_length = x.shape[2]

        # Per-frame 2D convolution.
        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)

        *_, h, w = x.shape
        # Treat every spatial location as an independent sequence of frames.
        x = rearrange(x, "(b f) c h w -> (b h w) f c", f=video_length)

        # For each time step, gather the first, previous, and current frame.
        head_frame_index = [0] * video_length
        prev_frame_index = torch.clamp(
            torch.arange(video_length) - 1, min=0
        ).long()
        curr_frame_index = torch.arange(video_length).long()

        conv_temp_nn_input = torch.cat(
            [
                x[:, head_frame_index],
                x[:, prev_frame_index],
                x[:, curr_frame_index],
            ],
            dim=2,
        ).contiguous()
        x = x + self.conv_temp(conv_temp_nn_input)

        x = rearrange(x, "(b h w) f c -> b c f h w", h=h, w=w)

        return x


class FFAttention(Attention):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
    """

    def __init__(
        self,
        *args,
        scale_qk: bool = True,
        processor: Optional["FFAttnProcessor"] = None,
        **kwargs,
    ):
        super().__init__(*args, scale_qk=scale_qk, processor=processor, **kwargs)

        # Set the attention processor. We use FFAttnProcessor by default, which relies on
        # torch.nn.functional.scaled_dot_product_attention for native Flash /
        # memory-efficient attention on PyTorch 2.x, but only with the default `scale` argument.
        if processor is None:
            processor = FFAttnProcessor()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states,
        video_length,
        encoder_hidden_states=None,
        attention_mask=None,
        **cross_attention_kwargs,
    ):
        # The `Attention` class can call different attention processors / attention functions.
        # Here we simply pass along all tensors to the selected processor class.
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty.
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            video_length=video_length,
            **cross_attention_kwargs,
        )


class FFAttnProcessor:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FFAttnProcessor requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        video_length,
        encoder_hidden_states=None,
        attention_mask=None,
    ):
        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        inner_dim = hidden_states.shape[-1]

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Sparse frame attention: every frame attends only to the keys/values of the
        # first frame of its clip.
        key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
        key = key[:, [0] * video_length].contiguous()
        key = rearrange(key, "b f d c -> (b f) d c")

        value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
        value = value[:, [0] * video_length].contiguous()
        value = rearrange(value, "b f d c -> (b f) d c")

        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # The output of sdp has shape (batch, num_heads, seq_len, head_dim).
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # Linear projection.
        hidden_states = attn.to_out[0](hidden_states)
        # Dropout.
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
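

# Minimal usage sketch. The shapes below (batch of 2, 320 channels, 8 frames,
# 16x16 latents, 8 heads of dim 40) are arbitrary assumptions chosen only to
# exercise the layers; they are not prescribed by this module. Running this
# requires PyTorch 2.x and diffusers.
if __name__ == "__main__":
    batch, channels, frames, height, width = 2, 320, 8, 16, 16

    # Inflated conv consumes a (b, c, f, h, w) video tensor and preserves its shape.
    conv = FFInflatedConv3d(channels, channels, kernel_size=3, padding=1)
    video = torch.randn(batch, channels, frames, height, width)
    print(conv(video).shape)  # torch.Size([2, 320, 8, 16, 16])

    # Attention consumes flattened spatial tokens, one row per (batch, frame) pair.
    attn = FFAttention(query_dim=channels, heads=8, dim_head=40)
    tokens = torch.randn(batch * frames, height * width, channels)
    print(attn(tokens, video_length=frames).shape)  # torch.Size([16, 256, 320])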