import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init


class Embedding(nn.Module):

    def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
        super(Embedding, self).__init__()
        self.in_timestep = in_timestep
        self.in_planes = in_planes
        self.embed_dim = embed_dim
        self.mid_dim = mid_dim
        # Embed the flattened encoder output into a word-embedding-like
        # vector (300-d by default).
        self.eEmbed = nn.Linear(in_timestep * in_planes, self.embed_dim)

    def forward(self, x):
        x = x.flatten(1)  # [B, T, C] -> [B, T * C]
        x = self.eEmbed(x)
        return x


class Attn_Rnn_Block(nn.Module):

    def __init__(self, featdim, hiddendim, embedding_dim, out_channels,
                 attndim):
        super(Attn_Rnn_Block, self).__init__()
        self.attndim = attndim
        self.embedding_dim = embedding_dim
        self.feat_embed = nn.Linear(featdim, attndim)
        self.hidden_embed = nn.Linear(hiddendim, attndim)
        self.attnfeat_embed = nn.Linear(attndim, 1)
        self.gru = nn.GRU(input_size=featdim + self.embedding_dim,
                          hidden_size=hiddendim,
                          batch_first=True)
        self.fc = nn.Linear(hiddendim, out_channels)
        self.init_weights()

    def init_weights(self):
        init.normal_(self.hidden_embed.weight, std=0.01)
        init.constant_(self.hidden_embed.bias, 0)
        init.normal_(self.attnfeat_embed.weight, std=0.01)
        init.constant_(self.attnfeat_embed.bias, 0)

    def _attn(self, feat, h_state):
        # Additive attention over the T encoder time steps, conditioned on
        # the current GRU hidden state.
        b, t, _ = feat.shape
        feat = self.feat_embed(feat)
        h_state = self.hidden_embed(h_state.squeeze(0)).unsqueeze(1)
        h_state = h_state.expand(b, t, self.attndim)
        sumTanh = torch.tanh(feat + h_state)
        attn_w = self.attnfeat_embed(sumTanh).squeeze(-1)
        attn_w = F.softmax(attn_w, dim=1).unsqueeze(1)  # [B, 1, T]
        return attn_w

    def forward(self, feat, h_state, label_input):
        attn_w = self._attn(feat, h_state)
        attn_feat = attn_w @ feat  # attention glimpse: [B, 1, featdim]
        attn_feat = attn_feat.squeeze(1)

        # One GRU step on [previous token embedding, glimpse].
        output, h_state = self.gru(
            torch.cat([label_input, attn_feat], 1).unsqueeze(1), h_state)
        pred = self.fc(output)

        return pred, h_state


class ASTERDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 embedding_dim=256,
                 hiddendim=256,
                 attndim=256,
                 max_len=25,
                 seed=False,
                 time_step=32,
                 **kwargs):
        super(ASTERDecoder, self).__init__()
        self.num_classes = out_channels
        self.bos = out_channels - 2
        self.eos = 0
        self.padding_idx = out_channels - 1
        self.seed = seed
        if seed:
            # Project the flattened encoder output to a 300-d embedding,
            # which is also used to initialise the decoder hidden state.
            self.embeder = Embedding(
                in_timestep=time_step,
                in_planes=in_channels,
            )
        self.word_embedding = nn.Embedding(self.num_classes,
                                           embedding_dim,
                                           padding_idx=self.padding_idx)
        self.attndim = attndim
        self.hiddendim = hiddendim
        self.max_seq_len = max_len + 1
        self.featdim = in_channels

        self.attn_rnn_block = Attn_Rnn_Block(
            featdim=self.featdim,
            hiddendim=hiddendim,
            embedding_dim=embedding_dim,
            out_channels=out_channels - 2,
            attndim=attndim,
        )
        self.embed_fc = nn.Linear(300, self.hiddendim)

    def get_initial_state(self, embed, tile_times=1):
        assert embed.shape[1] == 300
        state = self.embed_fc(embed)  # N * sDim
        if tile_times != 1:
            state = state.unsqueeze(1)
            trans_state = state.transpose(0, 1)
            state = trans_state.tile([tile_times, 1, 1])
            trans_state = state.transpose(0, 1)
            state = trans_state.reshape(-1, self.hiddendim)
        state = state.unsqueeze(0)  # 1 * N * sDim
        return state

    def forward(self, feat, data=None):
        # feat: encoder output of shape [B, T, in_channels], e.g. [B, 25, 512].
        b = feat.size(0)
        if self.seed:
            embedding_vectors = self.embeder(feat)
            h_state = self.get_initial_state(embedding_vectors)
        else:
            h_state = torch.zeros(1, b, self.hiddendim).to(feat.device)
        outputs = []
        if self.training:
            # Teacher forcing: feed the ground-truth token embeddings.
            label = data[0]
            label_embedding = self.word_embedding(label)  # [B, L, embedding_dim]
            tokens = label_embedding[:, 0, :]
            max_len = data[1].max() + 1
        else:
            # Greedy decoding: start from the <bos> token.
            tokens = torch.full([b, 1],
                                self.bos,
                                device=feat.device,
                                dtype=torch.long)
            tokens = self.word_embedding(tokens.squeeze(1))
            max_len = self.max_seq_len
        pred, h_state = self.attn_rnn_block(feat, h_state, tokens)
        outputs.append(pred)
        dec_seq = torch.full((feat.shape[0], max_len),
                             self.padding_idx,
                             dtype=torch.int64,
                             device=feat.device)
        dec_seq[:, :1] = torch.argmax(pred, dim=-1)
        for i in range(1, max_len):
            if not self.training:
                # Feed back the previous prediction and stop once every
                # sequence in the batch has emitted <eos>.
                max_idx = torch.argmax(pred, dim=-1).squeeze(1)
                tokens = self.word_embedding(max_idx)
                dec_seq[:, i] = max_idx
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
            else:
                tokens = label_embedding[:, i, :]
            pred, h_state = self.attn_rnn_block(feat, h_state, tokens)
            outputs.append(pred)
        preds = torch.cat(outputs, 1)
        if self.seed and self.training:
            return [embedding_vectors, preds]
        return preds if self.training else F.softmax(preds, -1)
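

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming an encoder that produces features of shape
# [B, T, in_channels] with T = 25, and a vocabulary laid out as the indices
# above imply (index 0 = <eos>, out_channels - 2 = <bos>, out_channels - 1 =
# padding). The sizes used here (in_channels=512, out_channels=39, batch of 2)
# are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    decoder = ASTERDecoder(in_channels=512, out_channels=39)
    decoder.eval()  # eval mode takes the greedy-decoding branch (data=None)
    feat = torch.randn(2, 25, 512)  # dummy encoder output: [B, T, C]
    with torch.no_grad():
        probs = decoder(feat)  # softmax over the out_channels - 2 classes
    print(probs.shape)  # [B, decoded_steps, out_channels - 2]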