"""This code is refer from: | |
https://github.com/byeonghu-na/MATRN | |
""" | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from openrec.modeling.decoders.abinet_decoder import BCNLanguage, PositionAttention, _get_length | |
from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, TransformerBlock | |


class BaseSemanticVisual_backbone_feature(nn.Module):

    def __init__(self,
                 d_model=512,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=2048,
                 dropout=0.0,
                 alignment_mask_example_prob=0.9,
                 alignment_mask_candidate_prob=0.9,
                 alignment_num_vis_mask=10,
                 max_length=25,
                 num_classes=37):
        super().__init__()
        self.mask_example_prob = alignment_mask_example_prob
        self.mask_candidate_prob = alignment_mask_candidate_prob
        self.num_vis_mask = alignment_num_vis_mask
        self.nhead = nhead
        self.d_model = d_model
        self.max_length = max_length + 1  # additional stop token

        # Multi-modal transformer over the concatenated visual and semantic features.
        self.model1 = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=True,
                with_cross_attn=False,
            ) for i in range(num_layers)
        ])
        self.pos_encoder_tfm = PositionalEncoding(dim=d_model,
                                                  dropout=0,
                                                  max_len=1024)
        self.model2_vis = PositionAttention(
            max_length=self.max_length,  # additional stop token
            in_channels=d_model,
            num_channels=d_model // 8,
            mode='nearest',
        )
        self.cls_vis = nn.Linear(d_model, num_classes)
        self.cls_sem = nn.Linear(d_model, num_classes)
        self.w_att = nn.Linear(2 * d_model, d_model)

        # Learnable token used to mask out selected visual features during training.
        v_token = torch.empty((1, d_model))
        self.v_token = nn.Parameter(v_token)
        torch.nn.init.uniform_(self.v_token, -0.001, 0.001)

        self.cls = nn.Linear(d_model, num_classes)

    def forward(self, l_feature, v_feature, lengths_l=None, v_attn=None):
        """
        Args:
            l_feature: (N, T, E) where N is batch size, T is the max length and E is the model dimension
            v_feature: (N, E, H, W)
            lengths_l: (N,) predicted text lengths from the language branch
            v_attn: (N, T, H, W) position attention maps from the vision branch

        Returns:
            A dict with the fused logits ('logits'), the visual and semantic
            branch logits ('v_logits', 's_logits') and their predicted lengths.
        """
        N, E, H, W = v_feature.size()
        v_feature = v_feature.flatten(2, 3).transpose(1, 2)  # (N, H*W, E)
        v_attn = v_attn.flatten(2, 3)  # (N, T, H*W)

        if self.training:
            # Semantic-visual masking: for a randomly chosen character position,
            # replace its most-attended visual features with the learnable v_token.
            for idx, length in enumerate(lengths_l):
                if np.random.random() <= self.mask_example_prob:
                    l_idx = np.random.randint(int(length))
                    v_random_idx = v_attn[idx, l_idx].argsort(
                        descending=True).cpu().numpy()[:self.num_vis_mask, ]
                    v_random_idx = v_random_idx[np.random.random(
                        v_random_idx.shape) <= self.mask_candidate_prob]
                    v_feature[idx, v_random_idx] = self.v_token

        # Spatial positional encoding, aggregated onto the semantic features
        # through the vision attention maps.
        zeros = v_feature.new_zeros((N, H * W, E))  # (N, H*W, E)
        base_pos = self.pos_encoder_tfm(zeros)  # (N, H*W, E)
        base_pos = torch.bmm(v_attn, base_pos)  # (N, T, E)
        l_feature = l_feature + base_pos

        sv_feature = torch.cat((v_feature, l_feature), dim=1)  # (N, H*W+T, E)
        for decoder_layer in self.model1:
            sv_feature = decoder_layer(sv_feature)  # (N, H*W+T, E)

        sv_to_v_feature = sv_feature[:, :H * W]  # (N, H*W, E)
        sv_to_s_feature = sv_feature[:, H * W:]  # (N, T, E)

        sv_to_v_feature = sv_to_v_feature.transpose(1, 2).reshape(N, E, H, W)
        sv_to_v_feature, _ = self.model2_vis(sv_to_v_feature)  # (N, T, E)
        sv_to_v_logits = self.cls_vis(sv_to_v_feature)  # (N, T, C)
        pt_v_lengths = _get_length(sv_to_v_logits)  # (N,)

        sv_to_s_logits = self.cls_sem(sv_to_s_feature)  # (N, T, C)
        pt_s_lengths = _get_length(sv_to_s_logits)  # (N,)

        # Gated fusion of the visual and semantic branches.
        f = torch.cat((sv_to_v_feature, sv_to_s_feature), dim=2)
        f_att = torch.sigmoid(self.w_att(f))
        output = f_att * sv_to_v_feature + (1 - f_att) * sv_to_s_feature

        logits = self.cls(output)  # (N, T, C)
        pt_lengths = _get_length(logits)

        return {
            'logits': logits,
            'pt_lengths': pt_lengths,
            'v_logits': sv_to_v_logits,
            'pt_v_lengths': pt_v_lengths,
            's_logits': sv_to_s_logits,
            'pt_s_lengths': pt_s_lengths,
            'name': 'alignment'
        }


class MATRNDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 nhead=8,
                 num_layers=3,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_length=25,
                 iter_size=3,
                 **kwargs):
        super().__init__()
        self.max_length = max_length + 1  # additional stop token
        d_model = in_channels
        self.pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model)
        self.encoder = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ])
        self.decoder = PositionAttention(
            max_length=self.max_length,  # additional stop token
            in_channels=d_model,
            num_channels=d_model // 8,
            mode='nearest',
        )
        self.out_channels = out_channels
        self.cls = nn.Linear(d_model, self.out_channels)
        self.iter_size = iter_size
        if iter_size > 0:
            # Language model for iterative semantic refinement.
            self.language = BCNLanguage(
                d_model=d_model,
                nhead=nhead,
                num_layers=4,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                max_length=max_length,
                num_classes=self.out_channels,
            )
            # Semantic-visual alignment module.
            self.semantic_visual = BaseSemanticVisual_backbone_feature(
                d_model=d_model,
                nhead=nhead,
                num_layers=2,
                dim_feedforward=dim_feedforward,
                max_length=max_length,
                num_classes=self.out_channels)

    def forward(self, x, data=None):
        # x: (bs, c, h, w) backbone feature map
        x = x.permute([0, 2, 3, 1])  # bs, h, w, c
        _, H, W, C = x.shape
        # assert H % 8 == 0 and W % 16 == 0, 'The height and width should be multiples of 8 and 16.'
        feature = x.flatten(1, 2)  # bs, h*w, c
        feature = self.pos_encoder(feature)  # bs, h*w, c
        for encoder_layer in self.encoder:
            feature = encoder_layer(feature)  # bs, h*w, c
        feature = feature.reshape([-1, H, W, C]).permute(0, 3, 1, 2)  # bs, c, h, w

        # Vision branch: position attention followed by the character classifier.
        v_feature, v_attn_input = self.decoder(feature)  # (bs[N], T, E)
        vis_logits = self.cls(v_feature)  # (bs[N], T, C)
        align_lengths = _get_length(vis_logits)
        align_logits = vis_logits
        all_l_res, all_a_res = [], []

        for _ in range(self.iter_size):
            # Language branch: refine the current predictions with the language model.
            tokens = F.softmax(align_logits, dim=-1)
            lengths = torch.clamp(align_lengths, 2,
                                  self.max_length)  # TODO: move to language model
            l_feature, l_logits = self.language(tokens, lengths)
            all_l_res.append(l_logits)

            # Alignment branch: fuse semantic and visual features.
            lengths_l = _get_length(l_logits)
            lengths_l.clamp_(2, self.max_length)
            a_res = self.semantic_visual(l_feature,
                                         feature,
                                         lengths_l=lengths_l,
                                         v_attn=v_attn_input)
            a_v_res = a_res['v_logits']  # visual-branch alignment logits
            all_a_res.append(a_v_res)
            a_s_res = a_res['s_logits']  # semantic-branch alignment logits
            align_logits = a_res['logits']
            all_a_res.append(a_s_res)
            all_a_res.append(align_logits)
            align_lengths = a_res['pt_lengths']

        if self.training:
            return {
                'align': all_a_res,
                'lang': all_l_res,
                'vision': vis_logits
            }
        else:
            return F.softmax(align_logits, -1)
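

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original MATRN code.
    # Assumptions: a 512-channel backbone feature map of size 8x32 (consistent
    # with the H % 8 == 0 and W % 16 == 0 hint above) and a 37-class charset;
    # adjust both to match your actual backbone and dictionary.
    decoder = MATRNDecoder(in_channels=512, out_channels=37)
    decoder.eval()
    dummy_feat = torch.randn(2, 512, 8, 32)  # (bs, c, h, w)
    with torch.no_grad():
        probs = decoder(dummy_feat)  # (bs, T, C) softmax over the charset
    print(probs.shape)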