Spaces:
Running
on
L40S
Running
on
L40S
import torch | |
import torch.nn.functional as F | |
from einops import rearrange, repeat | |
from torch import nn | |
from torch.nn.attention import SDPBackend, sdpa_kernel | |
class GEGLU(nn.Module):
    """Gated linear unit whose gate is passed through GELU.

    A single linear layer maps ``dim_in`` features to ``2 * dim_out``
    features. The result is split in half along the last dimension: the
    first half is the value, the second half is the gate.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # One projection produces both the value and the gate halves.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``value * gelu(gate)``; the last dimension is ``dim_out``."""
        projected = self.proj(x)
        value, gate = torch.chunk(projected, 2, dim=-1)
        return value * F.gelu(gate)
class FeedForward(nn.Module):
    """Feed-forward stack: GEGLU expansion, dropout, then a linear projection.

    Args:
        dim: Size of the input feature dimension.
        dim_out: Size of the output feature dimension; falls back to ``dim``
            when not given (or falsy).
        mult: Expansion factor; the hidden width is ``int(dim * mult)``.
        dropout: Dropout probability applied between GEGLU and the output
            projection.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int | None = None,
        mult: int = 4,
        dropout: float = 0.0,
    ):
        super().__init__()
        hidden_dim = int(dim * mult)
        if not dim_out:
            dim_out = dim
        # Layers are created in application order: expand, drop, project.
        layers = [
            GEGLU(dim, hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim_out),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward stack along the last dimension of ``x``."""
        return self.net(x)
class Attention(nn.Module): | |
def __init__( | |
self, | |
query_dim: int, | |
context_dim: int | None = None, | |
heads: int = 8, | |
dim_head: int = 64, | |
dropout: float = 0.0, | |
): | |
super().__init__() | |
self.heads = heads | |
self.dim_head = dim_head | |
inner_dim = dim_head * heads | |
context_dim = context_dim or query_dim | |
self.to_q = nn.Linear(query_dim, inner_dim, bias=False) | |
self.to_k = nn.Linear(context_dim, inner_dim, bias=False) | |
self.to_v = nn.Linear(context_dim, inner_dim, bias=False) | |
self.to_out = nn.Sequential( | |
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) | |
) | |
def forward( | |
self, x: torch.Tensor, context: torch.Tensor | None = None | |
) -> torch.Tensor: | |
q = self.to_q(x) | |
context = context if context is not None else x | |
k = self.to_k(context) | |
v = self.to_v(context) | |
q, k, v = map( | |
lambda t: rearrange(t, "b l (h d) -> b h l d", h=self.heads), | |
(q, k, v), | |
) | |
with sdpa_kernel(SDPBackend.FLASH_ATTENTION): | |
out = F.scaled_dot_product_attention(q, k, v) | |
out = rearrange(out, "b h l d -> b l (h d)") | |
out = self.to_out(out) | |
return out | |
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is preceded by its own LayerNorm and wrapped
    in a residual connection.

    Args:
        dim: Feature size of the input and output.
        n_heads: Number of attention heads.
        d_head: Feature size of each attention head.
        context_dim: Feature size of the cross-attention context.
        dropout: Dropout probability passed to the attention and feed-forward
            sub-layers.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int,
        context_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        # Settings shared by the self- and cross-attention layers.
        attn_kwargs = dict(
            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )
        self.attn1 = Attention(context_dim=None, **attn_kwargs)
        self.ff = FeedForward(dim, dropout=dropout)
        self.attn2 = Attention(context_dim=context_dim, **attn_kwargs)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """Run self-attention, cross-attention on ``context``, then feed-forward."""
        self_attended = self.attn1(self.norm1(x))
        x = self_attended + x
        cross_attended = self.attn2(self.norm2(x), context=context)
        x = cross_attended + x
        fed_forward = self.ff(self.norm3(x))
        return fed_forward + x
