import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init


class Embedding(nn.Module):
    """Projects the flattened encoder output to a word-embedding-like
    vector (300-d), used to semantically initialize the decoder state."""

    def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
        super(Embedding, self).__init__()
        self.in_timestep = in_timestep
        self.in_planes = in_planes
        self.embed_dim = embed_dim
        self.mid_dim = mid_dim  # stored for config compatibility; unused below
        self.eEmbed = nn.Linear(in_timestep * in_planes, self.embed_dim)

    def forward(self, x):
        x = x.flatten(1)  # [B, T, C] -> [B, T * C]
        x = self.eEmbed(x)  # [B, embed_dim]
        return x


class Attn_Rnn_Block(nn.Module):
    """One decoding step: additive (Bahdanau-style) attention over the
    encoder features followed by a single GRU step."""

    def __init__(self, featdim, hiddendim, embedding_dim, out_channels,
                 attndim):
        super(Attn_Rnn_Block, self).__init__()
        self.attndim = attndim
        self.embedding_dim = embedding_dim
        self.feat_embed = nn.Linear(featdim, attndim)
        self.hidden_embed = nn.Linear(hiddendim, attndim)
        self.attnfeat_embed = nn.Linear(attndim, 1)
        self.gru = nn.GRU(input_size=featdim + self.embedding_dim,
                          hidden_size=hiddendim,
                          batch_first=True)
        self.fc = nn.Linear(hiddendim, out_channels)
        self.init_weights()

    def init_weights(self):
        init.normal_(self.hidden_embed.weight, std=0.01)
        init.constant_(self.hidden_embed.bias, 0)
        init.normal_(self.attnfeat_embed.weight, std=0.01)
        init.constant_(self.attnfeat_embed.bias, 0)

    def _attn(self, feat, h_state):
        # Additive attention: score_t = v^T tanh(W_f * feat_t + W_h * h).
        b, t, _ = feat.shape
        feat = self.feat_embed(feat)  # [B, T, attndim]
        h_state = self.hidden_embed(h_state.squeeze(0)).unsqueeze(1)
        h_state = h_state.expand(b, t, self.attndim)
        sumTanh = torch.tanh(feat + h_state)
        attn_w = self.attnfeat_embed(sumTanh).squeeze(-1)  # [B, T]
        attn_w = F.softmax(attn_w, dim=1).unsqueeze(1)  # [B, 1, T]
        return attn_w

    def forward(self, feat, h_state, label_input):
        attn_w = self._attn(feat, h_state)
        attn_feat = attn_w @ feat  # weighted glimpse, [B, 1, featdim]
        attn_feat = attn_feat.squeeze(1)  # [B, featdim]
        # Feed [previous-symbol embedding; glimpse] as one GRU time step.
        output, h_state = self.gru(
            torch.cat([label_input, attn_feat], 1).unsqueeze(1), h_state)
        pred = self.fc(output)  # [B, 1, out_channels]
        return pred, h_state


class ASTERDecoder(nn.Module):
    """Attentional GRU decoder. Token index layout: eos = 0,
    bos = num_classes - 2, padding = num_classes - 1."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 embedding_dim=256,
                 hiddendim=256,
                 attndim=256,
                 max_len=25,
                 seed=False,
                 time_step=32,
                 **kwargs):
        super(ASTERDecoder, self).__init__()
        self.num_classes = out_channels
        self.bos = out_channels - 2
        self.eos = 0
        self.padding_idx = out_channels - 1
        self.seed = seed
        if seed:
            # Optional semantic branch: predict a word-embedding-like vector
            # from the encoder output and use it as the initial hidden state.
            self.embeder = Embedding(
                in_timestep=time_step,
                in_planes=in_channels,
            )
        self.word_embedding = nn.Embedding(self.num_classes,
                                           embedding_dim,
                                           padding_idx=self.padding_idx)
        self.attndim = attndim
        self.hiddendim = hiddendim
        self.max_seq_len = max_len + 1  # +1 for the EOS step
        self.featdim = in_channels

        self.attn_rnn_block = Attn_Rnn_Block(
            featdim=self.featdim,
            hiddendim=hiddendim,
            embedding_dim=embedding_dim,
            out_channels=out_channels - 2,  # bos/padding are never predicted
            attndim=attndim,
        )
        self.embed_fc = nn.Linear(300, self.hiddendim)

    def get_initial_state(self, embed, tile_times=1):
        assert embed.shape[1] == 300
        state = self.embed_fc(embed)  # [N, hiddendim]
        if tile_times != 1:
            # Repeat the state tile_times per sample (e.g. for beam search).
            state = state.unsqueeze(1)
            trans_state = state.transpose(0, 1)
            state = trans_state.tile([tile_times, 1, 1])
            trans_state = state.transpose(0, 1)
            state = trans_state.reshape(-1, self.hiddendim)
        state = state.unsqueeze(0)  # [1, N, hiddendim]
        return state

    def forward(self, feat, data=None):
        # feat: [B, T, C] encoder output, e.g. [B, 25, 512].
        b = feat.size(0)
        if self.seed:
            embedding_vectors = self.embeder(feat)
            h_state = self.get_initial_state(embedding_vectors)
        else:
            h_state = torch.zeros(1, b, self.hiddendim, device=feat.device)
        outputs = []
        if self.training:
            # Teacher forcing: data = [labels, label_lengths].
            label = data[0]
            label_embedding = self.word_embedding(label)  # [B, L, embedding_dim]
            tokens = label_embedding[:, 0, :]
            max_len = data[1].max() + 1
        else:
            # Greedy decoding starts from the BOS token.
            tokens = torch.full([b, 1],
                                self.bos,
                                device=feat.device,
                                dtype=torch.long)
            tokens = self.word_embedding(tokens.squeeze(1))
            max_len = self.max_seq_len
        pred, h_state = self.attn_rnn_block(feat, h_state, tokens)
        outputs.append(pred)
        dec_seq = torch.full((feat.shape[0], max_len),
                             self.padding_idx,
                             dtype=torch.int64,
                             device=feat.device)
        dec_seq[:, :1] = torch.argmax(pred, dim=-1)
        for i in range(1, max_len):
            if not self.training:
                max_idx = torch.argmax(pred, dim=-1).squeeze(1)
                tokens = self.word_embedding(max_idx)
                dec_seq[:, i] = max_idx
                # Stop once every sequence in the batch has emitted EOS.
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
            else:
                tokens = label_embedding[:, i, :]
            pred, h_state = self.attn_rnn_block(feat, h_state, tokens)
            outputs.append(pred)
        preds = torch.cat(outputs, 1)  # [B, steps, out_channels - 2]
        if self.seed and self.training:
            return [embedding_vectors, preds]
        return preds if self.training else F.softmax(preds, -1)
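

# ---------------------------------------------------------------------------
# Minimal smoke test: a hedged sketch, not part of the decoder itself. The
# 38-symbol charset, batch size, and feature shape below are assumptions
# chosen only to exercise both the inference and training paths.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    torch.manual_seed(0)
    decoder = ASTERDecoder(in_channels=512, out_channels=38, max_len=25)
    feat = torch.randn(2, 25, 512)  # [B, T, C] dummy encoder output

    # Inference: greedy decoding from BOS, softmax scores per step.
    decoder.eval()
    with torch.no_grad():
        probs = decoder(feat)
    print(probs.shape)  # [2, L, 36] with L <= max_len + 1 (early EOS stop)

    # Training: teacher forcing with data = [labels, lengths]; labels are
    # class indices in [0, 37], lengths bound the number of decode steps.
    decoder.train()
    labels = torch.randint(0, 38, (2, 25))
    lengths = torch.tensor([5, 7])
    logits = decoder(feat, data=[labels, lengths])
    print(logits.shape)  # [2, 8, 36] since max(lengths) + 1 == 8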