import torch
import torch.nn as nn
import torch.nn.functional as F

# Optional dependency: only required when Conv2dReLU/Conv3dReLU are built with
# use_batchnorm="inplace".
try:
    from inplace_abn import InPlaceABN
except ImportError:
    InPlaceABN = None

# Example model constructors:
# 2D: net = UNet2D(1, 2, pab_channels=64, use_batchnorm=True)
# 3D: net = UNet3D(1, 2, pab_channels=32, use_batchnorm=True)
class _NonLocalBlockND(nn.Module):
    """Non-local (self-attention) block for 1D/2D/3D feature maps.

    Dot-product formulation with optional sub-sampling of phi/g; the output
    projection W is zero-initialised so the block starts as an identity mapping.
    """
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()
assert dimension in [1, 2, 3]
self.dimension = dimension
self.sub_sample = sub_sample
self.in_channels = in_channels
self.inter_channels = inter_channels
if self.inter_channels is None:
self.inter_channels = in_channels // 2
if self.inter_channels == 0:
self.inter_channels = 1
if dimension == 3:
conv_nd = nn.Conv3d
max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
bn = nn.BatchNorm3d
elif dimension == 2:
conv_nd = nn.Conv2d
max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
bn = nn.BatchNorm2d
else:
conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=2)
bn = nn.BatchNorm1d
self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
if bn_layer:
self.W = nn.Sequential(
conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0),
bn(self.in_channels)
)
nn.init.constant_(self.W[1].weight, 0)
nn.init.constant_(self.W[1].bias, 0)
else:
self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0)
nn.init.constant_(self.W.weight, 0)
nn.init.constant_(self.W.bias, 0)
self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
if sub_sample:
self.g = nn.Sequential(self.g, max_pool_layer)
self.phi = nn.Sequential(self.phi, max_pool_layer)
    def forward(self, x):
        """
        :param x: (b, c, t, h, w) when dimension=3, (b, c, h, w) when 2, (b, c, t) when 1
        :return: tensor of the same shape as x
        """
batch_size = x.size(0)
g_x = self.g(x).view(batch_size, self.inter_channels, -1)
g_x = g_x.permute(0, 2, 1)
theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
theta_x = theta_x.permute(0, 2, 1)
phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
f = torch.matmul(theta_x, phi_x)
N = f.size(-1)
f_div_C = f / N
y = torch.matmul(f_div_C, g_x)
y = y.permute(0, 2, 1).contiguous()
y = y.view(batch_size, self.inter_channels, *x.size()[2:])
W_y = self.W(y)
z = W_y + x
return z
class NONLocalBlock1D(_NonLocalBlockND):
def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
super(NONLocalBlock1D, self).__init__(in_channels,
inter_channels=inter_channels,
dimension=1, sub_sample=sub_sample,
bn_layer=bn_layer)
class NONLocalBlock2D(_NonLocalBlockND):
def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
super(NONLocalBlock2D, self).__init__(in_channels,
inter_channels=inter_channels,
dimension=2, sub_sample=sub_sample,
bn_layer=bn_layer)
class NONLocalBlock3D(_NonLocalBlockND):
def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
super(NONLocalBlock3D, self).__init__(in_channels,
inter_channels=inter_channels,
dimension=3, sub_sample=sub_sample,
bn_layer=bn_layer)
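# Shape sanity check for the non-local blocks (illustrative only; the residual
# projection W maps back to in_channels, so input and output shapes match):
#   block2d = NONLocalBlock2D(in_channels=64)
#   y2d = block2d(torch.randn(1, 64, 32, 32))       # y2d.shape == (1, 64, 32, 32)
#   block3d = NONLocalBlock3D(in_channels=32)
#   y3d = block3d(torch.randn(1, 32, 8, 32, 32))    # y3d.shape == (1, 32, 8, 32, 32)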
class Conv2dReLU(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
):
if use_batchnorm == "inplace" and InPlaceABN is None:
            raise RuntimeError(
                "In order to use `use_batchnorm='inplace'`, the inplace_abn package must be installed. "
                "To install, see: https://github.com/mapillary/inplace_abn"
            )
conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
            bias=not use_batchnorm,
)
relu = nn.ReLU(inplace=True)
if use_batchnorm == "inplace":
bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
relu = nn.Identity()
elif use_batchnorm and use_batchnorm != "inplace":
bn = nn.BatchNorm2d(out_channels)
else:
bn = nn.Identity()
super(Conv2dReLU, self).__init__(conv, bn, relu)
class Conv3dReLU(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
):
if use_batchnorm == "inplace" and InPlaceABN is None:
            raise RuntimeError(
                "In order to use `use_batchnorm='inplace'`, the inplace_abn package must be installed. "
                "To install, see: https://github.com/mapillary/inplace_abn"
            )
conv = nn.Conv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
            bias=not use_batchnorm,
)
relu = nn.ReLU(inplace=True)
if use_batchnorm == "inplace":
bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
relu = nn.Identity()
elif use_batchnorm and use_batchnorm != "inplace":
bn = nn.BatchNorm3d(out_channels)
else:
bn = nn.Identity()
super(Conv3dReLU, self).__init__(conv, bn, relu)
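# Conv2dReLU / Conv3dReLU accept use_batchnorm=True, False, or "inplace" (the
# latter needs the optional inplace_abn package). Illustrative call with
# assumed sizes:
#   conv = Conv2dReLU(64, 64, kernel_size=3, padding=1, use_batchnorm=True)
#   y = conv(torch.randn(1, 64, 32, 32))   # y.shape == (1, 64, 32, 32)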
class PAB2D(nn.Module):
    """Position-wise Attention Block (2D): non-local style spatial attention."""
    def __init__(self, in_channels, out_channels, pab_channels=64):
        super(PAB2D, self).__init__()
        # Series of 1x1 convs (top/center) and 3x3 convs (bottom/out) to generate the attention feature maps
self.pab_channels = pab_channels
self.in_channels = in_channels
self.top_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
self.center_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
self.bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.map_softmax = nn.Softmax(dim=1)
self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
def forward(self, x):
bsize = x.size()[0]
h = x.size()[2]
w = x.size()[3]
x_top = self.top_conv(x)
x_center = self.center_conv(x)
x_bottom = self.bottom_conv(x)
x_top = x_top.flatten(2)
x_center = x_center.flatten(2).transpose(1, 2)
x_bottom = x_bottom.flatten(2).transpose(1, 2)
sp_map = torch.matmul(x_center, x_top)
sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h*w, h*w)
sp_map = torch.matmul(sp_map, x_bottom)
sp_map = sp_map.reshape(bsize, self.in_channels, h, w)
x = x + sp_map
x = self.out_conv(x)
# print('x_top',x_top.shape,'x_center',x_center.shape,'x_bottom',x_bottom.shape,'x',x.shape,'sp_map',sp_map.shape)
return x
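# Illustrative PAB2D call on a bottleneck feature map (sizes are assumptions).
# The attention map is (H*W) x (H*W), so memory grows quadratically with the
# number of spatial positions:
#   pab = PAB2D(512, 512, pab_channels=64)
#   y = pab(torch.randn(1, 512, 24, 24))   # y.shape == (1, 512, 24, 24)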
class MFAB2D(nn.Module):
    """Multi-scale Fusion Attention Block (2D): a modified SE block with two
    channel-attention branches, one for the skip connection and one for the
    upsampled high-level input."""
    def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16):
        super(MFAB2D, self).__init__()
self.hl_conv = nn.Sequential(
Conv2dReLU(
in_channels,
in_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
),
Conv2dReLU(
in_channels,
skip_channels,
kernel_size=1,
use_batchnorm=use_batchnorm,
)
)
self.SE_ll = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(skip_channels, skip_channels // reduction, 1),
nn.ReLU(inplace=True),
nn.Conv2d(skip_channels // reduction, skip_channels, 1),
nn.Sigmoid(),
)
self.SE_hl = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(skip_channels, skip_channels // reduction, 1),
nn.ReLU(inplace=True),
nn.Conv2d(skip_channels // reduction, skip_channels, 1),
nn.Sigmoid(),
)
self.conv1 = Conv2dReLU(
            skip_channels + skip_channels,  # the high-level C' is first projected to C, then concatenated with the C-channel skip
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv2 = Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
def forward(self, x, skip=None):
x = self.hl_conv(x)
x = F.interpolate(x, scale_factor=2, mode="nearest")
attention_hl = self.SE_hl(x)
if skip is not None:
attention_ll = self.SE_ll(skip)
attention_hl = attention_hl + attention_ll
x = x * attention_hl
x = torch.cat([x, skip], dim=1)
x = self.conv1(x)
x = self.conv2(x)
return x
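# Illustrative MFAB2D call (assumed sizes). The high-level input is upsampled
# 2x inside the block, so the skip tensor must be at twice its spatial size:
#   mfab = MFAB2D(in_channels=512, skip_channels=256, out_channels=256)
#   x_hl = torch.randn(1, 512, 12, 12)     # coarse, high-level features
#   skip = torch.randn(1, 256, 24, 24)     # skip connection at 2x resolution
#   y = mfab(x_hl, skip)                   # y.shape == (1, 256, 24, 24)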
class PAB3D(nn.Module):
    """Position-wise Attention Block (3D variant)."""
    def __init__(self, in_channels, out_channels, pab_channels=64):
        super(PAB3D, self).__init__()
        # Series of 1x1 convs (top/center) and 3x3 convs (bottom/out) to generate the attention feature maps
self.pab_channels = pab_channels
self.in_channels = in_channels
self.top_conv = nn.Conv3d(in_channels, pab_channels, kernel_size=1)
self.center_conv = nn.Conv3d(in_channels, pab_channels, kernel_size=1)
self.bottom_conv = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
self.map_softmax = nn.Softmax(dim=1)
self.out_conv = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
def forward(self, x):
bsize = x.size()[0]
h = x.size()[2]
w = x.size()[3]
d = x.size()[4]
x_top = self.top_conv(x)
x_center = self.center_conv(x)
x_bottom = self.bottom_conv(x)
x_top = x_top.flatten(2)
x_center = x_center.flatten(2).transpose(1, 2)
x_bottom = x_bottom.flatten(2).transpose(1, 2)
sp_map = torch.matmul(x_center, x_top)
sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h*w*d, h*w*d)
sp_map = torch.matmul(sp_map, x_bottom)
sp_map = sp_map.reshape(bsize, self.in_channels, h, w, d)
x = x + sp_map
x = self.out_conv(x)
# print('x_top',x_top.shape,'x_center',x_center.shape,'x_bottom',x_bottom.shape,'x',x.shape,'sp_map',sp_map.shape)
return x
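# The 3D attention map is (D*H*W) x (D*H*W), so even small volumes get large:
# a 2x16x16 bottleneck already gives a 512x512 map per batch element.
# Illustrative call with assumed sizes:
#   pab3d = PAB3D(256, 256, pab_channels=32)
#   y = pab3d(torch.randn(1, 256, 2, 16, 16))   # y.shape == (1, 256, 2, 16, 16)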
class MFAB3D(nn.Module):
    """Multi-scale Fusion Attention Block (3D variant): a modified SE block with
    two channel-attention branches, one for the skip connection and one for the
    upsampled high-level input."""
    def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16):
        super(MFAB3D, self).__init__()
self.hl_conv = nn.Sequential(
Conv3dReLU(
in_channels,
in_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
),
Conv3dReLU(
in_channels,
skip_channels,
kernel_size=1,
use_batchnorm=use_batchnorm,
)
)
self.SE_ll = nn.Sequential(
nn.AdaptiveAvgPool3d(1),
nn.Conv3d(skip_channels, skip_channels // reduction, 1),
nn.ReLU(inplace=True),
nn.Conv3d(skip_channels // reduction, skip_channels, 1),
nn.Sigmoid(),
)
self.SE_hl = nn.Sequential(
nn.AdaptiveAvgPool3d(1),
nn.Conv3d(skip_channels, skip_channels // reduction, 1),
nn.ReLU(inplace=True),
nn.Conv3d(skip_channels // reduction, skip_channels, 1),
nn.Sigmoid(),
)
self.conv1 = Conv3dReLU(
            skip_channels + skip_channels,  # the high-level C' is first projected to C, then concatenated with the C-channel skip
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv2 = Conv3dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
def forward(self, x, skip=None):
x = self.hl_conv(x)
x = F.interpolate(x, scale_factor=2, mode="nearest")
attention_hl = self.SE_hl(x)
if skip is not None:
attention_ll = self.SE_ll(skip)
attention_hl = attention_hl + attention_ll
x = x * attention_hl
x = torch.cat([x, skip], dim=1)
x = self.conv1(x)
x = self.conv2(x)
return x
class DoubleConv2D(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class Down2D(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
NONLocalBlock2D(in_channels),
DoubleConv2D(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
class Up2D(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv2D(in_channels, out_channels, in_channels // 2)
else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv2D(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
        # inputs are NCHW; pad x1 so its spatial size matches x2 before concatenation
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
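# Illustrative Up2D call (assumed sizes). These Up blocks are defined, but the
# final models decode with the MFAB blocks instead (see UNet2D/UNet3D.forward):
#   up = Up2D(in_channels=128, out_channels=64, bilinear=True)
#   x1 = torch.randn(1, 64, 16, 16)   # features to upsample
#   x2 = torch.randn(1, 64, 32, 32)   # skip connection
#   y = up(x1, x2)                    # y.shape == (1, 64, 32, 32)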
class OutConv2D(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv2D, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
class UNet2D(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True, pab_channels=64, use_batchnorm=True, aux_classifier=False):
super(UNet2D, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConv2D(n_channels, pab_channels)
self.down1 = Down2D(pab_channels, 2*pab_channels)
self.down2 = Down2D(2*pab_channels, 4*pab_channels)
self.down3 = Down2D(4*pab_channels, 8*pab_channels)
factor = 2 if bilinear else 1
self.down4 = Down2D(8*pab_channels, 16*pab_channels // factor)
self.pab = PAB2D(8*pab_channels,8*pab_channels)
        # plain U-Net decoder blocks, kept as an alternative to the MFAB decoder (see forward)
        self.up1 = Up2D(16*pab_channels, 8*pab_channels // factor, bilinear)
self.up2 = Up2D(8*pab_channels, 4*pab_channels // factor, bilinear)
self.up3 = Up2D(4*pab_channels, 2*pab_channels // factor, bilinear)
self.up4 = Up2D(2*pab_channels, pab_channels, bilinear)
self.mfab1 = MFAB2D(8*pab_channels,8*pab_channels,4*pab_channels,use_batchnorm)
self.mfab2 = MFAB2D(4*pab_channels,4*pab_channels,2*pab_channels,use_batchnorm)
self.mfab3 = MFAB2D(2*pab_channels,2*pab_channels,pab_channels,use_batchnorm)
self.mfab4 = MFAB2D(pab_channels,pab_channels,pab_channels,use_batchnorm)
self.outc = OutConv2D(pab_channels, n_classes)
        if not aux_classifier:
self.aux = None
else:
            # customizable auxiliary classification head (an earlier pooling-based variant is kept below for reference)
# self.aux = nn.Sequential(nn.AdaptiveAvgPool2d(1),
# nn.Flatten(),
# nn.Dropout(p=0.1, inplace=True),
# nn.Linear(8*pab_channels, 16, bias=True),
# nn.Dropout(p=0.1, inplace=True),
# nn.Linear(16, n_classes, bias=True),
# nn.Softmax(1))
self.aux = nn.Sequential(
NONLocalBlock2D(8*pab_channels),
nn.Conv2d(8*pab_channels,1,1),
nn.InstanceNorm2d(1),
nn.ReLU(),
nn.Flatten(),
                nn.Linear(24*24, 16, bias=True),  # assumes a 24x24 bottleneck map, e.g. a 384x384 input after four 2x poolings
nn.Dropout(p=0.2, inplace=True),
nn.Linear(16, n_classes, bias=True),
nn.Softmax(1))
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x5 = self.pab(x5)
x = self.mfab1(x5,x4)
x = self.mfab2(x,x3)
x = self.mfab3(x,x2)
x = self.mfab4(x,x1)
# x = self.up1(x5, x4)
# x = self.up2(x, x3)
# x = self.up3(x, x2)
# x = self.up4(x, x1)
        logits = self.outc(x)
        logits = F.softmax(logits, dim=1)
        if self.aux is None:
            return logits
        else:
            aux = self.aux(x5)
            return logits, aux
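# Illustrative UNet2D usage with the auxiliary classifier enabled. The aux head
# contains nn.Linear(24*24, 16), so it assumes a 24x24 bottleneck, e.g. a
# 384x384 input after four 2x poolings; with aux_classifier=False any input
# size divisible by 16 should work:
#   net = UNet2D(1, 2, pab_channels=64, use_batchnorm=True, aux_classifier=True)
#   seg, cls = net(torch.randn(1, 1, 384, 384))
#   # seg.shape == (1, 2, 384, 384), cls.shape == (1, 2)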
class DoubleConv3D(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1),
nn.BatchNorm3d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm3d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class Down3D(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool3d(2),
# NONLocalBlock3D(in_channels),
DoubleConv3D(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
class Up3D(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
        # if bilinear (trilinear for volumes), use the normal convolutions to reduce the number of channels
        if bilinear:
            # 'bilinear' only supports 4D tensors; 'trilinear' is the 3D equivalent for 5D inputs
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            self.conv = DoubleConv3D(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv3D(in_channels, out_channels)
    def forward(self, x1, x2):
        x1 = self.up(x1)
        # inputs are NCDHW; pad x1 in all three spatial dims so it matches x2 before concatenation
        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class OutConv3D(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv3D, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
class UNet3D(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True, pab_channels=64, use_batchnorm=True, aux_classifier=False):
super(UNet3D, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConv3D(n_channels, pab_channels)
self.down1 = Down3D(pab_channels, 2*pab_channels)
        self.nnblock2 = NONLocalBlock3D(2*pab_channels)  # optional non-local block; its use in forward is commented out
self.down2 = Down3D(2*pab_channels, 4*pab_channels)
self.down3 = Down3D(4*pab_channels, 8*pab_channels)
factor = 2 if bilinear else 1
self.down4 = Down3D(8*pab_channels, 16*pab_channels // factor)
self.pab = PAB3D(8*pab_channels,8*pab_channels)
        # plain U-Net decoder blocks, kept as an alternative to the MFAB decoder (see forward)
        self.up1 = Up3D(16*pab_channels, 8*pab_channels // factor, bilinear)
self.up2 = Up3D(8*pab_channels, 4*pab_channels // factor, bilinear)
self.up3 = Up3D(4*pab_channels, 2*pab_channels // factor, bilinear)
self.up4 = Up3D(2*pab_channels, pab_channels, bilinear)
self.mfab1 = MFAB3D(8*pab_channels,8*pab_channels,4*pab_channels,use_batchnorm)
self.mfab2 = MFAB3D(4*pab_channels,4*pab_channels,2*pab_channels,use_batchnorm)
self.mfab3 = MFAB3D(2*pab_channels,2*pab_channels,pab_channels,use_batchnorm)
self.mfab4 = MFAB3D(pab_channels,pab_channels,pab_channels,use_batchnorm)
self.outc = OutConv3D(pab_channels, n_classes)
        if not aux_classifier:
self.aux = None
else:
            # customizable auxiliary classification head (an earlier pooling-based variant is kept below for reference)
# self.aux = nn.Sequential(nn.AdaptiveMaxPool3d(1),
# nn.Flatten(),
# nn.Dropout(p=0.1, inplace=True),
# nn.Linear(8*pab_channels, 16, bias=True),
# nn.Dropout(p=0.1, inplace=True),
# nn.Linear(16, n_classes, bias=True),
# nn.Softmax(1))
self.aux = nn.Sequential(nn.Conv3d(8*pab_channels,1,1),
nn.InstanceNorm3d(1),
nn.ReLU(),
nn.Flatten(),
                nn.Linear(16*16*2, 16, bias=True),  # assumes the bottleneck volume flattens to 16*16*2 voxels (roughly input size / 16 per spatial dimension)
nn.Dropout(p=0.2, inplace=True),
nn.Linear(16, n_classes, bias=True),
nn.Softmax(1))
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
# x2 = self.nnblock2(x2)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x5 = self.pab(x5)
x = self.mfab1(x5,x4)
x = self.mfab2(x,x3)
x = self.mfab3(x,x2)
x = self.mfab4(x,x1)
# x = self.up1(x5, x4)
# x = self.up2(x, x3)
# x = self.up3(x, x2)
# x = self.up4(x, x1)
        logits = self.outc(x)
        logits = F.softmax(logits, dim=1)
        if self.aux is None:
            return logits
        else:
            aux = self.aux(x5)
            return logits, aux
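if __name__ == "__main__":
    # Minimal smoke test (a sketch; sizes and channel widths are assumptions chosen
    # so every 2x pooling divides evenly and the optional aux heads stay disabled).
    net2d = UNet2D(1, 2, pab_channels=16, use_batchnorm=True)
    out2d = net2d(torch.randn(1, 1, 64, 64))
    print("UNet2D output:", tuple(out2d.shape))   # expected (1, 2, 64, 64)

    net3d = UNet3D(1, 2, pab_channels=16, use_batchnorm=True)
    out3d = net3d(torch.randn(1, 1, 32, 32, 32))
    print("UNet3D output:", tuple(out3d.shape))   # expected (1, 2, 32, 32, 32)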