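"""DeepLabV3+ decoder on top of a ResNet backbone.

The deepest backbone features go through an ASPP module, are upsampled,
concatenated with a 1x1 projection of the low-level features, and refined
by a separable-convolution head (Chen et al., "Encoder-Decoder with Atrous
Separable Convolution for Semantic Image Segmentation", ECCV 2018).
"""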
from contextlib import ExitStack

import torch
import torch.nn.functional as F
from torch import nn

from isegm.model import ops

from .basic_blocks import SeparableConv2d
from .resnet import ResNetBackbone


class DeepLabV3Plus(nn.Module):
    def __init__(
        self,
        backbone="resnet50",
        norm_layer=nn.BatchNorm2d,
        backbone_norm_layer=None,
        ch=256,
        project_dropout=0.5,
        inference_mode=False,
        **kwargs
    ):
        super().__init__()
        if backbone_norm_layer is None:
            backbone_norm_layer = norm_layer

        self.backbone_name = backbone
        self.norm_layer = norm_layer
        self.backbone_norm_layer = backbone_norm_layer
        self.inference_mode = False  # toggled via set_prediction_mode() below
        self.ch = ch
        self.aspp_in_channels = 2048
        self.skip_project_in_channels = 256  # layer 1 out_channels
        self._kwargs = kwargs

        # Shallower backbones expose fewer channels at both tap points.
        if backbone == "resnet34":
            self.aspp_in_channels = 512
            self.skip_project_in_channels = 64
        self.backbone = ResNetBackbone(
            backbone=self.backbone_name,
            pretrained_base=False,
            norm_layer=self.backbone_norm_layer,
            **kwargs
        )

        # The head consumes the ASPP output (ch channels) concatenated with
        # the 32-channel skip projection of the low-level features.
        self.head = _DeepLabHead(
            in_channels=ch + 32,
            mid_channels=ch,
            out_channels=ch,
            norm_layer=self.norm_layer,
        )
        self.skip_project = _SkipProject(
            self.skip_project_in_channels, 32, norm_layer=self.norm_layer
        )
        self.aspp = _ASPP(
            in_channels=self.aspp_in_channels,
            atrous_rates=[12, 24, 36],
            out_channels=ch,
            project_dropout=project_dropout,
            norm_layer=self.norm_layer,
        )

        if inference_mode:
            self.set_prediction_mode()

    def load_pretrained_weights(self):
        pretrained = ResNetBackbone(
            backbone=self.backbone_name,
            pretrained_base=True,
            norm_layer=self.backbone_norm_layer,
            **self._kwargs
        )
        backbone_state_dict = self.backbone.state_dict()
        pretrained_state_dict = pretrained.state_dict()

        backbone_state_dict.update(pretrained_state_dict)
        self.backbone.load_state_dict(backbone_state_dict)

        if self.inference_mode:
            for param in self.backbone.parameters():
                param.requires_grad = False

    def set_prediction_mode(self):
        self.inference_mode = True
        self.eval()

    def forward(self, x, additional_features=None):
        with ExitStack() as stack:
            if self.inference_mode:
                stack.enter_context(torch.no_grad())

            c1, _, c3, c4 = self.backbone(x, additional_features)
            c1 = self.skip_project(c1)

            x = self.aspp(c4)
            x = F.interpolate(x, c1.size()[2:], mode="bilinear", align_corners=True)
            x = torch.cat((x, c1), dim=1)
            x = self.head(x)

        return (x,)
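

# 1x1 projection that compresses the low-level backbone features before they
# are concatenated with the upsampled ASPP output.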
class _SkipProject(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
        super().__init__()
        _activation = ops.select_activation_function("relu")

        self.skip_project = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            norm_layer(out_channels),
            _activation(),
        )

    def forward(self, x):
        return self.skip_project(x)
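

# Decoder head: two depthwise-separable 3x3 convolutions followed by a 1x1
# convolution that maps to the requested number of output channels.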
class _DeepLabHead(nn.Module):
    def __init__(
        self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d
    ):
        super().__init__()

        self.block = nn.Sequential(
            SeparableConv2d(
                in_channels=in_channels,
                out_channels=mid_channels,
                dw_kernel=3,
                dw_padding=1,
                activation="relu",
                norm_layer=norm_layer,
            ),
            SeparableConv2d(
                in_channels=mid_channels,
                out_channels=mid_channels,
                dw_kernel=3,
                dw_padding=1,
                activation="relu",
                norm_layer=norm_layer,
            ),
            nn.Conv2d(
                in_channels=mid_channels, out_channels=out_channels, kernel_size=1
            ),
        )

    def forward(self, x):
        return self.block(x)
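

# Atrous Spatial Pyramid Pooling: five parallel branches (a 1x1 conv, three
# dilated 3x3 convs, and image-level pooling) whose outputs are concatenated
# and projected back to out_channels.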
class _ASPP(nn.Module):
    def __init__(
        self,
        in_channels,
        atrous_rates,
        out_channels=256,
        project_dropout=0.5,
        norm_layer=nn.BatchNorm2d,
    ):
        super().__init__()

        b0 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=False,
            ),
            norm_layer(out_channels),
            nn.ReLU(),
        )

        rate1, rate2, rate3 = tuple(atrous_rates)
        b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
        b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
        b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
        b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)

        self.concurrent = nn.ModuleList([b0, b1, b2, b3, b4])

        project = [
            nn.Conv2d(
                in_channels=5 * out_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=False,
            ),
            norm_layer(out_channels),
            nn.ReLU(),
        ]
        if project_dropout > 0:
            project.append(nn.Dropout(project_dropout))
        self.project = nn.Sequential(*project)

    def forward(self, x):
        x = torch.cat([block(x) for block in self.concurrent], dim=1)
        return self.project(x)
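

# Image-level pooling branch of ASPP: global average pool, 1x1 conv, then
# bilinear upsampling back to the input resolution.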
class _AsppPooling(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer):
        super().__init__()

        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=False,
            ),
            norm_layer(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        pool = self.gap(x)
        return F.interpolate(pool, x.size()[2:], mode="bilinear", align_corners=True)
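

# Single dilated 3x3 conv branch of ASPP; padding equals the dilation rate,
# so the spatial size is preserved.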
def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
    block = nn.Sequential(
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=atrous_rate,
            dilation=atrous_rate,
            bias=False,
        ),
        norm_layer(out_channels),
        nn.ReLU(),
    )
    return block
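

# Minimal usage sketch (an illustration added for this write-up, not part of
# the original module). It assumes the isegm package is importable; the
# 512x512 input size is an arbitrary example.
if __name__ == "__main__":
    model = DeepLabV3Plus(backbone="resnet50", ch=256, inference_mode=True)

    dummy = torch.randn(1, 3, 512, 512)  # (batch, RGB, height, width)
    (features,) = model(dummy)

    # The decoder output matches the spatial size of the layer-1 skip
    # features, which for a standard ResNet is 1/4 of the input resolution,
    # so the expected shape here is (1, 256, 128, 128).
    print(features.shape)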