# NOTE: the following lines are web-page residue from the file-hosting UI,
# preserved here as comments so the module stays parseable:
# Hancy's picture
# init
# 851751e
# raw / history blame / 14 kB
#!/usr/bin/env python3
# This file is covered by the LICENSE file in the root of this project.
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
    """Darknet residual bottleneck: 1x1 squeeze -> 3x3 expand + identity skip.

    `planes` is a two-element sequence: planes[0] is the bottleneck width and
    planes[1] the output width, which must equal `inplanes` so the residual
    addition is shape-compatible.  `bn_d` is the BatchNorm momentum.
    """

    def __init__(self, inplanes, planes, bn_d=0.1):
        super(BasicBlock, self).__init__()
        # 1x1 conv squeezes the channel count down to planes[0].
        self.conv1 = nn.Conv2d(inplanes, planes[0], kernel_size=1,
                               stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(planes[0], momentum=bn_d)
        self.relu1 = nn.LeakyReLU(0.1)
        # 3x3 conv expands back to planes[1] channels.
        self.conv2 = nn.Conv2d(planes[0], planes[1], kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes[1], momentum=bn_d)
        self.relu2 = nn.LeakyReLU(0.1)

    def forward(self, x):
        # Two conv-bn-relu stages, then add the identity shortcut.
        y = self.relu1(self.bn1(self.conv1(x)))
        y = self.relu2(self.bn2(self.conv2(y)))
        return y + x
# ******************************************************************************
# Number of residual BasicBlocks per encoder stage, keyed by darknet depth.
# Key is the darknet variant (21 or 53 layers); value lists how many residual
# blocks each of the five encoder stages (enc1..enc5) stacks.
model_blocks = {
    21: [1, 1, 2, 2, 1],
    53: [1, 2, 8, 8, 4],
}
class Backbone(nn.Module):
    """
    Class for DarknetSeg. Subclasses PyTorch's own "nn" module

    Darknet-21/53 encoder (see `model_blocks`) over a multi-channel range
    image.  Downsampling convs use stride [1, s], so only the width axis is
    reduced; the activation captured just before each downsample is stored as
    a skip connection for the decoder, keyed by the output stride at which it
    was taken.
    """

    def __init__(self, params):
        # params: dict with keys
        #   "input_depth": {"range": bool, "xyz": bool, "remission": bool}
        #   "dropout":     probability for the inter-stage Dropout2d
        #   "bn_d":        BatchNorm momentum
        #   "OS":          desired output stride (power of two, at most 32)
        #   "extra":       {"layers": 21 or 53} — darknet variant
        super(Backbone, self).__init__()
        self.use_range = params["input_depth"]["range"]
        self.use_xyz = params["input_depth"]["xyz"]
        self.use_remission = params["input_depth"]["remission"]
        self.drop_prob = params["dropout"]
        self.bn_d = params["bn_d"]
        self.OS = params["OS"]
        self.layers = params["extra"]["layers"]
        # input depth calc: count enabled channels and remember which input
        # channel indices to keep; channel layout is (range, x, y, z, remission)
        self.input_depth = 0
        self.input_idxs = []
        if self.use_range:
            self.input_depth += 1
            self.input_idxs.append(0)
        if self.use_xyz:
            self.input_depth += 3
            self.input_idxs.extend([1, 2, 3])
        if self.use_remission:
            self.input_depth += 1
            self.input_idxs.append(4)
        # stride play: nominal per-stage width strides (product = 32)
        self.strides = [2, 2, 2, 2, 2]
        # check current stride
        current_os = 1
        for s in self.strides:
            current_os *= s
        # make the new stride
        if self.OS > current_os:
            print("Can't do OS, ", self.OS,
                  " because it is bigger than original ", current_os)
        else:
            # redo strides according to needed stride: walk the stages from
            # the last one backwards, turning 2-strides into 1s until the
            # product of strides equals the requested OS
            for i, stride in enumerate(reversed(self.strides), 0):
                if int(current_os) != self.OS:
                    if stride == 2:
                        current_os /= 2
                        self.strides[-1 - i] = 1
                if int(current_os) == self.OS:
                    break
        # check that darknet exists
        assert self.layers in model_blocks.keys()
        # generate layers depending on darknet type
        self.blocks = model_blocks[self.layers]
        # input layer: stem conv keeps full resolution
        self.conv1 = nn.Conv2d(self.input_depth, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32, momentum=self.bn_d)
        self.relu1 = nn.LeakyReLU(0.1)
        # encoder: five stages doubling channels 32 -> 1024
        self.enc1 = self._make_enc_layer(BasicBlock, [32, 64], self.blocks[0],
                                         stride=self.strides[0], bn_d=self.bn_d)
        self.enc2 = self._make_enc_layer(BasicBlock, [64, 128], self.blocks[1],
                                         stride=self.strides[1], bn_d=self.bn_d)
        self.enc3 = self._make_enc_layer(BasicBlock, [128, 256], self.blocks[2],
                                         stride=self.strides[2], bn_d=self.bn_d)
        self.enc4 = self._make_enc_layer(BasicBlock, [256, 512], self.blocks[3],
                                         stride=self.strides[3], bn_d=self.bn_d)
        self.enc5 = self._make_enc_layer(BasicBlock, [512, 1024], self.blocks[4],
                                         stride=self.strides[4], bn_d=self.bn_d)
        # for a bit of fun: shared dropout applied between encoder stages
        self.dropout = nn.Dropout2d(self.drop_prob)
        # last channels
        self.last_channels = 1024

    # make layer useful function
    def _make_enc_layer(self, block, planes, blocks, stride, bn_d=0.1):
        # One encoder stage: a width-only downsample conv (stride [1, stride])
        # followed by `blocks` residual BasicBlocks at planes[1] channels.
        layers = []
        # downsample
        layers.append(("conv", nn.Conv2d(planes[0], planes[1],
                                         kernel_size=3,
                                         stride=[1, stride], dilation=1,
                                         padding=1, bias=False)))
        layers.append(("bn", nn.BatchNorm2d(planes[1], momentum=bn_d)))
        layers.append(("relu", nn.LeakyReLU(0.1)))
        # blocks
        inplanes = planes[1]
        for i in range(0, blocks):
            layers.append(("residual_{}".format(i),
                           block(inplanes, planes, bn_d)))
        return nn.Sequential(OrderedDict(layers))

    def run_layer(self, x, layer, skips, os):
        # Apply `layer`; if it shrank either spatial dim, record the
        # *pre-downsample* input as the skip for the current output stride
        # and double `os`.  Returns the (possibly updated) x, skips, os.
        y = layer(x)
        if y.shape[2] < x.shape[2] or y.shape[3] < x.shape[3]:
            skips[os] = x.detach()
            os *= 2
        x = y
        return x, skips, os

    def forward(self, x, return_logits=False, return_list=None):
        # Returns:
        #   (x, skips)                 — default
        #   x                          — if return_logits (pre-dropout enc5 output)
        #   (x, skips, out_dict)       — if return_list is given; out_dict holds
        #                                detached CPU copies of the requested
        #                                'enc_0'..'enc_5' intermediate taps
        # filter input
        x = x[:, self.input_idxs]
        # run cnn
        # store for skip connections
        skips = {}
        out_dict = {}
        os = 1
        # first layer
        x, skips, os = self.run_layer(x, self.conv1, skips, os)
        x, skips, os = self.run_layer(x, self.bn1, skips, os)
        x, skips, os = self.run_layer(x, self.relu1, skips, os)
        if return_list and 'enc_0' in return_list:
            out_dict['enc_0'] = x.detach().cpu()  # 32, 64, 1024
        # all encoder blocks with intermediate dropouts
        x, skips, os = self.run_layer(x, self.enc1, skips, os)
        if return_list and 'enc_1' in return_list:
            out_dict['enc_1'] = x.detach().cpu()  # 64, 64, 512
        x, skips, os = self.run_layer(x, self.dropout, skips, os)
        x, skips, os = self.run_layer(x, self.enc2, skips, os)
        if return_list and 'enc_2' in return_list:
            out_dict['enc_2'] = x.detach().cpu()  # 128, 64, 256
        x, skips, os = self.run_layer(x, self.dropout, skips, os)
        x, skips, os = self.run_layer(x, self.enc3, skips, os)
        if return_list and 'enc_3' in return_list:
            out_dict['enc_3'] = x.detach().cpu()  # 256, 64, 128
        x, skips, os = self.run_layer(x, self.dropout, skips, os)
        x, skips, os = self.run_layer(x, self.enc4, skips, os)
        if return_list and 'enc_4' in return_list:
            out_dict['enc_4'] = x.detach().cpu()  # 512, 64, 64
        x, skips, os = self.run_layer(x, self.dropout, skips, os)
        x, skips, os = self.run_layer(x, self.enc5, skips, os)
        if return_list and 'enc_5' in return_list:
            out_dict['enc_5'] = x.detach().cpu()  # 1024, 64, 32
        if return_logits:
            # NOTE: returns the enc5 features *before* the final dropout
            return x
        x, skips, os = self.run_layer(x, self.dropout, skips, os)
        if return_list is not None:
            return x, skips, out_dict
        return x, skips

    def get_last_depth(self):
        # Channel count of the final encoder feature map.
        return self.last_channels

    def get_input_depth(self):
        # Number of input channels actually consumed (after channel filtering).
        return self.input_depth
class Decoder(nn.Module):
    """
    Darknet decoder: five (up)conv stages that undo the backbone's width-only
    downsampling, fusing the matching skip tensor after every upsample.
    """

    def __init__(self, params, OS=32, feature_depth=1024):
        super(Decoder, self).__init__()
        self.backbone_OS = OS
        self.backbone_feature_depth = feature_depth
        self.drop_prob = params["dropout"]
        self.bn_d = params["bn_d"]
        self.index = 0
        # nominal per-stage upsampling strides (product = 32)
        self.strides = [2, 2, 2, 2, 2]
        # product of all strides = the nominal output stride
        current_os = 1
        for s in self.strides:
            current_os *= s
        # flatten leading 2-strides to 1 until the product matches backbone OS
        for i, s in enumerate(self.strides):
            if int(current_os) == self.backbone_OS:
                break
            if s == 2:
                current_os /= 2
                self.strides[i] = 1
        # decoder stages, coarse to fine
        self.dec5 = self._make_dec_layer(BasicBlock,
                                         [self.backbone_feature_depth, 512],
                                         bn_d=self.bn_d,
                                         stride=self.strides[0])
        self.dec4 = self._make_dec_layer(BasicBlock, [512, 256],
                                         bn_d=self.bn_d, stride=self.strides[1])
        self.dec3 = self._make_dec_layer(BasicBlock, [256, 128],
                                         bn_d=self.bn_d, stride=self.strides[2])
        self.dec2 = self._make_dec_layer(BasicBlock, [128, 64],
                                         bn_d=self.bn_d, stride=self.strides[3])
        self.dec1 = self._make_dec_layer(BasicBlock, [64, 32],
                                         bn_d=self.bn_d, stride=self.strides[4])
        # execution order used by forward()
        self.layers = [self.dec5, self.dec4, self.dec3, self.dec2, self.dec1]
        # shared dropout for the final feature map
        self.dropout = nn.Dropout2d(self.drop_prob)
        # channel width of the final decoder output
        self.last_channels = 32

    def _make_dec_layer(self, block, planes, bn_d=0.1, stride=2):
        # One decoder stage: width-only (up)conv + BN + LeakyReLU + residual.
        in_ch, out_ch = planes[0], planes[1]
        if stride != 2:
            head = ("conv", nn.Conv2d(in_ch, out_ch,
                                      kernel_size=3, padding=1))
        else:
            # transposed conv doubles only the width dimension
            head = ("upconv", nn.ConvTranspose2d(in_ch, out_ch,
                                                 kernel_size=[1, 4],
                                                 stride=[1, 2],
                                                 padding=[0, 1]))
        modules = OrderedDict([
            head,
            ("bn", nn.BatchNorm2d(out_ch, momentum=bn_d)),
            ("relu", nn.LeakyReLU(0.1)),
            ("residual", block(out_ch, planes, bn_d)),
        ])
        return nn.Sequential(modules)

    def run_layer(self, x, layer, skips, os):
        # Apply `layer`; if it grew the width, halve `os` and add the skip
        # tensor the backbone stored at that output stride.
        up = layer(x)
        if up.shape[-1] > x.shape[-1]:
            os //= 2  # match skip
            up = up + skips[os].detach()  # add skip
        return up, skips, os

    def forward(self, x, skips, return_logits=False, return_list=None):
        # Returns:
        #   x              — default (post-dropout features)
        #   (x, logits)    — if return_logits; logits are pre-dropout features
        #   out_dict       — if return_list is given; detached CPU copies of
        #                    the requested 'dec_4'..'dec_0' taps
        os = self.backbone_OS
        out_dict = {}
        tags = ['dec_4', 'dec_3', 'dec_2', 'dec_1', 'dec_0']
        for layer, tag in zip(self.layers, tags):
            x, skips, os = self.run_layer(x, layer, skips, os)
            if return_list and tag in return_list:
                out_dict[tag] = x.detach().cpu()
        logits = torch.clone(x).detach()
        x = self.dropout(x)
        if return_logits:
            return x, logits
        if return_list is not None:
            return out_dict
        return x

    def get_last_depth(self):
        """Channel count of the final decoder feature map."""
        return self.last_channels
class Model(nn.Module):
    """Full segmentation network: darknet Backbone encoder + Decoder."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.backbone = Backbone(params=self.config["backbone"])
        self.decoder = Decoder(params=self.config["decoder"],
                               OS=self.config["backbone"]["OS"],
                               feature_depth=self.backbone.get_last_depth())

    def load_pretrained_weights(self, path):
        """Load backbone/decoder checkpoints from `path`, mapped to CPU."""
        for module, fname in ((self.backbone, "/backbone"),
                              (self.decoder, "/segmentation_decoder")):
            state = torch.load(path + fname,
                               map_location=lambda storage, loc: storage)
            module.load_state_dict(state, strict=True)

    def forward(self, x, return_logits=False, return_final_logits=False,
                return_list=None, agg_type='depth'):
        # Dispatch on the requested output; checks mirror priority order:
        # return_logits > return_list > return_final_logits > plain forward.
        if return_logits:
            # global-average-pooled backbone features as a numpy array
            feats = self.backbone(x, return_logits)
            feats = F.adaptive_avg_pool2d(feats, (1, 1)).squeeze()
            return torch.clone(feats).detach().cpu().numpy()
        if return_list is not None:
            # detached CPU taps of the requested encoder/decoder stages
            feats, skips, enc_taps = self.backbone(x, return_list=return_list)
            dec_taps = self.decoder(feats, skips, return_list=return_list)
            return {**enc_taps, **dec_taps}
        if return_final_logits:
            assert agg_type in ['all', 'sector', 'depth']
            feats, skips = self.backbone(x)
            feats, logits = self.decoder(feats, skips, True)
            B, C, H, W = logits.shape
            N = 16
            if agg_type == 'all':
                # average over the whole spatial extent
                logits = logits.mean([2, 3])
            elif agg_type == 'sector':
                # average inside N width-sectors
                logits = logits.view(B, C, H, N, W // N).mean([2, 4]).reshape(B, -1)
            else:
                # 'depth': average inside N height-bands
                logits = logits.view(B, C, N, H // N, W).mean([3, 4]).reshape(B, -1)
            return torch.clone(logits).detach().cpu().numpy()
        feats, skips = self.backbone(x)
        return self.decoder(feats, skips, False)