import torch.nn as nn import torch from src.backbones.convlstm import ConvLSTM class FPNConvLSTM(nn.Module): def __init__( self, input_dim, num_classes, inconv=[32, 64], n_levels=5, n_channels=64, hidden_size=88, input_shape=(128, 128), mid_conv=True, pad_value=0, ): """ Feature Pyramid Network with ConvLSTM baseline. Args: input_dim (int): Number of channels in the input images. num_classes (int): Number of classes. inconv (List[int]): Widths of the input convolutional layers. n_levels (int): Number of different levels in the feature pyramid. n_channels (int): Number of channels for each channel of the pyramid. hidden_size (int): Hidden size of the ConvLSTM. input_shape (int,int): Shape (H,W) of the input images. mid_conv (bool): If True, the feature pyramid is fed to a convolutional layer to reduce dimensionality before being given to the ConvLSTM. pad_value (float): Padding value (temporal) used by the dataloader. """ super(FPNConvLSTM, self).__init__() self.pad_value = pad_value self.inconv = ConvBlock( nkernels=[input_dim] + inconv, norm="group", pad_value=pad_value ) self.pyramid = PyramidBlock( input_dim=inconv[-1], n_channels=n_channels, n_levels=n_levels, pad_value=pad_value, ) if mid_conv: dim = n_channels * n_levels // 2 self.mid_conv = ConvBlock( nkernels=[self.pyramid.out_channels, dim], pad_value=pad_value, norm="group", ) else: dim = self.pyramid.out_channels self.mid_conv = None self.convlstm = ConvLSTM( input_dim=dim, input_size=input_shape, hidden_dim=hidden_size, kernel_size=(3, 3), return_all_layers=False, ) self.outconv = nn.Conv2d( in_channels=hidden_size, out_channels=num_classes, kernel_size=1 ) def forward(self, input, batch_positions=None): pad_mask = ( (input == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1) ) # BxT pad mask pad_mask = pad_mask if pad_mask.any() else None out = self.inconv.smart_forward(input) out = self.pyramid.smart_forward(out) if self.mid_conv is not None: out = self.mid_conv.smart_forward(out) _, out = self.convlstm(out, pad_mask=pad_mask) out = out[0][1] out = self.outconv(out) return out class TemporallySharedBlock(nn.Module): def __init__(self, pad_value=None): super(TemporallySharedBlock, self).__init__() self.out_shape = None self.pad_value = pad_value def smart_forward(self, input): if len(input.shape) == 4: return self.forward(input) else: b, t, c, h, w = input.shape if self.pad_value is not None: dummy = torch.zeros(input.shape, device=input.device).float() self.out_shape = self.forward(dummy.view(b * t, c, h, w)).shape out = input.view(b * t, c, h, w) if self.pad_value is not None: pad_mask = (out == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1) if pad_mask.any(): temp = ( torch.ones( self.out_shape, device=input.device, requires_grad=False ) * self.pad_value ) temp[~pad_mask] = self.forward(out[~pad_mask]) out = temp else: out = self.forward(out) else: out = self.forward(out) _, c, h, w = out.shape out = out.view(b, t, c, h, w) return out class PyramidBlock(TemporallySharedBlock): def __init__(self, input_dim, n_levels=5, n_channels=64, pad_value=None): """ Feature Pyramid Block. Performs atrous convolutions with different strides and concatenates the resulting feature maps along the channel dimension. Args: input_dim (int): Number of channels in the input images. n_levels (int): Number of levels. n_channels (int): Number of channels per level. pad_value (float): Padding value (temporal) used by the dataloader. """ super(PyramidBlock, self).__init__(pad_value=pad_value) dilations = [2 ** i for i in range(n_levels - 1)] self.inconv = nn.Conv2d(input_dim, n_channels, kernel_size=3, padding=1) self.convs = nn.ModuleList( [ nn.Conv2d( in_channels=n_channels, out_channels=n_channels, kernel_size=3, stride=1, dilation=d, padding=d, padding_mode="reflect", ) for d in dilations ] ) self.out_channels = n_levels * n_channels def forward(self, input): out = self.inconv(input) global_avg_pool = out.view(*out.shape[:2], -1).max(dim=-1)[0] out = torch.cat([cv(out) for cv in self.convs], dim=1) h, w = out.shape[-2:] out = torch.cat( [ out, global_avg_pool.unsqueeze(-1) .repeat(1, 1, h) .unsqueeze(-1) .repeat(1, 1, 1, w), ], dim=1, ) return out class ConvLayer(nn.Module): def __init__(self, nkernels, norm="batch", k=3, s=1, p=1, n_groups=4): super(ConvLayer, self).__init__() layers = [] if norm == "batch": nl = nn.BatchNorm2d elif norm == "instance": nl = nn.InstanceNorm2d elif norm == "group": nl = lambda num_feats: nn.GroupNorm( num_channels=num_feats, num_groups=n_groups ) else: nl = None for i in range(len(nkernels) - 1): layers.append( nn.Conv2d( in_channels=nkernels[i], out_channels=nkernels[i + 1], kernel_size=k, padding=p, stride=s, padding_mode="reflect", ) ) if nl is not None: layers.append(nl(nkernels[i + 1])) layers.append(nn.ReLU()) self.conv = nn.Sequential(*layers) def forward(self, input): return self.conv(input) class ConvBlock(TemporallySharedBlock): def __init__(self, nkernels, pad_value=None, norm="batch"): super(ConvBlock, self).__init__(pad_value=pad_value) self.conv = ConvLayer(nkernels=nkernels, norm=norm) def forward(self, input): return self.conv(input)