import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv3x3GNReLU(nn.Module):
    """3x3 conv -> GroupNorm -> ReLU, with optional 2x bilinear upsampling."""

    def __init__(self, in_channels, out_channels, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False
            ),
            nn.GroupNorm(32, out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.block(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
        return x


class DepthwiseSeparableConv2d(nn.Module):
    """Factors a dense conv into a per-channel (depthwise) conv followed by a
    1x1 (pointwise) conv, giving far fewer weights than a standard conv."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size, stride, padding, groups=in_channels
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x
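

# Hedged sketch, not part of the original module: a quick parameter-count
# comparison motivating the depthwise-separable factorization. The 256 -> 256
# channel widths are illustrative assumptions, not values taken from this file.
def _separable_param_demo():
    dense = nn.Conv2d(256, 256, kernel_size=3, padding=1)
    light = DepthwiseSeparableConv2d(256, 256, kernel_size=3, padding=1)
    n_dense = sum(p.numel() for p in dense.parameters())  # 590,080 incl. biases
    n_light = sum(p.numel() for p in light.parameters())  # 68,352 incl. biases
    # Roughly an 8-9x reduction for a 3x3 kernel at equal channel width.
    print(f"standard: {n_dense:,}  separable: {n_light:,}")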


class LightFPNBlock(nn.Module):
    """FPN lateral block: upsample the coarser map 2x and add the projected skip.
    The skip connection is required, so it has no default value."""

    def __init__(self, pyramid_channels, skip_channels):
        super().__init__()
        self.skip_conv = DepthwiseSeparableConv2d(
            skip_channels, pyramid_channels, kernel_size=1
        )

    def forward(self, x, skip):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        skip = self.skip_conv(skip)
        x = x + skip
        return x


class SegmentationBlock(nn.Module):
    """Decodes one pyramid level with n_upsamples conv+upsample stages, so every
    level can be brought to a common spatial resolution before merging."""

    def __init__(self, in_channels, out_channels, n_upsamples=0):
        super().__init__()
        blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))]
        if n_upsamples > 1:
            for _ in range(1, n_upsamples):
                blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True))
        self.block = nn.Sequential(*blocks)

    def forward(self, x):
        return self.block(x)
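

# Hedged shape check (the 8x8 input below is an assumption, not a value from
# this file): n_upsamples=3 applies three bilinear x2 upsamples, so a coarse
# pyramid level lands on the same resolution as the finest one.
def _segmentation_block_demo():
    block = SegmentationBlock(256, 128, n_upsamples=3)
    out = block(torch.rand(1, 256, 8, 8))
    assert out.shape == (1, 128, 64, 64)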


class MergeBlock(nn.Module):
    """Merges the decoded pyramid levels by summation ("add") or channel
    concatenation ("cat")."""

    def __init__(self, policy):
        super().__init__()
        if policy not in ["add", "cat"]:
            raise ValueError(
                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy)
            )
        self.policy = policy

    def forward(self, x):
        if self.policy == "add":
            return sum(x)
        elif self.policy == "cat":
            return torch.cat(x, dim=1)
        else:
            raise ValueError(
                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(
                    self.policy
                )
            )


class FPNDecoder(nn.Module):
    """Feature Pyramid Network decoder using depthwise-separable lateral convs."""

    def __init__(
        self,
        encoder_channels,
        encoder_depth=5,
        pyramid_channels=256,
        segmentation_channels=128,
        dropout=0.2,
        merge_policy="add",
    ):
        super().__init__()
        self.out_channels = (
            segmentation_channels
            if merge_policy == "add"
            else segmentation_channels * 4
        )
        if encoder_depth < 3:
            raise ValueError(
                "Encoder depth for FPN decoder cannot be less than 3, got {}.".format(
                    encoder_depth
                )
            )

        # Work from the deepest encoder stage outward.
        encoder_channels = encoder_channels[::-1]
        encoder_channels = encoder_channels[: encoder_depth + 1]

        self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1)
        self.p4 = LightFPNBlock(pyramid_channels, encoder_channels[1])
        self.p3 = LightFPNBlock(pyramid_channels, encoder_channels[2])
        self.p2 = LightFPNBlock(pyramid_channels, encoder_channels[3])

        # Deeper levels need more upsampling to reach the p2 resolution.
        self.seg_blocks = nn.ModuleList(
            [
                SegmentationBlock(
                    pyramid_channels, segmentation_channels, n_upsamples=n_upsamples
                )
                for n_upsamples in [3, 2, 1, 0]
            ]
        )

        self.merge = MergeBlock(merge_policy)
        self.dropout = nn.Dropout2d(p=dropout, inplace=True)

    def forward(self, *features):
        c2, c3, c4, c5 = features[-4:]  # strides 4, 8, 16, 32

        # Top-down pathway with lateral skip connections.
        p5 = self.p5(c5)
        p4 = self.p4(p5, c4)
        p3 = self.p3(p4, c3)
        p2 = self.p2(p3, c2)

        feature_pyramid = [
            seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])
        ]
        x = self.merge(feature_pyramid)
        x = self.dropout(x)
        return x
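

# Minimal end-to-end smoke test, assuming a typical ResNet-style encoder that
# yields feature maps at strides 1, 2, 4, 8, 16, 32. The channel counts and
# 256x256 input size are illustrative assumptions, not part of the decoder.
if __name__ == "__main__":
    encoder_channels = (3, 64, 256, 512, 1024, 2048)  # hypothetical encoder
    features = [
        torch.rand(1, c, 256 // s, 256 // s)
        for c, s in zip(encoder_channels, (1, 2, 4, 8, 16, 32))
    ]
    decoder = FPNDecoder(encoder_channels)
    out = decoder(*features)
    # With merge_policy="add" the output keeps segmentation_channels and sits
    # at stride 4 relative to the input.
    print(out.shape)  # torch.Size([1, 128, 64, 64])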