ghlee94's picture
Init
2a13495
import torch
import torch.nn as nn
import torch.nn.functional as F
from segmentation_models_pytorch.base import modules
class PSPBlock(nn.Module):
def __init__(self, in_channels, out_channels, pool_size, use_bathcnorm=True):
super().__init__()
if pool_size == 1:
use_bathcnorm = False # PyTorch does not support BatchNorm for 1x1 shape
self.pool = nn.Sequential(
nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)),
modules.Conv2dReLU(
in_channels, out_channels, (1, 1), use_batchnorm=use_bathcnorm
),
)
def forward(self, x):
h, w = x.size(2), x.size(3)
x = self.pool(x)
x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=True)
return x
class PSPModule(nn.Module):
def __init__(self, in_channels, sizes=(1, 2, 3, 6), use_bathcnorm=True):
super().__init__()
self.blocks = nn.ModuleList(
[
PSPBlock(
in_channels,
in_channels // len(sizes),
size,
use_bathcnorm=use_bathcnorm,
)
for size in sizes
]
)
def forward(self, x):
xs = [block(x) for block in self.blocks] + [x]
x = torch.cat(xs, dim=1)
return x
class PSPDecoder(nn.Module):
def __init__(
self, encoder_channels, use_batchnorm=True, out_channels=512, dropout=0.2,
):
super().__init__()
self.psp = PSPModule(
in_channels=encoder_channels[-1],
sizes=(1, 2, 3, 6),
use_bathcnorm=use_batchnorm,
)
self.conv = modules.Conv2dReLU(
in_channels=encoder_channels[-1] * 2,
out_channels=out_channels,
kernel_size=1,
use_batchnorm=use_batchnorm,
)
self.dropout = nn.Dropout2d(p=dropout)
def forward(self, *features):
x = features[-1]
x = self.psp(x)
x = self.conv(x)
x = self.dropout(x)
return x