#!/usr/bin/env python

# This model is part of the paper "ReconResNet: Regularised Residual Learning for MR Image Reconstruction of Undersampled Cartesian and Radial Data" (https://doi.org/10.1016/j.compbiomed.2022.105321)
# and has been published on GitHub: https://github.com/soumickmj/NCC1701/blob/main/Bridge/WarpDrives/ReconResNet/ReconResNet.py

import torch.nn as nn
from tricorder.torch.transforms import Interpolator

__author__ = "Soumick Chatterjee"
__copyright__ = "Copyright 2019, Soumick Chatterjee & OvGU:ESF:MEMoRIAL"
__credits__ = ["Soumick Chatterjee"]
__license__ = "apache-2.0"
__version__ = "1.0.0"
__email__ = "soumick@live.com"
__status__ = "Published"

# NOTE: the names layer_conv, layer_convtrans, layer_norm, layer_drop,
# layer_pad and act_relu used by the building blocks below are NOT defined at
# import time. ReconResNetBase.__init__ publishes them as module globals
# (2D or 3D flavour) right before it instantiates the blocks, so the blocks
# can only be constructed after a ReconResNetBase has been configured.


class ResidualBlock(nn.Module):
    """Residual unit: pad-conv-norm-act-dropout-pad-conv-norm plus identity skip.

    Args:
        in_features: Number of input channels; the output has the same number.
        drop_prob: Probability handed to the (in-place) dropout layer.
    """

    def __init__(self, in_features, drop_prob=0.2):
        super().__init__()

        conv_block = [
            layer_pad(1),
            layer_conv(in_features, in_features, 3),
            layer_norm(in_features),
            act_relu(),
            layer_drop(p=drop_prob, inplace=True),
            layer_pad(1),
            layer_conv(in_features, in_features, 3),
            layer_norm(in_features),
        ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        return x + self.conv_block(x)


class DownsamplingBlock(nn.Module):
    """Stride-2 convolution followed by normalisation and activation.

    Args:
        in_features: Number of input channels.
        out_features: Number of output channels.
    """

    def __init__(self, in_features, out_features):
        super().__init__()

        conv_block = [
            layer_conv(in_features, out_features, 3, stride=2, padding=1),
            layer_norm(out_features),
            act_relu(),
        ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        """Apply the downsampling convolution block to ``x``."""
        return self.conv_block(x)


class UpsamplingBlock(nn.Module):
    """Upsample either with a transposed convolution or with interpolation.

    Args:
        in_features: Number of input channels.
        out_features: Number of output channels.
        mode: ``"convtrans"`` selects a stride-2 transposed convolution; any
            other value selects "interpolate first, then pad + convolve".
        interpolator: Callable invoked as ``interpolator(x, out_shape)``;
            presumably resizes ``x`` to ``out_shape`` (defined in tricorder,
            not visible here). Required unless ``mode == "convtrans"`` and
            ``post_interp_convtrans`` is false.
        post_interp_convtrans: In ``"convtrans"`` mode, interpolate the result
            to ``out_shape`` (followed by a 1-wide convolution) whenever the
            transposed convolution did not already produce that size.
    """

    def __init__(self, in_features, out_features, mode="convtrans", interpolator=None, post_interp_convtrans=False):
        super().__init__()

        self.interpolator = interpolator
        self.mode = mode
        self.post_interp_convtrans = post_interp_convtrans
        # post_conv is registered before conv_block (and regardless of mode)
        # to keep state_dict keys and parameter order of existing checkpoints.
        if self.post_interp_convtrans:
            self.post_conv = layer_conv(out_features, out_features, 1)

        if mode == "convtrans":
            conv_block = [
                layer_convtrans(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
            ]
        else:
            conv_block = [
                layer_pad(1),
                layer_conv(in_features, out_features, 3),
            ]
        conv_block += [layer_norm(out_features), act_relu()]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x, out_shape=None):
        """Upsample ``x``.

        Args:
            x: Input tensor; dimensions from index 2 onward are treated as the
                spatial size.
            out_shape: Desired spatial size. Ignored in plain ``"convtrans"``
                mode; otherwise forwarded to the interpolator.

        Returns:
            The upsampled tensor.
        """
        if self.mode != "convtrans":
            return self.conv_block(self.interpolator(x, out_shape))

        x = self.conv_block(x)
        if not self.post_interp_convtrans:
            return x

        # Compare as tuples so that an out_shape given as a list (or any other
        # sequence) equal to the current size is recognised as a match instead
        # of triggering a needless interpolation + post_conv.
        if out_shape is not None and tuple(x.shape[2:]) == tuple(out_shape):
            return x
        return self.post_conv(self.interpolator(x, out_shape))


class ReconResNetBase(nn.Module):
    """ReconResNet: initial conv, downsampling, residual blocks, upsampling, final conv.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        res_blocks: Number of residual blocks. Should use 14, as that gives a
            number of trainable parameters close to the number of possible
            pixel values in a 256x256 image.
        starting_nfeatures: Channels produced by the initial convolution; each
            downsampling block doubles and each upsampling block halves it.
        updown_blocks: Number of downsampling (and of upsampling) blocks.
        is_relu_leaky: Use ``nn.PReLU`` if true, ``nn.ReLU`` otherwise.
        do_batchnorm: Use batch normalisation if true, instance normalisation
            otherwise.
        res_drop_prob: Dropout probability inside every residual block.
        is_replicatepad: ``0`` for reflection padding, ``1`` for replication
            padding.
        out_act: ``"sigmoid"``, ``"relu"`` (the activation chosen through
            ``is_relu_leaky``) or ``"tanh"``; any other value leaves the
            output without a final activation.
        forwardV: Forward variant, 0-5 (see the ``forwardV*`` methods). Any
            other value leaves ``forward`` unassigned, i.e. calling the model
            falls back to ``nn.Module``'s default forward.
        upinterp_algo: ``"convtrans"`` or an interpolation mode name that is
            handed to ``Interpolator``.
        post_interp_convtrans: See ``UpsamplingBlock``.
        is3D: Build the network from 3D layers instead of 2D layers.

    Raises:
        ValueError: If ``is_replicatepad`` is neither 0 nor 1.
    """

    def __init__(self, in_channels=1, out_channels=1, res_blocks=14, starting_nfeatures=64, updown_blocks=2,
                 is_relu_leaky=True, do_batchnorm=False, res_drop_prob=0.2, is_replicatepad=0, out_act="sigmoid",
                 forwardV=0, upinterp_algo='convtrans', post_interp_convtrans=False, is3D=False):
        super().__init__()

        # Validate up front: previously an unsupported value left layer_pad
        # undefined, which ended in a NameError or - worse - silently reused
        # the padding class of a previously constructed model.
        if is_replicatepad == 0:
            pad_name = "ReflectionPad"
        elif is_replicatepad == 1:
            pad_name = "ReplicationPad"
        else:
            raise ValueError(
                f"is_replicatepad must be 0 (reflection) or 1 (replication), got {is_replicatepad!r}")

        suffix = "3d" if is3D else "2d"
        norm_name = "BatchNorm" if do_batchnorm else "InstanceNorm"
        layers = {
            "layer_conv": getattr(nn, "Conv" + suffix),
            "layer_convtrans": getattr(nn, "ConvTranspose" + suffix),
            "layer_norm": getattr(nn, norm_name + suffix),
            "layer_drop": getattr(nn, "Dropout" + suffix),
            "layer_pad": getattr(nn, pad_name + suffix),
            "interp_mode": 'trilinear' if is3D else 'bilinear',
            "act_relu": nn.PReLU if is_relu_leaky else nn.ReLU,
        }

        # The building-block classes look these names up as module globals.
        globals().update(layers)

        pad = layers["layer_pad"]
        conv = layers["layer_conv"]
        norm = layers["layer_norm"]
        activation = layers["act_relu"]

        self.forwardV = forwardV
        self.upinterp_algo = upinterp_algo

        interpolator = Interpolator(
            mode=layers["interp_mode"] if self.upinterp_algo == "convtrans" else self.upinterp_algo)

        # Initial convolution block
        intialConv = [
            pad(3),
            conv(in_channels, starting_nfeatures, 7),
            norm(starting_nfeatures),
            activation(),
        ]

        # Downsampling [the shapes are saved at forward time for upsampling]
        downsam = []
        in_features = starting_nfeatures
        for _ in range(updown_blocks):
            downsam.append(DownsamplingBlock(in_features, in_features * 2))
            in_features *= 2

        # Residual blocks
        resblocks = [ResidualBlock(in_features, res_drop_prob) for _ in range(res_blocks)]

        # Upsampling
        upsam = []
        for _ in range(updown_blocks):
            upsam.append(UpsamplingBlock(in_features, in_features // 2, self.upinterp_algo,
                                         interpolator, post_interp_convtrans))
            in_features //= 2

        # Output layer
        finalconv = [
            pad(3),
            conv(starting_nfeatures, out_channels, 7),
        ]
        if out_act == "sigmoid":
            finalconv.append(nn.Sigmoid())
        elif out_act == "relu":
            finalconv.append(activation())
        elif out_act == "tanh":
            finalconv.append(nn.Tanh())

        # Attribute names and assignment order are kept as-is so that
        # state_dict keys and parameter order stay checkpoint-compatible.
        self.intialConv = nn.Sequential(*intialConv)
        self.downsam = nn.ModuleList(downsam)
        self.resblocks = nn.Sequential(*resblocks)
        self.upsam = nn.ModuleList(upsam)
        self.finalconv = nn.Sequential(*finalconv)

        forward_variants = (self.forwardV0, self.forwardV1, self.forwardV2,
                            self.forwardV3, self.forwardV4, self.forwardV5)
        for version, variant in enumerate(forward_variants):
            if self.forwardV == version:
                self.forward = variant
                break

    def _encode(self, out):
        """Run all downsampling blocks; return the result and the pre-block spatial shapes."""
        shapes = []
        for downblock in self.downsam:
            shapes.append(out.shape[2:])
            out = downblock(out)
        return out, shapes

    def _decode(self, out, shapes):
        """Run all upsampling blocks, targeting the recorded shapes in reverse order."""
        for upblock, shape in zip(self.upsam, reversed(shapes)):
            out = upblock(out, shape)
        return out

    def forwardV0(self, x):
        """v0: Original version, no additional skip connections."""
        out, shapes = self._encode(self.intialConv(x))
        out = self.resblocks(out)
        out = self._decode(out, shapes)
        return self.finalconv(out)

    def forwardV1(self, x):
        """v1: input is added to the final output."""
        out, shapes = self._encode(self.intialConv(x))
        out = self.resblocks(out)
        out = self._decode(out, shapes)
        return x + self.finalconv(out)

    def forwardV2(self, x):
        """v2: residual of v1 + input to the residual blocks added back with their output."""
        out, shapes = self._encode(self.intialConv(x))
        out = out + self.resblocks(out)
        out = self._decode(out, shapes)
        return x + self.finalconv(out)

    def forwardV3(self, x):
        """v3: residual of v2 + input of the initial conv added back with its output.

        ``x + self.intialConv(x)`` adds tensors with ``in_channels`` and
        ``starting_nfeatures`` channels, so it relies on broadcasting.
        """
        out, shapes = self._encode(x + self.intialConv(x))
        out = out + self.resblocks(out)
        out = self._decode(out, shapes)
        return x + self.finalconv(out)

    def forwardV4(self, x):
        """v4: residual of v3 + output of the initial conv added back with the input of the final conv."""
        iniconv = x + self.intialConv(x)
        out, shapes = self._encode(iniconv)
        out = out + self.resblocks(out)
        out = self._decode(out, shapes)
        out = iniconv + out
        return x + self.finalconv(out)

    def forwardV5(self, x):
        """v5: residual of v4 + every down block's input added to the matching up block's output."""
        # outs is a stack: outs[k] is the feature map fed into down block k.
        outs = [x + self.intialConv(x)]
        shapes = []
        for downblock in self.downsam:
            shapes.append(outs[-1].shape[2:])
            outs.append(downblock(outs[-1]))

        outs[-1] = outs[-1] + self.resblocks(outs[-1])

        for i, upblock in enumerate(self.upsam):
            upsampled = upblock(outs.pop(), shapes[-1 - i])
            # Additive skip with the feature map from the same resolution level.
            outs[-1] = outs[-1] + upsampled

        return x + self.finalconv(outs.pop())