# Source: uploaded by fffiloni ("Upload 244 files", commit b3f324b, verified).
from typing import Union, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .resnet_block import ResnetBlock3D
from .attention import TemporalAttnBlock
from .normalize import Normalize
from .ops import cast_tuple, video_to_image
from .conv import CausalConv3d
from einops import rearrange
from .block import Block
class Upsample(Block):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution.

    ``forward`` is wrapped by ``video_to_image`` (imported from ``.ops``).
    NOTE(review): presumably that decorator lets this 2D block be applied to
    video tensors frame by frame -- confirm against its definition.

    Args:
        in_channels: Number of channels expected by the convolution.
        out_channels: Number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The convolution is always enabled; the flag is kept as an attribute
        # so ``forward`` can branch on it.
        self.with_conv = True
        if self.with_conv:
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=1, padding=1
            )

    @video_to_image
    def forward(self, x):
        """Upsample ``x`` by a factor of 2 and, if enabled, convolve the result."""
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        if not self.with_conv:
            return upsampled
        return self.conv(upsampled)
class Downsample(Block):
    """2x spatial downsampling via a stride-2 3x3 convolution.

    ``forward`` is wrapped by ``video_to_image`` (imported from ``.ops``).
    NOTE(review): presumably that decorator lets this 2D block be applied to
    video tensors frame by frame -- confirm against its definition.

    Args:
        in_channels: Number of channels expected by the convolution.
        out_channels: Number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The convolution is always enabled; the average-pool path in
        # ``forward`` is only reachable if this flag is changed afterwards.
        self.with_conv = True
        if self.with_conv:
            # The conv itself is unpadded: torch convolutions cannot pad
            # asymmetrically, so ``forward`` pads the input by hand.
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, padding=0
            )

    @video_to_image
    def forward(self, x):
        """Downsample ``x`` by 2, using the conv or average pooling."""
        if not self.with_conv:
            return F.avg_pool2d(x, kernel_size=2, stride=2)
        # Zero-pad one element at the end of each of the last two dimensions
        # (F.pad lists pads starting from the last dimension).
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class SpatialDownsample2x(Block):
    """Spatial-only downsampling with a ``CausalConv3d`` whose temporal kernel is 1.

    Args:
        chan_in: Input channel count passed to the convolution.
        chan_out: Output channel count passed to the convolution.
        kernel_size: Spatial kernel size; an int is expanded with
            ``cast_tuple(..., 2)``.
        stride: Spatial stride; an int is expanded with ``cast_tuple(..., 2)``.
    """

    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int]] = (3, 3),
        stride: Union[int, Tuple[int]] = (2, 2),
    ):
        super().__init__()
        spatial_kernel = cast_tuple(kernel_size, 2)
        spatial_stride = cast_tuple(stride, 2)
        self.chan_in = chan_in
        self.chan_out = chan_out
        self.kernel_size = spatial_kernel
        # Prefix a 1 so the 3D conv neither mixes nor strides over the
        # leading kernel dimension.
        self.conv = CausalConv3d(
            self.chan_in,
            self.chan_out,
            (1, *spatial_kernel),
            stride=(1, *spatial_stride),
            padding=0,
        )

    def forward(self, x):
        """Zero-pad the trailing edge of the last two dims, then convolve.

        NOTE(review): assumes ``x`` is laid out as (b, c, t, h, w) -- confirm
        against callers.
        """
        # F.pad lists pads from the last dimension backwards: one element at
        # the end of each of the last two dims, nothing on the third-to-last.
        padded = F.pad(x, (0, 1, 0, 1, 0, 0), mode="constant", value=0)
        return self.conv(padded)
class SpatialUpsample2x(Block):
    """Spatial-only 2x nearest-neighbour upsampling followed by a ``CausalConv3d``.

    Args:
        chan_in: Input channel count passed to the convolution.
        chan_out: Output channel count passed to the convolution.
        kernel_size: Spatial kernel size; an int is expanded with
            ``cast_tuple(..., 2)``.
        stride: Spatial stride of the convolution; an int is expanded with
            ``cast_tuple(..., 2)``.
    """

    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int]] = (3, 3),
        stride: Union[int, Tuple[int]] = (1, 1),
    ):
        super().__init__()
        # Normalise int arguments to 2-tuples (as SpatialDownsample2x does);
        # otherwise ``(1,) + kernel_size`` raises TypeError for an int even
        # though the annotations advertise ints as valid.
        kernel_size = cast_tuple(kernel_size, 2)
        stride = cast_tuple(stride, 2)
        self.chan_in = chan_in
        self.chan_out = chan_out
        self.kernel_size = kernel_size
        # Prefix a 1 so the 3D conv neither mixes nor strides over the
        # leading kernel dimension.
        self.conv = CausalConv3d(
            self.chan_in,
            self.chan_out,
            (1,) + self.kernel_size,
            stride=(1,) + stride,
            padding=1,
        )

    def forward(self, x):
        """Double the last two dims of a (b, c, t, h, w) tensor, then convolve."""
        t = x.shape[2]
        # Fold time into channels so the 2D nearest-neighbour interpolation
        # only touches the last two dims, then unfold again.
        x = rearrange(x, "b c t h w -> b (c t) h w")
        x = F.interpolate(x, scale_factor=(2, 2), mode="nearest")
        x = rearrange(x, "b (c t) h w -> b c t h w", t=t)
        x = self.conv(x)
        return x
class TimeDownsample2x(Block):
    """Temporal downsampling by average pooling with stride 2 along dim 2.

    Args:
        chan_in: Accepted for interface symmetry; not used by this block.
        chan_out: Accepted for interface symmetry; not used by this block.
        kernel_size: Temporal extent of the pooling window.
    """

    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: int = 3,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        # Pool only along the first of the three pooled dims; the attribute
        # keeps the name ``conv`` although it is an average pool.
        self.conv = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))

    def forward(self, x):
        """Front-pad dim 2 with copies of the first slice, then pool."""
        # Replicating the first slice ``kernel_size - 1`` times means the
        # first pooling window contains nothing but that slice.
        n_pad = self.kernel_size - 1
        lead = x[:, :, :1, :, :].repeat((1, 1, n_pad, 1, 1))
        padded = torch.cat((lead, x), dim=2)
        return self.conv(padded)
class TimeUpsample2x(Block):
    """Temporal 2x upsampling that leaves the first slice along dim 2 untouched.

    Args:
        chan_in: Accepted for interface symmetry; not used by this block.
        chan_out: Accepted for interface symmetry; not used by this block.
    """

    def __init__(
        self,
        chan_in,
        chan_out,
    ):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with every slice after the first doubled along dim 2.

        An input with a single slice along dim 2 is returned unchanged.
        """
        if x.size(2) <= 1:
            return x
        head, tail = x[:, :, :1], x[:, :, 1:]
        # Only the remaining slices are interpolated; the first is re-attached
        # as is, so T slices become 1 + 2 * (T - 1).
        tail = F.interpolate(tail, scale_factor=(2, 1, 1), mode="trilinear")
        return torch.cat([head, tail], dim=2)
