import torch
import torch.nn.functional as F
from torch import nn


def cast_tuple(t, length = 1):
    """Return ``t`` unchanged if it is already a tuple, else repeat it ``length`` times."""
    return t if isinstance(t, tuple) else ((t,) * length)


def divisible_by(num, den):
    """True when ``num`` is an exact multiple of ``den``."""
    return (num % den) == 0


def is_odd(n):
    """True when ``n`` is odd."""
    return not divisible_by(n, 2)


class CausalConv3d(nn.Module):
    """3D convolution that is causal along the time axis.

    All temporal padding is placed in front of the sequence, so output frame
    ``t`` never depends on input frames later than ``t``.  Height and width use
    symmetric "same" padding, so spatial size is preserved.  ``stride`` and
    ``dilation`` (popped from ``kwargs``) apply to the time dimension only.
    """

    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size,
        pad_mode = 'replicate',  # default now matches the previously hard-coded behavior
        **kwargs
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)
        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        # symmetric "same" spatial padding requires odd spatial kernels
        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        dilation = kwargs.pop('dilation', 1)
        stride = kwargs.pop('stride', 1)

        self.pad_mode = pad_mode

        # front-only temporal padding keeps the convolution causal
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.time_pad = time_pad
        # F.pad ordering for 5D input: (w_left, w_right, h_top, h_bottom, t_front, t_back)
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        stride = (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs)

    def forward(self, x):
        # BUGFIX: honor the configured pad mode.  The original hard-coded
        # mode='replicate' here, silently ignoring the ``pad_mode`` argument;
        # the constructor default is now 'replicate' so behavior is unchanged
        # for every existing caller.
        x = F.pad(x, self.time_causal_padding, mode = self.pad_mode)
        return self.conv(x)


class Swish(nn.Module):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        # BUGFIX: F.sigmoid is deprecated; torch.sigmoid is the supported,
        # numerically identical replacement.
        return x * torch.sigmoid(x)


class ResBlockX(nn.Module):
    """Residual block with a fixed channel count: x + f(x).

    ``inchannel`` must be divisible by 32 (GroupNorm group count).
    """

    def __init__(self, inchannel) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.GroupNorm(32, inchannel),
            Swish(),
            CausalConv3d(inchannel, inchannel, 3),
            nn.GroupNorm(32, inchannel),
            Swish(),
            CausalConv3d(inchannel, inchannel, 3)
        )

    def forward(self, x):
        return x + self.conv(x)


class ResBlockXY(nn.Module):
    """Residual block that changes channel count; the skip path is a 1x1x1 conv.

    Both ``inchannel`` and ``outchannel`` must be divisible by 32 (GroupNorm).
    """

    def __init__(self, inchannel, outchannel) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.GroupNorm(32, inchannel),
            Swish(),
            CausalConv3d(inchannel, outchannel, 3),
            nn.GroupNorm(32, outchannel),
            Swish(),
            CausalConv3d(outchannel, outchannel, 3)
        )
        # 1x1x1 projection so the skip connection matches outchannel
        self.conv_1 = nn.Conv3d(inchannel, outchannel, 1)

    def forward(self, x):
        return self.conv_1(x) + self.conv(x)


class PoolDown222(nn.Module):
    """Causal 2x average-pool over (T, H, W).

    One frame of front padding (replicate) keeps the temporal pooling causal;
    output length is (T + 1) // 2.
    """

    def __init__(self) -> None:
        super().__init__()
        self.pool = nn.AvgPool3d(2, 2)

    def forward(self, x):
        # pad one frame at the front of time only: (w, w, h, h, t_front, t_back)
        x = F.pad(x, (0, 0, 0, 0, 1, 0), 'replicate')
        return self.pool(x)


class PoolDown122(nn.Module):
    """2x average-pool over (H, W) only; time dimension is untouched."""

    def __init__(self) -> None:
        super().__init__()
        self.pool = nn.AvgPool3d((1, 2, 2), (1, 2, 2))

    def forward(self, x):
        return self.pool(x)


class Unpool222(nn.Module):
    """Nearest-neighbor 2x upsampling over (T, H, W).

    The first upsampled frame is dropped, giving 2T - 1 output frames —
    the inverse of the front padding applied by PoolDown222.
    """

    def __init__(self) -> None:
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        x = self.up(x)
        # drop the leading duplicated frame to mirror causal front padding
        return x[:, :, 1:]


class Unpool122(nn.Module):
    """Nearest-neighbor 2x upsampling over (H, W) only."""

    def __init__(self) -> None:
        super().__init__()
        self.up = nn.Upsample(scale_factor=(1, 2, 2), mode='nearest')

    def forward(self, x):
        x = self.up(x)
        return x


