import torch import torch.nn as nn import torch.nn.functional as F # 2D: net = UNet2D(1,2,pab_channels=64,use_batchnorm=True) # 3D: net = UNet3D(1,2,pab_channels=32,use_batchnorm=True) class _NonLocalBlockND(nn.Module): def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): super(_NonLocalBlockND, self).__init__() assert dimension in [1, 2, 3] self.dimension = dimension self.sub_sample = sub_sample self.in_channels = in_channels self.inter_channels = inter_channels if self.inter_channels is None: self.inter_channels = in_channels // 2 if self.inter_channels == 0: self.inter_channels = 1 if dimension == 3: conv_nd = nn.Conv3d max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) bn = nn.BatchNorm3d elif dimension == 2: conv_nd = nn.Conv2d max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) bn = nn.BatchNorm2d else: conv_nd = nn.Conv1d max_pool_layer = nn.MaxPool1d(kernel_size=(2)) bn = nn.BatchNorm1d self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0) if bn_layer: self.W = nn.Sequential( conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), bn(self.in_channels) ) nn.init.constant_(self.W[1].weight, 0) nn.init.constant_(self.W[1].bias, 0) else: self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0) nn.init.constant_(self.W.weight, 0) nn.init.constant_(self.W.bias, 0) self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0) self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0) if sub_sample: self.g = nn.Sequential(self.g, max_pool_layer) self.phi = nn.Sequential(self.phi, max_pool_layer) def forward(self, x): ''' :param x: (b, c, t, h, w) :return: ''' batch_size = x.size(0) g_x = self.g(x).view(batch_size, self.inter_channels, -1) g_x = g_x.permute(0, 2, 1) theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) theta_x = theta_x.permute(0, 2, 1) phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) f = torch.matmul(theta_x, phi_x) N = f.size(-1) f_div_C = f / N y = torch.matmul(f_div_C, g_x) y = y.permute(0, 2, 1).contiguous() y = y.view(batch_size, self.inter_channels, *x.size()[2:]) W_y = self.W(y) z = W_y + x return z class NONLocalBlock1D(_NonLocalBlockND): def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): super(NONLocalBlock1D, self).__init__(in_channels, inter_channels=inter_channels, dimension=1, sub_sample=sub_sample, bn_layer=bn_layer) class NONLocalBlock2D(_NonLocalBlockND): def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): super(NONLocalBlock2D, self).__init__(in_channels, inter_channels=inter_channels, dimension=2, sub_sample=sub_sample, bn_layer=bn_layer) class NONLocalBlock3D(_NonLocalBlockND): def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): super(NONLocalBlock3D, self).__init__(in_channels, inter_channels=inter_channels, dimension=3, sub_sample=sub_sample, bn_layer=bn_layer) class Conv2dReLU(nn.Sequential): def __init__( self, in_channels, out_channels, kernel_size, padding=0, stride=1, use_batchnorm=True, ): if use_batchnorm == "inplace" and InPlaceABN is None: raise RuntimeError( "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. " + "To install see: https://github.com/mapillary/inplace_abn" ) conv = nn.Conv2d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=not (use_batchnorm), ) relu = nn.ReLU(inplace=True) if use_batchnorm == "inplace": bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) relu = nn.Identity() elif use_batchnorm and use_batchnorm != "inplace": bn = nn.BatchNorm2d(out_channels) else: bn = nn.Identity() super(Conv2dReLU, self).__init__(conv, bn, relu) class Conv3dReLU(nn.Sequential): def __init__( self, in_channels, out_channels, kernel_size, padding=0, stride=1, use_batchnorm=True, ): if use_batchnorm == "inplace" and InPlaceABN is None: raise RuntimeError( "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. " + "To install see: https://github.com/mapillary/inplace_abn" ) conv = nn.Conv3d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=not (use_batchnorm), ) relu = nn.ReLU(inplace=True) if use_batchnorm == "inplace": bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) relu = nn.Identity() elif use_batchnorm and use_batchnorm != "inplace": bn = nn.BatchNorm3d(out_channels) else: bn = nn.Identity() super(Conv3dReLU, self).__init__(conv, bn, relu) class PAB2D(nn.Module): def __init__(self, in_channels, out_channels, pab_channels=64): super(PAB2D, self).__init__() # Series of 1x1 conv to generate attention feature maps self.pab_channels = pab_channels self.in_channels = in_channels self.top_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1) self.center_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1) self.bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) self.map_softmax = nn.Softmax(dim=1) self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) def forward(self, x): bsize = x.size()[0] h = x.size()[2] w = x.size()[3] x_top = self.top_conv(x) x_center = self.center_conv(x) x_bottom = self.bottom_conv(x) x_top = x_top.flatten(2) x_center = x_center.flatten(2).transpose(1, 2) x_bottom = x_bottom.flatten(2).transpose(1, 2) sp_map = torch.matmul(x_center, x_top) sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h*w, h*w) sp_map = torch.matmul(sp_map, x_bottom) sp_map = sp_map.reshape(bsize, self.in_channels, h, w) x = x + sp_map x = self.out_conv(x) # print('x_top',x_top.shape,'x_center',x_center.shape,'x_bottom',x_bottom.shape,'x',x.shape,'sp_map',sp_map.shape) return x class MFAB2D(nn.Module): def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16): # MFAB is just a modified version of SE-blocks, one for skip, one for input super(MFAB2D, self).__init__() self.hl_conv = nn.Sequential( Conv2dReLU( in_channels, in_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm, ), Conv2dReLU( in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm, ) ) self.SE_ll = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Conv2d(skip_channels, skip_channels // reduction, 1), nn.ReLU(inplace=True), nn.Conv2d(skip_channels // reduction, skip_channels, 1), nn.Sigmoid(), ) self.SE_hl = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Conv2d(skip_channels, skip_channels // reduction, 1), nn.ReLU(inplace=True), nn.Conv2d(skip_channels // reduction, skip_channels, 1), nn.Sigmoid(), ) self.conv1 = Conv2dReLU( skip_channels + skip_channels, # we transform C-prime form high level to C from skip connection out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm, ) self.conv2 = Conv2dReLU( out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm, ) def forward(self, x, skip=None): x = self.hl_conv(x) x = F.interpolate(x, scale_factor=2, mode="nearest") attention_hl = self.SE_hl(x) if skip is not None: attention_ll = self.SE_ll(skip) attention_hl = attention_hl + attention_ll x = x * attention_hl x = torch.cat([x, skip], dim=1) x = self.conv1(x) x = self.conv2(x) return x class PAB3D(nn.Module): def __init__(self, in_channels, out_channels, pab_channels=64): super(PAB3D, self).__init__() # Series of 1x1 conv to generate attention feature maps self.pab_channels = pab_channels self.in_channels = in_channels self.top_conv = nn.Conv3d(in_channels, pab_channels, kernel_size=1) self.center_conv = nn.Conv3d(in_channels, pab_channels, kernel_size=1) self.bottom_conv = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) self.map_softmax = nn.Softmax(dim=1) self.out_conv = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) def forward(self, x): bsize = x.size()[0] h = x.size()[2] w = x.size()[3] d = x.size()[4] x_top = self.top_conv(x) x_center = self.center_conv(x) x_bottom = self.bottom_conv(x) x_top = x_top.flatten(2) x_center = x_center.flatten(2).transpose(1, 2) x_bottom = x_bottom.flatten(2).transpose(1, 2) sp_map = torch.matmul(x_center, x_top) sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h*w*d, h*w*d) sp_map = torch.matmul(sp_map, x_bottom) sp_map = sp_map.reshape(bsize, self.in_channels, h, w, d) x = x + sp_map x = self.out_conv(x) # print('x_top',x_top.shape,'x_center',x_center.shape,'x_bottom',x_bottom.shape,'x',x.shape,'sp_map',sp_map.shape) return x class MFAB3D(nn.Module): def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16): # MFAB is just a modified version of SE-blocks, one for skip, one for input super(MFAB3D, self).__init__() self.hl_conv = nn.Sequential( Conv3dReLU( in_channels, in_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm, ), Conv3dReLU( in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm, ) ) self.SE_ll = nn.Sequential( nn.AdaptiveAvgPool3d(1), nn.Conv3d(skip_channels, skip_channels // reduction, 1), nn.ReLU(inplace=True), nn.Conv3d(skip_channels // reduction, skip_channels, 1), nn.Sigmoid(), ) self.SE_hl = nn.Sequential( nn.AdaptiveAvgPool3d(1), nn.Conv3d(skip_channels, skip_channels // reduction, 1), nn.ReLU(inplace=True), nn.Conv3d(skip_channels // reduction, skip_channels, 1), nn.Sigmoid(), ) self.conv1 = Conv3dReLU( skip_channels + skip_channels, # we transform C-prime form high level to C from skip connection out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm, ) self.conv2 = Conv3dReLU( out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm, ) def forward(self, x, skip=None): x = self.hl_conv(x) x = F.interpolate(x, scale_factor=2, mode="nearest") attention_hl = self.SE_hl(x) if skip is not None: attention_ll = self.SE_ll(skip) attention_hl = attention_hl + attention_ll x = x * attention_hl x = torch.cat([x, skip], dim=1) x = self.conv1(x) x = self.conv2(x) return x class DoubleConv2D(nn.Module): """(convolution => [BN] => ReLU) * 2""" def __init__(self, in_channels, out_channels, mid_channels=None): super().__init__() if not mid_channels: mid_channels = out_channels self.double_conv = nn.Sequential( nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), nn.BatchNorm2d(mid_channels), nn.ReLU(inplace=True), nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x): return self.double_conv(x) class Down2D(nn.Module): """Downscaling with maxpool then double conv""" def __init__(self, in_channels, out_channels): super().__init__() self.maxpool_conv = nn.Sequential( nn.MaxPool2d(2), NONLocalBlock2D(in_channels), DoubleConv2D(in_channels, out_channels) ) def forward(self, x): return self.maxpool_conv(x) class Up2D(nn.Module): """Upscaling then double conv""" def __init__(self, in_channels, out_channels, bilinear=True): super().__init__() # if bilinear, use the normal convolutions to reduce the number of channels if bilinear: self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv = DoubleConv2D(in_channels, out_channels, in_channels // 2) else: self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) self.conv = DoubleConv2D(in_channels, out_channels) def forward(self, x1, x2): x1 = self.up(x1) # input is CHW diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) # if you have padding issues, see # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd x = torch.cat([x2, x1], dim=1) return self.conv(x) class OutConv2D(nn.Module): def __init__(self, in_channels, out_channels): super(OutConv2D, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) def forward(self, x): return self.conv(x) class UNet2D(nn.Module): def __init__(self, n_channels, n_classes, bilinear=True, pab_channels=64, use_batchnorm=True, aux_classifier = False): super(UNet2D, self).__init__() self.n_channels = n_channels self.n_classes = n_classes self.bilinear = bilinear self.inc = DoubleConv2D(n_channels, pab_channels) self.down1 = Down2D(pab_channels, 2*pab_channels) self.down2 = Down2D(2*pab_channels, 4*pab_channels) self.down3 = Down2D(4*pab_channels, 8*pab_channels) factor = 2 if bilinear else 1 self.down4 = Down2D(8*pab_channels, 16*pab_channels // factor) self.pab = PAB2D(8*pab_channels,8*pab_channels) self.up1 = Up2D(16*pab_channels, 8*pab_channels // factor, bilinear) self.up2 = Up2D(8*pab_channels, 4*pab_channels // factor, bilinear) self.up3 = Up2D(4*pab_channels, 2*pab_channels // factor, bilinear) self.up4 = Up2D(2*pab_channels, pab_channels, bilinear) self.mfab1 = MFAB2D(8*pab_channels,8*pab_channels,4*pab_channels,use_batchnorm) self.mfab2 = MFAB2D(4*pab_channels,4*pab_channels,2*pab_channels,use_batchnorm) self.mfab3 = MFAB2D(2*pab_channels,2*pab_channels,pab_channels,use_batchnorm) self.mfab4 = MFAB2D(pab_channels,pab_channels,pab_channels,use_batchnorm) self.outc = OutConv2D(pab_channels, n_classes) if aux_classifier == False: self.aux = None else: # customize the auxiliary classification loss # self.aux = nn.Sequential(nn.AdaptiveAvgPool2d(1), # nn.Flatten(), # nn.Dropout(p=0.1, inplace=True), # nn.Linear(8*pab_channels, 16, bias=True), # nn.Dropout(p=0.1, inplace=True), # nn.Linear(16, n_classes, bias=True), # nn.Softmax(1)) self.aux = nn.Sequential( NONLocalBlock2D(8*pab_channels), nn.Conv2d(8*pab_channels,1,1), nn.InstanceNorm2d(1), nn.ReLU(), nn.Flatten(), nn.Linear(24*24, 16, bias=True), nn.Dropout(p=0.2, inplace=True), nn.Linear(16, n_classes, bias=True), nn.Softmax(1)) def forward(self, x): x1 = self.inc(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x5 = self.pab(x5) x = self.mfab1(x5,x4) x = self.mfab2(x,x3) x = self.mfab3(x,x2) x = self.mfab4(x,x1) # x = self.up1(x5, x4) # x = self.up2(x, x3) # x = self.up3(x, x2) # x = self.up4(x, x1) logits = self.outc(x) logits = F.softmax(logits,1) if self.aux ==None: return logits else: aux = self.aux(x5) return logits, aux class DoubleConv3D(nn.Module): """(convolution => [BN] => ReLU) * 2""" def __init__(self, in_channels, out_channels, mid_channels=None): super().__init__() if not mid_channels: mid_channels = out_channels self.double_conv = nn.Sequential( nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1), nn.BatchNorm3d(mid_channels), nn.ReLU(inplace=True), nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm3d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x): return self.double_conv(x) class Down3D(nn.Module): """Downscaling with maxpool then double conv""" def __init__(self, in_channels, out_channels): super().__init__() self.maxpool_conv = nn.Sequential( nn.MaxPool3d(2), # NONLocalBlock3D(in_channels), DoubleConv3D(in_channels, out_channels) ) def forward(self, x): return self.maxpool_conv(x) class Up3D(nn.Module): """Upscaling then double conv""" def __init__(self, in_channels, out_channels, bilinear=True): super().__init__() # if bilinear, use the normal convolutions to reduce the number of channels if bilinear: self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv = DoubleConv3D(in_channels, out_channels, in_channels // 2) else: self.up = nn.ConvTranspose3d(in_channels , in_channels // 2, kernel_size=2, stride=2) self.conv = DoubleConv3D(in_channels, out_channels) def forward(self, x1, x2): x1 = self.up(x1) # input is CHW diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) # if you have padding issues, see # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd x = torch.cat([x2, x1], dim=1) return self.conv(x) class OutConv3D(nn.Module): def __init__(self, in_channels, out_channels): super(OutConv3D, self).__init__() self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1) def forward(self, x): return self.conv(x) class UNet3D(nn.Module): def __init__(self, n_channels, n_classes, bilinear=True, pab_channels=64, use_batchnorm=True, aux_classifier = False): super(UNet3D, self).__init__() self.n_channels = n_channels self.n_classes = n_classes self.bilinear = bilinear self.inc = DoubleConv3D(n_channels, pab_channels) self.down1 = Down3D(pab_channels, 2*pab_channels) self.nnblock2 = NONLocalBlock3D(2*pab_channels) self.down2 = Down3D(2*pab_channels, 4*pab_channels) self.down3 = Down3D(4*pab_channels, 8*pab_channels) factor = 2 if bilinear else 1 self.down4 = Down3D(8*pab_channels, 16*pab_channels // factor) self.pab = PAB3D(8*pab_channels,8*pab_channels) self.up1 = Up3D(16*pab_channels, 8*pab_channels // factor, bilinear) self.up2 = Up3D(8*pab_channels, 4*pab_channels // factor, bilinear) self.up3 = Up3D(4*pab_channels, 2*pab_channels // factor, bilinear) self.up4 = Up3D(2*pab_channels, pab_channels, bilinear) self.mfab1 = MFAB3D(8*pab_channels,8*pab_channels,4*pab_channels,use_batchnorm) self.mfab2 = MFAB3D(4*pab_channels,4*pab_channels,2*pab_channels,use_batchnorm) self.mfab3 = MFAB3D(2*pab_channels,2*pab_channels,pab_channels,use_batchnorm) self.mfab4 = MFAB3D(pab_channels,pab_channels,pab_channels,use_batchnorm) self.outc = OutConv3D(pab_channels, n_classes) if aux_classifier == False: self.aux = None else: # customize the auxiliary classification loss # self.aux = nn.Sequential(nn.AdaptiveMaxPool3d(1), # nn.Flatten(), # nn.Dropout(p=0.1, inplace=True), # nn.Linear(8*pab_channels, 16, bias=True), # nn.Dropout(p=0.1, inplace=True), # nn.Linear(16, n_classes, bias=True), # nn.Softmax(1)) self.aux = nn.Sequential(nn.Conv3d(8*pab_channels,1,1), nn.InstanceNorm3d(1), nn.ReLU(), nn.Flatten(), nn.Linear(16*16*2, 16, bias=True), nn.Dropout(p=0.2, inplace=True), nn.Linear(16, n_classes, bias=True), nn.Softmax(1)) def forward(self, x): x1 = self.inc(x) x2 = self.down1(x1) # x2 = self.nnblock2(x2) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x5 = self.pab(x5) x = self.mfab1(x5,x4) x = self.mfab2(x,x3) x = self.mfab3(x,x2) x = self.mfab4(x,x1) # x = self.up1(x5, x4) # x = self.up2(x, x3) # x = self.up3(x, x2) # x = self.up4(x, x1) logits = self.outc(x) logits = F.softmax(logits,1) if self.aux ==None: return logits else: aux = self.aux(x5) return logits, aux