class TimeDownsampleRes2x(nn.Module):
    """Temporal 2x downsampling that blends average pooling with a strided conv.

    The output is ``alpha * avg_pool(x) + (1 - alpha) * conv(x)`` with
    ``alpha = sigmoid(mix_factor)``, where ``mix_factor`` is a learnable
    parameter.

    NOTE: the pooled branch keeps the input channel count while the conv
    branch produces ``out_channels``, so the two must be broadcast-compatible
    for the sum to work.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        kernel_size: Temporal pooling window; also expanded to a 3-tuple via
            ``cast_tuple`` for the convolution kernel.
        mix_factor: Initial value of the learnable pre-sigmoid blend weight.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        mix_factor: float = 2,
    ):
        super().__init__()
        self.kernel_size = cast_tuple(kernel_size, 3)
        temporal_stride = (2, 1, 1)
        self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=temporal_stride)
        # No padding on the first kernel dim: ``forward`` pads the front of
        # dim 2 by hand with copies of the first slice.
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=temporal_stride,
            padding=(0, 1, 1),
        )
        self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))

    def forward(self, x):
        """Front-pad dim 2 with the first slice and blend the two branches."""
        gate = torch.sigmoid(self.mix_factor)
        n_pad = self.kernel_size[0] - 1
        lead = x[:, :, :1, :, :].repeat((1, 1, n_pad, 1, 1))
        padded = torch.cat((lead, x), dim=2)
        pooled = gate * self.avg_pool(padded)
        convolved = (1 - gate) * self.conv(padded)
        return pooled + convolved
class TimeUpsampleRes2x(nn.Module):
    """Temporal 2x upsampling blended with a convolution of the upsampled tensor.

    The output is ``alpha * up(x) + (1 - alpha) * conv(up(x))`` with
    ``alpha = sigmoid(mix_factor)``, where ``mix_factor`` is a learnable
    parameter and ``up`` doubles every slice after the first along dim 2.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        kernel_size: Kernel size forwarded to ``CausalConv3d``.
        mix_factor: Initial value of the learnable pre-sigmoid blend weight.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        mix_factor: float = 2,
    ):
        super().__init__()
        self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
        self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))

    def forward(self, x):
        """Upsample along dim 2 (keeping the first slice) and blend with the conv."""
        gate = torch.sigmoid(self.mix_factor)
        if x.size(2) > 1:
            # The first slice is passed through; only the rest is interpolated.
            head, tail = x[:, :, :1], x[:, :, 1:]
            tail = F.interpolate(tail, scale_factor=(2, 1, 1), mode="trilinear")
            x = torch.cat([head, tail], dim=2)
        skip = gate * x
        refined = (1 - gate) * self.conv(x)
        return skip + refined