class TransformerBlockTimeMix(nn.Module):
    """Transformer block that attends along the frame (time) axis.

    The input is regrouped so that every spatial position becomes its own
    sequence over ``num_frames`` frames. An input feed-forward, a
    self-attention and a cross-attention are applied with residuals, followed
    by an output feed-forward that has no residual of its own.

    Args:
        dim: Feature size of the input and output.
        n_heads: Number of attention heads.
        d_head: Feature size of each head; the inner width is
            ``n_heads * d_head``.
        context_dim: Feature size of the cross-attention context.
        dropout: Dropout probability passed to the sub-layers.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int,
        context_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        width = n_heads * d_head
        self.norm_in = nn.LayerNorm(dim)
        self.ff_in = FeedForward(dim, dim_out=width, dropout=dropout)
        # Settings shared by the self- and cross-attention layers.
        attn_kwargs = dict(
            query_dim=width, heads=n_heads, dim_head=d_head, dropout=dropout
        )
        self.attn1 = Attention(context_dim=None, **attn_kwargs)
        self.ff = FeedForward(width, dim_out=dim, dropout=dropout)
        self.attn2 = Attention(context_dim=context_dim, **attn_kwargs)
        self.norm1 = nn.LayerNorm(width)
        self.norm2 = nn.LayerNorm(width)
        self.norm3 = nn.LayerNorm(width)

    def forward(
        self, x: torch.Tensor, context: torch.Tensor, num_frames: int
    ) -> torch.Tensor:
        """Mix information across frames for every spatial position.

        Args:
            x: 3-D tensor laid out as ((batch * frames), positions, channels).
            context: Cross-attention context, passed to the second attention.
            num_frames: Number of frames folded into the leading axis of ``x``.

        Returns:
            Tensor in the same ((batch * frames), positions, channels) layout.
        """
        num_positions = x.shape[1]
        # Move the frame axis into the sequence slot: one sequence per position.
        x = rearrange(x, "(b t) s c -> (b s) t c", t=num_frames)

        projected = self.ff_in(self.norm_in(x))
        x = projected + x
        self_attended = self.attn1(self.norm1(x), context=None)
        x = self_attended + x
        cross_attended = self.attn2(self.norm2(x), context=context)
        x = cross_attended + x
        # No residual on the output feed-forward; its result is returned as-is.
        x = self.ff(self.norm3(x))

        # Restore the original ((batch * frames), positions, channels) layout.
        return rearrange(x, "(b s) t c -> (b t) s c", s=num_positions)
class SkipConnect(nn.Module):
    """Parameter-free mixer that sums the spatial and temporal branches."""

    def forward(
        self, x_spatial: torch.Tensor, x_temporal: torch.Tensor
    ) -> torch.Tensor:
        """Return the element-wise sum ``x_spatial + x_temporal``."""
        combined = x_spatial + x_temporal
        return combined
class MultiviewTransformer(nn.Module):
    """Spatial transformer over feature maps with per-block temporal mixing.

    The input feature map is group-normalised, flattened to a token sequence
    and projected to ``n_heads * d_head`` channels. Each of the ``depth``
    stages runs a ``TransformerBlock`` followed by a
    ``TransformerBlockTimeMix`` whose output is added back through
    ``SkipConnect``. The result is projected back to ``in_channels``,
    reshaped to a feature map and added to the input.

    Args:
        in_channels: Channel count of the input feature map. It is normalised
            with ``GroupNorm(32, in_channels)``, so it must be divisible by 32.
        n_heads: Number of attention heads.
        d_head: Feature size of each attention head.
        name: Identifier of this layer; compared against ``unflatten_names``.
        unflatten_names: Names of layers whose spatial block should attend
            jointly over all frames of a sample. ``None`` (the default) means
            no layers.
        depth: Number of (spatial block, time-mix block) stages.
        context_dim: Feature size of the cross-attention context.
        dropout: Dropout probability passed to every block.
    """

    def __init__(
        self,
        in_channels: int,
        n_heads: int,
        d_head: int,
        name: str,
        unflatten_names: list[str] | None = None,
        depth: int = 1,
        context_dim: int = 1024,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.name = name
        # A fresh list per instance: a mutable ``[]`` default would be shared
        # (and mutated) across every instance built without this argument.
        self.unflatten_names = [] if unflatten_names is None else unflatten_names
        inner_dim = n_heads * d_head
        self.norm = nn.GroupNorm(32, in_channels, eps=1e-6)
        self.proj_in = nn.Linear(in_channels, inner_dim)
        self.transformer_blocks = nn.ModuleList(
            [
                TransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    context_dim=context_dim,
                    dropout=dropout,
                )
                for _ in range(depth)
            ]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)
        self.time_mixer = SkipConnect()
        self.time_mix_blocks = nn.ModuleList(
            [
                TransformerBlockTimeMix(
                    inner_dim,
                    n_heads,
                    d_head,
                    context_dim=context_dim,
                    dropout=dropout,
                )
                for _ in range(depth)
            ]
        )

    def forward(
        self, x: torch.Tensor, context: torch.Tensor, num_frames: int
    ) -> torch.Tensor:
        """Apply the spatial and temporal transformer stages to a feature map.

        Args:
            x: 4-D feature map laid out as ((batch * frames), channels, h, w).
            context: 3-D cross-attention context whose leading axis is indexed
                with a stride of ``num_frames``.
            num_frames: Number of frames folded into the leading axis of ``x``.

        Returns:
            Tensor with the same shape as ``x`` (input plus transformer output).
        """
        assert context.ndim == 3
        _, _, h, w = x.shape
        x_in = x

        # Temporal context: take every ``num_frames``-th entry and repeat it
        # once per spatial position, matching the time-mix block's
        # "(b s) t c" batch layout.
        time_context = repeat(context[::num_frames], "b ... -> (b n) ...", n=h * w)

        # The membership test does not change during the call; evaluate once.
        unflatten = self.name in self.unflatten_names
        if unflatten:
            # Frames are merged into one sequence per sample below, so only
            # one context entry per sample is kept.
            context = context[::num_frames]

        x = self.norm(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        x = self.proj_in(x)

        for block, mix_block in zip(self.transformer_blocks, self.time_mix_blocks):
            if unflatten:
                # Let the spatial block attend across all frames at once.
                x = rearrange(
                    x, "(b t) (h w) c -> b (t h w) c", t=num_frames, h=h, w=w
                )
            x = block(x, context=context)
            if unflatten:
                x = rearrange(
                    x, "b (t h w) c -> (b t) (h w) c", t=num_frames, h=h, w=w
                )
            x_mix = mix_block(x, context=time_context, num_frames=num_frames)
            x = self.time_mixer(x_spatial=x, x_temporal=x_mix)

        x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        # Residual connection around the whole module.
        return x + x_in