"""FLAVR video frame interpolation architecture.

Adapted from:
    https://github.com/tarun005/FLAVR/blob/main/model/FLAVR_arch.py
    https://github.com/tarun005/FLAVR/blob/main/model/resnet_3D.py (only SEGating)
"""

import importlib

import torch
import torch.nn as nn
import torch.nn.functional as F


class SEGating(nn.Module):
    """Squeeze-and-excitation style gating over the channels of a 5-D (N, C, T, H, W) feature map."""

    def __init__(self, inplanes, reduction=16):
        super().__init__()
        # `reduction` is kept for signature compatibility with the original FLAVR code
        # but is unused: the gate is a single 1x1x1 convolution followed by a sigmoid.
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.attn_layer = nn.Sequential(
            nn.Conv3d(inplanes, inplanes, kernel_size=1, stride=1, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        out = self.pool(x)
        y = self.attn_layer(out)
        return x * y
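
# SEGating usage sketch (illustrative, not from the upstream repo): the gate pools each
# channel to a single value, maps it through the 1x1x1 conv + sigmoid, and rescales the
# input, so the output keeps the input shape:
#
#     gate = SEGating(inplanes=64)
#     y = gate(torch.randn(2, 64, 4, 32, 32))   # -> (2, 64, 4, 32, 32)
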
def joinTensors(X1, X2, type="concat"):
    """Fuse two feature maps: channel-wise concat, element-wise add, or pass X1 through unchanged."""

    if type == "concat":
        return torch.cat([X1, X2], dim=1)
    elif type == "add":
        return X1 + X2
    else:
        return X1
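
# Note: with type="concat" the fused tensor has twice the channels, which is why
# UNet_3D_3D below scales its decoder input widths by `growth = 2` when
# joinType == "concat".
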
class Conv_2d(nn.Module):

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=False, batchnorm=False):
        super().__init__()
        self.conv = [nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)]

        if batchnorm:
            self.conv += [nn.BatchNorm2d(out_ch)]

        self.conv = nn.Sequential(*self.conv)

    def forward(self, x):
        return self.conv(x)

class upConv3D(nn.Module):
    """3-D upsampling block: transposed conv, or trilinear upsample + 1x1x1 conv, followed by SEGating."""

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose", batchnorm=False):
        super().__init__()

        self.upmode = upmode

        if self.upmode == "transpose":
            self.upconv = [
                nn.ConvTranspose3d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding),
                SEGating(out_ch)
            ]
        else:
            self.upconv = [
                nn.Upsample(mode='trilinear', scale_factor=(1, 2, 2), align_corners=False),
                nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1),
                SEGating(out_ch)
            ]

        if batchnorm:
            self.upconv += [nn.BatchNorm3d(out_ch)]

        self.upconv = nn.Sequential(*self.upconv)

    def forward(self, x):
        return self.upconv(x)

class Conv_3d(nn.Module):

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=True, batchnorm=False):
        super().__init__()
        self.conv = [
            nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
            SEGating(out_ch)
        ]

        if batchnorm:
            self.conv += [nn.BatchNorm3d(out_ch)]

        self.conv = nn.Sequential(*self.conv)

    def forward(self, x):
        return self.conv(x)

class upConv2D(nn.Module):

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose", batchnorm=False):
        super().__init__()

        self.upmode = upmode

        if self.upmode == "transpose":
            self.upconv = [nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding)]
        else:
            self.upconv = [
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1)
            ]

        if batchnorm:
            self.upconv += [nn.BatchNorm2d(out_ch)]

        self.upconv = nn.Sequential(*self.upconv)

    def forward(self, x):
        return self.upconv(x)

class UNet_3D_3D(nn.Module):
    """FLAVR encoder-decoder: a 3-D ResNet encoder with a U-Net style 3-D decoder and a 2-D output head."""

    def __init__(self, block, n_inputs, n_outputs, batchnorm=False, joinType="concat", upmode="transpose"):
        super().__init__()

        nf = [512, 256, 128, 64]
        out_channels = 3 * n_outputs
        self.joinType = joinType
        self.n_outputs = n_outputs

        # Skip connections joined by concatenation double the decoder input channels.
        growth = 2 if joinType == "concat" else 1
        self.lrelu = nn.LeakyReLU(0.2, True)

        unet_3D = importlib.import_module(".resnet_3D", "models.flavr")
        if n_outputs > 1:
            unet_3D.useBias = True
        self.encoder = getattr(unet_3D, block)(pretrained=False, bn=batchnorm)

        self.decoder = nn.Sequential(
            Conv_3d(nf[0], nf[1], kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
            upConv3D(nf[1] * growth, nf[2], kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1), upmode=upmode, batchnorm=batchnorm),
            upConv3D(nf[2] * growth, nf[3], kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1), upmode=upmode, batchnorm=batchnorm),
            Conv_3d(nf[3] * growth, nf[3], kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
            upConv3D(nf[3] * growth, nf[3], kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1), upmode=upmode, batchnorm=batchnorm)
        )

        self.feature_fuse = Conv_2d(nf[3] * n_inputs, nf[3], kernel_size=1, stride=1, batchnorm=batchnorm)

        self.outconv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(nf[3], out_channels, kernel_size=7, stride=1, padding=0)
        )

    def forward(self, images):

        # `images` is a list of n_inputs frames, each (B, 3, H, W); stack them along a
        # new temporal axis to get a (B, 3, n_inputs, H, W) volume.
        images = torch.stack(images, dim=2)

        # Subtract the per-sample, per-channel mean and restore it at the end.
        mean_ = images.mean(2, keepdim=True).mean(3, keepdim=True).mean(4, keepdim=True)
        images = images - mean_

        x_0, x_1, x_2, x_3, x_4 = self.encoder(images)

        dx_3 = self.lrelu(self.decoder[0](x_4))
        dx_3 = joinTensors(dx_3, x_3, type=self.joinType)

        dx_2 = self.lrelu(self.decoder[1](dx_3))
        dx_2 = joinTensors(dx_2, x_2, type=self.joinType)

        dx_1 = self.lrelu(self.decoder[2](dx_2))
        dx_1 = joinTensors(dx_1, x_1, type=self.joinType)

        dx_0 = self.lrelu(self.decoder[3](dx_1))
        dx_0 = joinTensors(dx_0, x_0, type=self.joinType)

        # Fold the temporal dimension into channels before the 2-D fusion and output head.
        dx_out = self.lrelu(self.decoder[4](dx_0))
        dx_out = torch.cat(torch.unbind(dx_out, 2), 1)

        out = self.lrelu(self.feature_fuse(dx_out))
        out = self.outconv(out)

        # Split the 3*n_outputs channels into n_outputs RGB frames and add the mean back.
        out = torch.split(out, dim=1, split_size_or_sections=3)
        mean_ = mean_.squeeze(2)
        out = [o + mean_ for o in out]

        return out
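

# Construction sketch (illustrative; assumes the companion `models.flavr.resnet_3D`
# module from the FLAVR repo is importable and exposes a builder such as `unet_18`
# for the `block` argument):
#
#     model = UNet_3D_3D("unet_18", n_inputs=4, n_outputs=1, joinType="concat", upmode="transpose")
#     mid = model([f0, f1, f2, f3])[0]   # each f_i: (B, 3, H, W), with H and W padded so the
#                                        # encoder can downsample cleanly (see InputPadder below)
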
class InputPadder:
    """Pads images so that their spatial dimensions are divisible by `divisor`."""

    def __init__(self, dims, divisor=16):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor
        pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor
        # Padding order follows F.pad: (left, right, top, bottom).
        self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]

    def pad(self, input_tensor):
        return F.pad(input_tensor, self._pad, mode='replicate')

    def unpad(self, input_tensor):
        return self._unpad(input_tensor)

    def _unpad(self, x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]
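

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the upstream FLAVR repo). It only
    # exercises the self-contained pieces defined here; building UNet_3D_3D would
    # additionally require the companion `models.flavr.resnet_3D` encoder module.
    frame = torch.randn(1, 3, 270, 480)            # height 270 is not divisible by 16
    padder = InputPadder(frame.shape, divisor=16)
    padded = padder.pad(frame)
    assert padded.shape[-2] % 16 == 0 and padded.shape[-1] % 16 == 0
    assert padder.unpad(padded).shape == frame.shape

    gate = SEGating(inplanes=8)
    feat = torch.randn(2, 8, 4, 16, 16)
    assert gate(feat).shape == feat.shape          # gating preserves the input shape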