import torch
import torch.nn as nn
import torch.nn.functional as F


class CAM(nn.Module):
    '''
    Convolutional Alignment Module
    '''

    # The current version only supports inputs whose size is a power of 2,
    # such as 32, 64, 128, etc. You can adapt it to any input size by
    # changing the padding or stride.
    def __init__(self,
                 channels_list=[64, 128, 256, 512],
                 strides_list=[[2, 2], [1, 1], [1, 1]],
                 in_shape=[8, 32],
                 maxT=25,
                 depth=4,
                 num_channels=128):
        super(CAM, self).__init__()
        # cascade multiscale features
        fpn = []
        for i in range(1, len(channels_list)):
            fpn.append(
                nn.Sequential(
                    nn.Conv2d(channels_list[i - 1], channels_list[i], (3, 3),
                              (strides_list[i - 1][0], strides_list[i - 1][1]),
                              1),
                    nn.BatchNorm2d(channels_list[i]),
                    nn.ReLU(True)))
        self.fpn = nn.Sequential(*fpn)

        # convolutional alignment
        # convs
        assert depth % 2 == 0, 'the depth of CAM must be an even number.'
        strides = []
        conv_ksizes = []
        deconv_ksizes = []
        h, w = in_shape[0], in_shape[1]
        for i in range(0, int(depth / 2)):
            # use stride 2 only while the dimension is still large enough to be halved
            stride = [2] if 2**(depth / 2 - i) <= h else [1]
            stride = stride + [2] if 2**(depth / 2 - i) <= w else stride + [1]
            strides.append(stride)
            conv_ksizes.append([3, 3])
            deconv_ksizes.append([_**2 for _ in stride])
        convs = [
            nn.Sequential(
                nn.Conv2d(channels_list[-1], num_channels,
                          tuple(conv_ksizes[0]), tuple(strides[0]),
                          (int((conv_ksizes[0][0] - 1) / 2),
                           int((conv_ksizes[0][1] - 1) / 2))),
                nn.BatchNorm2d(num_channels),
                nn.ReLU(True))
        ]
        for i in range(1, int(depth / 2)):
            convs.append(
                nn.Sequential(
                    nn.Conv2d(num_channels, num_channels,
                              tuple(conv_ksizes[i]), tuple(strides[i]),
                              (int((conv_ksizes[i][0] - 1) / 2),
                               int((conv_ksizes[i][1] - 1) / 2))),
                    nn.BatchNorm2d(num_channels),
                    nn.ReLU(True)))
        self.convs = nn.Sequential(*convs)

        # deconvs
        deconvs = []
        for i in range(1, int(depth / 2)):
            deconvs.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        num_channels, num_channels,
                        tuple(deconv_ksizes[int(depth / 2) - i]),
                        tuple(strides[int(depth / 2) - i]),
                        (int(deconv_ksizes[int(depth / 2) - i][0] / 4.),
                         int(deconv_ksizes[int(depth / 2) - i][1] / 4.))),
                    nn.BatchNorm2d(num_channels),
                    nn.ReLU(True)))
        # the last deconv produces maxT attention maps
        deconvs.append(
            nn.Sequential(
                nn.ConvTranspose2d(num_channels, maxT,
                                   tuple(deconv_ksizes[0]), tuple(strides[0]),
                                   (int(deconv_ksizes[0][0] / 4.),
                                    int(deconv_ksizes[0][1] / 4.))),
                nn.Sigmoid()))
        self.deconvs = nn.Sequential(*deconvs)

    def forward(self, input):
        # fuse the multiscale features from fine to coarse
        x = input[0]
        for i in range(0, len(self.fpn)):
            x = self.fpn[i](x) + input[i + 1]
        # downsampling convs, keeping intermediate features for skip connections
        conv_feats = []
        for i in range(0, len(self.convs)):
            x = self.convs[i](x)
            conv_feats.append(x)
        # upsampling deconvs with skip connections
        for i in range(0, len(self.deconvs) - 1):
            x = self.deconvs[i](x)
            x = x + conv_feats[len(conv_feats) - 2 - i]
        x = self.deconvs[-1](x)
        return x


class CAMSimp(nn.Module):
    '''
    Simplified alignment module: a 1x1 convolution that maps the feature
    channels directly to maxT attention maps.
    '''

    def __init__(self, maxT=25, num_channels=128):
        super(CAMSimp, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(num_channels, maxT, 1, 1, 0),
                                  nn.Sigmoid())

    def forward(self, x):
        x = self.conv(x)
        return x


class DANDecoder(nn.Module):
    '''
    Decoupled Text Decoder
    '''

    def __init__(self,
                 out_channels,
                 in_channels,
                 use_cam=True,
                 max_len=25,
                 channels_list=[64, 128, 256, 512],
                 strides_list=[[2, 2], [1, 1], [1, 1]],
                 in_shape=[8, 32],
                 depth=4,
                 dropout=0.3,
                 **kwargs):
        super(DANDecoder, self).__init__()
        self.eos = 0
        self.bos = out_channels - 2
        self.ignore_index = out_channels - 1
        nchannel = in_channels
        self.nchannel = in_channels
        self.use_cam = use_cam
        if use_cam:
            self.cam = CAM(channels_list=channels_list,
                           strides_list=strides_list,
                           in_shape=in_shape,
                           maxT=max_len + 1,
                           depth=depth,
                           num_channels=nchannel)
        else:
            self.cam = CAMSimp(maxT=max_len + 1, num_channels=nchannel)
        self.pre_lstm = nn.LSTM(nchannel,
                                int(nchannel / 2),
                                bidirectional=True)
        self.rnn = nn.GRUCell(nchannel * 2, nchannel)
        self.generator = nn.Sequential(
            nn.Dropout(p=dropout), nn.Linear(nchannel, out_channels - 2))
        self.char_embeddings = nn.Embedding(out_channels,
                                            embedding_dim=in_channels,
                                            padding_idx=out_channels - 1)

    def forward(self, inputs, data=None):
        A = self.cam(inputs)  # attention maps: (B, maxT, H, W)
        if isinstance(inputs, list):
            feature = inputs[-1]
        else:
            feature = inputs
        nB, nC, nH, nW = feature.shape
        nT = A.shape[1]
        # normalize each attention map to sum to 1
        A = A / A.view(nB, nT, -1).sum(2).view(nB, nT, 1, 1)
        # attention-weighted sum of the feature map per time step
        C = feature.view(nB, 1, nC, nH, nW) * A.view(nB, nT, 1, nH, nW)
        C = C.view(nB, nT, nC, -1).sum(3).transpose(1, 0)  # (T, B, C)
        C, _ = self.pre_lstm(C)  # (T, B, C)
        C = F.dropout(C, p=0.3, training=self.training)
        if self.training:
            # teacher forcing with the ground-truth characters
            text = data[0]
            text_length = data[-1]
            nsteps = int(text_length.max())
            gru_res = torch.zeros_like(C)
            hidden = torch.zeros(nB, self.nchannel).type_as(C.data)
            prev_emb = self.char_embeddings(text[:, 0])
            for i in range(0, nsteps + 1):
                hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
                                  hidden)
                gru_res[i, :, :] = hidden
                prev_emb = self.char_embeddings(text[:, i + 1])
            gru_res = self.generator(gru_res)
            return gru_res[:nsteps + 1, :, :].transpose(1, 0)
        else:
            # greedy decoding, feeding the previous prediction back in
            gru_res = torch.zeros_like(C)
            hidden = torch.zeros(nB, self.nchannel).type_as(C.data)
            prev_emb = self.char_embeddings(
                torch.zeros(nB, dtype=torch.int64, device=feature.device) +
                self.bos)
            dec_seq = torch.full((nB, nT),
                                 self.ignore_index,
                                 dtype=torch.int64,
                                 device=feature.device)
            for i in range(0, nT):
                hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
                                  hidden)
                gru_res[i, :, :] = hidden
                mid_res = self.generator(hidden).argmax(-1)
                dec_seq[:, i] = mid_res.squeeze(0)
                # stop once every sequence in the batch has emitted <eos>
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
                prev_emb = self.char_embeddings(mid_res)
            gru_res = self.generator(gru_res)
            return F.softmax(gru_res.transpose(1, 0), -1)
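

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the vocabulary size, batch size, and
# feature-pyramid shapes below are assumptions, not values dictated by this
# module. It builds a DANDecoder over 512-channel backbone features and runs
# greedy decoding on random multi-scale feature maps matching the default
# channels_list / strides_list of CAM.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    decoder = DANDecoder(out_channels=40, in_channels=512, max_len=25)
    decoder.eval()
    # Hypothetical backbone pyramid: finest scale first, coarsest (512, 8, 32) last.
    feats = [
        torch.randn(2, 64, 16, 64),
        torch.randn(2, 128, 8, 32),
        torch.randn(2, 256, 8, 32),
        torch.randn(2, 512, 8, 32),
    ]
    with torch.no_grad():
        probs = decoder(feats)
    # (batch, max_len + 1, out_channels - 2) character probabilities
    print(probs.shape)  # torch.Size([2, 26, 38])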