""" Taken from https://github.com/roserustowicz/crop-type-mapping/ Implementation by the authors of the paper : "Semantic Segmentation of crop type in Africa: A novel Dataset and analysis of deep learning methods" R.M. Rustowicz et al. Slightly modified to support image sequences of varying length in the same batch. """ import torch import torch.nn as nn def conv_block(in_dim, middle_dim, out_dim): model = nn.Sequential( nn.Conv3d(in_dim, middle_dim, kernel_size=3, stride=1, padding=1), nn.BatchNorm3d(middle_dim), nn.LeakyReLU(inplace=True), nn.Conv3d(middle_dim, out_dim, kernel_size=3, stride=1, padding=1), nn.BatchNorm3d(out_dim), nn.LeakyReLU(inplace=True), ) return model def center_in(in_dim, out_dim): model = nn.Sequential( nn.Conv3d(in_dim, out_dim, kernel_size=3, stride=1, padding=1), nn.BatchNorm3d(out_dim), nn.LeakyReLU(inplace=True)) return model def center_out(in_dim, out_dim): model = nn.Sequential( nn.Conv3d(in_dim, in_dim, kernel_size=3, stride=1, padding=1), nn.BatchNorm3d(in_dim), nn.LeakyReLU(inplace=True), nn.ConvTranspose3d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1)) return model def up_conv_block(in_dim, out_dim): model = nn.Sequential( nn.ConvTranspose3d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1), nn.BatchNorm3d(out_dim), nn.LeakyReLU(inplace=True), ) return model class UNet3D(nn.Module): def __init__(self, in_channel, n_classes, feats=8, pad_value=None, zero_pad=True, out_nonlin=False): super(UNet3D, self).__init__() self.in_channel = in_channel self.n_classes = n_classes self.pad_value = pad_value self.zero_pad = zero_pad self.en3 = conv_block(in_channel, feats * 4, feats * 4) self.pool_3 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0) self.en4 = conv_block(feats * 4, feats * 8, feats * 8) self.pool_4 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0) self.center_in = center_in(feats * 8, feats * 16) self.center_out = center_out(feats * 16, feats * 8) self.dc4 = conv_block(feats * 16, feats * 8, feats * 8) self.trans3 = up_conv_block(feats * 8, feats * 4) self.dc3 = conv_block(feats * 8, feats * 4, feats * 2) self.final = nn.Conv3d(feats * 2, n_classes, kernel_size=3, stride=1, padding=1) if out_nonlin: self.out_sigm = nn.Sigmoid() # this is for predicting mean values in [0, 1] self.out_relu = nn.ReLU() # this is for predicting var values > 0 # self.fn = nn.Linear(timesteps, 1) # self.logsoftmax = nn.LogSoftmax(dim=1) # self.dropout = nn.Dropout(p=dropout, inplace=True) def forward(self, x, batch_positions=None): out = x.permute(0, 2, 1, 3, 4) # x was BxTxCxHxW, now BxCxTxHxW if self.pad_value is not None: pad_mask = (out == self.pad_value).all(dim=-1).all(dim=-1).all(dim=1) # BxT pad mask if self.zero_pad: out[out == self.pad_value] = 0 en3 = self.en3(out) pool_3 = self.pool_3(en3) en4 = self.en4(pool_3) pool_4 = self.pool_4(en4) center_in = self.center_in(pool_4) center_out = self.center_out(center_in) concat4 = torch.cat([center_out, en4[:, :, :center_out.shape[2], :, :]], dim=1) dc4 = self.dc4(concat4) trans3 = self.trans3(dc4) concat3 = torch.cat([trans3, en3[:, :, :trans3.shape[2], :, :]], dim=1) dc3 = self.dc3(concat3) final = self.final(dc3) final = final.permute(0, 1, 3, 4, 2) # BxCxHxWxT # shape_num = final.shape[0:4] # final = final.reshape(-1,final.shape[4]) if self.pad_value is not None: if pad_mask.any(): # masked mean pad_mask = pad_mask[:, :final.shape[-1]] #match new temporal length (due to pooling) pad_mask = ~pad_mask # 0 on padded values out = (final.permute(1, 2, 3, 0, 4) * pad_mask[None, None, None, :, :]).sum(dim=-1) / pad_mask.sum( dim=-1)[None, None, None, :] out = out.permute(3, 0, 1, 2) else: out = final.mean(dim=-1) else: out = final.mean(dim=-1) if hasattr(self, 'out_sigm'): out_mean = self.out_sigm(out[:,:,:13,...]) # mean predictions out_std = self.out_relu(out[:,:,13:,...]) # var predictions # stack mean and var predictions out = torch.cat((out_mean, out_std), dim=2) # final = self.dropout(final) # final = self.fn(final) # final = final.reshape(shape_num) return out