# jaxmetaverse's picture
# Upload folder using huggingface_hub
# 82ea528 verified
"""
https://github.com/tarun005/FLAVR/blob/main/model/FLAVR_arch.py
https://github.com/tarun005/FLAVR/blob/main/model/resnet_3D.py (only SEGating)
"""
import math
import numpy as np
import importlib
import torch
import torch.nn as nn
import torch.nn.functional as F
class SEGating(nn.Module):
    """Channel-wise gating for volumetric feature maps.

    The input is average-pooled down to one value per channel, passed through
    a 1x1x1 convolution followed by a sigmoid, and the resulting gate is
    multiplied back onto the input.
    """

    def __init__(self, inplanes, reduction=16):
        """Build the gating layers.

        Args:
            inplanes: Channel count of the incoming feature map; the gate
                produces the same number of channels.
            reduction: Accepted as part of the signature but not used by
                this implementation.
        """
        super().__init__()

        self.pool = nn.AdaptiveAvgPool3d(1)

        gate_conv = nn.Conv3d(inplanes, inplanes, kernel_size=1, stride=1, bias=True)
        self.attn_layer = nn.Sequential(gate_conv, nn.Sigmoid())

    def forward(self, x):
        """Return ``x`` scaled by its per-channel sigmoid gate."""
        gate = self.attn_layer(self.pool(x))
        return x * gate
def joinTensors(X1, X2, type="concat"):
    """Merge two tensors according to ``type``.

    Args:
        X1: First tensor.
        X2: Second tensor.
        type: ``"concat"`` concatenates the two along dim 1, ``"add"``
            returns their element-wise sum; any other value returns ``X1``
            unchanged.

    Returns:
        The merged tensor (or ``X1`` itself for an unrecognised ``type``).
    """
    if type == "concat":
        return torch.cat((X1, X2), dim=1)
    if type == "add":
        return X1 + X2
    # Unrecognised modes ignore X2 entirely.
    return X1
class Conv_2d(nn.Module):
    """2-D convolution, optionally followed by batch normalisation."""

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=False, batchnorm=False):
        """Build the layer stack.

        Args:
            in_ch: Input channel count.
            out_ch: Output channel count.
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            padding: Convolution padding.
            bias: Whether the convolution has a bias term.
            batchnorm: When true, append ``BatchNorm2d(out_ch)``.
        """
        super().__init__()

        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
        ]
        if batchnorm:
            layers.append(nn.BatchNorm2d(out_ch))

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution (and batch norm, if configured) to ``x``."""
        return self.conv(x)
class upConv3D(nn.Module):
    """3-D upsampling stage followed by ``SEGating`` and optional batch norm.

    With ``upmode="transpose"`` the stage is a ``ConvTranspose3d``; with any
    other value it is a trilinear ``Upsample`` by ``(1, 2, 2)`` followed by a
    1x1x1 ``Conv3d``.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose", batchnorm=False):
        """Build the layer stack.

        Args:
            in_ch: Input channel count.
            out_ch: Output channel count.
            kernel_size: Kernel size of the transposed convolution
                (only used when ``upmode == "transpose"``).
            stride: Stride of the transposed convolution
                (only used when ``upmode == "transpose"``).
            padding: Padding of the transposed convolution
                (only used when ``upmode == "transpose"``).
            upmode: ``"transpose"`` for a learned transposed convolution;
                anything else selects trilinear upsampling + 1x1x1 conv.
            batchnorm: When true, append ``BatchNorm3d(out_ch)``.
        """
        super().__init__()

        self.upmode = upmode

        if upmode == "transpose":
            stages = [
                nn.ConvTranspose3d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding),
                SEGating(out_ch),
            ]
        else:
            stages = [
                nn.Upsample(mode='trilinear', scale_factor=(1, 2, 2), align_corners=False),
                nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1),
                SEGating(out_ch),
            ]

        if batchnorm:
            stages.append(nn.BatchNorm3d(out_ch))

        self.upconv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the upsampling stack."""
        return self.upconv(x)
class Conv_3d(nn.Module):
    """3-D convolution followed by ``SEGating`` and optional batch norm."""

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=True, batchnorm=False):
        """Build the layer stack.

        Args:
            in_ch: Input channel count.
            out_ch: Output channel count.
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            padding: Convolution padding.
            bias: Whether the convolution has a bias term.
            batchnorm: When true, append ``BatchNorm3d(out_ch)``.
        """
        super().__init__()

        layers = [
            nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
            SEGating(out_ch),
        ]
        if batchnorm:
            layers.append(nn.BatchNorm3d(out_ch))

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution, gating and (if configured) batch norm."""
        return self.conv(x)
class upConv2D(nn.Module):
    """2-D upsampling stage with optional batch norm.

    With ``upmode="transpose"`` the stage is a ``ConvTranspose2d``; with any
    other value it is a bilinear ``Upsample`` by 2 followed by a 1x1
    ``Conv2d``.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose", batchnorm=False):
        """Build the layer stack.

        Args:
            in_ch: Input channel count.
            out_ch: Output channel count.
            kernel_size: Kernel size of the transposed convolution
                (only used when ``upmode == "transpose"``).
            stride: Stride of the transposed convolution
                (only used when ``upmode == "transpose"``).
            padding: Padding of the transposed convolution
                (only used when ``upmode == "transpose"``).
            upmode: ``"transpose"`` for a learned transposed convolution;
                anything else selects bilinear upsampling + 1x1 conv.
            batchnorm: When true, append ``BatchNorm2d(out_ch)``.
        """
        super().__init__()

        self.upmode = upmode

        if upmode == "transpose":
            stages = [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding),
            ]
        else:
            stages = [
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1),
            ]

        if batchnorm:
            stages.append(nn.BatchNorm2d(out_ch))

        self.upconv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the upsampling stack."""
        return self.upconv(x)
