|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
from tricorder.torch.transforms import Interpolator |
|
|
|
__author__ = "Soumick Chatterjee" |
|
__copyright__ = "Copyright 2019, Soumick Chatterjee & OvGU:ESF:MEMoRIAL" |
|
__credits__ = ["Soumick Chatterjee"] |
|
|
|
__license__ = "apache-2.0" |
|
__version__ = "1.0.0" |
|
__email__ = "[email protected]" |
|
__status__ = "Published" |
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Two-convolution residual block computing ``x + F(x)``.

    ``F`` is: pad(1) -> 3x3 conv -> norm -> activation -> channel dropout ->
    pad(1) -> 3x3 conv -> norm. Both convolutions keep ``in_features``
    channels, and the 1-pixel padding in front of each unpadded 3x3 conv keeps
    the spatial size, so the skip addition is shape-compatible.

    The layer factories (``layer_pad``, ``layer_conv``, ``layer_norm``,
    ``act_relu``, ``layer_drop``) are module-level names that
    ``ReconResNetBase.__init__`` injects via ``globals().update``. This block
    cannot be built until that has happened (or the names were set by hand).

    Args:
        in_features: number of input (and output) channels.
        drop_prob: dropout probability applied after the first activation.
    """

    def __init__(self, in_features, drop_prob=0.2):
        super(ResidualBlock, self).__init__()

        conv_block = [layer_pad(1),
                      layer_conv(in_features, in_features, 3),
                      layer_norm(in_features),
                      act_relu(),
                      # Deliberately NOT in-place: nn.ReLU saves its output for
                      # the backward pass, so overwriting it in place makes
                      # autograd raise a version-counter RuntimeError in training.
                      layer_drop(p=drop_prob, inplace=False),
                      layer_pad(1),
                      layer_conv(in_features, in_features, 3),
                      layer_norm(in_features)]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        """Return ``x + conv_block(x)``."""
        return x + self.conv_block(x)
|
|
|
|
|
class DownsamplingBlock(nn.Module):
    """Strided-convolution downsampling stage.

    A 3x3 convolution with ``stride=2, padding=1`` maps a spatial axis of size
    ``N`` to ``ceil(N / 2)``. It is followed by normalisation and activation.
    The layer factories (``layer_conv``, ``layer_norm``, ``act_relu``) are the
    module-level names injected by ``ReconResNetBase.__init__``.

    Args:
        in_features: channels of the incoming tensor.
        out_features: channels produced by the convolution.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.conv_block = nn.Sequential(
            layer_conv(in_features, out_features, 3, stride=2, padding=1),
            layer_norm(out_features),
            act_relu(),
        )

    def forward(self, x):
        """Apply conv -> norm -> activation."""
        downsampled = self.conv_block(x)
        return downsampled
|
|
|
|
|
class UpsamplingBlock(nn.Module):
    """Upsampling stage that changes the channel count and grows the spatial size.

    Behaviour depends on ``mode``:

    * ``"convtrans"``: a stride-2 3x3 transposed convolution
      (``padding=1, output_padding=1``) maps an axis of size ``N`` to ``2N``.
      If ``post_interp_convtrans`` is set and the result's spatial shape
      differs from ``out_shape``, it is resized with ``interpolator`` and
      passed through an extra 1x1 convolution (``post_conv``).
    * anything else: ``interpolator(x, out_shape)`` runs first, then
      ``layer_pad(1)`` followed by an unpadded 3x3 convolution.

    Normalisation and activation always follow the resampling convolution.
    The layer factories are the module-level names injected by
    ``ReconResNetBase.__init__``.

    Args:
        in_features: channels of the incoming tensor.
        out_features: channels produced by the block.
        mode: ``"convtrans"`` or an interpolation mode handled by ``interpolator``.
        interpolator: callable invoked as ``interpolator(x, out_shape)``.
            Assumed to resize ``x``'s spatial dims to ``out_shape`` (TODO
            confirm against the tricorder ``Interpolator``).
        post_interp_convtrans: enable the resize + 1x1 conv fix-up in
            ``"convtrans"`` mode.
    """

    def __init__(self, in_features, out_features, mode="convtrans", interpolator=None, post_interp_convtrans=False):
        super().__init__()

        # Registration order (interpolator, post_conv, conv_block) is kept so
        # parameter initialisation order and state_dict layout stay unchanged.
        self.interpolator = interpolator
        self.mode = mode
        self.post_interp_convtrans = post_interp_convtrans
        if post_interp_convtrans:
            self.post_conv = layer_conv(out_features, out_features, 1)

        if mode == "convtrans":
            resample = [layer_convtrans(in_features, out_features, 3,
                                        stride=2, padding=1, output_padding=1)]
        else:
            resample = [layer_pad(1), layer_conv(in_features, out_features, 3)]

        self.conv_block = nn.Sequential(*resample, layer_norm(out_features), act_relu())

    def forward(self, x, out_shape=None):
        """Upsample ``x``.

        Args:
            x: feature tensor (channels at dim 1, spatial dims from dim 2).
            out_shape: target spatial shape. It is ignored in ``"convtrans"``
                mode without ``post_interp_convtrans``.
        """
        if self.mode != "convtrans":
            # Interpolate first, then convolve at the target resolution.
            return self.conv_block(self.interpolator(x, out_shape))

        upsampled = self.conv_block(x)
        needs_resize = self.post_interp_convtrans and upsampled.shape[2:] != out_shape
        if needs_resize:
            return self.post_conv(self.interpolator(upsampled, out_shape))
        return upsampled
