import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
    'resnet152', 'resnet200'
]

class FilterResponseNormNd(nn.Module):
    def __init__(self, ndim, num_features, eps=1e-6, learnable_eps=False):
        """Filter Response Normalization with a Thresholded Linear Unit (TLU).

        Input Variables:
        ----------------
            ndim: An integer indicating the number of dimensions of the
                expected input tensor.
            num_features: An integer indicating the number of input feature
                dimensions.
            eps: A scalar constant or learnable variable.
            learnable_eps: A bool value indicating whether eps is learnable.
        """
        assert ndim in [3, 4, 5], \
            'FilterResponseNorm only supports 3d, 4d or 5d inputs.'
        super(FilterResponseNormNd, self).__init__()
        shape = (1, num_features) + (1,) * (ndim - 2)
        self.eps = nn.Parameter(torch.ones(*shape) * eps)
        if not learnable_eps:
            self.eps.requires_grad_(False)
        self.gamma = nn.Parameter(torch.empty(*shape))
        self.beta = nn.Parameter(torch.empty(*shape))
        self.tau = nn.Parameter(torch.empty(*shape))
        self.reset_parameters()

    def forward(self, x):
        # Average over all spatial dims, e.g. (2, 3, 4) for a 5d input.
        avg_dims = tuple(range(2, x.dim()))
        nu2 = torch.pow(x, 2).mean(dim=avg_dims, keepdim=True)
        x = x * torch.rsqrt(nu2 + torch.abs(self.eps))
        # Thresholded Linear Unit: max(gamma * x + beta, tau).
        return torch.max(self.gamma * x + self.beta, self.tau)

    def reset_parameters(self):
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)
        nn.init.zeros_(self.tau)
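
# Usage sketch (illustrative shapes; FRN is defined here as an alternative
# normalization and is not wired into the network below, which uses
# GroupNorm): FRN normalizes each channel by its mean squared activation over
# the spatial dims, so it behaves identically at batch size 1, and the
# learned tau threshold plays the role of a ReLU:
#
#     frn = FilterResponseNormNd(ndim=5, num_features=64)
#     x = torch.randn(1, 64, 8, 16, 16)   # (N, C, D, H, W)
#     y = frn(x)                          # same shape as x
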
def conv3x3x3(in_planes, out_planes, stride=1):
    """3x3x3 convolution with padding."""
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False)

def downsample_basic_block(x, planes, stride):
    """Type-A shortcut: strided average pooling plus zero-padding of the
    extra output channels, so the skip connection adds no parameters."""
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    # new_zeros creates the padding on the same device and dtype as out,
    # replacing the deprecated Variable / torch.cuda.FloatTensor pattern.
    zero_pads = out.new_zeros(out.size(0), planes - out.size(1),
                              out.size(2), out.size(3), out.size(4))
    return torch.cat([out, zero_pads], dim=1)
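
# Sketch (illustrative shapes): with shortcut type 'A', doubling the channels
# while halving each spatial dim costs no extra parameters:
#
#     x = torch.randn(2, 64, 8, 28, 28)
#     y = downsample_basic_block(x, planes=128, stride=2)
#     # y.shape == (2, 128, 4, 14, 14)
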
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3x3(inplanes, planes, stride)
        # GroupNorm (32 groups) is used in place of BatchNorm3d throughout.
        self.gn1 = nn.GroupNorm(32, planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.gn2 = nn.GroupNorm(32, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.gn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.gn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.gn1 = nn.GroupNorm(32, planes)
        self.conv2 = nn.Conv3d(
            planes, planes, kernel_size=3, stride=stride, padding=1,
            bias=False)
        self.gn2 = nn.GroupNorm(32, planes)
        self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
        self.gn3 = nn.GroupNorm(32, planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.gn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.gn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.gn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = torch.sigmoid(x)
        return x
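
# Usage sketch (illustrative sizes): a 2-layer MLP head maps flattened
# features to logits; ReLU is applied between layers but not after the last:
#
#     head = MLP(input_dim=2048, hidden_dim=256, output_dim=10, num_layers=2)
#     logits = head(torch.randn(4, 2048))   # -> shape (4, 10)
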
class ResNet(nn.Module):
    def __init__(self,
                 block,
                 layers,
                 sample_size,
                 sample_duration,
                 shortcut_type='B',
                 num_classes=400):
        super(ResNet, self).__init__()
        self.num_classes = num_classes
        self.inplanes = 64
        # Single-channel input (e.g. a grayscale volume).
        self.conv1 = nn.Conv3d(
            1,
            64,
            kernel_size=7,
            stride=(1, 2, 2),
            padding=(3, 3, 3),
            bias=False)
        self.gn1 = nn.GroupNorm(32, 64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(
            block, 128, layers[1], shortcut_type, stride=2)
        self.layer3 = self._make_layer(
            block, 256, layers[2], shortcut_type, stride=2)
        self.layer4 = self._make_layer(
            block, 512, layers[3], shortcut_type, stride=2)
        last_duration = int(math.ceil(sample_duration / 16))
        last_size = int(math.ceil(sample_size / 32))
        self.avgpool = nn.AvgPool3d(
            (last_duration, last_size, last_size), stride=1)
        # MLP classification head in place of the original single fc layer;
        # the hard-coded 81920 input dim is tied to the expected input
        # volume geometry.
        self.classifier = MLP(81920, 256, self.num_classes, 2,
                              sigmoid_output=False)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.GroupNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if shortcut_type == 'A':
                # Parameter-free shortcut: pooling plus zero-padding.
                downsample = partial(
                    downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride)
            else:
                # Projection shortcut: 1x1x1 convolution plus GroupNorm.
                downsample = nn.Sequential(
                    nn.Conv3d(
                        self.inplanes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=stride,
                        bias=False),
                    nn.GroupNorm(32, planes * block.expansion))

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.conv1(x)
        x = self.gn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        # Keep the flattened features around for downstream inspection.
        self.feature = x
        x = self.classifier(x)
        if self.num_classes == 1:
            x = torch.sigmoid(x)
        return x

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv3d') != -1:
        nn.init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.0)
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0.0)

def get_fine_tuning_parameters(model, ft_begin_index):
    """Build param groups that fine-tune layer{ft_begin_index} onward plus
    the classification head; everything earlier gets lr 0.0."""
    if ft_begin_index == 0:
        return model.parameters()

    ft_module_names = []
    for i in range(ft_begin_index, 5):
        ft_module_names.append('layer{}'.format(i))
    ft_module_names.append('classifier')

    parameters = []
    for k, v in model.named_parameters():
        for ft_module in ft_module_names:
            if ft_module in k:
                parameters.append({'params': v})
                break
        else:
            parameters.append({'params': v, 'lr': 0.0})
    return parameters
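
# Usage sketch: fine-tune from layer3 onward; earlier parameters are assigned
# lr 0.0 through their param group, so the optimizer leaves them unchanged:
#
#     params = get_fine_tuning_parameters(model, ft_begin_index=3)
#     optimizer = torch.optim.SGD(params, lr=1e-3, momentum=0.9)
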
def resnet10(**kwargs):
    """Constructs a ResNet-10 model."""
    model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
    return model


def resnet18(**kwargs):
    """Constructs a ResNet-18 model."""
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model


def resnet34(**kwargs):
    """Constructs a ResNet-34 model."""
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    return model


def resnet50(**kwargs):
    """Constructs a ResNet-50 model."""
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    # Optionally re-initialize with Xavier: model.apply(weights_init)
    return model


def resnet101(**kwargs):
    """Constructs a ResNet-101 model."""
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    return model


def resnet152(**kwargs):
    """Constructs a ResNet-152 model."""
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    return model


def resnet200(**kwargs):
    """Constructs a ResNet-200 model."""
    model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs)
    return model
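

if __name__ == '__main__':
    # Minimal smoke test with illustrative sample_size / sample_duration
    # values. The MLP head's hard-coded 81920 input dim ties a full forward
    # pass to one specific input geometry, so only construction and the
    # parameter count are checked here.
    model = resnet50(sample_size=112, sample_duration=16, num_classes=2)
    n_params = sum(p.numel() for p in model.parameters())
    print('resnet50 parameters: {:.1f}M'.format(n_params / 1e6))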