# This implementation is based on the DenseNet-BC implementation in torchvision
# https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
# https://github.com/gpleiss/efficient_densenet_pytorch/blob/master/models/densenet.py
import math
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp

def _bn_function_factory(norm, relu, conv):
    """Build a closure that concatenates prior feature maps and applies the
    bottleneck BN -> ReLU -> 1x1 conv. Wrapping these ops in one callable lets
    torch.utils.checkpoint recompute them during the backward pass instead of
    storing the concatenated activations."""
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = conv(relu(norm(concated_features)))
        return bottleneck_output

    return bn_function
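
# Illustrative example (hypothetical shapes, not from the training pipeline):
# given prior feature maps of 16 and 12 channels, bn_function concatenates them
# to 28 channels before the 1x1 bottleneck convolution:
#   fn = _bn_function_factory(nn.BatchNorm3d(28), nn.ReLU(), nn.Conv3d(28, 48, 1, bias=False))
#   fn(torch.randn(1, 16, 8, 8, 8), torch.randn(1, 12, 8, 8, 8)).shape  # -> (1, 48, 8, 8, 8)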

class _DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm3d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv3d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv3d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate
        self.efficient = efficient
    def forward(self, *prev_features):
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        # Checkpointing only pays off when gradients will flow through the
        # inputs; fall back to a plain forward pass otherwise (e.g. inference).
        if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        else:
            bottleneck_output = bn_function(*prev_features)
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features
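
# Minimal usage sketch (shapes are illustrative only):
#   layer = _DenseLayer(num_input_features=16, growth_rate=12, bn_size=4, drop_rate=0)
#   layer(torch.randn(1, 16, 8, 8, 8)).shape  # -> (1, 12, 8, 8, 8)
# Each layer emits exactly growth_rate new channels regardless of its input width.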
class _Transition(nn.Sequential):
def __init__(self, num_input_features, num_output_features):
super(_Transition, self).__init__()
self.add_module('norm', nn.BatchNorm3d(num_input_features))
self.add_module('relu', nn.ReLU(inplace=True))
self.add_module('conv', nn.Conv3d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2))
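
# A transition both compresses channels (1x1 conv) and halves every spatial
# dimension (stride-2 average pooling), e.g. with num_output_features = int(C * compression):
#   (N, C, D, H, W) -> (N, num_output_features, D // 2, H // 2, W // 2)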
class _DenseBlock(nn.Module):
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, efficient=False):
super(_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(
num_input_features + i * growth_rate,
growth_rate=growth_rate,
bn_size=bn_size,
drop_rate=drop_rate,
efficient=efficient,
)
self.add_module('denselayer%d' % (i + 1), layer)
    def forward(self, init_features):
        features = [init_features]
        for name, layer in self.named_children():
            # Dense connectivity: each layer sees every preceding feature map.
            new_features = layer(*features)
            features.append(new_features)
        return torch.cat(features, 1)
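
# Channel bookkeeping sketch: a block fed C channels returns
# C + num_layers * growth_rate channels, e.g.
#   block = _DenseBlock(num_layers=3, num_input_features=16, bn_size=4, growth_rate=12, drop_rate=0)
#   block(torch.randn(1, 16, 8, 8, 8)).shape  # -> (1, 16 + 3 * 12, 8, 8, 8) = (1, 52, 8, 8, 8)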

class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        tgt_modalities (list) - target modalities; one classification head is built per entry
        growth_rate (int) - how many filters to add each layer (`k` in the paper)
        block_config (list of 3 or 4 ints) - how many layers in each pooling block
        compression (float) - channel compression factor applied by each transition layer
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for the number of bottleneck features
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        efficient (bool) - set to True to use checkpointing; much more memory-efficient, but slower
        load_from_ckpt (bool) - set to True to skip weight initialization when the weights
            will be loaded from a checkpoint
    """
    def __init__(self, tgt_modalities, growth_rate=12, block_config=(3, 3, 3), compression=0.5,
                 num_init_features=16, bn_size=4, drop_rate=0, efficient=False, load_from_ckpt=False):
        super(DenseNet, self).__init__()
        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv3d(1, num_init_features, kernel_size=7, stride=2, padding=0, bias=False)),
        ]))
        self.features.add_module('norm0', nn.BatchNorm3d(num_init_features))
        self.features.add_module('relu0', nn.ReLU(inplace=True))
        self.features.add_module('pool0', nn.MaxPool3d(kernel_size=3, stride=2, padding=0, ceil_mode=False))
        self.tgt_modalities = tgt_modalities
        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                efficient=efficient,
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            # Note: `i != len(block_config)` is always true, so a transition
            # follows every dense block, including the last. (torchvision uses
            # `i != len(block_config) - 1`, which skips the final transition;
            # keeping it here shrinks the 3D volume once more before flattening.)
            if i != len(block_config):
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=int(num_features * compression))
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = int(num_features * compression)
# Final batch norm
self.features.add_module('norm_final', nn.BatchNorm3d(num_features))
        # Classification heads: one small MLP per target modality
        self.tgt = torch.nn.ModuleDict()
        flat_size = self.test_size()  # run the dummy forward pass once, not once per modality
        for k in tgt_modalities:
            self.tgt[k] = torch.nn.Sequential(
                torch.nn.Linear(flat_size, 256),
                torch.nn.ReLU(),
                torch.nn.Linear(256, 1)
            )
print(f'load_from_ckpt: {load_from_ckpt}')
        # Initialization (skipped when the weights come from a checkpoint)
        if not load_from_ckpt:
            for name, param in self.named_parameters():
                if 'conv' in name and 'weight' in name:
                    # He (Kaiming) initialization: n = out_channels * kernel volume
                    n = param.size(0) * param.size(2) * param.size(3) * param.size(4)
                    param.data.normal_().mul_(math.sqrt(2. / n))
                elif 'norm' in name and 'weight' in name:
                    param.data.fill_(1)
                elif 'norm' in name and 'bias' in name:
                    param.data.fill_(0)
                elif ('classifier' in name or 'tgt' in name) and 'bias' in name:
                    param.data.fill_(0)
    def forward(self, x, shap=True):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = torch.flatten(out, 1)
        out_tgt = {k: head(out).squeeze(1) for k, head in self.tgt.items()}
        if shap:
            # SHAP-style explainers want a single tensor, so stack the
            # per-modality logits and transpose to (batch, num_modalities).
            out_tgt = torch.stack(list(out_tgt.values()))
            return out_tgt.T
        else:
            return out_tgt
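
    # Illustrative call patterns (hypothetical batch `x` of shape (N, 1, 182, 218, 182)):
    #   model(x, shap=False)  # -> {'NC': tensor(N,), 'MCI': tensor(N,), 'DE': tensor(N,)}
    #   model(x)              # -> tensor of shape (N, num_modalities)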
    def test_size(self):
        # Flattened feature size for one dummy volume. The hardcoded
        # 182 x 218 x 182 shape matches MRI volumes registered to MNI152 space.
        case = torch.ones((1, 1, 182, 218, 182))
        output = self.features(case).view(-1).size(0)
        return output
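
# Hedged training-time sketch (the labels dict and loss choice are illustrative,
# not taken from this repo's training code):
#   model = DenseNet(tgt_modalities=['NC', 'MCI', 'DE'], efficient=True)
#   logits = model(batch, shap=False)  # dict: modality -> (N,) logits
#   loss = sum(F.binary_cross_entropy_with_logits(logits[k], labels[k]) for k in logits)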

if __name__ == "__main__":
    model = DenseNet(
        tgt_modalities=['NC', 'MCI', 'DE'],
        growth_rate=12,
        block_config=(2, 3, 2),
        compression=0.5,
        num_init_features=16,
        drop_rate=0.2)
    print(model)
    torch.manual_seed(42)
    x = torch.rand((1, 1, 182, 218, 182))
    features = model.features(x)  # model.features is already an nn.Sequential
    print(features.shape)
    print(sum(p.numel() for p in model.parameters()))
    out = model(x, shap=False)
    print(out)