import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torch.nn import init
import torch.nn.functional as F
from icecream import ic


class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'.format(input.dim()))

    def forward(self, input):
        self._check_input_dim(input)
        # training is hard-coded to True: normalize with the statistics of the
        # current batch (and keep updating the running estimates) even in eval mode.
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            True, self.momentum, self.eps)

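# Hedged sketch (hypothetical helper, not part of the original model): because
# ContBatchNorm3d always passes training=True to F.batch_norm, its output comes
# from batch statistics even after .eval(), unlike a standard nn.BatchNorm3d.
def _demo_cont_batchnorm_eval_equals_train():
    bn = ContBatchNorm3d(4)
    x = torch.randn(2, 4, 8, 8, 8)
    y_train = bn(x)
    bn.eval()
    y_eval = bn(x)
    ic(torch.allclose(y_train, y_eval))  # expected: True
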
class LUConv(nn.Module):
    def __init__(self, in_chan, out_chan, act):
        super(LUConv, self).__init__()
        self.conv1 = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1)
        self.bn1 = ContBatchNorm3d(out_chan)

        if act == 'relu':
            # nn.ReLU takes only an `inplace` flag, not a channel count.
            self.activation = nn.ReLU(inplace=True)
        elif act == 'prelu':
            self.activation = nn.PReLU(out_chan)
        elif act == 'elu':
            self.activation = nn.ELU(inplace=True)
        else:
            raise ValueError('unsupported activation: {}'.format(act))

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        return out

def _make_nConv(in_channel, depth, act, double_channel=False):
    if double_channel:
        layer1 = LUConv(in_channel, 32 * (2 ** (depth + 1)), act)
        layer2 = LUConv(32 * (2 ** (depth + 1)), 32 * (2 ** (depth + 1)), act)
    else:
        layer1 = LUConv(in_channel, 32 * (2 ** depth), act)
        layer2 = LUConv(32 * (2 ** depth), 32 * (2 ** depth) * 2, act)

    return nn.Sequential(layer1, layer2)

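# Hedged sketch (hypothetical helper, not used by the model): channel widths
# produced by _make_nConv. With double_channel=False the second conv doubles
# the width, so depth 0 gives 1 -> 32 -> 64, depth 1 gives 64 -> 64 -> 128,
# depth 2 gives 128 -> 128 -> 256, and depth 3 gives 256 -> 256 -> 512.
def _demo_make_nconv_channels():
    block = _make_nConv(1, 0, 'relu')
    x = torch.rand(1, 1, 8, 8, 8)
    ic(block(x).shape)  # expected: torch.Size([1, 64, 8, 8, 8])
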
class DownTransition(nn.Module):
    def __init__(self, in_channel, depth, act):
        super(DownTransition, self).__init__()
        self.ops = _make_nConv(in_channel, depth, act)
        self.maxpool = nn.MaxPool3d(2)
        self.current_depth = depth

    def forward(self, x):
        if self.current_depth == 3:
            # Bottleneck block: no pooling, so `out` and `out_before_pool` match.
            out = self.ops(x)
            out_before_pool = out
        else:
            out_before_pool = self.ops(x)
            out = self.maxpool(out_before_pool)
        return out, out_before_pool

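# Hedged sketch (hypothetical helper): every DownTransition halves the spatial
# size except the bottleneck one (depth == 3), which only widens the channels.
def _demo_down_transition_shapes():
    down = DownTransition(1, 0, 'relu')
    x = torch.rand(1, 1, 16, 16, 16)
    out, skip = down(x)
    ic(out.shape, skip.shape)  # expected: (1, 64, 8, 8, 8) and (1, 64, 16, 16, 16)
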
class UpTransition(nn.Module):
    def __init__(self, inChans, outChans, depth, act):
        super(UpTransition, self).__init__()
        self.depth = depth
        self.up_conv = nn.ConvTranspose3d(inChans, outChans, kernel_size=2, stride=2)
        self.ops = _make_nConv(inChans + outChans // 2, depth, act, double_channel=True)

    def forward(self, x, skip_x):
        out_up_conv = self.up_conv(x)
        concat = torch.cat((out_up_conv, skip_x), 1)
        out = self.ops(concat)
        return out

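# Hedged sketch (hypothetical helper, illustrative shapes): UpTransition doubles
# the spatial size of the bottleneck with a transposed conv, concatenates the
# skip connection, and runs two convs over the concatenation.
def _demo_up_transition_shapes():
    up = UpTransition(512, 512, 2, 'relu')
    bottleneck = torch.rand(1, 512, 4, 4, 4)
    skip = torch.rand(1, 256, 8, 8, 8)
    ic(up(bottleneck, skip).shape)  # expected: torch.Size([1, 256, 8, 8, 8])
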
class OutputTransition(nn.Module):
    def __init__(self, inChans, n_labels):
        super(OutputTransition, self).__init__()
        self.final_conv = nn.Conv3d(inChans, n_labels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.sigmoid(self.final_conv(x))
        return out

class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, drop_rate, kernel, pooling, BN=True, relu_type='leaky'):
        super().__init__()
        kernel_size, kernel_stride, kernel_padding = kernel
        pool_kernel, pool_stride, pool_padding = pooling
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, kernel_stride, kernel_padding, bias=False)
        self.pooling = nn.MaxPool3d(pool_kernel, pool_stride, pool_padding)
        self.BN = nn.BatchNorm3d(out_channels)
        self.relu = nn.LeakyReLU(inplace=False) if relu_type == 'leaky' else nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(drop_rate, inplace=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.pooling(x)
        x = self.BN(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x

class AttentionModule(nn.Module):
    def __init__(self, in_channels, out_channels, drop_rate=0.1):
        super(AttentionModule, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=False)
        self.attention = ConvLayer(in_channels, out_channels, drop_rate, (1, 1, 0), (1, 1, 0))

    def forward(self, x, return_attention=True):
        feats = self.conv(x)
        # Softmax over the channel dimension; the explicit dim avoids the
        # deprecated implicit-dim behaviour of F.softmax.
        att = F.softmax(self.attention(x), dim=1)

        out = feats * att

        if return_attention:
            return att, out

        return out

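# Hedged sketch (hypothetical helper; 512 input channels and 3 output classes
# are illustrative): the attention branch is a 1x1x1 conv followed by a softmax
# over the channel dimension, which gates the 1x1x1 feature conv.
def _demo_attention_module_shapes():
    att_mod = AttentionModule(512, 3)
    x = torch.rand(1, 512, 6, 7, 6)
    att, out = att_mod(x)
    ic(att.shape, out.shape)  # expected: both torch.Size([1, 3, 6, 7, 6])
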
class UNet3D(nn.Module):

    def __init__(self, n_class=1, act='relu', pretrained=False, input_size=(1, 1, 182, 218, 182), attention=False, drop_rate=0.1, blocks=4):
        super(UNet3D, self).__init__()

        self.blocks = blocks
        self.down_tr64 = DownTransition(1, 0, act)
        self.down_tr128 = DownTransition(64, 1, act)
        self.down_tr256 = DownTransition(128, 2, act)
        self.down_tr512 = DownTransition(256, 3, act)

        self.up_tr256 = UpTransition(512, 512, 2, act)
        self.up_tr128 = UpTransition(256, 256, 1, act)
        self.up_tr64 = UpTransition(128, 128, 0, act)
        self.out_tr = OutputTransition(64, 1)

        self.pretrained = pretrained
        self.attention = attention
        if pretrained:
            print("Using image pretrained model checkpoint")
            weight_dir = '/home/skowshik/ADRD_repo/img_pretrained_ckpt/Genesis_Chest_CT.pt'
            checkpoint = torch.load(weight_dir)
            state_dict = checkpoint['state_dict']
            unParalled_state_dict = {}
            for key in state_dict.keys():
                unParalled_state_dict[key.replace("module.", "")] = state_dict[key]
            self.load_state_dict(unParalled_state_dict)
            # The decoder is only needed to load the pretrained weights; drop it
            # afterwards since forward() only uses the encoder path.
            del self.up_tr256
            del self.up_tr128
            del self.up_tr64
            del self.out_tr

        if self.blocks == 5:
            self.down_tr1024 = DownTransition(512, 4, act)

        if attention:
            self.attention_module = AttentionModule(1024 if self.blocks == 5 else 512, n_class, drop_rate=drop_rate)

        self.avgpool = nn.AvgPool3d((6, 7, 6), stride=(6, 6, 6))

        # Run a dummy forward pass to record the feature shape of one sample.
        dummy_inp = torch.rand(input_size)
        dummy_feats = self.forward(dummy_inp, stage='get_features')
        dummy_feats = dummy_feats[0]
        self.in_features = list(dummy_feats.shape)
        ic(self.in_features)

        self._init_weights()

    def _init_weights(self):
        if not self.pretrained:
            for m in self.modules():
                if isinstance(m, nn.Conv3d):
                    init.kaiming_normal_(m.weight)
                elif isinstance(m, ContBatchNorm3d):
                    init.constant_(m.weight, 1)
                    init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    init.kaiming_normal_(m.weight)
                    init.constant_(m.bias, 0)
        elif self.attention:
            # With a pretrained encoder, only the randomly added attention head
            # needs initialization.
            for m in self.attention_module.modules():
                if isinstance(m, nn.Conv3d):
                    init.kaiming_normal_(m.weight)
                elif isinstance(m, nn.BatchNorm3d):
                    init.constant_(m.weight, 1)
                    init.constant_(m.bias, 0)

    def forward(self, x, stage='normal', attention=False):
        ic('backbone forward')
        self.out64, self.skip_out64 = self.down_tr64(x)
        self.out128, self.skip_out128 = self.down_tr128(self.out64)
        self.out256, self.skip_out256 = self.down_tr256(self.out128)
        self.out512, self.skip_out512 = self.down_tr512(self.out256)
        if self.blocks == 5:
            self.out1024, self.skip_out1024 = self.down_tr1024(self.out512)
            ic(self.out1024.shape)

        ic(hasattr(self, 'attention_module'))
        if hasattr(self, 'attention_module'):
            att, feats = self.attention_module(self.out1024 if self.blocks == 5 else self.out512)
        else:
            feats = self.out1024 if self.blocks == 5 else self.out512
        ic(feats.shape)
        if attention:
            return att, feats
        return feats
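

# Hedged usage sketch (the input size below is an assumption, not taken from the
# original repo): build the backbone from scratch and run one forward pass.
# With blocks=4 the encoder applies three 2x poolings, so a 64^3 volume yields
# a (1, 512, 8, 8, 8) feature map.
if __name__ == '__main__':
    model = UNet3D(pretrained=False, attention=False, input_size=(1, 1, 64, 64, 64))
    volume = torch.rand(1, 1, 64, 64, 64)
    with torch.no_grad():
        features = model(volume)
    ic(features.shape)  # expected: torch.Size([1, 512, 8, 8, 8])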