# Pyramid Scene Parsing (PSP) decoder modules.
import torch
import torch.nn as nn
import torch.nn.functional as F

from feature_extractor_models.base import modules
class PSPBlock(nn.Module):
    """One pyramid-pooling branch.

    Adaptive-average-pools the input to a fixed ``pool_size x pool_size``
    grid, projects it with a 1x1 conv+ReLU, then bilinearly upsamples the
    result back to the input's spatial size.
    """

    def __init__(self, in_channels, out_channels, pool_size, use_bathcnorm=True):
        super().__init__()
        # A 1x1 pooled map has a single spatial element per channel, so
        # BatchNorm is degenerate there — force it off for that branch.
        if pool_size == 1:
            use_bathcnorm = False
        pooling = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))
        projection = modules.Conv2dReLU(
            in_channels, out_channels, (1, 1), use_batchnorm=use_bathcnorm
        )
        self.pool = nn.Sequential(pooling, projection)

    def forward(self, x):
        height, width = x.size(2), x.size(3)
        pooled = self.pool(x)
        return F.interpolate(
            pooled, size=(height, width), mode="bilinear", align_corners=True
        )
class PSPModule(nn.Module):
    """Pyramid pooling over a feature map.

    Runs one :class:`PSPBlock` per grid size in ``sizes`` and concatenates
    every branch output with the original input along the channel axis,
    so the output has ``2 * in_channels`` channels when
    ``in_channels % len(sizes) == 0``.
    """

    def __init__(self, in_channels, sizes=(1, 2, 3, 6), use_bathcnorm=True):
        super().__init__()
        # Each branch produces an equal share of the input channel count.
        branch_channels = in_channels // len(sizes)
        branches = [
            PSPBlock(in_channels, branch_channels, size, use_bathcnorm=use_bathcnorm)
            for size in sizes
        ]
        self.blocks = nn.ModuleList(branches)

    def forward(self, x):
        pyramid = [block(x) for block in self.blocks]
        pyramid.append(x)  # keep the untouched input as the last slice
        return torch.cat(pyramid, dim=1)
class PSPDecoder(nn.Module):
    """PSPNet-style decoder head.

    Applies pyramid pooling to the deepest encoder feature map, projects
    the concatenated result to ``out_channels`` with a 1x1 conv+ReLU, and
    finishes with 2D dropout.
    """

    def __init__(
        self, encoder_channels, use_batchnorm=True, out_channels=512, dropout=0.2
    ):
        super().__init__()
        deepest_channels = encoder_channels[-1]
        self.psp = PSPModule(
            in_channels=deepest_channels,
            sizes=(1, 2, 3, 6),
            use_bathcnorm=use_batchnorm,
        )
        # Pyramid pooling doubles the channel count: four C/4 branches
        # concatenated with the original C-channel input.
        self.conv = modules.Conv2dReLU(
            in_channels=deepest_channels * 2,
            out_channels=out_channels,
            kernel_size=1,
            use_batchnorm=use_batchnorm,
        )
        self.dropout = nn.Dropout2d(p=dropout)

    def forward(self, *features):
        # Only the last (deepest) encoder stage is used.
        x = self.psp(features[-1])
        x = self.conv(x)
        return self.dropout(x)