class ResBlockDown(nn.Module):
    """Discriminator residual block that halves (T, H, W) and remaps channels."""

    def __init__(self, inchannel, outchannel) -> None:
        super().__init__()
        # NOTE(review): attribute renamed from the misspelled 'blcok' to
        # 'block'; checkpoints saved under the old name need key remapping.
        self.block = nn.Sequential(
            CausalConv3d(inchannel, outchannel, 3),
            nn.LeakyReLU(inplace=True),
            PoolDown222(),
            CausalConv3d(outchannel, outchannel, 3),
            nn.LeakyReLU(inplace=True)
        )
        self.res = nn.Sequential(
            PoolDown222(),
            nn.Conv3d(inchannel, outchannel, 1)
        )

    def forward(self, x):
        return self.res(x) + self.block(x)


class Discriminator(nn.Module):
    """Video discriminator: 5 downsampling residual stages, global pool, MLP head.

    Accepts (B, 3, T, H, W); a 4D image batch (B, 3, H, W) is promoted to a
    single-frame video.  Returns one logit per sample, shape (B, 1).
    """

    def __init__(self) -> None:
        super().__init__()
        self.block = nn.Sequential(
            CausalConv3d(3, 64, 3),
            nn.LeakyReLU(inplace=True),
            ResBlockDown(64, 128),
            ResBlockDown(128, 256),
            ResBlockDown(256, 256),
            ResBlockDown(256, 256),
            ResBlockDown(256, 256),
            CausalConv3d(256, 256, 3),
            nn.LeakyReLU(inplace=True),
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(256, 256),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        if x.ndim == 4:
            # treat an image batch as a one-frame video
            x = x.unsqueeze(2)
        return self.block(x)


class Encoder(nn.Module):
    """Causal video encoder: (B, 3, T, H, W) -> (B, 16, ~T/4, H/8, W/8).

    Time is downsampled 4x (two PoolDown222 stages) and space 8x (two
    PoolDown222 plus one PoolDown122).  The 16 output channels are presumably
    mean ++ log-variance of an 8-dim latent (the Decoder takes 8 channels) —
    TODO confirm against the training code.
    """

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            CausalConv3d(3, 64, 3),
            ResBlockX(64),
            ResBlockX(64),
            PoolDown222(),
            ResBlockXY(64, 128),
            ResBlockX(128),
            PoolDown222(),
            ResBlockX(128),
            ResBlockX(128),
            PoolDown122(),
            ResBlockXY(128, 256),
            ResBlockX(256),
            ResBlockX(256),
            ResBlockX(256),
            nn.GroupNorm(32, 256),
            Swish(),
            nn.Conv3d(256, 16, 1)
        )

    def forward(self, x):
        return self.encoder(x)


class Decoder(nn.Module):
    """Causal video decoder: (B, 8, t, h, w) -> (B, 3, ~4t, 8h, 8w).

    Mirrors the Encoder: one Unpool122 plus two Unpool222 stages undo the
    spatial/temporal pooling.  Note the 8 input channels are half of the
    Encoder's 16 outputs (see Encoder docstring).
    """

    def __init__(self) -> None:
        super().__init__()
        self.decoder = nn.Sequential(
            CausalConv3d(8, 256, 3),
            ResBlockX(256),
            ResBlockX(256),
            ResBlockX(256),
            ResBlockX(256),
            Unpool122(),
            CausalConv3d(256, 256, 3),
            ResBlockXY(256, 128),
            ResBlockX(128),
            Unpool222(),
            CausalConv3d(128, 128, 3),
            ResBlockX(128),
            ResBlockX(128),
            Unpool222(),
            CausalConv3d(128, 128, 3),
            ResBlockXY(128, 64),
            ResBlockX(64),
            nn.GroupNorm(32, 64),
            Swish(),
            CausalConv3d(64, 64, 3)
        )
        self.conv_out = nn.Conv3d(64, 3, 1)

    def forward(self, x):
        return self.conv_out(self.decoder(x))


if __name__ == '__main__':
    encoder = Encoder()
    decoder = Decoder()
    dis = Discriminator()

    x = torch.randn((1, 3, 1, 64, 64))
    embedding = encoder(x)
    # BUGFIX: the encoder emits 16 channels while the decoder expects 8, so
    # the original ``decoder(embedding)`` crashed with a channel mismatch.
    # Split into (mean, logvar) halves and decode the mean.
    mean, logvar = embedding.chunk(2, dim=1)
    y = decoder(mean)
    score = dis(y)
    print(embedding.shape, y.shape, score.shape)