# Commit 1615d09 by curt-park: "Refactor code"
import torch.nn as nn
from isegm.model import ops
class ConvHead(nn.Module):
    """Plain convolutional head.

    The head is a single ``nn.Sequential`` stored as ``self.convhead``. It
    consists of ``num_layers`` repetitions of
    ``Conv2d -> ReLU -> normalisation`` (all keeping ``in_channels``
    channels), followed by one 1x1 convolution that maps ``in_channels`` to
    ``out_channels``.

    Args:
        out_channels: Number of channels produced by the final 1x1 conv.
        in_channels: Channel count of the input and of every hidden conv.
        num_layers: How many Conv2d/ReLU/norm groups precede the final conv.
        kernel_size: Kernel size of the hidden convolutions.
        padding: Padding of the hidden convolutions.
        norm_layer: Callable taking a channel count and returning a module,
            or ``None`` to insert ``nn.Identity()`` in the norm position.
    """

    def __init__(
        self,
        out_channels,
        in_channels=32,
        num_layers=1,
        kernel_size=3,
        padding=1,
        norm_layer=nn.BatchNorm2d,
    ):
        super().__init__()

        modules = []
        for _ in range(num_layers):
            modules.append(
                nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding)
            )
            modules.append(nn.ReLU())
            # A placeholder is inserted even without normalisation so that
            # the positions of the following layers in the Sequential stay
            # the same either way.
            if norm_layer is None:
                modules.append(nn.Identity())
            else:
                modules.append(norm_layer(in_channels))

        # Final 1x1 projection to the requested number of output channels.
        modules.append(nn.Conv2d(in_channels, out_channels, 1, padding=0))

        self.convhead = nn.Sequential(*modules)

    def forward(self, *inputs):
        """Run the head on the first positional input; the rest are ignored."""
        features = inputs[0]
        return self.convhead(features)
class SepConvHead(nn.Module):
    """Head built from ``SeparableConv2d`` blocks plus a 1x1 output conv.

    The head is a single ``nn.Sequential`` stored as ``self.layers``:
    ``num_layers`` separable-convolution blocks (each created with
    ``activation="relu"``), an optional ``nn.Dropout`` directly after one of
    them, and finally a 1x1 ``nn.Conv2d`` from ``mid_channels`` to
    ``num_outputs``.

    Args:
        num_outputs: Number of channels produced by the final 1x1 conv.
        in_channels: Input channels of the first separable block.
        mid_channels: Output channels of every separable block (and input
            channels of every block after the first).
        num_layers: Number of separable blocks.
        kernel_size: Passed to each block as ``dw_kernel``.
        padding: Passed to each block as ``dw_padding``.
        dropout_ratio: Dropout probability; dropout is only added when this
            is greater than zero.
        dropout_indx: Index of the separable block after which the dropout
            layer is inserted.
        norm_layer: Forwarded to each separable block as ``norm_layer``.
    """

    def __init__(
        self,
        num_outputs,
        in_channels,
        mid_channels,
        num_layers=1,
        kernel_size=3,
        padding=1,
        dropout_ratio=0.0,
        dropout_indx=0,
        norm_layer=nn.BatchNorm2d,
    ):
        super().__init__()

        stages = []
        for layer_idx in range(num_layers):
            # Only the first block sees the raw input width; every later
            # block consumes the previous block's ``mid_channels`` output.
            if layer_idx == 0:
                block_in = in_channels
            else:
                block_in = mid_channels

            block = SeparableConv2d(
                in_channels=block_in,
                out_channels=mid_channels,
                dw_kernel=kernel_size,
                dw_padding=padding,
                norm_layer=norm_layer,
                activation="relu",
            )
            stages.append(block)

            if dropout_ratio > 0:
                if dropout_indx == layer_idx:
                    stages.append(nn.Dropout(dropout_ratio))

        # Final 1x1 projection to the requested number of outputs.
        projection = nn.Conv2d(
            in_channels=mid_channels,
            out_channels=num_outputs,
            kernel_size=1,
            padding=0,
        )
        stages.append(projection)

        self.layers = nn.Sequential(*stages)

    def forward(self, *inputs):
        """Run the head on the first positional input; the rest are ignored."""
        return self.layers(inputs[0])
class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 (pointwise) convolution.

    ``self.body`` is an ``nn.Sequential`` of, in order:

    1. a ``Conv2d`` with ``groups=in_channels`` that keeps the channel count,
    2. a 1x1 ``Conv2d`` mapping ``in_channels`` to ``out_channels``,
    3. ``norm_layer(out_channels)``, or ``nn.Identity()`` when ``norm_layer``
       is ``None``,
    4. the activation module.

    Args:
        in_channels: Channels of the input (and of the depthwise conv).
        out_channels: Channels produced by the pointwise conv.
        dw_kernel: Kernel size of the depthwise conv.
        dw_padding: Padding of the depthwise conv.
        dw_stride: Stride of the depthwise conv.
        activation: Activation spec handed to
            ``ops.select_activation_function``. The result of that helper is
            called with no arguments to create the activation module.
            NOTE(review): how the default ``None`` is treated depends on that
            helper, which is defined outside this file -- confirm.
        use_bias: Whether both convolutions have a bias term.
        norm_layer: Callable taking a channel count and returning a module,
            or ``None`` for no normalisation.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dw_kernel,
        dw_padding,
        dw_stride=1,
        activation=None,
        use_bias=False,
        norm_layer=None,
    ):
        super().__init__()

        # Resolve the activation first; the result is instantiated last so
        # the module creation order matches the Sequential order below.
        activation_factory = ops.select_activation_function(activation)

        # groups == in_channels: each input channel gets its own filter.
        depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=dw_kernel,
            stride=dw_stride,
            padding=dw_padding,
            bias=use_bias,
            groups=in_channels,
        )
        # 1x1 conv mixes channels and sets the output width.
        pointwise = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias
        )

        if norm_layer is None:
            normalisation = nn.Identity()
        else:
            normalisation = norm_layer(out_channels)

        self.body = nn.Sequential(
            depthwise,
            pointwise,
            normalisation,
            activation_factory(),
        )

    def forward(self, x):
        """Apply depthwise conv, pointwise conv, norm and activation to ``x``."""
        return self.body(x)