# nmed2024 / adrd / nn / unet.py
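# 3D U-Net backbone used as an imaging feature extractor. The encoder
# (DownTransition blocks) produces the features; the decoder
# (UpTransition/OutputTransition) is constructed so that an image-pretrained
# checkpoint (Genesis_Chest_CT.pt) can be loaded, and is then removed so only
# the encoder remains. An optional AttentionModule gates the encoder output.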
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torch.nn import init
import torch.nn.functional as F
from icecream import ic
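
# Batch norm that always normalizes with the statistics of the current batch:
# F.batch_norm is called with training=True, so batch statistics are used (and
# the running statistics updated) even when the module is in eval mode.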
class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'.format(input.dim()))
#super(ContBatchNorm3d, self)._check_input_dim(input)
def forward(self, input):
self._check_input_dim(input)
return F.batch_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
True, self.momentum, self.eps)
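
# Basic unit: Conv3d (3x3x3, padding 1) -> ContBatchNorm3d -> activation
# (relu / prelu / elu).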
class LUConv(nn.Module):
def __init__(self, in_chan, out_chan, act):
super(LUConv, self).__init__()
self.conv1 = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1)
self.bn1 = ContBatchNorm3d(out_chan)
if act == 'relu':
            self.activation = nn.ReLU(inplace=True)  # nn.ReLU takes `inplace`, not a channel count
elif act == 'prelu':
self.activation = nn.PReLU(out_chan)
elif act == 'elu':
self.activation = nn.ELU(inplace=True)
else:
            raise ValueError(f'unsupported activation: {act}')
def forward(self, x):
out = self.activation(self.bn1(self.conv1(x)))
return out
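
# Two stacked LUConv layers. With double_channel=False the channel count goes
# in_channel -> 32 * 2**depth -> 64 * 2**depth; with double_channel=True both
# layers output 32 * 2**(depth + 1) channels.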
def _make_nConv(in_channel, depth, act, double_channel=False):
    if double_channel:
        layer1 = LUConv(in_channel, 32 * (2 ** (depth + 1)), act)
        layer2 = LUConv(32 * (2 ** (depth + 1)), 32 * (2 ** (depth + 1)), act)
    else:
        layer1 = LUConv(in_channel, 32 * (2 ** depth), act)
        layer2 = LUConv(32 * (2 ** depth), 32 * (2 ** depth) * 2, act)
    return nn.Sequential(layer1, layer2)
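
# Encoder block: two convolutions followed by 2x2x2 max-pooling, except at
# depth 3 (the bottleneck), where pooling is skipped and both returned tensors
# are identical. Returns (pooled_output, pre_pool_output) for skip connections.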
class DownTransition(nn.Module):
    def __init__(self, in_channel, depth, act):
        super(DownTransition, self).__init__()
        self.ops = _make_nConv(in_channel, depth, act)
self.maxpool = nn.MaxPool3d(2)
self.current_depth = depth
def forward(self, x):
if self.current_depth == 3:
out = self.ops(x)
out_before_pool = out
else:
out_before_pool = self.ops(x)
out = self.maxpool(out_before_pool)
return out, out_before_pool
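
# Decoder block: transposed-conv upsampling (kernel 2, stride 2), channel-wise
# concatenation with the skip connection, then two convolutions.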
class UpTransition(nn.Module):
    def __init__(self, inChans, outChans, depth, act):
        super(UpTransition, self).__init__()
        self.depth = depth
        self.up_conv = nn.ConvTranspose3d(inChans, outChans, kernel_size=2, stride=2)
        self.ops = _make_nConv(inChans + outChans // 2, depth, act, double_channel=True)
def forward(self, x, skip_x):
out_up_conv = self.up_conv(x)
        concat = torch.cat((out_up_conv, skip_x), dim=1)
out = self.ops(concat)
return out
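
# 1x1x1 convolution to n_labels followed by a sigmoid.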
class OutputTransition(nn.Module):
def __init__(self, inChans, n_labels):
super(OutputTransition, self).__init__()
self.final_conv = nn.Conv3d(inChans, n_labels, kernel_size=1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
out = self.sigmoid(self.final_conv(x))
return out
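
# Conv3d -> MaxPool3d -> BatchNorm3d -> (Leaky)ReLU -> Dropout. The `kernel`
# and `pooling` arguments are (size, stride, padding) tuples.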
class ConvLayer(nn.Module):
def __init__(self, in_channels, out_channels, drop_rate, kernel, pooling, BN=True, relu_type='leaky'):
super().__init__()
kernel_size, kernel_stride, kernel_padding = kernel
pool_kernel, pool_stride, pool_padding = pooling
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, kernel_stride, kernel_padding, bias=False)
self.pooling = nn.MaxPool3d(pool_kernel, pool_stride, pool_padding)
self.BN = nn.BatchNorm3d(out_channels)
self.relu = nn.LeakyReLU(inplace=False) if relu_type=='leaky' else nn.ReLU(inplace=False)
self.dropout = nn.Dropout(drop_rate, inplace=False)
def forward(self, x):
x = self.conv(x)
x = self.pooling(x)
x = self.BN(x)
x = self.relu(x)
x = self.dropout(x)
return x
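
# Computes features with a 1x1x1 convolution and gates them element-wise with
# a softmax attention map produced by a parallel ConvLayer branch.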
class AttentionModule(nn.Module):
def __init__(self, in_channels, out_channels, drop_rate=0.1):
super(AttentionModule, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=False)
self.attention = ConvLayer(in_channels, out_channels, drop_rate, (1, 1, 0), (1, 1, 0))
def forward(self, x, return_attention=True):
feats = self.conv(x)
        att = F.softmax(self.attention(x), dim=1)  # softmax over the channel dimension
out = feats * att
if return_attention:
return att, out
return out
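
# Encoder-only 3D U-Net. Four DownTransition blocks (five when blocks == 5),
# optional attention gating, and a dummy forward pass at construction time to
# record the per-sample output feature shape in self.in_features.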
class UNet3D(nn.Module):
# the number of convolutions in each layer corresponds
# to what is in the actual prototxt, not the intent
def __init__(self, n_class=1, act='relu', pretrained=False, input_size=(1,1,182,218,182), attention=False, drop_rate=0.1, blocks=4):
super(UNet3D, self).__init__()
self.blocks = blocks
        self.down_tr64 = DownTransition(1, 0, act)
        self.down_tr128 = DownTransition(64, 1, act)
        self.down_tr256 = DownTransition(128, 2, act)
        self.down_tr512 = DownTransition(256, 3, act)
        self.up_tr256 = UpTransition(512, 512, 2, act)
        self.up_tr128 = UpTransition(256, 256, 1, act)
        self.up_tr64 = UpTransition(128, 128, 0, act)
        self.out_tr = OutputTransition(64, 1)
self.pretrained = pretrained
self.attention = attention
if pretrained:
print("Using image pretrained model checkpoint")
weight_dir = '/home/skowshik/ADRD_repo/img_pretrained_ckpt/Genesis_Chest_CT.pt'
checkpoint = torch.load(weight_dir)
state_dict = checkpoint['state_dict']
unParalled_state_dict = {}
for key in state_dict.keys():
unParalled_state_dict[key.replace("module.", "")] = state_dict[key]
self.load_state_dict(unParalled_state_dict)
del self.up_tr256
del self.up_tr128
del self.up_tr64
del self.out_tr
if self.blocks == 5:
self.down_tr1024 = DownTransition(512,4,act)
# self.conv1 = nn.Conv3d(512, 256, 1, 1, 0, bias=False)
# self.conv2 = nn.Conv3d(256, 128, 1, 1, 0, bias=False)
# self.conv3 = nn.Conv3d(128, 64, 1, 1, 0, bias=False)
if attention:
self.attention_module = AttentionModule(1024 if self.blocks==5 else 512, n_class, drop_rate=drop_rate)
# Output.
self.avgpool = nn.AvgPool3d((6,7,6), stride=(6,6,6))
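        # Run a dummy forward pass to record the per-sample output feature
        # shape in self.in_features (the batch dimension is dropped below).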
dummy_inp = torch.rand(input_size)
dummy_feats = self.forward(dummy_inp, stage='get_features')
dummy_feats = dummy_feats[0]
self.in_features = list(dummy_feats.shape)
ic(self.in_features)
self._init_weights()
def _init_weights(self):
if not self.pretrained:
for m in self.modules():
if isinstance(m, nn.Conv3d):
init.kaiming_normal_(m.weight)
elif isinstance(m, ContBatchNorm3d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight)
init.constant_(m.bias, 0)
elif self.attention:
for m in self.attention_module.modules():
if isinstance(m, nn.Conv3d):
init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm3d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
else:
pass
# Zero initialize the last batchnorm in each residual branch.
# for m in self.modules():
# if isinstance(m, BottleneckBlock):
# init.constant_(m.out_conv.bn.weight, 0)
def forward(self, x, stage='normal', attention=False):
ic('backbone forward')
        self.out64, self.skip_out64 = self.down_tr64(x)
        self.out128, self.skip_out128 = self.down_tr128(self.out64)
        self.out256, self.skip_out256 = self.down_tr256(self.out128)
        self.out512, self.skip_out512 = self.down_tr512(self.out256)
if self.blocks == 5:
            self.out1024, self.skip_out1024 = self.down_tr1024(self.out512)
ic(self.out1024.shape)
# self.out = self.conv1(self.out512)
# self.out = self.conv2(self.out)
# self.out = self.conv3(self.out)
# self.out = self.conv(self.out)
ic(hasattr(self, 'attention_module'))
if hasattr(self, 'attention_module'):
att, feats = self.attention_module(self.out1024 if self.blocks==5 else self.out512)
else:
feats = self.out1024 if self.blocks==5 else self.out512
ic(feats.shape)
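        # `att` is only defined when the attention module exists, so callers
        # should pass attention=True only when the model was built with
        # attention enabled.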
if attention:
return att, feats
return feats
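

# A minimal usage sketch (not part of the original file), assuming the model is
# built without the pretrained checkpoint: extract encoder features from a
# random volume matching the default input size.
if __name__ == '__main__':
    model = UNet3D(pretrained=False, attention=False, blocks=4)
    x = torch.rand(1, 1, 182, 218, 182)
    feats = model(x)  # 512-channel feature map after three max-poolings
    ic(feats.shape)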