# P-PD / networks /drn_seg.py
# mrneuralnet's picture
# Initial commit
# e875957
import math
import torch
import torch.nn as nn
from networks.drn import drn_c_26
def fill_up_weights(up):
    """Overwrite ``up.weight`` in place with a fixed, separable upsampling kernel.

    Each spatial tap (i, j) of filter ``[0, 0]`` is set to
    ``(1 - |i / f - c|) * (1 - |j / f - c|)``, i.e. a triangular profile along
    each axis, and that filter is then copied into slot ``[k, 0]`` for every
    other index ``k`` along the first weight dimension.

    Args:
        up: A module whose ``weight`` is a 4-D tensor. Only entries with
            index 0 along the second dimension are written.
    """
    weight = up.weight.data
    rows, cols = weight.size(2), weight.size(3)

    # Both axes share the half-width and centre derived from the kernel height.
    half = math.ceil(rows / 2)
    centre = (2 * half - 1 - half % 2) / (2. * half)

    # 1-D triangular profiles; their outer product is the 2-D kernel.
    row_profile = [1 - math.fabs(i / half - centre) for i in range(rows)]
    col_profile = [1 - math.fabs(j / half - centre) for j in range(cols)]

    for i, row_value in enumerate(row_profile):
        for j, col_value in enumerate(col_profile):
            weight[0, 0, i, j] = row_value * col_value

    # Replicate the reference filter across the remaining leading indices.
    for k in range(1, weight.size(0)):
        weight[k, 0, :, :] = weight[0, 0, :, :]
class DRNSeg(nn.Module):
    """Segmentation network: DRN-C-26 trunk, 1x1 classifier, 8x upsampling.

    Args:
        classes: Number of output channels of the 1x1 classifier.
        pretrained_drn: Forwarded to ``drn_c_26`` as ``pretrained``.
        pretrained_model: Optional checkpoint path; when truthy, the trunk
            weights are loaded from it via :meth:`load_pretrained`.
        use_torch_up: If true, upsample with ``nn.UpsamplingBilinear2d``;
            otherwise use a frozen grouped transposed convolution whose
            weights are set by ``fill_up_weights``.
    """

    def __init__(self, classes, pretrained_drn=False,
                 pretrained_model=None, use_torch_up=False):
        super(DRNSeg, self).__init__()

        backbone = drn_c_26(pretrained=pretrained_drn)
        # Keep every child of the backbone except the last two
        # (presumably its pooling / classification head -- confirm in drn.py).
        trunk_layers = list(backbone.children())[:-2]
        self.base = nn.Sequential(*trunk_layers)
        if pretrained_model:
            self.load_pretrained(pretrained_model)

        # 1x1 classifier, initialised from N(0, sqrt(2 / fan)) with zero bias,
        # where fan = kernel_h * kernel_w * out_channels.
        classifier = nn.Conv2d(backbone.out_dim, classes,
                               kernel_size=1, bias=True)
        kernel_h, kernel_w = classifier.kernel_size
        fan = kernel_h * kernel_w * classifier.out_channels
        classifier.weight.data.normal_(0, math.sqrt(2. / fan))
        classifier.bias.data.zero_()
        self.seg = classifier

        if use_torch_up:
            self.up = nn.UpsamplingBilinear2d(scale_factor=8)
            return

        # Kernel 16 / stride 8 / padding 4 yields an output exactly 8x the
        # input size; one group per class keeps the channels independent.
        deconv = nn.ConvTranspose2d(classes, classes, 16, stride=8, padding=4,
                                    output_padding=0, groups=classes,
                                    bias=False)
        fill_up_weights(deconv)
        deconv.weight.requires_grad = False  # fixed kernel, never trained
        self.up = deconv

    def forward(self, x):
        """Run trunk, classifier and upsampler in sequence on ``x``."""
        return self.up(self.seg(self.base(x)))

    def optim_parameters(self, memo=None):
        """Yield the parameters of ``base`` then ``seg`` (``up`` is excluded).

        ``memo`` is accepted but not used.
        """
        for module in (self.base, self.seg):
            for param in module.parameters():
                yield param

    def load_pretrained(self, pretrained_model):
        """Load trunk weights from the checkpoint at ``pretrained_model``.

        The checkpoint is expected to hold a state dict under the ``'model'``
        key; only entries whose first dotted component is ``base`` are used,
        with the leading ``base.`` stripped before loading into ``self.base``.
        """
        print("loading the pretrained drn model from %s" % pretrained_model)
        checkpoint = torch.load(pretrained_model, map_location='cpu')
        if hasattr(checkpoint, '_metadata'):
            del checkpoint._metadata

        # Keep only trunk entries, dropping the 5-character "base." prefix.
        trunk_weights = {}
        for key, value in checkpoint['model'].items():
            if key.split('.')[0] == 'base':
                trunk_weights[key[5:]] = value

        self.base.load_state_dict(trunk_weights)
class DRNSub(nn.Module):
    """Classifier that reuses the trunk of a two-class ``DRNSeg``.

    Args:
        num_classes: Output size of the final linear layer.
        pretrained_model: Optional checkpoint path; when truthy, its
            ``'model'`` entry is loaded into the full ``DRNSeg`` before the
            trunk is taken from it.
        fix_base: If true, gradients are disabled for all trunk parameters.
    """

    def __init__(self, num_classes, pretrained_model=None, fix_base=False):
        super(DRNSub, self).__init__()

        segmenter = DRNSeg(2)
        if pretrained_model:
            print("loading the pretrained drn model from %s" % pretrained_model)
            checkpoint = torch.load(pretrained_model, map_location='cpu')
            segmenter.load_state_dict(checkpoint['model'])

        # Only the trunk is kept; the segmentation head and upsampler are dropped.
        self.base = segmenter.base
        if fix_base:
            for weight in self.base.parameters():
                weight.requires_grad = False

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # 512 input features -- presumably the trunk's channel count; confirm.
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Apply trunk, global average pooling, flatten, and the linear layer."""
        pooled = self.avgpool(self.base(x))
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)