import torch
import torch.nn as nn

from src.backbones.convlstm import ConvLSTM


class FPNConvLSTM(nn.Module):
def __init__(
self,
input_dim,
num_classes,
inconv=[32, 64],
n_levels=5,
n_channels=64,
hidden_size=88,
input_shape=(128, 128),
mid_conv=True,
pad_value=0,
):
"""
Feature Pyramid Network with ConvLSTM baseline.
Args:
input_dim (int): Number of channels in the input images.
num_classes (int): Number of classes.
inconv (List[int]): Widths of the input convolutional layers.
n_levels (int): Number of different levels in the feature pyramid.
            n_channels (int): Number of channels for each level of the pyramid.
            hidden_size (int): Hidden size of the ConvLSTM.
            input_shape (int, int): Shape (H, W) of the input images.
mid_conv (bool): If True, the feature pyramid is fed to a convolutional layer
to reduce dimensionality before being given to the ConvLSTM.
pad_value (float): Padding value (temporal) used by the dataloader.
"""
super(FPNConvLSTM, self).__init__()
self.pad_value = pad_value
self.inconv = ConvBlock(
nkernels=[input_dim] + inconv, norm="group", pad_value=pad_value
)
self.pyramid = PyramidBlock(
input_dim=inconv[-1],
n_channels=n_channels,
n_levels=n_levels,
pad_value=pad_value,
)
        if mid_conv:
            # Reduce the concatenated pyramid to half its width before the ConvLSTM
            dim = n_channels * n_levels // 2
self.mid_conv = ConvBlock(
nkernels=[self.pyramid.out_channels, dim],
pad_value=pad_value,
norm="group",
)
else:
dim = self.pyramid.out_channels
self.mid_conv = None
self.convlstm = ConvLSTM(
input_dim=dim,
input_size=input_shape,
hidden_dim=hidden_size,
kernel_size=(3, 3),
return_all_layers=False,
)
self.outconv = nn.Conv2d(
in_channels=hidden_size, out_channels=num_classes, kernel_size=1
)

    def forward(self, input, batch_positions=None):
        # batch_positions is accepted for interface compatibility but is unused here
        pad_mask = (
            (input == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1)
        )  # BxT pad mask
        pad_mask = pad_mask if pad_mask.any() else None
out = self.inconv.smart_forward(input)
out = self.pyramid.smart_forward(out)
if self.mid_conv is not None:
out = self.mid_conv.smart_forward(out)
        _, out = self.convlstm(out, pad_mask=pad_mask)
        out = out[0][1]  # final state of the single ConvLSTM layer
out = self.outconv(out)
return out
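
    # Usage sketch (a minimal illustration; input_dim, num_classes and the
    # sample sizes below are assumptions, and H, W must match input_shape):
    #
    #   model = FPNConvLSTM(input_dim=10, num_classes=20)
    #   x = torch.rand(2, 4, 10, 128, 128)  # (B, T, C, H, W)
    #   logits = model(x)                   # -> (2, 20, 128, 128)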


class TemporallySharedBlock(nn.Module):
    """
    Helper module for convolutional blocks that are shared across a sequence.
    smart_forward folds the temporal dimension of a (B, T, C, H, W) input into
    the batch dimension, applies the block to all frames at once, and restores
    the temporal layout, leaving fully padded frames filled with pad_value.
    """

    def __init__(self, pad_value=None):
        super(TemporallySharedBlock, self).__init__()
        self.out_shape = None
        self.pad_value = pad_value

    def smart_forward(self, input):
        if len(input.shape) == 4:
            # Already a batch of single frames (B, C, H, W): apply directly
            return self.forward(input)
        else:
            b, t, c, h, w = input.shape

            if self.pad_value is not None:
                # Dummy pass to record the block's output shape, used below to
                # fill fully padded frames with pad_value
                dummy = torch.zeros(input.shape, device=input.device).float()
                self.out_shape = self.forward(dummy.view(b * t, c, h, w)).shape

            out = input.view(b * t, c, h, w)
            if self.pad_value is not None:
                # A frame counts as padding if all of its values equal pad_value
                pad_mask = (out == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1)
                if pad_mask.any():
                    # Forward only the non-padded frames; padded positions keep pad_value
                    temp = (
                        torch.ones(
                            self.out_shape, device=input.device, requires_grad=False
                        )
                        * self.pad_value
                    )
                    temp[~pad_mask] = self.forward(out[~pad_mask])
                    out = temp
                else:
                    out = self.forward(out)
            else:
                out = self.forward(out)
            _, c, h, w = out.shape
            out = out.view(b, t, c, h, w)
            return out
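
    # Behaviour sketch (hypothetical shapes): a 5D sequence is folded to 4D,
    # processed in a single batched pass, then unfolded again:
    #
    #   blk = ConvBlock(nkernels=[3, 16], pad_value=0)
    #   y = blk.smart_forward(torch.rand(2, 4, 3, 32, 32))  # -> (2, 4, 16, 32, 32)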


class PyramidBlock(TemporallySharedBlock):
def __init__(self, input_dim, n_levels=5, n_channels=64, pad_value=None):
"""
        Feature Pyramid Block. Applies atrous (dilated) convolutions with
        increasing dilation rates and concatenates the resulting feature maps,
        along with a spatially broadcast global-pooled map, on the channel
        dimension.
Args:
input_dim (int): Number of channels in the input images.
n_levels (int): Number of levels.
n_channels (int): Number of channels per level.
pad_value (float): Padding value (temporal) used by the dataloader.
"""
super(PyramidBlock, self).__init__(pad_value=pad_value)
        dilations = [2 ** i for i in range(n_levels - 1)]  # 1, 2, 4, ...
self.inconv = nn.Conv2d(input_dim, n_channels, kernel_size=3, padding=1)
self.convs = nn.ModuleList(
[
nn.Conv2d(
in_channels=n_channels,
out_channels=n_channels,
kernel_size=3,
stride=1,
dilation=d,
padding=d,
padding_mode="reflect",
)
for d in dilations
]
)
self.out_channels = n_levels * n_channels

    def forward(self, input):
        out = self.inconv(input)
        # Note: despite the name, this is a global *max* pool over the spatial dims
        global_avg_pool = out.view(*out.shape[:2], -1).max(dim=-1)[0]
out = torch.cat([cv(out) for cv in self.convs], dim=1)
        h, w = out.shape[-2:]
        # Broadcast the pooled vector to (N, C, H, W) and append it as the last level
        out = torch.cat(
[
out,
global_avg_pool.unsqueeze(-1)
.repeat(1, 1, h)
.unsqueeze(-1)
.repeat(1, 1, 1, w),
],
dim=1,
)
return out
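
    # Shape sketch with the defaults (n_levels=5, n_channels=64): the four
    # dilated convolutions (d = 1, 2, 4, 8) contribute 64 channels each and the
    # broadcast pooled branch adds 64 more, so (N, C_in, H, W) -> (N, 320, H, W).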


class ConvLayer(nn.Module):
def __init__(self, nkernels, norm="batch", k=3, s=1, p=1, n_groups=4):
super(ConvLayer, self).__init__()
layers = []
if norm == "batch":
nl = nn.BatchNorm2d
elif norm == "instance":
nl = nn.InstanceNorm2d
elif norm == "group":
nl = lambda num_feats: nn.GroupNorm(
num_channels=num_feats, num_groups=n_groups
)
else:
nl = None
for i in range(len(nkernels) - 1):
layers.append(
nn.Conv2d(
in_channels=nkernels[i],
out_channels=nkernels[i + 1],
kernel_size=k,
padding=p,
stride=s,
padding_mode="reflect",
)
)
if nl is not None:
layers.append(nl(nkernels[i + 1]))
layers.append(nn.ReLU())
self.conv = nn.Sequential(*layers)

    def forward(self, input):
return self.conv(input)
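
    # Example (hypothetical widths): ConvLayer(nkernels=[10, 32, 64], norm="group")
    # stacks two [Conv2d 3x3 (reflect pad) -> GroupNorm(4 groups) -> ReLU] stages,
    # mapping (N, 10, H, W) to (N, 64, H, W).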


class ConvBlock(TemporallySharedBlock):
def __init__(self, nkernels, pad_value=None, norm="batch"):
super(ConvBlock, self).__init__(pad_value=pad_value)
self.conv = ConvLayer(nkernels=nkernels, norm=norm)

    def forward(self, input):
return self.conv(input)
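

if __name__ == "__main__":
    # Minimal smoke test; a sketch under the assumption that the ConvLSTM import
    # at the top of this file resolves, with illustrative (not canonical) sizes.
    model = FPNConvLSTM(input_dim=10, num_classes=20)
    x = torch.rand(2, 4, 10, 128, 128)  # (B, T, C, H, W); H, W must match input_shape
    print(model(x).shape)  # expected: torch.Size([2, 20, 128, 128])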