#!/usr/bin/env python3
# This file is covered by the LICENSE file in the root of this project.
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
    """Darknet residual block: 1x1 bottleneck conv then 3x3 conv, each
    followed by BatchNorm + LeakyReLU, with an identity skip connection.

    ``planes`` is a two-element sequence [bottleneck channels, output
    channels]; the output channel count must equal ``inplanes`` so the
    residual addition is shape-valid.
    """

    def __init__(self, inplanes, planes, bn_d=0.1):
        super(BasicBlock, self).__init__()
        # 1x1 conv squeezes channels; bias omitted because BN follows
        self.conv1 = nn.Conv2d(inplanes, planes[0], kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(planes[0], momentum=bn_d)
        self.relu1 = nn.LeakyReLU(0.1)
        # 3x3 conv restores the channel count (padding keeps H and W)
        self.conv2 = nn.Conv2d(planes[0], planes[1], kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes[1], momentum=bn_d)
        self.relu2 = nn.LeakyReLU(0.1)

    def forward(self, x):
        shortcut = x
        y = self.relu1(self.bn1(self.conv1(x)))
        y = self.relu2(self.bn2(self.conv2(y)))
        y += shortcut
        return y
# ******************************************************************************
# Number of residual BasicBlocks in each of the five encoder stages,
# keyed by total darknet depth (21- or 53-layer variants).
model_blocks = {
    21: [1, 1, 2, 2, 1],
    53: [1, 2, 8, 8, 4],
}
class Backbone(nn.Module):
    """Darknet encoder for DarknetSeg. Subclasses PyTorch's own "nn" module.

    Consumes a multi-channel image whose channels are selected from
    [range, x, y, z, remission] and produces 1024-channel features plus a
    dict of skip tensors keyed by the output stride at which they were taken.
    All downsampling happens along the width axis only (stride=[1, s]), so
    the input height is preserved end to end.

    params keys read here:
      input_depth: {"range": bool, "xyz": bool, "remission": bool}
      dropout:     Dropout2d probability applied between encoder stages
      bn_d:        BatchNorm momentum
      OS:          desired output stride (width downsample factor, max 32)
      extra:       {"layers": 21 or 53} darknet depth

    Raises:
      ValueError: if extra["layers"] is not a supported darknet depth.
    """

    def __init__(self, params):
        super(Backbone, self).__init__()
        self.use_range = params["input_depth"]["range"]
        self.use_xyz = params["input_depth"]["xyz"]
        self.use_remission = params["input_depth"]["remission"]
        self.drop_prob = params["dropout"]
        self.bn_d = params["bn_d"]
        self.OS = params["OS"]
        self.layers = params["extra"]["layers"]

        # Work out how many input channels are used and which indices of the
        # incoming tensor they live at (channel order: range, x, y, z, remission).
        self.input_depth = 0
        self.input_idxs = []
        if self.use_range:
            self.input_depth += 1
            self.input_idxs.append(0)
        if self.use_xyz:
            self.input_depth += 3
            self.input_idxs.extend([1, 2, 3])
        if self.use_remission:
            self.input_depth += 1
            self.input_idxs.append(4)

        # Per-stage width strides; the product is the native output stride.
        self.strides = [2, 2, 2, 2, 2]
        current_os = 1
        for s in self.strides:
            current_os *= s
        if self.OS > current_os:
            # Can't upsample past the native stride; keep defaults and warn.
            print("Can't do OS, ", self.OS,
                  " because it is bigger than original ", current_os)
        else:
            # Flatten strides from the LAST stage backwards until the
            # cumulative stride matches the requested OS.
            for i, stride in enumerate(reversed(self.strides), 0):
                if int(current_os) != self.OS:
                    if stride == 2:
                        current_os /= 2
                        self.strides[-1 - i] = 1
                    if int(current_os) == self.OS:
                        break

        # Validate the darknet depth with a real exception rather than an
        # assert: asserts are stripped under `python -O`.
        if self.layers not in model_blocks:
            raise ValueError(
                "invalid darknet layer count: {} (expected one of {})".format(
                    self.layers, sorted(model_blocks)))
        # Residual-block counts for each encoder stage.
        self.blocks = model_blocks[self.layers]

        # Stem: plain 3x3 conv, no downsampling.
        self.conv1 = nn.Conv2d(self.input_depth, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32, momentum=self.bn_d)
        self.relu1 = nn.LeakyReLU(0.1)

        # Encoder stages: channels double, width halves where stride == 2.
        self.enc1 = self._make_enc_layer(BasicBlock, [32, 64], self.blocks[0],
                                         stride=self.strides[0], bn_d=self.bn_d)
        self.enc2 = self._make_enc_layer(BasicBlock, [64, 128], self.blocks[1],
                                         stride=self.strides[1], bn_d=self.bn_d)
        self.enc3 = self._make_enc_layer(BasicBlock, [128, 256], self.blocks[2],
                                         stride=self.strides[2], bn_d=self.bn_d)
        self.enc4 = self._make_enc_layer(BasicBlock, [256, 512], self.blocks[3],
                                         stride=self.strides[3], bn_d=self.bn_d)
        self.enc5 = self._make_enc_layer(BasicBlock, [512, 1024], self.blocks[4],
                                         stride=self.strides[4], bn_d=self.bn_d)

        # Regularization between stages ("for a bit of fun" in the original).
        self.dropout = nn.Dropout2d(self.drop_prob)
        # Channel count of the final feature map.
        self.last_channels = 1024

    def _make_enc_layer(self, block, planes, blocks, stride, bn_d=0.1):
        """Build one encoder stage: a width-strided 3x3 conv + BN + LeakyReLU
        followed by `blocks` residual BasicBlocks."""
        layers = []
        # Downsample along width only: stride=[1, stride].
        layers.append(("conv", nn.Conv2d(planes[0], planes[1],
                                         kernel_size=3,
                                         stride=[1, stride], dilation=1,
                                         padding=1, bias=False)))
        layers.append(("bn", nn.BatchNorm2d(planes[1], momentum=bn_d)))
        layers.append(("relu", nn.LeakyReLU(0.1)))
        # Residual blocks keep the channel count at planes[1].
        inplanes = planes[1]
        for i in range(0, blocks):
            layers.append(("residual_{}".format(i),
                           block(inplanes, planes, bn_d)))
        return nn.Sequential(OrderedDict(layers))

    def run_layer(self, x, layer, skips, os):
        """Apply `layer`; if it shrank H or W, stash the layer *input*
        (detached) as the skip tensor for the current output stride and
        double the output-stride bookkeeping."""
        y = layer(x)
        if y.shape[2] < x.shape[2] or y.shape[3] < x.shape[3]:
            skips[os] = x.detach()
            os *= 2
        x = y
        return x, skips, os

    def forward(self, x, return_logits=False, return_list=None):
        """Run the encoder.

        Args:
          x: input tensor; only channels listed in self.input_idxs are kept.
          return_logits: if True, return ONLY the enc5 features (before the
            final dropout) -- note this path returns a single tensor, not a
            tuple.
          return_list: optional collection of stage tags 'enc_0'..'enc_5';
            matching intermediate features are collected detached on CPU.

        Returns:
          (x, skips) by default, (x, skips, out_dict) when return_list is
          not None, or just x when return_logits is True.
        """
        # Keep only the configured input channels.
        x = x[:, self.input_idxs]
        # Skip tensors keyed by output stride, filled as resolution drops.
        skips = {}
        out_dict = {}
        os = 1
        # Stem.
        x, skips, os = self.run_layer(x, self.conv1, skips, os)
        x, skips, os = self.run_layer(x, self.bn1, skips, os)
        x, skips, os = self.run_layer(x, self.relu1, skips, os)
        if return_list and 'enc_0' in return_list:
            out_dict['enc_0'] = x.detach().cpu()  # e.g. 32, 64, 1024
        # All encoder stages with intermediate dropouts.
        x, skips, os = self.run_layer(x, self.enc1, skips, os)
        if return_list and 'enc_1' in return_list:
            out_dict['enc_1'] = x.detach().cpu()  # e.g. 64, 64, 512
        x, skips, os = self.run_layer(x, self.dropout, skips, os)
        x, skips, os = self.run_layer(x, self.enc2, skips, os)
        if return_list and 'enc_2' in return_list:
            out_dict['enc_2'] = x.detach().cpu()  # e.g. 128, 64, 256
        x, skips, os = self.run_layer(x, self.dropout, skips, os)
        x, skips, os = self.run_layer(x, self.enc3, skips, os)
        if return_list and 'enc_3' in return_list:
            out_dict['enc_3'] = x.detach().cpu()  # e.g. 256, 64, 128
        x, skips, os = self.run_layer(x, self.dropout, skips, os)
        x, skips, os = self.run_layer(x, self.enc4, skips, os)
        if return_list and 'enc_4' in return_list:
            out_dict['enc_4'] = x.detach().cpu()  # e.g. 512, 64, 64
        x, skips, os = self.run_layer(x, self.dropout, skips, os)
        x, skips, os = self.run_layer(x, self.enc5, skips, os)
        if return_list and 'enc_5' in return_list:
            out_dict['enc_5'] = x.detach().cpu()  # e.g. 1024, 64, 32
        if return_logits:
            return x
        x, skips, os = self.run_layer(x, self.dropout, skips, os)
        if return_list is not None:
            return x, skips, out_dict
        return x, skips

    def get_last_depth(self):
        """Channel count of the final encoder features (1024)."""
        return self.last_channels

    def get_input_depth(self):
        """Number of input channels actually consumed."""
        return self.input_depth
class Decoder(nn.Module):
    """Darknet decoder for DarknetSeg. Subclasses PyTorch's own "nn" module.

    Upsamples the backbone features back to the input width, fusing the
    encoder skip tensors (keyed by output stride) along the way.  Width is
    the only axis that changes; height is untouched.
    """

    def __init__(self, params, OS=32, feature_depth=1024):
        super(Decoder, self).__init__()
        self.backbone_OS = OS
        self.backbone_feature_depth = feature_depth
        self.drop_prob = params["dropout"]
        self.bn_d = params["bn_d"]
        self.index = 0
        # Start from the full stride ladder; flatten leading entries until
        # the cumulative upsampling matches the backbone's output stride.
        self.strides = [2, 2, 2, 2, 2]
        current_os = 1
        for s in self.strides:
            current_os *= s
        for i in range(len(self.strides)):
            if int(current_os) == self.backbone_OS:
                break
            if self.strides[i] == 2:
                current_os /= 2
                self.strides[i] = 1

        # Decoder stages: channels halve, width doubles where stride == 2.
        self.dec5 = self._make_dec_layer(
            BasicBlock, [self.backbone_feature_depth, 512],
            bn_d=self.bn_d, stride=self.strides[0])
        self.dec4 = self._make_dec_layer(BasicBlock, [512, 256],
                                         bn_d=self.bn_d,
                                         stride=self.strides[1])
        self.dec3 = self._make_dec_layer(BasicBlock, [256, 128],
                                         bn_d=self.bn_d,
                                         stride=self.strides[2])
        self.dec2 = self._make_dec_layer(BasicBlock, [128, 64],
                                         bn_d=self.bn_d,
                                         stride=self.strides[3])
        self.dec1 = self._make_dec_layer(BasicBlock, [64, 32],
                                         bn_d=self.bn_d,
                                         stride=self.strides[4])

        # Execution order used by forward(), paired with the skip dict.
        self.layers = [self.dec5, self.dec4, self.dec3, self.dec2, self.dec1]
        # Regularization on the final features.
        self.dropout = nn.Dropout2d(self.drop_prob)
        # Channel count of the decoder output.
        self.last_channels = 32

    def _make_dec_layer(self, block, planes, bn_d=0.1, stride=2):
        """One decoder stage: a width-upsampling transposed conv (or a plain
        3x3 conv when stride is 1), BN + LeakyReLU, then one residual block."""
        in_ch, out_ch = planes[0], planes[1]
        if stride == 2:
            # Transposed conv doubles the width only (kernel/stride act on W).
            head = ("upconv", nn.ConvTranspose2d(in_ch, out_ch,
                                                 kernel_size=[1, 4],
                                                 stride=[1, 2],
                                                 padding=[0, 1]))
        else:
            head = ("conv", nn.Conv2d(in_ch, out_ch,
                                      kernel_size=3, padding=1))
        modules = [
            head,
            ("bn", nn.BatchNorm2d(out_ch, momentum=bn_d)),
            ("relu", nn.LeakyReLU(0.1)),
            ("residual", block(out_ch, planes, bn_d)),
        ]
        return nn.Sequential(OrderedDict(modules))

    def run_layer(self, x, layer, skips, os):
        """Apply one decoder stage; if it grew the width, halve the
        output-stride bookkeeping and add the matching (detached) skip."""
        up = layer(x)
        if up.shape[-1] > x.shape[-1]:
            os //= 2                       # match skip resolution
            up = up + skips[os].detach()   # fuse skip connection
        x = up
        return x, skips, os

    def forward(self, x, skips, return_logits=False, return_list=None):
        """Decode backbone features.

        Returns (x, logits) when return_logits is True (logits are the
        pre-dropout features, detached), the collected out_dict when
        return_list is not None, or just x otherwise.
        """
        os = self.backbone_OS
        out_dict = {}
        # dec5 -> dec1 produce features tagged 'dec_4' down to 'dec_0'.
        for stage_idx, stage in enumerate(self.layers):
            x, skips, os = self.run_layer(x, stage, skips, os)
            tag = 'dec_{}'.format(4 - stage_idx)
            if return_list and tag in return_list:
                out_dict[tag] = x.detach().cpu()
        logits = torch.clone(x).detach()
        x = self.dropout(x)
        if return_logits:
            return x, logits
        if return_list is not None:
            return out_dict
        return x

    def get_last_depth(self):
        """Channel count of the decoder output (32)."""
        return self.last_channels
class Model(nn.Module):
    """Full encoder/decoder segmentation model wiring Backbone and Decoder.

    config keys read here:
      backbone: params dict for Backbone (must contain "OS")
      decoder:  params dict for Decoder
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.backbone = Backbone(params=self.config["backbone"])
        self.decoder = Decoder(params=self.config["decoder"],
                               OS=self.config["backbone"]["OS"],
                               feature_depth=self.backbone.get_last_depth())

    def load_pretrained_weights(self, path):
        """Load state dicts from <path>/backbone and
        <path>/segmentation_decoder, mapped onto CPU, with strict matching.

        NOTE(review): torch.load unpickles arbitrary objects -- only call
        this on checkpoint files from a trusted source.
        """
        w_dict = torch.load(path + "/backbone",
                            map_location=lambda storage, loc: storage)
        self.backbone.load_state_dict(w_dict, strict=True)
        w_dict = torch.load(path + "/segmentation_decoder",
                            map_location=lambda storage, loc: storage)
        self.decoder.load_state_dict(w_dict, strict=True)

    def forward(self, x, return_logits=False, return_final_logits=False,
                return_list=None, agg_type='depth'):
        """Run the model; the flags select one of four mutually exclusive
        output modes, checked in this order:

          return_logits: encoder-only features, globally average-pooled,
            returned as a numpy array.
          return_list (not None): dict of detached intermediate encoder and
            decoder features named in return_list.
          return_final_logits: decoder pre-dropout features aggregated by
            `agg_type` ('all' | 'sector' | 'depth'), returned as numpy.
          default: decoded feature map (after dropout).

        Raises:
          ValueError: if return_final_logits is set with an unknown agg_type.
        """
        if return_logits:
            logits = self.backbone(x, return_logits)
            logits = F.adaptive_avg_pool2d(logits, (1, 1)).squeeze()
            # NOTE(review): .squeeze() also drops the batch dim when B == 1,
            # so a single sample yields a 1-D array -- confirm callers expect
            # this.
            logits = torch.clone(logits).detach().cpu().numpy()
            return logits
        elif return_list is not None:
            x, skips, enc_dict = self.backbone(x, return_list=return_list)
            dec_dict = self.decoder(x, skips, return_list=return_list)
            out_dict = {**enc_dict, **dec_dict}
            return out_dict
        elif return_final_logits:
            # Validate with a real exception rather than an assert, which is
            # stripped under `python -O`.
            if agg_type not in ('all', 'sector', 'depth'):
                raise ValueError("unknown agg_type: {!r}".format(agg_type))
            y, skips = self.backbone(x)
            y, logits = self.decoder(y, skips, True)
            B, C, H, W = logits.shape
            N = 16  # number of width (sector) / height (depth) bins
            if agg_type == 'all':
                # Global average over both spatial axes.
                logits = logits.mean([2, 3])
            elif agg_type == 'sector':
                # Average within N width bins (W must be divisible by N).
                logits = logits.view(B, C, H, N, W // N).mean([2, 4]).reshape(B, -1)
            elif agg_type == 'depth':
                # Average within N height bins (H must be divisible by N).
                logits = logits.view(B, C, N, H // N, W).mean([3, 4]).reshape(B, -1)
            logits = torch.clone(logits).detach().cpu().numpy()
            return logits
        else:
            y, skips = self.backbone(x)
            y = self.decoder(y, skips, False)
            return y