from typing import Optional
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention import Attention
class InflatedConv3d(nn.Conv2d):
    def forward(self, x):
        video_length = x.shape[2]
        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
        return x
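
# Illustrative usage sketch (not part of the original module): InflatedConv3d
# shares one 2D convolution across all frames of a (batch, channels, frames,
# height, width) tensor by folding the frame axis into the batch axis. The
# helper name `_demo_inflated_conv3d` and the shapes below are hypothetical.
def _demo_inflated_conv3d():
    conv = InflatedConv3d(in_channels=4, out_channels=8, kernel_size=3, padding=1)
    video = torch.randn(2, 4, 16, 32, 32)  # (b, c, f, h, w)
    out = conv(video)
    assert out.shape == (2, 8, 16, 32, 32)  # the same spatial conv is applied per frame
    return out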

class FFInflatedConv3d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            **kwargs,
        )
        self.conv_temp = nn.Linear(3 * out_channels, out_channels)
        nn.init.zeros_(self.conv_temp.weight.data)  # zero-initialized so the temporal branch starts as a no-op
        nn.init.zeros_(self.conv_temp.bias.data)

    def forward(self, x):
        video_length = x.shape[2]
        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        *_, h, w = x.shape

        x = rearrange(x, "(b f) c h w -> (b h w) f c", f=video_length)
        # For every frame, gather features from the first frame, the previous
        # frame, and the current frame, then mix them with the linear layer.
        head_frame_index = [0] * video_length
        prev_frame_index = torch.clamp(
            torch.arange(video_length) - 1, min=0
        ).long()
        curr_frame_index = torch.arange(video_length).long()
        conv_temp_nn_input = torch.cat([
            x[:, head_frame_index],
            x[:, prev_frame_index],
            x[:, curr_frame_index],
        ], dim=2).contiguous()
        x = x + self.conv_temp(conv_temp_nn_input)

        x = rearrange(x, "(b h w) f c -> b c f h w", h=h, w=w)
        return x
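
# Illustrative usage sketch (not part of the original module): FFInflatedConv3d
# adds a zero-initialized linear mixing of first/previous/current-frame features
# on top of the frame-wise 2D convolution, so at initialization it matches
# InflatedConv3d exactly. The helper name `_demo_ff_inflated_conv3d` is hypothetical.
def _demo_ff_inflated_conv3d():
    conv = FFInflatedConv3d(in_channels=4, out_channels=8, kernel_size=3, padding=1)
    video = torch.randn(2, 4, 16, 32, 32)  # (b, c, f, h, w)
    out = conv(video)
    assert out.shape == (2, 8, 16, 32, 32)  # temporal residual starts as a no-op
    return out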

class FFAttention(Attention):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
    """
    def __init__(
        self,
        *args,
        scale_qk: bool = True,
        processor: Optional["FFAttnProcessor"] = None,
        **kwargs,
    ):
        super().__init__(*args, scale_qk=scale_qk, processor=processor, **kwargs)
        # Set the attention processor. FFAttnProcessor is used by default; it relies on
        # torch.nn.functional.scaled_dot_product_attention (PyTorch 2.x) for native
        # flash / memory-efficient attention.
        if processor is None:
            processor = FFAttnProcessor()
        self.set_processor(processor)

    def forward(self, hidden_states, video_length, encoder_hidden_states=None, attention_mask=None,
                **cross_attention_kwargs):
        # The `Attention` class can call different attention processors / attention functions.
        # Here we simply pass all tensors along to the selected processor class.
        # For the standard processor defined here, `**cross_attention_kwargs` is empty.
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            video_length=video_length,
            **cross_attention_kwargs,
        )
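
# Illustrative usage sketch (not part of the original module): FFAttention is a
# drop-in replacement for diffusers' Attention whose forward additionally takes
# `video_length`, the number of frames folded into the batch dimension. The
# helper name `_demo_ff_attention` and all sizes below are hypothetical and
# assume PyTorch 2.x (required by FFAttnProcessor).
def _demo_ff_attention():
    attn = FFAttention(query_dim=64, heads=8, dim_head=8)
    # 2 videos x 16 frames folded into the batch axis, 256 tokens per frame.
    hidden_states = torch.randn(2 * 16, 256, 64)
    out = attn(hidden_states, video_length=16)
    assert out.shape == (2 * 16, 256, 64)
    return out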

class FFAttnProcessor:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FFAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
            )

    def __call__(self, attn: Attention, hidden_states, video_length, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        inner_dim = hidden_states.shape[-1]

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Sparse causal attention: every frame attends to the first frame only.
        # `former_frame_index` is kept from the previous-frame variant but is not
        # used below, where keys/values are gathered from frame 0.
        former_frame_index = torch.arange(video_length) - 1
        former_frame_index[0] = 0

        key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
        key = key[:, [0] * video_length].contiguous()
        key = rearrange(key, "b f d c -> (b f) d c")

        value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
        value = value[:, [0] * video_length].contiguous()
        value = rearrange(value, "b f d c -> (b f) d c")

        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
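
# Minimal smoke test (not part of the original module): FFAttnProcessor can also
# be attached to a plain diffusers Attention block via `set_processor`; every
# frame then attends to the keys/values of the first frame. The sizes below are
# hypothetical and assume PyTorch 2.x and a diffusers version that exposes
# `Attention.set_processor`.
if __name__ == "__main__":
    attn = Attention(query_dim=64, heads=8, dim_head=8)
    attn.set_processor(FFAttnProcessor())
    frames = torch.randn(1 * 8, 64, 64)  # (batch * frames, tokens, channels)
    out = attn.processor(attn, frames, video_length=8)
    print(out.shape)  # torch.Size([8, 64, 64])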