File size: 5,667 Bytes
c614b0f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from engine.BiRefNet.config import Config
from engine.BiRefNet.models.modules.deform_conv import DeformableConv2d
# Module-level configuration, instantiated once when this module is imported.
# Every module below reads `config.batch_size` at construction time to choose
# between nn.BatchNorm2d (batch_size > 1) and a no-op nn.Identity.
config = Config()
class _ASPPModule(nn.Module):
    """One ASPP branch: atrous ``nn.Conv2d`` -> BatchNorm (optional) -> ReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        planes: Channel count produced by this branch.
        kernel_size: Size of the convolution kernel.
        padding: Zero-padding applied by the convolution.
        dilation: Dilation (atrous rate) of the convolution.

    The normalisation layer is ``nn.BatchNorm2d`` only when
    ``config.batch_size > 1``; otherwise ``nn.Identity`` is used. The choice
    is fixed when the module is built.
    """

    def __init__(self, in_channels, planes, kernel_size, padding, dilation):
        super().__init__()
        # Bias is disabled; the (optional) BatchNorm that follows provides the shift.
        self.atrous_conv = nn.Conv2d(
            in_channels,
            planes,
            kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        # NOTE(review): presumably BatchNorm is skipped for single-sample
        # batches because batch statistics are unusable there - confirm.
        if config.batch_size > 1:
            self.bn = nn.BatchNorm2d(planes)
        else:
            self.bn = nn.Identity()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv, normalisation and ReLU to ``x``, returning the result."""
        return self.relu(self.bn(self.atrous_conv(x)))
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head built from ``_ASPPModule`` branches.

    Five parallel branches are concatenated along the channel axis and fused
    by a 1x1 convolution:

    * ``aspp1``: 1x1 convolution;
    * ``aspp2``..``aspp4``: 3x3 convolutions with increasing dilation and
      ``padding == dilation`` (so spatial size is preserved);
    * ``global_avg_pool``: global average pooling + 1x1 conv, bilinearly
      resized back to the branch resolution.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count of the output; defaults to ``in_channels``.
        output_stride: Selects the dilation rates:
            16 -> (1, 6, 12, 18), 8 -> (1, 12, 24, 36).

    Raises:
        NotImplementedError: If ``output_stride`` is neither 8 nor 16.
    """

    def __init__(self, in_channels=64, out_channels=None, output_stride=16):
        super().__init__()
        self.down_scale = 1
        if out_channels is None:
            out_channels = in_channels
        # Channel width of every intermediate branch (256 when down_scale == 1).
        self.in_channelster = 256 // self.down_scale

        if output_stride == 16:
            dilations = [1, 6, 12, 18]
        elif output_stride == 8:
            dilations = [1, 12, 24, 36]
        else:
            # Keep the exception type callers may rely on, but say why it failed.
            raise NotImplementedError(
                f"ASPP supports output_stride 8 or 16, got {output_stride!r}"
            )

        self.aspp1 = _ASPPModule(
            in_channels, self.in_channelster, 1, padding=0, dilation=dilations[0]
        )
        # 3x3 branches: padding == dilation keeps the spatial size unchanged.
        self.aspp2 = _ASPPModule(
            in_channels,
            self.in_channelster,
            3,
            padding=dilations[1],
            dilation=dilations[1],
        )
        self.aspp3 = _ASPPModule(
            in_channels,
            self.in_channelster,
            3,
            padding=dilations[2],
            dilation=dilations[2],
        )
        self.aspp4 = _ASPPModule(
            in_channels,
            self.in_channelster,
            3,
            padding=dilations[3],
            dilation=dilations[3],
        )
        # Image-level branch: pool to 1x1, project, normalise (optional), ReLU.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False),
            (
                nn.BatchNorm2d(self.in_channelster)
                if config.batch_size > 1
                else nn.Identity()
            ),
            nn.ReLU(inplace=True),
        )
        # Fuse the 5 concatenated branches down to out_channels.
        self.conv1 = nn.Conv2d(self.in_channelster * 5, out_channels, 1, bias=False)
        self.bn1 = (
            nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
        )
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Run all branches on ``x`` and return the fused, dropout-regularised map.

        ``x`` is expected to be a 4-D feature map (bilinear interpolation of
        the pooled branch requires 4-D input).
        """
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # Broadcast the 1x1 pooled features back to the branch resolution.
        x5 = F.interpolate(x5, size=x1.size()[2:], mode="bilinear", align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        return self.dropout(x)
##################### Deformable ASPP variants (DeformableConv2d instead of dilated nn.Conv2d)
class _ASPPModuleDeformable(nn.Module):
    """One deformable ASPP branch: ``DeformableConv2d`` -> BatchNorm (optional) -> ReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        planes: Channel count produced by this branch.
        kernel_size: Kernel size forwarded to ``DeformableConv2d``.
        padding: Padding forwarded to ``DeformableConv2d``.

    Unlike ``_ASPPModule`` there is no dilation argument. The normalisation
    layer is ``nn.BatchNorm2d`` only when ``config.batch_size > 1``, else
    ``nn.Identity``.
    """

    def __init__(self, in_channels, planes, kernel_size, padding):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            bias=False,
        )
        self.atrous_conv = DeformableConv2d(in_channels, planes, **conv_options)
        use_batch_norm = config.batch_size > 1
        self.bn = nn.BatchNorm2d(planes) if use_batch_norm else nn.Identity()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply deformable conv, normalisation and ReLU to ``x``."""
        features = self.atrous_conv(x)
        features = self.bn(features)
        return self.relu(features)
