Spaces:
Running
Running
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class SAREncoder(nn.Module):
    """Encode a 2D feature map into one holistic vector per sample.

    The feature map is max-pooled over its height, the resulting width-wise
    sequence is run through a 2-layer LSTM, and the output of the last time
    step is passed through a linear layer.

    Args:
        enc_bi_rnn: Truthy to make the LSTM bidirectional.
        enc_drop_rnn: Dropout value forwarded to ``nn.LSTM``.
        in_channels: Channel count of the input feature map (LSTM input size).
        d_enc: Hidden size of the LSTM.
        **kwargs: Accepted for config compatibility and ignored.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 in_channels=512,
                 d_enc=512,
                 **kwargs):
        super().__init__()

        # LSTM encoder over the width-wise sequence.
        self.rnn_encoder = nn.LSTM(
            input_size=in_channels,
            hidden_size=d_enc,
            num_layers=2,
            dropout=enc_drop_rnn,
            bidirectional=bool(enc_bi_rnn),
            batch_first=True,
        )

        # Global feature transformation; a bidirectional LSTM doubles the
        # width of its output, hence the direction multiplier.
        out_size = d_enc * (1 + int(enc_bi_rnn))
        self.linear = nn.Linear(out_size, out_size)

    def forward(self, feat):
        """Return the holistic feature for ``feat``.

        Args:
            feat: 4D feature map; dim 2 is pooled away (treated as height).

        Returns:
            Tensor of shape ``(bsz, encoder_rnn_out_size)``.
        """
        height = feat.shape[2]
        # Collapse the height axis with a full-height max-pool window.
        pooled = F.max_pool2d(feat,
                              kernel_size=(height, 1),
                              stride=1,
                              padding=0)
        # bsz * C * 1 * W -> bsz * W * C
        sequence = pooled.squeeze(2).permute(0, 2, 1).contiguous()
        rnn_out, _ = self.rnn_encoder(sequence)  # bsz * T * hidden_size
        last_step = rnn_out[:, -1, :]  # bsz * hidden_size
        return self.linear(last_step)  # bsz * C
class SARDecoder(nn.Module):
    """LSTM decoder with additive 2D attention over a feature map.

    In training mode the decoder is teacher-forced with the ground-truth
    labels; in eval mode it decodes greedily, feeding back the arg-max
    character of each step.

    Args:
        in_channels: Channel count of the feature map; also used as the
            attention, embedding and decoder hidden size.
        out_channels: Number of classes. Index ``0`` is the end token,
            ``out_channels - 2`` the start token and ``out_channels - 1``
            the padding token.
        max_len: Maximum text length; decoding runs ``max_len + 1`` steps.
        enc_bi_rnn: Forwarded to ``SAREncoder``.
        enc_drop_rnn: Forwarded to ``SAREncoder``.
        dec_bi_rnn: Forwarded to the decoder ``nn.LSTM``.
            NOTE(review): ``conv1x1_1`` and ``prediction`` are sized for a
            unidirectional decoder, so ``True`` looks unsupported - confirm.
        dec_drop_rnn: Dropout of the decoder ``nn.LSTM``.
        pred_dropout: Dropout applied to the logits in ``forward_train``.
        pred_concat: If true, predict from the concatenation of hidden state,
            attention feature and holistic feature; otherwise from the
            attention feature alone.
        mask: If true, attention is masked beyond each sample's valid width
            (requires ``data[-1]`` to hold the valid ratios).
        use_lstm: If true, the holistic feature comes from ``SAREncoder``;
            otherwise from global average pooling.
        **kwargs: Accepted for config compatibility and ignored.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 max_len=25,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 pred_dropout=0.1,
                 pred_concat=True,
                 mask=True,
                 use_lstm=True,
                 **kwargs):
        super().__init__()
        self.num_classes = out_channels
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.end_idx = 0
        self.max_seq_len = max_len + 1
        self.pred_concat = pred_concat
        self.mask = mask
        self.use_lstm = use_lstm

        enc_dim = in_channels
        d = in_channels
        embedding_dim = in_channels
        dec_dim = in_channels

        if use_lstm:
            # encoder module
            self.encoder = SAREncoder(enc_bi_rnn=enc_bi_rnn,
                                      enc_drop_rnn=enc_drop_rnn,
                                      in_channels=in_channels,
                                      d_enc=enc_dim)

        # decoder module
        # 2D attention layer
        self.conv1x1_1 = nn.Linear(dec_dim, d)
        self.conv3x3_1 = nn.Conv2d(in_channels,
                                   d,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1)
        self.conv1x1_2 = nn.Linear(d, 1)

        # Decoder input embedding
        self.embedding = nn.Embedding(self.num_classes,
                                      embedding_dim,
                                      padding_idx=self.padding_idx)
        self.rnndecoder = nn.LSTM(input_size=embedding_dim,
                                  hidden_size=dec_dim,
                                  num_layers=2,
                                  dropout=dec_drop_rnn,
                                  bidirectional=dec_bi_rnn,
                                  batch_first=True)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        if pred_concat:
            fc_in_channel = in_channels + in_channels + dec_dim
        else:
            fc_in_channel = in_channels
        self.prediction = nn.Linear(fc_in_channel, self.num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def _2d_attation(self, feat, tokens, data, training):
        """Run the decoder LSTM and attend over ``feat`` for every step.

        Args:
            feat: Feature map, ``bsz * C * h * w``.
            tokens: Decoder input sequence, ``bsz * T * C``.
            data: Batch data whose last item holds per-sample valid width
                ratios; only read when ``self.mask`` is set. May be ``None``,
                in which case no mask is applied.
            training: Unused; kept for call compatibility.

        Returns:
            ``[hidden_state, attn_feat]``, both ``bsz * T * C``.
        """
        hidden_state = self.rnndecoder(tokens)[0]
        attn_query = self.conv1x1_1(hidden_state)
        bsz = attn_query.size(0)
        attn_query = attn_query.unsqueeze(-1).unsqueeze(-1)
        # bsz * T * attn_size * 1 * 1
        attn_key = self.conv3x3_1(feat).unsqueeze(1)
        # bsz * 1 * attn_size * h * w
        attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1))
        attn_weight = attn_weight.permute(0, 1, 3, 4, 2).contiguous()
        attn_weight = self.conv1x1_2(attn_weight)
        _, T, h, w, c = attn_weight.size()

        # Without batch data there are no valid ratios to mask with, so
        # attend over the whole map instead of failing on ``None[-1]``.
        if self.mask and data is not None:
            valid_ratios = data[-1]
            attn_mask = torch.zeros_like(attn_weight, dtype=torch.bool)
            for i, valid_ratio in enumerate(valid_ratios):
                # float() accepts both Python numbers and 0-d tensors.
                valid_width = min(w, math.ceil(w * float(valid_ratio)))
                attn_mask[i, :, :, valid_width:, :] = True
            attn_weight = attn_weight.masked_fill(attn_mask, float('-inf'))

        # Softmax jointly over all h * w positions.
        attn_weight = attn_weight.view(bsz, T, -1)
        attn_weight = F.softmax(attn_weight, dim=-1)
        attn_weight = attn_weight.view(bsz, T, h, w,
                                       c).permute(0, 1, 4, 2, 3).contiguous()
        # attn_weight: bsz, T, 1, h, w
        # feat.unsqueeze(1): bsz, 1, f_c, h, w
        attn_feat = torch.sum(torch.mul(feat.unsqueeze(1), attn_weight),
                              (3, 4),
                              keepdim=False)
        return [hidden_state, attn_feat]

    def forward_train(self, feat, holistic_feat, data):
        """Teacher-forced decoding.

        Args:
            feat: Feature map, ``bsz * C * h * w``.
            holistic_feat: Holistic feature, ``bsz * C``.
            data: Sequence with labels at index 0, label lengths at index 1
                and (when masking) valid ratios as the last item.

        Returns:
            Logits for every label position (the step fed with the holistic
            feature is dropped), ``bsz * (1 + max_len) * num_classes``.
        """
        max_len = int(data[1].max())
        label = data[0][:, :1 + max_len]
        label_embedding = self.embedding(label)
        holistic_feat = holistic_feat.unsqueeze(1)
        tokens = torch.cat((holistic_feat, label_embedding), dim=1)
        hidden_state, attn_feat = self._2d_attation(feat,
                                                    tokens,
                                                    data,
                                                    training=self.training)
        bsz, seq_len, _ = hidden_state.size()

        # linear transformation
        if self.pred_concat:
            f_c = holistic_feat.size(-1)
            holistic_feat = holistic_feat.expand(bsz, seq_len, f_c)
            preds = self.prediction(
                torch.cat((hidden_state, attn_feat, holistic_feat), 2))
        else:
            preds = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        preds = self.pred_dropout(preds)
        return preds[:, 1:, :]

    def forward_test(self, feat, holistic_feat, data=None):
        """Greedy decoding.

        Decoding stops after ``max_seq_len`` steps, or earlier once every
        sample in the batch has emitted the end token.

        Args:
            feat: Feature map, ``bsz * C * h * w``.
            holistic_feat: Holistic feature, ``bsz * C``.
            data: Optional batch data; only its last item (valid ratios) is
                read, and only when masking is enabled.

        Returns:
            Per-step class probabilities, ``bsz * steps * num_classes`` with
            ``steps <= max_seq_len``.
        """
        bsz = feat.shape[0]
        seq_len = self.max_seq_len
        holistic_feat = holistic_feat.unsqueeze(1)

        start_tokens = torch.full((bsz, ),
                                  self.start_idx,
                                  device=feat.device,
                                  dtype=torch.long)
        tokens = self.embedding(start_tokens)
        tokens = tokens.unsqueeze(1).expand(-1, seq_len, -1)
        # cat() materializes the expanded view, so the in-place writes below
        # touch real, per-position memory.
        tokens = torch.cat((holistic_feat, tokens), dim=1)

        if self.pred_concat:
            # Loop-invariant: broadcast the holistic feature once.
            f_c = holistic_feat.size(-1)
            holistic_expanded = holistic_feat.expand(bsz, seq_len + 1, f_c)

        # Tracks which samples have already produced the end token. The
        # check must use predicted ids: ``tokens`` holds float embeddings,
        # which cannot be compared against a token index.
        finished = torch.zeros(bsz, dtype=torch.bool, device=feat.device)
        outputs = []
        for i in range(1, seq_len + 1):
            hidden_state, attn_feat = self._2d_attation(
                feat, tokens, data=data, training=self.training)
            if self.pred_concat:
                preds = self.prediction(
                    torch.cat((hidden_state, attn_feat, holistic_expanded),
                              2))
            else:
                preds = self.prediction(attn_feat)
            # bsz * (seq_len + 1) * num_classes
            char_output = F.softmax(preds[:, i, :], -1)
            outputs.append(char_output)
            _, max_idx = torch.max(char_output, dim=1, keepdim=False)
            if i < seq_len:
                tokens[:, i + 1, :] = self.embedding(max_idx)
            finished |= max_idx == self.end_idx
            if finished.all():
                break
        return torch.stack(outputs, 1)

    def forward(self, feat, data=None):
        """Decode ``feat`` (teacher-forced when training, greedy otherwise).

        Args:
            feat: Feature map, ``bsz * C * h * w``.
            data: Batch data; required in training mode (labels, lengths and,
                when masking, valid ratios), optional in eval mode.

        Returns:
            ``bsz * seq_len * num_classes`` logits (training) or
            probabilities (eval).
        """
        if self.use_lstm:
            holistic_feat = self.encoder(feat)  # bsz c
        else:
            # flatten(1) keeps the batch dimension even when bsz == 1,
            # which a bare squeeze() would drop.
            holistic_feat = F.adaptive_avg_pool2d(feat, (1, 1)).flatten(1)

        if self.training:
            return self.forward_train(feat, holistic_feat, data=data)
        return self.forward_test(feat, holistic_feat, data=data)