|
from typing import Union, Tuple |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from .resnet_block import ResnetBlock3D |
|
from .attention import TemporalAttnBlock |
|
from .normalize import Normalize |
|
from .ops import cast_tuple, video_to_image |
|
from .conv import CausalConv3d |
|
from einops import rearrange |
|
from .block import Block |
|
|
|
class Upsample(Block):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution.

    Args:
        in_channels: Number of channels fed to the convolution.
        out_channels: Number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The convolution is always enabled here; the flag is kept as an
        # attribute because ``forward`` still branches on it.
        self.with_conv = True
        if self.with_conv:
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=1, padding=1
            )

    # NOTE(review): ``video_to_image`` is defined in ``.ops``; presumably it
    # lets these 2D ops run on video tensors frame by frame -- confirm there.
    @video_to_image
    def forward(self, x):
        """Scale ``x`` by a factor of 2 and, if enabled, convolve the result."""
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
|
|
|
class Downsample(Block):
    """2x downsampling with a stride-2 3x3 convolution.

    Args:
        in_channels: Number of channels fed to the convolution.
        out_channels: Number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Always True in this implementation; ``forward`` keeps an average
        # pooling fallback for the disabled case.
        self.with_conv = True
        if self.with_conv:
            # padding=0 on purpose: ``forward`` pads asymmetrically by hand.
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, padding=0
            )

    # NOTE(review): ``video_to_image`` is defined in ``.ops``; presumably it
    # lets these 2D ops run on video tensors frame by frame -- confirm there.
    @video_to_image
    def forward(self, x):
        """Downsample ``x`` by convolution, or by average pooling if disabled."""
        if not self.with_conv:
            return F.avg_pool2d(x, kernel_size=2, stride=2)

        # Add one zero column on the right of the last dim and one zero row
        # at the end of the second-to-last dim before the unpadded conv.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
|
|
class SpatialDownsample2x(Block):
    """Strided ``CausalConv3d`` whose kernel and stride are 1 along the first conv axis.

    Args:
        chan_in: Input channel count passed to the convolution.
        chan_out: Output channel count passed to the convolution.
        kernel_size: Kernel over the last two axes; normalised with
            ``cast_tuple(kernel_size, 2)``.
        stride: Stride over the last two axes; normalised with
            ``cast_tuple(stride, 2)``.
    """

    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int]] = (3, 3),
        stride: Union[int, Tuple[int]] = (2, 2),
    ):
        super().__init__()
        spatial_kernel = cast_tuple(kernel_size, 2)
        spatial_stride = cast_tuple(stride, 2)

        self.chan_in = chan_in
        self.chan_out = chan_out
        self.kernel_size = spatial_kernel

        # Prepend 1 so the first conv axis is neither mixed nor strided.
        conv_kernel = (1,) + self.kernel_size
        conv_stride = (1,) + spatial_stride
        self.conv = CausalConv3d(
            self.chan_in,
            self.chan_out,
            conv_kernel,
            stride=conv_stride,
            padding=0,
        )

    def forward(self, x):
        """Zero-pad the trailing edge of the last two axes, then convolve."""
        # Pad pairs run from the last axis backwards: (0, 1) for the last
        # axis, (0, 1) for the one before it, (0, 0) for the third-to-last.
        padded = F.pad(x, (0, 1, 0, 1, 0, 0), mode="constant", value=0)
        return self.conv(padded)
|
|
|
|
|
class SpatialUpsample2x(Block):
    """Nearest-neighbour 2x upsampling of ``h`` and ``w`` followed by a ``CausalConv3d``.

    Args:
        chan_in: Input channel count passed to the convolution.
        chan_out: Output channel count passed to the convolution.
        kernel_size: Kernel over ``h`` and ``w``; an int or a 2-tuple,
            normalised with ``cast_tuple(kernel_size, 2)``.
        stride: Stride over ``h`` and ``w``; an int or a 2-tuple,
            normalised with ``cast_tuple(stride, 2)``.
    """

    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int]] = (3, 3),
        stride: Union[int, Tuple[int]] = (1, 1),
    ):
        super().__init__()
        # Normalise to tuples so that an int argument (allowed by the type
        # hints) does not break the ``(1,) + ...`` concatenations below.
        kernel_size = cast_tuple(kernel_size, 2)
        stride = cast_tuple(stride, 2)

        self.chan_in = chan_in
        self.chan_out = chan_out
        self.kernel_size = kernel_size

        # Prepend 1 so the first conv axis is neither mixed nor strided.
        self.conv = CausalConv3d(
            self.chan_in,
            self.chan_out,
            (1,) + self.kernel_size,
            stride=(1,) + stride,
            padding=1,
        )

    def forward(self, x):
        """Upsample a ``(b, c, t, h, w)`` tensor by 2 in ``h`` and ``w``, then convolve.

        Args:
            x: Tensor laid out as ``b c t h w``.

        Returns:
            The output of ``self.conv`` applied to the upsampled tensor.
        """
        t = x.shape[2]
        # Fold ``t`` into the channel axis so the 2D nearest interpolation
        # only touches ``h`` and ``w``, then restore the 5D layout.
        x = rearrange(x, "b c t h w -> b (c t) h w")
        x = F.interpolate(x, scale_factor=(2, 2), mode="nearest")
        x = rearrange(x, "b (c t) h w -> b c t h w", t=t)
        return self.conv(x)
|
|
|
class TimeDownsample2x(Block):
    """Stride-2 average pooling along dim 2, with first-frame replication padding.

    Args:
        chan_in: Accepted but not used by this block.
        chan_out: Accepted but not used by this block.
        kernel_size: Pooling window length along dim 2.
    """

    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: int = 3
    ):
        super().__init__()
        self.kernel_size = kernel_size
        # Despite the attribute name this is a parameter-free average pool.
        self.conv = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))

    def forward(self, x):
        """Prepend ``kernel_size - 1`` copies of the first frame, then pool.

        For an input with ``T`` entries along dim 2 the result has
        ``(T - 1) // 2 + 1`` entries along that dim.
        """
        leading = x[:, :, :1, :, :].repeat(1, 1, self.kernel_size - 1, 1, 1)
        padded = torch.cat((leading, x), dim=2)
        return self.conv(padded)
|
|
|
class TimeUpsample2x(Block):
    """Trilinear 2x upsampling along dim 2 that leaves the first frame untouched.

    Args:
        chan_in: Accepted but not used by this block.
        chan_out: Accepted but not used by this block.
    """

    def __init__(
        self,
        chan_in,
        chan_out
    ):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with every frame after the first interpolated by 2 along dim 2.

        An input with a single entry along dim 2 is returned unchanged; an
        input with ``T > 1`` entries yields ``2 * T - 1`` entries.
        """
        if x.size(2) <= 1:
            return x

        head, tail = x[:, :, :1], x[:, :, 1:]
        tail = F.interpolate(tail, scale_factor=(2, 1, 1), mode='trilinear')
        return torch.cat([head, tail], dim=2)