|
|
|
|
|
class ReconResNetBase(nn.Module):
    """ReconResNet image-to-image network: stem, encoder, residual bottleneck, decoder.

    Layout:

    * a 7x7 stem convolution (``intialConv``);
    * ``updown_blocks`` ``DownsamplingBlock``s, each doubling the channel count;
    * ``res_blocks`` ``ResidualBlock``s;
    * ``updown_blocks`` ``UpsamplingBlock``s, each halving the channel count;
    * a 7x7 output convolution (``finalconv``) with an optional activation.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by ``finalconv``.
        res_blocks: number of residual blocks in the bottleneck.
        starting_nfeatures: channel width after the stem convolution.
        updown_blocks: number of downsampling (and matching upsampling) blocks.
        is_relu_leaky: use ``nn.PReLU`` activations if true, otherwise ``nn.ReLU``.
        do_batchnorm: use BatchNorm if true, otherwise InstanceNorm.
        res_drop_prob: dropout probability inside each residual block.
        is_replicatepad: 0 for reflection padding, 1 for replication padding.
        out_act: "sigmoid", "relu" or "tanh". Any other value adds no output
            activation.
        forwardV: which ``forwardV<k>`` (k = 0..5) becomes ``self.forward``.
            Any other value leaves ``self.forward`` unassigned, so the class's
            own ``forward`` is used (``nn.Module``'s default raises
            ``NotImplementedError``).
        upinterp_algo: "convtrans" for transposed-convolution upsampling.
            Anything else is used as the ``Interpolator`` mode, and the
            upsampling blocks interpolate before convolving.
        post_interp_convtrans: in "convtrans" mode, resize (plus a 1x1 conv)
            whenever the transposed-conv output does not match the
            encoder-side spatial shape.
        is3D: build 3-D layers instead of 2-D ones.

    Raises:
        ValueError: if ``is_replicatepad`` is neither 0 nor 1.

    Note:
        The chosen layer classes are written into this module's globals
        (``globals().update``), because ``ResidualBlock``,
        ``DownsamplingBlock`` and ``UpsamplingBlock`` look them up there.
        Building a model therefore changes the layer types used by any later
        direct construction of those blocks.

        The additive skips ``x + ...`` in ``forwardV1``-``forwardV5`` require
        ``x`` to broadcast against the stem/output tensors (for example
        ``in_channels == 1``).
    """

    def __init__(self, in_channels=1, out_channels=1, res_blocks=14, starting_nfeatures=64, updown_blocks=2,
                 is_relu_leaky=True, do_batchnorm=False, res_drop_prob=0.2, is_replicatepad=0, out_act="sigmoid",
                 forwardV=0, upinterp_algo='convtrans', post_interp_convtrans=False, is3D=False):
        super(ReconResNetBase, self).__init__()

        # Fail fast: any other value used to leave ``layer_pad`` unset, which
        # either raised NameError or silently reused the padding class left in
        # the module globals by a previously built model.
        if is_replicatepad not in (0, 1):
            raise ValueError(
                "is_replicatepad must be 0 (reflection) or 1 (replication), got {!r}".format(is_replicatepad))

        if is3D:
            layers = {
                "layer_conv": nn.Conv3d,
                "layer_convtrans": nn.ConvTranspose3d,
                "layer_norm": nn.BatchNorm3d if do_batchnorm else nn.InstanceNorm3d,
                "layer_drop": nn.Dropout3d,
                "layer_pad": nn.ReflectionPad3d if is_replicatepad == 0 else nn.ReplicationPad3d,
                "interp_mode": 'trilinear',
            }
        else:
            layers = {
                "layer_conv": nn.Conv2d,
                "layer_convtrans": nn.ConvTranspose2d,
                "layer_norm": nn.BatchNorm2d if do_batchnorm else nn.InstanceNorm2d,
                "layer_drop": nn.Dropout2d,
                "layer_pad": nn.ReflectionPad2d if is_replicatepad == 0 else nn.ReplicationPad2d,
                "interp_mode": 'bilinear',
            }
        layers["act_relu"] = nn.PReLU if is_relu_leaky else nn.ReLU
        # The building-block classes (and the bare names below) resolve the
        # layer factories through the module globals.
        globals().update(layers)

        self.forwardV = forwardV
        self.upinterp_algo = upinterp_algo

        # In "convtrans" mode the interpolator only serves the optional
        # post-transposed-conv resize, so it gets the dimension-appropriate
        # linear mode. Otherwise upinterp_algo itself is the interpolation mode.
        interpolator = Interpolator(
            mode=layers["interp_mode"] if upinterp_algo == "convtrans" else upinterp_algo)

        # Construction order below is unchanged so parameter initialisation
        # consumes the RNG in the same sequence as before.
        intialConv = [layer_pad(3),
                      layer_conv(in_channels, starting_nfeatures, 7),
                      layer_norm(starting_nfeatures),
                      act_relu()]

        # Encoder: each block halves the spatial size and doubles the channels.
        downsam = []
        in_features = starting_nfeatures
        for _ in range(updown_blocks):
            downsam.append(DownsamplingBlock(in_features, in_features * 2))
            in_features *= 2

        resblocks = [ResidualBlock(in_features, res_drop_prob) for _ in range(res_blocks)]

        # Decoder: mirror of the encoder, so it ends at starting_nfeatures channels.
        upsam = []
        for _ in range(updown_blocks):
            upsam.append(UpsamplingBlock(in_features, in_features // 2,
                                         upinterp_algo, interpolator, post_interp_convtrans))
            in_features //= 2

        finalconv = [layer_pad(3),
                     layer_conv(starting_nfeatures, out_channels, 7)]
        if out_act == "sigmoid":
            finalconv.append(nn.Sigmoid())
        elif out_act == "relu":
            finalconv.append(act_relu())
        elif out_act == "tanh":
            finalconv.append(nn.Tanh())

        # Attribute names (including the "intialConv" spelling) form the
        # state_dict keys, so they must stay as they are for checkpoint
        # compatibility.
        self.intialConv = nn.Sequential(*intialConv)
        self.downsam = nn.ModuleList(downsam)
        self.resblocks = nn.Sequential(*resblocks)
        self.upsam = nn.ModuleList(upsam)
        self.finalconv = nn.Sequential(*finalconv)

        forward_variants = {
            0: self.forwardV0,
            1: self.forwardV1,
            2: self.forwardV2,
            3: self.forwardV3,
            4: self.forwardV4,
            5: self.forwardV5,
        }
        if self.forwardV in forward_variants:
            self.forward = forward_variants[self.forwardV]

    def _encode(self, out):
        """Run the downsampling blocks on ``out``.

        Returns:
            The encoded tensor, plus the spatial shape (``shape[2:]``) seen at
            the input of each downsampling block, in encoder order.
        """
        shapes = []
        for downblock in self.downsam:
            shapes.append(out.shape[2:])
            out = downblock(out)
        return out, shapes

    def _decode(self, out, shapes):
        """Run the upsampling blocks, targeting the recorded encoder shapes in reverse order."""
        for upblock, shape in zip(self.upsam, reversed(shapes)):
            out = upblock(out, shape)
        return out

    def forwardV0(self, x):
        """Plain pipeline: stem -> encoder -> residual blocks -> decoder -> output conv."""
        out, shapes = self._encode(self.intialConv(x))
        out = self.resblocks(out)
        return self.finalconv(self._decode(out, shapes))

    def forwardV1(self, x):
        """V0 plus a global residual: ``x + network(x)``."""
        out, shapes = self._encode(self.intialConv(x))
        out = self.resblocks(out)
        return x + self.finalconv(self._decode(out, shapes))

    def forwardV2(self, x):
        """V1 plus a residual connection around the whole residual-block stack."""
        out, shapes = self._encode(self.intialConv(x))
        out = out + self.resblocks(out)
        return x + self.finalconv(self._decode(out, shapes))

    def forwardV3(self, x):
        """V2 plus a residual connection around the stem convolution."""
        out, shapes = self._encode(x + self.intialConv(x))
        out = out + self.resblocks(out)
        return x + self.finalconv(self._decode(out, shapes))

    def forwardV4(self, x):
        """V3 plus a skip from the stem features to the decoder output."""
        iniconv = x + self.intialConv(x)
        out, shapes = self._encode(iniconv)
        out = out + self.resblocks(out)
        out = iniconv + self._decode(out, shapes)
        return x + self.finalconv(out)

    def forwardV5(self, x):
        """V3 with an additive encoder-to-decoder skip at every resolution level."""
        out = x + self.intialConv(x)
        skips = []
        for downblock in self.downsam:
            skips.append(out)
            out = downblock(out)
        out = out + self.resblocks(out)
        # Each upsampled map is added to the encoder tensor that fed the
        # matching downsampling block; that tensor's shape is the target size.
        for upblock, skip in zip(self.upsam, reversed(skips)):
            out = skip + upblock(out, skip.shape[2:])
        return x + self.finalconv(out)