# nmed2024/adrd/nn/dense_net.py
# This implementation is based on the DenseNet-BC implementation in torchvision
# https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
# https://github.com/gpleiss/efficient_densenet_pytorch/blob/master/models/densenet.py
import math
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
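

# Packages concat -> BN -> ReLU -> 1x1x1 conv into a single callable so it can
# be handed to torch.utils.checkpoint: in the memory-efficient mode the
# concatenated features are recomputed during the backward pass instead of
# being kept in memory.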
def _bn_function_factory(norm, relu, conv):
def bn_function(*inputs):
concated_features = torch.cat(inputs, 1)
bottleneck_output = conv(relu(norm(concated_features)))
return bottleneck_output
return bn_function


class _DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False):
        super(_DenseLayer, self).__init__()
        # Bottleneck: BN -> ReLU -> 1x1x1 conv reduces the concatenated input
        # to bn_size * growth_rate channels before the 3x3x3 convolution.
        self.add_module('norm1', nn.BatchNorm3d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv3d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv3d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate
        self.efficient = efficient
    def forward(self, *prev_features):
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
            # Recompute the bottleneck during backward to save memory; checkpointing
            # only pays off when at least one input actually requires grad.
            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        else:
            bottleneck_output = bn_function(*prev_features)
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features
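

# Transition between dense blocks: BN -> ReLU -> 1x1x1 conv compresses the
# channel count, then average pooling halves every spatial dimension.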
class _Transition(nn.Sequential):
def __init__(self, num_input_features, num_output_features):
super(_Transition, self).__init__()
self.add_module('norm', nn.BatchNorm3d(num_input_features))
self.add_module('relu', nn.ReLU(inplace=True))
self.add_module('conv', nn.Conv3d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2))
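

# Dense block: layer i receives the concatenation of the block input and the
# outputs of all preceding layers, so its input width grows by growth_rate
# per layer.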
class _DenseBlock(nn.Module):
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, efficient=False):
super(_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(
num_input_features + i * growth_rate,
growth_rate=growth_rate,
bn_size=bn_size,
drop_rate=drop_rate,
efficient=efficient,
)
self.add_module('denselayer%d' % (i + 1), layer)
    def forward(self, init_features):
        features = [init_features]
        for layer in self.children():
            new_features = layer(*features)
            features.append(new_features)
        return torch.cat(features, 1)
class DenseNet(nn.Module):
r"""Densenet-BC model class, based on
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
Args:
growth_rate (int) - how many filters to add each layer (`k` in paper)
block_config (list of 3 or 4 ints) - how many layers in each pooling block
num_init_features (int) - the number of filters to learn in the first convolution layer
bn_size (int) - multiplicative factor for number of bottle neck layers
(i.e. bn_size * k features in the bottleneck layer)
drop_rate (float) - dropout rate after each dense layer
tgt_modalities (list) - list of target modalities
efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower.
"""
    def __init__(self, tgt_modalities, growth_rate=12, block_config=(3, 3, 3), compression=0.5,
                 num_init_features=16, bn_size=4, drop_rate=0, efficient=False, load_from_ckpt=False):
super(DenseNet, self).__init__()
        # First convolution: 7x7x7 strided conv + BN + ReLU + max pooling
        # downsample the volume before the first dense block.
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv3d(1, num_init_features, kernel_size=7, stride=2, padding=0, bias=False)),
        ]))
        self.features.add_module('norm0', nn.BatchNorm3d(num_init_features))
        self.features.add_module('relu0', nn.ReLU(inplace=True))
        self.features.add_module('pool0', nn.MaxPool3d(kernel_size=3, stride=2, padding=0, ceil_mode=False))
self.tgt_modalities = tgt_modalities
# Each denseblock
num_features = num_init_features
for i, num_layers in enumerate(block_config):
block = _DenseBlock(
num_layers=num_layers,
num_input_features=num_features,
bn_size=bn_size,
growth_rate=growth_rate,
drop_rate=drop_rate,
efficient=efficient,
)
self.features.add_module('denseblock%d' % (i + 1), block)
num_features = num_features + num_layers * growth_rate
            # Note: the reference torchvision implementation skips the transition
            # after the final dense block (`i != len(block_config) - 1`). The
            # condition below is always true, so here a transition (compression
            # + 2x pooling) also follows the last block.
            if i != len(block_config):
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=int(num_features * compression))
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = int(num_features * compression)
# Final batch norm
self.features.add_module('norm_final', nn.BatchNorm3d(num_features))
        # Classification heads: one small MLP per target modality, each emitting
        # a single logit. The input width is inferred once by a dummy forward pass.
        self.tgt = torch.nn.ModuleDict()
        feature_size = self.test_size()
        for k in tgt_modalities:
            self.tgt[k] = torch.nn.Sequential(
                torch.nn.Linear(feature_size, 256),
                torch.nn.ReLU(),
                torch.nn.Linear(256, 1)
            )
        print(f'load_from_ckpt: {load_from_ckpt}')
# Initialization
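        # He-style (Kaiming) normal init for conv weights using the fan-out
        # count (out_channels x kernel volume); BatchNorm weights to 1, biases to 0.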
if not load_from_ckpt:
for name, param in self.named_parameters():
if 'conv' in name and 'weight' in name:
n = param.size(0) * param.size(2) * param.size(3) * param.size(4)
param.data.normal_().mul_(math.sqrt(2. / n))
elif 'norm' in name and 'weight' in name:
param.data.fill_(1)
elif 'norm' in name and 'bias' in name:
param.data.fill_(0)
elif ('classifier' in name or 'tgt' in name) and 'bias' in name:
param.data.fill_(0)
    def forward(self, x, shap=True):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        # Flatten the full feature volume (no global pooling) before the heads.
        out = torch.flatten(out, 1)
        out_tgt = {k: head(out).squeeze(1) for k, head in self.tgt.items()}
        if shap:
            # Stack the per-modality logits into one (batch, num_modalities)
            # tensor, as expected by SHAP-style attribution tools.
            out_tgt = torch.stack(list(out_tgt.values()))
            return out_tgt.T
        else:
            return out_tgt
    def test_size(self):
        # Infer the flattened feature size by pushing a dummy volume of the
        # expected input shape (1, 1, 182, 218, 182) through the feature extractor.
        with torch.no_grad():
            case = torch.ones((1, 1, 182, 218, 182))
            output = self.features(case).view(-1).size(0)
        return output


if __name__ == "__main__":
    # Smoke test: build a small configuration and push one dummy volume through it.
    model = DenseNet(
        tgt_modalities=['NC', 'MCI', 'DE'],
        growth_rate=12,
        block_config=(2, 3, 2),
        compression=0.5,
        num_init_features=16,
        drop_rate=0.2)
    print(model)

    torch.manual_seed(42)
    x = torch.rand((1, 1, 182, 218, 182))
    features = model.features(x)
    print(features.shape)
    print(sum(p.numel() for p in model.parameters()))

    out = model(x, shap=False)
    print(out)
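
    # A minimal sketch of the memory-efficient mode: with efficient=True the
    # bottleneck activations inside each dense layer are recomputed during the
    # backward pass via torch.utils.checkpoint instead of being stored. The
    # checkpointed path is only taken when some input requires grad.
    model_eff = DenseNet(
        tgt_modalities=['NC', 'MCI', 'DE'],
        growth_rate=12,
        block_config=(2, 3, 2),
        compression=0.5,
        num_init_features=16,
        drop_rate=0.2,
        efficient=True)
    out_eff = model_eff(torch.rand((1, 1, 182, 218, 182), requires_grad=True), shap=False)
    print({k: v.shape for k, v in out_eff.items()})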