|
|
|
class TimeDownsampleRes2x(nn.Module):
    """Blend of stride-2 average pooling and a stride-2 ``Conv3d`` along dim 2.

    The two branches are mixed with a learnable weight
    ``alpha = sigmoid(mix_factor)``.

    Args:
        in_channels: Input channels of the convolution branch.
        out_channels: Output channels of the convolution branch.
        kernel_size: Pooling window length along dim 2; also expanded with
            ``cast_tuple(kernel_size, 3)`` to form the convolution kernel.
        mix_factor: Initial value of the learnable pre-sigmoid mixing weight.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        mix_factor: float = 2,
    ):
        super().__init__()
        self.kernel_size = cast_tuple(kernel_size, 3)
        self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
        # No padding along dim 2: ``forward`` pads that axis itself.
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=(2, 1, 1),
            padding=(0, 1, 1),
        )
        self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))

    def forward(self, x):
        """Pad the front of dim 2 with copies of the first frame and mix both branches."""
        alpha = torch.sigmoid(self.mix_factor)

        n_pad = self.kernel_size[0] - 1
        leading = x[:, :, :1, :, :].repeat(1, 1, n_pad, 1, 1)
        padded = torch.cat((leading, x), dim=2)

        pooled = self.avg_pool(padded)
        convolved = self.conv(padded)
        return alpha * pooled + (1 - alpha) * convolved
|
|
|
class TimeUpsampleRes2x(nn.Module):
    """Trilinear 2x upsampling along dim 2 blended with a convolution of the result.

    The upsampled tensor and its convolved version are mixed with a
    learnable weight ``alpha = sigmoid(mix_factor)``.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        kernel_size: Kernel size passed to ``CausalConv3d``.
        mix_factor: Initial value of the learnable pre-sigmoid mixing weight.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        mix_factor: float = 2,
    ):
        super().__init__()
        self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
        self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))

    def forward(self, x):
        """Upsample all frames after the first along dim 2, then mix with the conv branch.

        An input with a single entry along dim 2 skips the interpolation
        but still goes through the mixing step.
        """
        alpha = torch.sigmoid(self.mix_factor)

        if x.size(2) > 1:
            head, tail = x[:, :, :1], x[:, :, 1:]
            tail = F.interpolate(tail, scale_factor=(2, 1, 1), mode='trilinear')
            x = torch.cat([head, tail], dim=2)

        return alpha * x + (1 - alpha) * self.conv(x)
