import torch
import torch.nn.functional as F
from torch import nn


def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)


def divisible_by(num, den):
    return (num % den) == 0


def is_odd(n):
    return not divisible_by(n, 2)


class CausalConv3d(nn.Module):
    """3D convolution that is causal in time: all temporal padding is applied
    at the front of the clip, so output frame t never sees frames > t."""
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size,
        pad_mode = 'replicate',
        **kwargs
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)
        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        dilation = kwargs.pop('dilation', 1)
        stride = kwargs.pop('stride', 1)
        self.pad_mode = pad_mode

        # temporal padding goes entirely to the front; spatial padding is symmetric
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.time_pad = time_pad
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        # stride and dilation only act on the temporal axis
        stride = (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode = self.pad_mode)
        return self.conv(x)
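

# Illustrative sanity check, not part of the original file: with front-only
# temporal padding, perturbing the last input frame must leave every output
# frame except the last one unchanged.
def _check_causality():
    torch.manual_seed(0)
    conv = CausalConv3d(3, 8, 3)
    x = torch.randn(1, 3, 5, 16, 16)
    x_perturbed = x.clone()
    x_perturbed[:, :, -1] += 1.0    # change only the final frame
    y, y_perturbed = conv(x), conv(x_perturbed)
    assert torch.allclose(y[:, :, :-1], y_perturbed[:, :, :-1], atol = 1e-6)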


class Swish(nn.Module):
    """Swish / SiLU activation: x * sigmoid(x)."""
    def forward(self, x):
        return x * torch.sigmoid(x)


class ResBlockX(nn.Module):
    """Residual block that keeps the channel count fixed."""
    def __init__(self, inchannel) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.GroupNorm(32, inchannel),
            Swish(),
            CausalConv3d(inchannel, inchannel, 3),
            nn.GroupNorm(32, inchannel),
            Swish(),
            CausalConv3d(inchannel, inchannel, 3)
        )

    def forward(self, x):
        return x + self.conv(x)


class ResBlockXY(nn.Module):
    """Residual block that changes the channel count; the skip path uses a
    1x1x1 convolution to match dimensions."""
    def __init__(self, inchannel, outchannel) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.GroupNorm(32, inchannel),
            Swish(),
            CausalConv3d(inchannel, outchannel, 3),
            nn.GroupNorm(32, outchannel),
            Swish(),
            CausalConv3d(outchannel, outchannel, 3)
        )
        self.conv_1 = nn.Conv3d(inchannel, outchannel, 1)

    def forward(self, x):
        return self.conv_1(x) + self.conv(x)


class PoolDown222(nn.Module):
    """2x average-pool downsampling in time, height and width; the temporal
    axis is front-padded (replicate) so a length-T clip maps to (T + 1) // 2."""
    def __init__(self) -> None:
        super().__init__()
        self.pool = nn.AvgPool3d(2, 2)

    def forward(self, x):
        x = F.pad(x, (0, 0, 0, 0, 1, 0), 'replicate')
        return self.pool(x)


class PoolDown122(nn.Module):
    """2x average-pool downsampling in height and width only."""
    def __init__(self) -> None:
        super().__init__()
        self.pool = nn.AvgPool3d((1, 2, 2), (1, 2, 2))

    def forward(self, x):
        return self.pool(x)


class Unpool222(nn.Module):
    """2x nearest-neighbour upsampling in all three axes; dropping the first
    upsampled frame maps T to 2T - 1, the inverse of PoolDown222 on odd T."""
    def __init__(self) -> None:
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        x = self.up(x)
        return x[:, :, 1:]


class Unpool122(nn.Module):
    """2x nearest-neighbour upsampling in height and width only."""
    def __init__(self) -> None:
        super().__init__()
        self.up = nn.Upsample(scale_factor=(1, 2, 2), mode='nearest')

    def forward(self, x):
        return self.up(x)
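

# Illustrative shape check, not part of the original file: the front pad gives
# the temporal mappings T -> (T + 1) // 2 (down) and T -> 2T - 1 (up), which
# are mutual inverses on odd-length clips.
def _check_pooling():
    x = torch.randn(1, 8, 5, 32, 32)
    assert PoolDown222()(x).shape == (1, 8, 3, 16, 16)
    assert PoolDown122()(x).shape == (1, 8, 5, 16, 16)
    assert Unpool222()(x).shape == (1, 8, 9, 64, 64)
    assert Unpool122()(x).shape == (1, 8, 5, 64, 64)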


class ResBlockDown(nn.Module):
    """Downsampling residual block for the discriminator."""
    def __init__(self, inchannel, outchannel) -> None:
        super().__init__()
        self.block = nn.Sequential(
            CausalConv3d(inchannel, outchannel, 3),
            nn.LeakyReLU(inplace=True),
            PoolDown222(),
            CausalConv3d(outchannel, outchannel, 3),
            nn.LeakyReLU(inplace=True)
        )
        self.res = nn.Sequential(
            PoolDown222(),
            nn.Conv3d(inchannel, outchannel, 1)
        )

    def forward(self, x):
        return self.res(x) + self.block(x)


class Discriminator(nn.Module):
    """Video discriminator: five 2x downsampling stages, global average
    pooling, then a two-layer MLP head producing one logit per sample."""
    def __init__(self) -> None:
        super().__init__()
        self.block = nn.Sequential(
            CausalConv3d(3, 64, 3),
            nn.LeakyReLU(inplace=True),
            ResBlockDown(64, 128),
            ResBlockDown(128, 256),
            ResBlockDown(256, 256),
            ResBlockDown(256, 256),
            ResBlockDown(256, 256),
            CausalConv3d(256, 256, 3),
            nn.LeakyReLU(inplace=True),
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(256, 256),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        # accept image batches (B, C, H, W) by adding a singleton time axis
        if x.ndim == 4:
            x = x.unsqueeze(2)
        return self.block(x)
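

# Illustrative check, not part of the original file: the discriminator handles
# both image batches and video clips and returns one logit per sample.
def _check_discriminator():
    dis = Discriminator()
    assert dis(torch.randn(2, 3, 64, 64)).shape == (2, 1)       # images
    assert dis(torch.randn(2, 3, 5, 64, 64)).shape == (2, 1)    # 5-frame clips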


class Encoder(nn.Module):
    """Compresses (B, 3, T, H, W) video roughly 4x in time and 8x in space
    into a 16-channel latent."""
    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            CausalConv3d(3, 64, 3),
            ResBlockX(64),
            ResBlockX(64),
            PoolDown222(),
            ResBlockXY(64, 128),
            ResBlockX(128),
            PoolDown222(),
            ResBlockX(128),
            ResBlockX(128),
            PoolDown122(),
            ResBlockXY(128, 256),
            ResBlockX(256),
            ResBlockX(256),
            ResBlockX(256),
            nn.GroupNorm(32, 256),
            Swish(),
            nn.Conv3d(256, 16, 1)
        )

    def forward(self, x):
        return self.encoder(x)
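

# Illustrative shape trace, not part of the original file: a 17-frame 64x64
# clip becomes a 5-frame 8x8 latent with 16 channels.
def _check_encoder_shapes():
    z = Encoder()(torch.randn(1, 3, 17, 64, 64))
    assert z.shape == (1, 16, 5, 8, 8)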


class Decoder(nn.Module):
    """Mirrors the encoder: takes an 8-channel latent and upsamples 8x in
    space and roughly 4x in time back to 3-channel video."""
    def __init__(self) -> None:
        super().__init__()
        self.decoder = nn.Sequential(
            CausalConv3d(8, 256, 3),
            ResBlockX(256),
            ResBlockX(256),
            ResBlockX(256),
            ResBlockX(256),
            Unpool122(),
            CausalConv3d(256, 256, 3),
            ResBlockXY(256, 128),
            ResBlockX(128),
            Unpool222(),
            CausalConv3d(128, 128, 3),
            ResBlockX(128),
            ResBlockX(128),
            Unpool222(),
            CausalConv3d(128, 128, 3),
            ResBlockXY(128, 64),
            ResBlockX(64),
            nn.GroupNorm(32, 64),
            Swish(),
            CausalConv3d(64, 64, 3)
        )
        self.conv_out = nn.Conv3d(64, 3, 1)

    def forward(self, x):
        return self.conv_out(self.decoder(x))
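

# Illustrative shape trace, not part of the original file: a 5-frame 8x8
# latent decodes back to the 17-frame 64x64 clip of the encoder check above.
def _check_decoder_shapes():
    y = Decoder()(torch.randn(1, 8, 5, 8, 8))
    assert y.shape == (1, 3, 17, 64, 64)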


if __name__ == '__main__':
    encoder = Encoder()
    decoder = Decoder()
    dis = Discriminator()

    x = torch.randn((1, 3, 1, 64, 64))
    embedding = encoder(x)

    # the encoder emits 16 channels but the decoder consumes 8, which suggests
    # the latent holds (mean, logvar) pairs; sampling here is an assumption
    # about the intended VAE usage, not something this file spells out
    mean, logvar = embedding.chunk(2, dim=1)
    z = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)

    y = decoder(z)
    score = dis(y)
    print('latent:', tuple(embedding.shape))
    print('reconstruction:', tuple(y.shape))
    print('discriminator output:', tuple(score.shape))