# Copyright (c) OpenMMLab. All rights reserved.
import math
from abc import abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvModule(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=False,
activation="leaky_relu",
order=("conv", "norm", "act"),
act_inplace=True):
super().__init__()
self.conv = nn.Conv2d(in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
self.norm = nn.BatchNorm2d(out_channels)
        if activation == "leaky_relu":
            self.act = nn.LeakyReLU(negative_slope=0.01, inplace=act_inplace)
        elif activation == "silu":
            self.act = nn.SiLU(inplace=act_inplace)
        elif activation == "gelu":
            self.act = nn.GELU()
        elif activation is None:
            self.act = nn.Identity()
        else:
            raise ValueError(f"Unsupported activation: {activation}")
self.order = order
def forward(self, x):
for i in self.order:
x = getattr(self, i)(x)
return x
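# Usage sketch (illustrative): ConvModule(64, 128, 3, padding=1) applies
# Conv2d -> BatchNorm2d -> LeakyReLU following the default
# ("conv", "norm", "act") order.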
class BaseMergeCell(nn.Module):
"""The basic class for cells used in NAS-FPN and NAS-FCOS.
BaseMergeCell takes 2 inputs. After applying convolution
on them, they are resized to the target size. Then,
they go through binary_op, which depends on the type of cell.
If with_out_conv is True, the result of output will go through
another convolution layer.
Args:
in_channels (int): number of input channels in out_conv layer.
out_channels (int): number of output channels in out_conv layer.
with_out_conv (bool): Whether to use out_conv layer
out_conv_cfg (dict): Config dict for convolution layer, which should
contain "groups", "kernel_size", "padding", "bias" to build
out_conv layer.
out_norm_cfg (dict): Config dict for normalization layer in out_conv.
out_conv_order (tuple): The order of conv/norm/activation layers in
out_conv.
with_input1_conv (bool): Whether to use convolution on input1.
with_input2_conv (bool): Whether to use convolution on input2.
input_conv_cfg (dict): Config dict for building input1_conv layer and
input2_conv layer, which is expected to contain the type of
convolution.
Default: None, which means using conv2d.
input_norm_cfg (dict): Config dict for normalization layer in
input1_conv and input2_conv layer. Default: None.
upsample_mode (str): Interpolation method used to resize the output
of input1_conv and input2_conv to target size. Currently, we
support ['nearest', 'bilinear']. Default: 'nearest'.
"""
def __init__(self,
fused_channels=256,
out_channels=256,
with_out_conv=True,
out_conv_cfg=dict(
groups=1, kernel_size=3, padding=1, bias=True),
out_conv_order=('act', 'conv', 'norm'),
with_input1_conv=False,
with_input2_conv=False,
upsample_mode='nearest'):
super().__init__()
assert upsample_mode in ['nearest', 'bilinear']
self.with_out_conv = with_out_conv
self.with_input1_conv = with_input1_conv
self.with_input2_conv = with_input2_conv
self.upsample_mode = upsample_mode
if self.with_out_conv:
self.out_conv = ConvModule(
fused_channels,
out_channels,
**out_conv_cfg,
order=out_conv_order)
self.input1_conv = self._build_input_conv(
out_channels) if with_input1_conv else nn.Sequential()
self.input2_conv = self._build_input_conv(
out_channels) if with_input2_conv else nn.Sequential()
def _build_input_conv(self, channel):
return ConvModule(
channel,
channel,
3,
padding=1,
bias=True)
@abstractmethod
def _binary_op(self, x1, x2):
pass
def _resize(self, x, size):
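        # Three cases: same spatial size -> pass through; smaller -> upsample
        # via F.interpolate; larger -> pad to an integer multiple of the
        # target size, then max-pool down.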
if x.shape[-2:] == size:
return x
elif x.shape[-2:] < size:
return F.interpolate(x, size=size, mode=self.upsample_mode)
else:
if x.shape[-2] % size[-2] != 0 or x.shape[-1] % size[-1] != 0:
h, w = x.shape[-2:]
target_h, target_w = size
pad_h = math.ceil(h / target_h) * target_h - h
pad_w = math.ceil(w / target_w) * target_w - w
pad_l = pad_w // 2
pad_r = pad_w - pad_l
pad_t = pad_h // 2
pad_b = pad_h - pad_t
pad = (pad_l, pad_r, pad_t, pad_b)
x = F.pad(x, pad, mode='constant', value=0.0)
kernel_size = (x.shape[-2] // size[-2], x.shape[-1] // size[-1])
x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
return x
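    # Worked example of the downsampling branch: shrinking a 25x25 map to 8x8
    # pads it to 32x32 (pad_h = pad_w = 7, split 3/4), then max-pools with
    # kernel and stride (4, 4) to land exactly on 8x8.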
def forward(self, x1, x2, out_size=None):
assert x1.shape[:2] == x2.shape[:2]
assert out_size is None or len(out_size) == 2
if out_size is None: # resize to larger one
out_size = max(x1.size()[2:], x2.size()[2:])
x1 = self.input1_conv(x1)
x2 = self.input2_conv(x2)
x1 = self._resize(x1, out_size)
x2 = self._resize(x2, out_size)
x = self._binary_op(x1, x2)
if self.with_out_conv:
x = self.out_conv(x)
return x
class SumCell(BaseMergeCell):
def __init__(self, in_channels, out_channels, **kwargs):
super().__init__(in_channels, out_channels, **kwargs)
def _binary_op(self, x1, x2):
return x1 + x2
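# Usage sketch (illustrative): a SumCell fusing two 256-channel maps of
# different strides resizes the smaller one up to the larger before adding:
#   cell = SumCell(256, 256)
#   out = cell(torch.randn(1, 256, 32, 32), torch.randn(1, 256, 16, 16))
#   assert out.shape == (1, 256, 32, 32)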
class ConcatCell(BaseMergeCell):
def __init__(self, in_channels, out_channels, **kwargs):
super().__init__(in_channels * 2, out_channels, **kwargs)
def _binary_op(self, x1, x2):
ret = torch.cat([x1, x2], dim=1)
return ret
class GlobalPoolingCell(BaseMergeCell):
def __init__(self, in_channels=None, out_channels=None, **kwargs):
super().__init__(in_channels, out_channels, **kwargs)
self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
def _binary_op(self, x1, x2):
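        # Global-context gating: x2's globally pooled response (sigmoid-
        # squashed) reweights x1 channel-wise before the residual add to x2.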
x2_att = self.global_pool(x2).sigmoid()
return x2 + x2_att * x1
class Conv3x3GNReLU(nn.Module):
def __init__(self, in_channels, out_channels, upsample=False):
super().__init__()
self.upsample = upsample
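        # GroupNorm(32, ...) requires out_channels divisible by 32 (true for
        # the default segmentation_channels=128 used below).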
self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False),
nn.GroupNorm(32, out_channels),
nn.ReLU(inplace=True),
)
def forward(self, x):
x = self.block(x)
if self.upsample:
x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
return x
class SegmentationBlock(nn.Module):
def __init__(self, in_channels, out_channels, n_upsamples=0):
super().__init__()
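        # With n_upsamples = k, the block upscales by 2**k overall: the first
        # conv upsamples iff k > 0, and k - 1 further convs upsample once each.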
blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))]
if n_upsamples > 1:
for _ in range(1, n_upsamples):
blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True))
self.block = nn.Sequential(*blocks)
def forward(self, x):
return self.block(x)
class MergeBlock(nn.Module):
def __init__(self, policy):
super().__init__()
if policy not in ["add", "cat"]:
raise ValueError("`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy))
self.policy = policy
def forward(self, x):
if self.policy == "add":
return sum(x)
elif self.policy == "cat":
return torch.cat(x, dim=1)
else:
raise ValueError("`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy))
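# e.g. MergeBlock("cat") applied to five [N, 128, H, W] maps yields a
# [N, 640, H, W] tensor, matching out_channels = segmentation_channels * 5
# in NASFPNDecoder below.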
class NASFPNDecoder(nn.Module):
"""NAS-FPN.
Implementation of `NAS-FPN: Learning Scalable Feature Pyramid Architecture
for Object Detection <https://arxiv.org/abs/1904.07392>`_
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale)
depth (int): Number of output scales.
stack_times (int): The number of times the pyramid architecture will
be stacked.
"""
def __init__(self,
in_channels,
pyramid_channels=256,
segmentation_channels=128,
depth=5,
stack_times=3,
merge_policy="add",
deep_supervision=False):
super().__init__()
assert isinstance(in_channels, (list, tuple))
self.in_channels = in_channels
self.pyramid_channels = pyramid_channels
self.num_ins = len(in_channels) # num of input feature levels
self.depth = depth # num of output feature levels
assert self.num_ins == self.depth
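        # NOTE: the NAS-FPN wiring, the five segmentation blocks, and the
        # "cat" channel count below all assume exactly five pyramid levels.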
self.stack_times = stack_times
self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 5
self.deep_supervision = deep_supervision
# add lateral connections
self.lateral_convs = nn.ModuleList()
for i in range(depth):
l_conv = ConvModule(
in_channels[i],
pyramid_channels,
1,
activation=None)
self.lateral_convs.append(l_conv)
# add NAS FPN connections
self.fpn_stages = nn.ModuleList()
for _ in range(self.stack_times):
stage = nn.ModuleDict()
# gp(p6, p4) -> p4_1
stage['gp_64_4'] = GlobalPoolingCell(
in_channels=pyramid_channels,
out_channels=pyramid_channels)
# sum(p4_1, p4) -> p4_2
stage['sum_44_4'] = SumCell(
in_channels=pyramid_channels,
out_channels=pyramid_channels)
# sum(p4_2, p3) -> p3_out
stage['sum_43_3'] = SumCell(
in_channels=pyramid_channels,
out_channels=pyramid_channels)
# sum(p3_out, p4_2) -> p4_out
stage['sum_34_4'] = SumCell(
in_channels=pyramid_channels,
out_channels=pyramid_channels)
# sum(p5, gp(p4_out, p3_out)) -> p5_out
stage['gp_43_5'] = GlobalPoolingCell(with_out_conv=False)
stage['sum_55_5'] = SumCell(
in_channels=pyramid_channels,
out_channels=pyramid_channels)
# sum(p7, gp(p5_out, p4_2)) -> p7_out
stage['gp_54_7'] = GlobalPoolingCell(with_out_conv=False)
stage['sum_77_7'] = SumCell(
in_channels=pyramid_channels,
out_channels=pyramid_channels)
# gp(p7_out, p5_out) -> p6_out
stage['gp_75_6'] = GlobalPoolingCell(
in_channels=pyramid_channels,
out_channels=pyramid_channels)
self.fpn_stages.append(stage)
self.seg_blocks = nn.ModuleList(
[
SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples)
for n_upsamples in [4, 3, 2, 1, 0]
]
)
self.merge = MergeBlock(merge_policy)
def forward(self, *features):
"""Forward function."""
        # project encoder features to pyramid_channels via lateral 1x1 convs
features = [
lateral_conv(features[i])
for i, lateral_conv in enumerate(self.lateral_convs)
]
        # NOTE: the p3-p7 names follow the NAS-FPN paper; here they simply
        # denote the five lateral outputs from highest to lowest resolution.
p3, p4, p5, p6, p7 = features[-5:]
for stage in self.fpn_stages:
# gp(p6, p4) -> p4_1
p4_1 = stage['gp_64_4'](p6, p4, out_size=p4.shape[-2:])
# sum(p4_1, p4) -> p4_2
p4_2 = stage['sum_44_4'](p4_1, p4, out_size=p4.shape[-2:])
# sum(p4_2, p3) -> p3_out
p3 = stage['sum_43_3'](p4_2, p3, out_size=p3.shape[-2:])
# sum(p3_out, p4_2) -> p4_out
p4 = stage['sum_34_4'](p3, p4_2, out_size=p4.shape[-2:])
# sum(p5, gp(p4_out, p3_out)) -> p5_out
p5_tmp = stage['gp_43_5'](p4, p3, out_size=p5.shape[-2:])
p5 = stage['sum_55_5'](p5, p5_tmp, out_size=p5.shape[-2:])
# sum(p7, gp(p5_out, p4_2)) -> p7_out
p7_tmp = stage['gp_54_7'](p5, p4_2, out_size=p7.shape[-2:])
p7 = stage['sum_77_7'](p7, p7_tmp, out_size=p7.shape[-2:])
# gp(p7_out, p5_out) -> p6_out
p6 = stage['gp_75_6'](p7, p5, out_size=p6.shape[-2:])
feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p7, p6, p5, p4, p3])]
x = self.merge(feature_pyramid)
if self.deep_supervision and self.training:
return p4, p3, x
        return x
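
if __name__ == "__main__":
    # Minimal smoke test (not part of the original module): build the decoder
    # with assumed encoder channel counts and check the fused output shape.
    decoder = NASFPNDecoder(in_channels=[16, 32, 64, 128, 256])
    decoder.eval()
    # five feature maps at strides 4, 8, 16, 32 and 64 of a 256x256 input
    feats = [
        torch.randn(2, c, 256 // 2 ** i, 256 // 2 ** i)
        for i, c in enumerate([16, 32, 64, 128, 256], start=2)
    ]
    out = decoder(*feats)
    # with the default merge_policy="add", out_channels == segmentation_channels
    print(out.shape)  # torch.Size([2, 128, 64, 64])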