import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init


class Embedding(nn.Module):

    def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
        super(Embedding, self).__init__()
        self.in_timestep = in_timestep
        self.in_planes = in_planes
        self.embed_dim = embed_dim
        self.mid_dim = mid_dim
        # Project the flattened encoder output to a word-embedding-like vector.
        self.eEmbed = nn.Linear(in_timestep * in_planes, self.embed_dim)

    def forward(self, x):
        x = x.flatten(1)
        x = self.eEmbed(x)
        return x


class Attn_Rnn_Block(nn.Module):

    def __init__(self, featdim, hiddendim, embedding_dim, out_channels,
                 attndim):
        super(Attn_Rnn_Block, self).__init__()
        self.attndim = attndim
        self.embedding_dim = embedding_dim
        self.feat_embed = nn.Linear(featdim, attndim)
        self.hidden_embed = nn.Linear(hiddendim, attndim)
        self.attnfeat_embed = nn.Linear(attndim, 1)
        self.gru = nn.GRU(input_size=featdim + self.embedding_dim,
                          hidden_size=hiddendim,
                          batch_first=True)
        self.fc = nn.Linear(hiddendim, out_channels)
        self.init_weights()

    def init_weights(self):
        init.normal_(self.hidden_embed.weight, std=0.01)
        init.constant_(self.hidden_embed.bias, 0)
        init.normal_(self.attnfeat_embed.weight, std=0.01)
        init.constant_(self.attnfeat_embed.bias, 0)

    def _attn(self, feat, h_state):
        # Additive attention between the encoder features and the current
        # GRU hidden state.
        b, t, _ = feat.shape
        feat = self.feat_embed(feat)
        h_state = self.hidden_embed(h_state.squeeze(0)).unsqueeze(1)
        h_state = h_state.expand(b, t, self.attndim)
        sumTanh = torch.tanh(feat + h_state)
        attn_w = self.attnfeat_embed(sumTanh).squeeze(-1)
        attn_w = F.softmax(attn_w, dim=1).unsqueeze(1)  # [B, 1, T]
        return attn_w

    def forward(self, feat, h_state, label_input):
        attn_w = self._attn(feat, h_state)
        # Glimpse: attention-weighted sum of the encoder features.
        attn_feat = attn_w @ feat
        attn_feat = attn_feat.squeeze(1)
        output, h_state = self.gru(
            torch.cat([label_input, attn_feat], 1).unsqueeze(1), h_state)
        pred = self.fc(output)
        return pred, h_state
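
# Shape walk-through for a single decoding step of Attn_Rnn_Block
# (B = batch size, T = number of encoder time steps; names follow the code above):
#   feat:        [B, T, featdim]        encoder features
#   h_state:     [1, B, hiddendim]      previous GRU hidden state
#   label_input: [B, embedding_dim]     embedding of the previous character
#   attn_w:      [B, 1, T]              attention weights from _attn()
#   attn_feat:   [B, featdim]           glimpse vector (attn_w @ feat, squeezed)
#   pred:        [B, 1, out_channels]   logits for the next character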


class ASTERDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 embedding_dim=256,
                 hiddendim=256,
                 attndim=256,
                 max_len=25,
                 seed=False,
                 time_step=32,
                 **kwargs):
        super(ASTERDecoder, self).__init__()
        self.num_classes = out_channels
        # Special token ids: <eos> is index 0, <bos> and <pad> sit at the end
        # of the charset.
        self.bos = out_channels - 2
        self.eos = 0
        self.padding_idx = out_channels - 1
        self.seed = seed
        if seed:
            # Optional semantic branch: embeds the whole encoder output into a
            # 300-d vector used to initialize the decoder state.
            self.embeder = Embedding(
                in_timestep=time_step,
                in_planes=in_channels,
            )
        self.word_embedding = nn.Embedding(self.num_classes,
                                           embedding_dim,
                                           padding_idx=self.padding_idx)
        self.attndim = attndim
        self.hiddendim = hiddendim
        self.max_seq_len = max_len + 1
        self.featdim = in_channels
        self.attn_rnn_block = Attn_Rnn_Block(
            featdim=self.featdim,
            hiddendim=hiddendim,
            embedding_dim=embedding_dim,
            out_channels=out_channels - 2,
            attndim=attndim,
        )
        self.embed_fc = nn.Linear(300, self.hiddendim)

    def get_initial_state(self, embed, tile_times=1):
        assert embed.shape[1] == 300
        state = self.embed_fc(embed)  # N * sDim
        if tile_times != 1:
            # Repeat the state when each sample is tiled (e.g. for beam search).
            state = state.unsqueeze(1)
            trans_state = state.transpose(0, 1)
            state = trans_state.tile([tile_times, 1, 1])
            trans_state = state.transpose(0, 1)
            state = trans_state.reshape(-1, self.hiddendim)
        state = state.unsqueeze(0)  # 1 * N * sDim
        return state

    def forward(self, feat, data=None):
        # feat: [B, T, in_channels] encoder features.
        b = feat.size(0)
        if self.seed:
            embedding_vectors = self.embeder(feat)
            h_state = self.get_initial_state(embedding_vectors)
        else:
            h_state = torch.zeros(1, b, self.hiddendim).to(feat.device)
        outputs = []
        if self.training:
            # Teacher forcing: feed ground-truth character embeddings.
            label = data[0]
            label_embedding = self.word_embedding(label)  # [B, L, embedding_dim]
            tokens = label_embedding[:, 0, :]
            max_len = data[1].max() + 1
        else:
            # Inference: start from <bos> and feed back the greedy predictions.
            tokens = torch.full([b, 1],
                                self.bos,
                                device=feat.device,
                                dtype=torch.long)
            tokens = self.word_embedding(tokens.squeeze(1))
            max_len = self.max_seq_len
        pred, h_state = self.attn_rnn_block(feat, h_state, tokens)
        outputs.append(pred)
        # dec_seq collects the greedy predictions and is only used to stop
        # decoding once every sequence in the batch has emitted <eos>.
        dec_seq = torch.full((feat.shape[0], max_len),
                             self.padding_idx,
                             dtype=torch.int64,
                             device=feat.device)
        dec_seq[:, :1] = torch.argmax(pred, dim=-1)
        for i in range(1, max_len):
            if not self.training:
                max_idx = torch.argmax(pred, dim=-1).squeeze(1)
                tokens = self.word_embedding(max_idx)
                dec_seq[:, i] = max_idx
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
            else:
                tokens = label_embedding[:, i, :]
            pred, h_state = self.attn_rnn_block(feat, h_state, tokens)
            outputs.append(pred)
        preds = torch.cat(outputs, 1)
        if self.seed and self.training:
            return [embedding_vectors, preds]
        return preds if self.training else F.softmax(preds, -1)
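

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the model definition).
# The charset size, batch size, and feature shape below are assumed values;
# the decoder only requires feat of shape [B, time_step, in_channels] and,
# in training mode, data = [label_ids, label_lengths].
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    num_classes = 39  # assumed charset: 36 characters + <eos>, with <bos>/<pad> appended
    decoder = ASTERDecoder(in_channels=512,
                           out_channels=num_classes,
                           seed=True,
                           time_step=32)
    decoder.eval()
    feat = torch.randn(2, 32, 512)  # [B, time_step, in_channels]
    with torch.no_grad():
        probs = decoder(feat)  # [B, <=max_len+1, num_classes - 2] softmax scores
    print(probs.shape)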