class TimeDownsampleResAdv2x(nn.Module):
    """Temporal 2x downsampling mixing average pooling with a res/attn/conv branch.

    The output is
    ``mix_factor * avg_pool(x) + (1 - mix_factor) * conv(attn(res(x)))``.
    ``mix_factor`` is a plain attribute fixed at 1 here, so the second branch
    is weighted by zero, although it is still evaluated.

    Args:
        in_channels: Channels of the resnet block, the attention block and
            the convolution input.
        out_channels: Output channels of the convolution.
        kernel_size: Temporal pooling window; also expanded to a 3-tuple via
            ``cast_tuple`` for the convolution kernel.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
    ):
        super().__init__()
        self.kernel_size = cast_tuple(kernel_size, 3)
        temporal_stride = (2, 1, 1)
        self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=temporal_stride)
        self.attn = TemporalAttnBlock(in_channels)
        self.res = ResnetBlock3D(
            in_channels=in_channels, out_channels=in_channels, dropout=0.0
        )
        # No padding on the first kernel dim: ``forward`` pads the front of
        # dim 2 by hand with copies of the first slice.
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=temporal_stride,
            padding=(0, 1, 1),
        )
        self.mix_factor = 1

    def forward(self, x):
        """Front-pad dim 2 with the first slice and mix the two branches."""
        n_pad = self.kernel_size[0] - 1
        lead = x[:, :, :1, :, :].repeat((1, 1, n_pad, 1, 1))
        padded = torch.cat((lead, x), dim=2)
        pooled = self.mix_factor * self.avg_pool(padded)
        features = self.attn(self.res(padded))
        learned = (1 - self.mix_factor) * self.conv(features)
        return pooled + learned
class TimeUpsampleResAdv2x(nn.Module):
    """Temporal 2x upsampling mixed with a res/attn/conv branch.

    The output is
    ``mix_factor * up(x) + (1 - mix_factor) * conv(attn(res(up(x))))`` where
    ``up`` doubles every slice after the first along dim 2. ``mix_factor`` is
    a plain attribute fixed at 1 here, so the second branch is weighted by
    zero, although it is still evaluated.

    Args:
        in_channels: Channels of the resnet block, attention block, norm
            layer and the convolution input.
        out_channels: Output channels of the convolution.
        kernel_size: Kernel size forwarded to ``CausalConv3d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
    ):
        super().__init__()
        self.res = ResnetBlock3D(
            in_channels=in_channels, out_channels=in_channels, dropout=0.0
        )
        self.attn = TemporalAttnBlock(in_channels)
        # Constructed but not used by ``forward``.
        self.norm = Normalize(in_channels=in_channels)
        self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
        self.mix_factor = 1

    def forward(self, x):
        """Upsample along dim 2 (keeping the first slice) and mix the branches."""
        if x.size(2) > 1:
            # The first slice is passed through; only the rest is interpolated.
            head, tail = x[:, :, :1], x[:, :, 1:]
            tail = F.interpolate(tail, scale_factor=(2, 1, 1), mode="trilinear")
            x = torch.cat([head, tail], dim=2)
        skip = self.mix_factor * x
        features = self.attn(self.res(x))
        learned = (1 - self.mix_factor) * self.conv(features)
        return skip + learned