class ASPPDeformable(nn.Module):
    """ASPP head whose convolutional branches use ``DeformableConv2d``.

    Branches, concatenated along the channel axis and fused by a 1x1 conv:

    * ``aspp1``: a 1x1 deformable branch;
    * ``aspp_deforms``: one deformable branch per entry of
      ``parallel_block_sizes`` (kernel size ``k``, padding ``k // 2``);
    * ``global_avg_pool``: global average pooling + 1x1 conv, bilinearly
      resized back to the branch resolution.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count of the output; defaults to ``in_channels``.
        parallel_block_sizes: Kernel sizes of the parallel deformable
            branches. Any iterable of ints is accepted; the default is an
            immutable tuple so it cannot be shared/mutated across calls.
    """

    def __init__(self, in_channels, out_channels=None, parallel_block_sizes=(1, 3, 7)):
        super().__init__()
        self.down_scale = 1
        if out_channels is None:
            out_channels = in_channels
        # Channel width of every intermediate branch (256 when down_scale == 1).
        self.in_channelster = 256 // self.down_scale

        self.aspp1 = _ASPPModuleDeformable(
            in_channels, self.in_channelster, 1, padding=0
        )
        # padding = k // 2 is intended to preserve spatial size for odd k
        # (assuming DeformableConv2d follows nn.Conv2d size arithmetic - confirm).
        self.aspp_deforms = nn.ModuleList(
            [
                _ASPPModuleDeformable(
                    in_channels,
                    self.in_channelster,
                    conv_size,
                    padding=int(conv_size // 2),
                )
                for conv_size in parallel_block_sizes
            ]
        )
        # Image-level branch: pool to 1x1, project, normalise (optional), ReLU.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False),
            (
                nn.BatchNorm2d(self.in_channelster)
                if config.batch_size > 1
                else nn.Identity()
            ),
            nn.ReLU(inplace=True),
        )
        # 2 fixed branches (aspp1 + global pool) plus the parallel ones. Count
        # via the built ModuleList so one-shot iterables also work.
        self.conv1 = nn.Conv2d(
            self.in_channelster * (2 + len(self.aspp_deforms)),
            out_channels,
            1,
            bias=False,
        )
        self.bn1 = (
            nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
        )
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Run all branches on ``x`` and return the fused, dropout-regularised map.

        ``x`` is expected to be a 4-D feature map (bilinear interpolation of
        the pooled branch requires 4-D input).
        """
        x1 = self.aspp1(x)
        x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
        x5 = self.global_avg_pool(x)
        # Broadcast the 1x1 pooled features back to the branch resolution.
        x5 = F.interpolate(x5, size=x1.size()[2:], mode="bilinear", align_corners=True)
        x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        return self.dropout(x)
|