|
|
|
class TimeDownsampleResAdv2x(nn.Module):
    """Stride-2 average pooling along dim 2 mixed with a res -> attn -> conv branch.

    Args:
        in_channels: Channels used by the resnet block, the attention block
            and the convolution input.
        out_channels: Output channels of the convolution.
        kernel_size: Pooling window length along dim 2; also expanded with
            ``cast_tuple(kernel_size, 3)`` to form the convolution kernel.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3
    ):
        super().__init__()
        self.kernel_size = cast_tuple(kernel_size, 3)
        self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
        self.attn = TemporalAttnBlock(in_channels)
        self.res = ResnetBlock3D(
            in_channels=in_channels, out_channels=in_channels, dropout=0.0
        )
        # No padding along dim 2: ``forward`` pads that axis itself.
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=(2, 1, 1),
            padding=(0, 1, 1),
        )
        # Fixed (non-learnable) weight: with a value of 1 the second branch
        # is still evaluated in ``forward`` but multiplied by zero.
        self.mix_factor = 1

    def forward(self, x):
        """Pad the front of dim 2 with copies of the first frame and mix both branches."""
        n_pad = self.kernel_size[0] - 1
        leading = x[:, :, :1, :, :].repeat(1, 1, n_pad, 1, 1)
        padded = torch.cat((leading, x), dim=2)

        pooled = self.avg_pool(padded)
        refined = self.conv(self.attn(self.res(padded)))
        return self.mix_factor * pooled + (1 - self.mix_factor) * refined
|
|
|
class TimeUpsampleResAdv2x(nn.Module):
    """Trilinear 2x upsampling along dim 2 mixed with a res -> attn -> conv branch.

    Args:
        in_channels: Channels used by the resnet block, the attention block,
            the normalisation layer and the convolution input.
        out_channels: Output channels of the convolution.
        kernel_size: Kernel size passed to ``CausalConv3d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
    ):
        super().__init__()
        self.res = ResnetBlock3D(
            in_channels=in_channels, out_channels=in_channels, dropout=0.0
        )
        self.attn = TemporalAttnBlock(in_channels)
        # Registered as a submodule but not called in ``forward``.
        self.norm = Normalize(in_channels=in_channels)
        self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
        # Fixed (non-learnable) weight: with a value of 1 the second branch
        # is still evaluated in ``forward`` but multiplied by zero.
        self.mix_factor = 1

    def forward(self, x):
        """Upsample all frames after the first along dim 2, then mix with the learned branch."""
        if x.size(2) > 1:
            head, tail = x[:, :, :1], x[:, :, 1:]
            tail = F.interpolate(tail, scale_factor=(2, 1, 1), mode='trilinear')
            x = torch.cat([head, tail], dim=2)

        refined = self.conv(self.attn(self.res(x)))
        return self.mix_factor * x + (1 - self.mix_factor) * refined
|
|