Refactor code (1615d09)
from contextlib import ExitStack

import torch
import torch.nn.functional as F
from torch import nn

from isegm.model import ops
from .basic_blocks import SeparableConv2d
from .resnet import ResNetBackbone
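
# DeepLabV3+ segmentation model: a ResNet backbone, an ASPP module over the
# deepest features, a 1x1 projection of the low-level (layer 1) features, and
# a separable-convolution decoder head that fuses the two streams.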
class DeepLabV3Plus(nn.Module):
    def __init__(
        self,
        backbone="resnet50",
        norm_layer=nn.BatchNorm2d,
        backbone_norm_layer=None,
        ch=256,
        project_dropout=0.5,
        inference_mode=False,
        **kwargs
    ):
        super(DeepLabV3Plus, self).__init__()
        if backbone_norm_layer is None:
            backbone_norm_layer = norm_layer

        self.backbone_name = backbone
        self.norm_layer = norm_layer
        self.backbone_norm_layer = backbone_norm_layer
        self.inference_mode = False
        self.ch = ch
        self.aspp_in_channels = 2048
        self.skip_project_in_channels = 256  # layer 1 out_channels
        self._kwargs = kwargs

        if backbone == "resnet34":
            self.aspp_in_channels = 512
            self.skip_project_in_channels = 64

        self.backbone = ResNetBackbone(
            backbone=self.backbone_name,
            pretrained_base=False,
            norm_layer=self.backbone_norm_layer,
            **kwargs
        )
        self.head = _DeepLabHead(
            in_channels=ch + 32,
            mid_channels=ch,
            out_channels=ch,
            norm_layer=self.norm_layer,
        )
        self.skip_project = _SkipProject(
            self.skip_project_in_channels, 32, norm_layer=self.norm_layer
        )
        self.aspp = _ASPP(
            in_channels=self.aspp_in_channels,
            atrous_rates=[12, 24, 36],
            out_channels=ch,
            project_dropout=project_dropout,
            norm_layer=self.norm_layer,
        )

        if inference_mode:
            self.set_prediction_mode()
    def load_pretrained_weights(self):
        pretrained = ResNetBackbone(
            backbone=self.backbone_name,
            pretrained_base=True,
            norm_layer=self.backbone_norm_layer,
            **self._kwargs
        )
        backbone_state_dict = self.backbone.state_dict()
        pretrained_state_dict = pretrained.state_dict()

        backbone_state_dict.update(pretrained_state_dict)
        self.backbone.load_state_dict(backbone_state_dict)

        if self.inference_mode:
            for param in self.backbone.parameters():
                param.requires_grad = False

    def set_prediction_mode(self):
        self.inference_mode = True
        self.eval()

    def forward(self, x, additional_features=None):
        with ExitStack() as stack:
            if self.inference_mode:
                stack.enter_context(torch.no_grad())

            c1, _, c3, c4 = self.backbone(x, additional_features)
            c1 = self.skip_project(c1)

            x = self.aspp(c4)
            x = F.interpolate(x, c1.size()[2:], mode="bilinear", align_corners=True)
            x = torch.cat((x, c1), dim=1)
            x = self.head(x)

        return (x,)
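
# Projects the low-level backbone features (layer 1) down to a small number of
# channels before they are concatenated with the upsampled ASPP output.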
class _SkipProject(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
        super(_SkipProject, self).__init__()
        _activation = ops.select_activation_function("relu")

        self.skip_project = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            norm_layer(out_channels),
            _activation(),
        )

    def forward(self, x):
        return self.skip_project(x)
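
# Decoder head: two depthwise-separable 3x3 convolutions followed by a 1x1
# convolution that maps the fused features to the requested output channels.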
class _DeepLabHead(nn.Module):
    def __init__(
        self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d
    ):
        super(_DeepLabHead, self).__init__()

        self.block = nn.Sequential(
            SeparableConv2d(
                in_channels=in_channels,
                out_channels=mid_channels,
                dw_kernel=3,
                dw_padding=1,
                activation="relu",
                norm_layer=norm_layer,
            ),
            SeparableConv2d(
                in_channels=mid_channels,
                out_channels=mid_channels,
                dw_kernel=3,
                dw_padding=1,
                activation="relu",
                norm_layer=norm_layer,
            ),
            nn.Conv2d(
                in_channels=mid_channels, out_channels=out_channels, kernel_size=1
            ),
        )

    def forward(self, x):
        return self.block(x)
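
# Atrous Spatial Pyramid Pooling: five parallel branches (a 1x1 conv, three
# atrous 3x3 convs with different dilation rates, and image-level pooling)
# whose outputs are concatenated and projected back to `out_channels`.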
class _ASPP(nn.Module):
    def __init__(
        self,
        in_channels,
        atrous_rates,
        out_channels=256,
        project_dropout=0.5,
        norm_layer=nn.BatchNorm2d,
    ):
        super(_ASPP, self).__init__()

        b0 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=False,
            ),
            norm_layer(out_channels),
            nn.ReLU(),
        )

        rate1, rate2, rate3 = tuple(atrous_rates)
        b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
        b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
        b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
        b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)

        self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])

        project = [
            nn.Conv2d(
                in_channels=5 * out_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=False,
            ),
            norm_layer(out_channels),
            nn.ReLU(),
        ]
        if project_dropout > 0:
            project.append(nn.Dropout(project_dropout))
        self.project = nn.Sequential(*project)

    def forward(self, x):
        x = torch.cat([block(x) for block in self.concurent], dim=1)
        return self.project(x)
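
# Image-level pooling branch of ASPP: global average pooling, a 1x1 conv,
# then bilinear upsampling back to the input resolution.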
class _AsppPooling(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer):
        super(_AsppPooling, self).__init__()

        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=False,
            ),
            norm_layer(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        pool = self.gap(x)
        return F.interpolate(pool, x.size()[2:], mode="bilinear", align_corners=True)
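
# A single 3x3 atrous convolution branch (conv + norm + ReLU).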
def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
    block = nn.Sequential(
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=atrous_rate,
            dilation=atrous_rate,
            bias=False,
        ),
        norm_layer(out_channels),
        nn.ReLU(),
    )
    return block
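

# Minimal usage sketch, assuming a plain 3-channel RGB input whose spatial
# size is divisible by 32 and no additional click features; the exact input
# contract depends on ResNetBackbone in this repo.
if __name__ == "__main__":
    model = DeepLabV3Plus(backbone="resnet34", ch=128)
    model.set_prediction_mode()  # eval mode; forward then runs under no_grad
    dummy = torch.randn(1, 3, 320, 320)  # (batch, channels, height, width)
    (out,) = model(dummy)
    # For a standard ResNet, layer 1 has stride 4, so `out` should be
    # roughly (1, ch, 80, 80) here.
    print(out.shape)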