import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from torch.nn.attention import SDPBackend, sdpa_kernel
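
# Spatial-temporal ("multiview") transformer blocks built on PyTorch SDPA:
# a GEGLU-gated feed-forward, a multi-head Attention wrapper, a spatial
# TransformerBlock, a temporal TransformerBlockTimeMix, and a
# MultiviewTransformer that runs one of each per depth level and merges the
# two branches through a simple skip connection.
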
class GEGLU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)
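

# GEGLU-gated MLP: expands `dim` by `mult`, applies dropout, and projects to
# `dim_out` (defaults back to `dim`).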
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int | None = None,
        mult: int = 4,
        dropout: float = 0.0,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out or dim
        self.net = nn.Sequential(
            GEGLU(dim, inner_dim), nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
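

# Multi-head attention on top of F.scaled_dot_product_attention. With
# `context=None` it is self-attention; otherwise queries come from `x` and
# keys/values from `context`. Note that forward() pins the FLASH_ATTENTION
# SDPA backend, which is only available for certain device/dtype combinations
# (for example, half-precision inputs on recent CUDA GPUs).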
class Attention(nn.Module):
    def __init__(
        self,
        query_dim: int,
        context_dim: int | None = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads
        context_dim = context_dim or query_dim
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )

    def forward(
        self, x: torch.Tensor, context: torch.Tensor | None = None
    ) -> torch.Tensor:
        q = self.to_q(x)
        context = context if context is not None else x
        k = self.to_k(context)
        v = self.to_v(context)
        q, k, v = map(
            lambda t: rearrange(t, "b l (h d) -> b h l d", h=self.heads),
            (q, k, v),
        )
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            out = F.scaled_dot_product_attention(q, k, v)
        out = rearrange(out, "b h l d -> b l (h d)")
        out = self.to_out(out)
        return out
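

# Spatial transformer block (pre-norm): self-attention, cross-attention over
# `context`, then feed-forward, each with a residual connection.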
class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int,
        context_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.attn1 = Attention(
            query_dim=dim,
            context_dim=None,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.ff = FeedForward(dim, dropout=dropout)
        self.attn2 = Attention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
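

# Temporal ("time-mix") transformer block. Its forward pass reshapes tokens
# from "(b t) s c" to "(b s) t c" so attention runs along the frame axis;
# `ff_in` projects `dim` to the block's inner width and the final feed-forward
# projects back to `dim` (without a residual, as written below).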
class TransformerBlockTimeMix(nn.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int,
        context_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        inner_dim = n_heads * d_head
        self.norm_in = nn.LayerNorm(dim)
        self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout)
        self.attn1 = Attention(
            query_dim=inner_dim,
            context_dim=None,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout)
        self.attn2 = Attention(
            query_dim=inner_dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.norm1 = nn.LayerNorm(inner_dim)
        self.norm2 = nn.LayerNorm(inner_dim)
        self.norm3 = nn.LayerNorm(inner_dim)

    def forward(
        self, x: torch.Tensor, context: torch.Tensor, num_frames: int
    ) -> torch.Tensor:
        _, s, _ = x.shape
        x = rearrange(x, "(b t) s c -> (b s) t c", t=num_frames)
        x = self.ff_in(self.norm_in(x)) + x
        x = self.attn1(self.norm1(x), context=None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x))
        x = rearrange(x, "(b s) t c -> (b t) s c", s=s)
        return x
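

# Merges the spatial and temporal branches by summation.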
class SkipConnect(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(
        self, x_spatial: torch.Tensor, x_temporal: torch.Tensor
    ) -> torch.Tensor:
        return x_spatial + x_temporal
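

# Multiview transformer module: GroupNorm + linear projection in, `depth`
# interleaved spatial/temporal block pairs, linear projection out, and an
# outer residual. When `name` is listed in `unflatten_names`, the spatial
# blocks attend jointly over all frames/views of a sample instead of one
# frame at a time.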
class MultiviewTransformer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        n_heads: int,
        d_head: int,
        name: str,
        unflatten_names: list[str] | None = None,
        depth: int = 1,
        context_dim: int = 1024,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.name = name
        # Avoid a mutable default argument; None means "no blocks to unflatten".
        self.unflatten_names = unflatten_names if unflatten_names is not None else []
        inner_dim = n_heads * d_head
        self.norm = nn.GroupNorm(32, in_channels, eps=1e-6)
        self.proj_in = nn.Linear(in_channels, inner_dim)
        self.transformer_blocks = nn.ModuleList(
            [
                TransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    context_dim=context_dim,
                    dropout=dropout,
                )
                for _ in range(depth)
            ]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)
        self.time_mixer = SkipConnect()
        self.time_mix_blocks = nn.ModuleList(
            [
                TransformerBlockTimeMix(
                    inner_dim,
                    n_heads,
                    d_head,
                    context_dim=context_dim,
                    dropout=dropout,
                )
                for _ in range(depth)
            ]
        )

    def forward(
        self, x: torch.Tensor, context: torch.Tensor, num_frames: int
    ) -> torch.Tensor:
        assert context.ndim == 3
        _, _, h, w = x.shape
        x_in = x

        time_context = context
        time_context_first_timestep = time_context[::num_frames]
        time_context = repeat(
            time_context_first_timestep, "b ... -> (b n) ...", n=h * w
        )

        if self.name in self.unflatten_names:
            context = context[::num_frames]

        x = self.norm(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        x = self.proj_in(x)

        for block, mix_block in zip(self.transformer_blocks, self.time_mix_blocks):
            if self.name in self.unflatten_names:
                x = rearrange(
                    x, "(b t) (h w) c -> b (t h w) c", t=num_frames, h=h, w=w
                )
            x = block(x, context=context)
            if self.name in self.unflatten_names:
                x = rearrange(
                    x, "b (t h w) c -> (b t) (h w) c", t=num_frames, h=h, w=w
                )
            x_mix = mix_block(x, context=time_context, num_frames=num_frames)
            x = self.time_mixer(x_spatial=x, x_temporal=x_mix)

        x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        out = x + x_in
        return out
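

if __name__ == "__main__":
    # Minimal smoke test. The shapes and hyperparameters below are illustrative
    # assumptions, not values taken from any particular checkpoint. Note that
    # Attention.forward pins the FLASH_ATTENTION SDPA backend, so running this
    # requires a PyTorch build / device / dtype combination that supports it.
    b, t, c, h, w = 2, 3, 320, 16, 16  # 2 samples x 3 frames of 16x16 latents
    model = MultiviewTransformer(
        in_channels=c,
        n_heads=5,
        d_head=64,
        name="mid",
        depth=1,
        context_dim=1024,
    )
    x = torch.randn(b * t, c, h, w)
    context = torch.randn(b * t, 77, 1024)  # e.g. 77 conditioning tokens per frame
    out = model(x, context=context, num_frames=t)
    assert out.shape == (b * t, c, h, w)
    print(out.shape)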