class UNet_3D_3D(nn.Module):
    """U-Net style network with a 3-D encoder/decoder and a 2-D output head.

    The encoder is looked up by name in the ``models.flavr.resnet_3D`` module
    and must return five feature maps. The decoder walks back up through
    them with skip connections, after which dim 2 of the decoder output is
    folded into the channel dim and a 2-D head produces ``3 * n_outputs``
    channels.
    """

    def __init__(self, block, n_inputs, n_outputs, batchnorm=False, joinType="concat", upmode="transpose"):
        """Build encoder, decoder and output head.

        Args:
            block: Name of the encoder factory inside ``resnet_3D``.
            n_inputs: Number of input frames; sizes the input of the
                feature-fusion convolution.
            n_outputs: Number of 3-channel outputs to predict.
            batchnorm: Forwarded to the encoder (``bn=``) and to every
                decoder / fusion layer.
            joinType: How skip connections are merged (see ``joinTensors``).
            upmode: Upsampling mode forwarded to every ``upConv3D``.
        """
        super().__init__()

        widths = [512, 256, 128, 64]
        head_channels = 3 * n_outputs

        self.joinType = joinType
        self.n_outputs = n_outputs

        # Concatenated skips double the channel count entering the next stage.
        skip_mult = 2 if joinType == "concat" else 1
        self.lrelu = nn.LeakyReLU(0.2, True)

        backbone = importlib.import_module(".resnet_3D", "models.flavr")
        if n_outputs > 1:
            # Module-level switch on resnet_3D; set before the encoder is built.
            backbone.useBias = True
        encoder_factory = getattr(backbone, block)
        self.encoder = encoder_factory(pretrained=False, bn=batchnorm)

        # Every upsampling stage shares the same geometry.
        up_kwargs = dict(
            kernel_size=(3, 4, 4),
            stride=(1, 2, 2),
            padding=(1, 1, 1),
            upmode=upmode,
            batchnorm=batchnorm,
        )
        self.decoder = nn.Sequential(
            Conv_3d(widths[0], widths[1], kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
            upConv3D(widths[1] * skip_mult, widths[2], **up_kwargs),
            upConv3D(widths[2] * skip_mult, widths[3], **up_kwargs),
            Conv_3d(widths[3] * skip_mult, widths[3], kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
            upConv3D(widths[3] * skip_mult, widths[3], **up_kwargs),
        )

        self.feature_fuse = Conv_2d(widths[3] * n_inputs, widths[3], kernel_size=1, stride=1, batchnorm=batchnorm)

        self.outconv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(widths[3], head_channels, kernel_size=7, stride=1, padding=0),
        )

    def forward(self, images):
        """Predict output frames from a sequence of input frames.

        Args:
            images: Sequence of tensors that are stacked along a new dim 2
                (presumably each is ``(N, C, H, W)`` -- confirm with caller).

        Returns:
            A list of tensors, one per 3-channel slice of the head output,
            each with the input mean added back.
        """
        frames = torch.stack(images, dim=2)

        ## Batch mean normalization works slightly better than global mean normalization, thanks to https://github.com/myungsub/CAIN
        mean_ = frames.mean(2, keepdim=True).mean(3, keepdim=True).mean(4, keepdim=True)
        frames = frames - mean_

        x_0, x_1, x_2, x_3, x_4 = self.encoder(frames)

        # Decoder stages 0-3 each consume the running feature map and are then
        # merged with the matching encoder output, deepest first.
        feat = x_4
        for idx, skip in enumerate((x_3, x_2, x_1, x_0)):
            feat = self.lrelu(self.decoder[idx](feat))
            feat = joinTensors(feat, skip, type=self.joinType)

        feat = self.lrelu(self.decoder[4](feat))

        # Fold dim 2 into the channel dim so the 2-D head can consume it.
        fused_in = torch.cat(torch.unbind(feat, 2), 1)

        fused = self.lrelu(self.feature_fuse(fused_in))
        prediction = self.outconv(fused)

        chunks = torch.split(prediction, dim=1, split_size_or_sections=3)
        offset = mean_.squeeze(2)
        return [chunk + offset for chunk in chunks]
class InputPadder:
    """Pad the last two dims of a tensor up to a multiple of ``divisor``.

    Padding is split between both sides of each dim, with the odd element
    (if any) going to the trailing side.
    """

    def __init__(self, dims, divisor=16):
        """Compute the padding amounts.

        Args:
            dims: Shape-like sequence whose last two entries are height and
                width.
            divisor: Value the padded height and width must be divisible by.
        """
        self.ht, self.wd = dims[-2:]

        # Distance from each size up to the next multiple of divisor (0 if aligned).
        extra_h = (-self.ht) % divisor
        extra_w = (-self.wd) % divisor

        left = extra_w // 2
        top = extra_h // 2
        # Order follows F.pad: (left, right, top, bottom).
        self._pad = [left, extra_w - left, top, extra_h - top]

    def pad(self, input_tensor):
        """Return ``input_tensor`` replicate-padded by the stored amounts."""
        return F.pad(input_tensor, self._pad, mode='replicate')

    def unpad(self, input_tensor):
        """Crop away the padding added by :meth:`pad`."""
        return self._unpad(input_tensor)

    def _unpad(self, x):
        """Slice the last two dims of ``x`` back to the unpadded region."""
        ht, wd = x.shape[-2:]
        left, right, top, bottom = self._pad
        return x[..., top:ht - bottom, left:wd - right]