"""This code is refer from:
https://github.com/byeonghu-na/MATRN
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from openrec.modeling.decoders.abinet_decoder import BCNLanguage, PositionAttention, _get_length
from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, TransformerBlock
class BaseSemanticVisual_backbone_feature(nn.Module):
def __init__(self,
d_model=512,
nhead=8,
num_layers=4,
dim_feedforward=2048,
dropout=0.0,
alignment_mask_example_prob=0.9,
alignment_mask_candidate_prob=0.9,
alignment_num_vis_mask=10,
max_length=25,
num_classes=37):
super().__init__()
self.mask_example_prob = alignment_mask_example_prob
        self.mask_candidate_prob = alignment_mask_candidate_prob
self.num_vis_mask = alignment_num_vis_mask
self.nhead = nhead
self.d_model = d_model
self.max_length = max_length + 1 # additional stop token
self.model1 = nn.ModuleList([
TransformerBlock(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
attention_dropout_rate=dropout,
residual_dropout_rate=dropout,
with_self_attn=True,
with_cross_attn=False,
            ) for _ in range(num_layers)
])
self.pos_encoder_tfm = PositionalEncoding(dim=d_model,
dropout=0,
max_len=1024)
self.model2_vis = PositionAttention(
max_length=self.max_length, # additional stop token
in_channels=d_model,
num_channels=d_model // 8,
mode='nearest',
)
self.cls_vis = nn.Linear(d_model, num_classes)
self.cls_sem = nn.Linear(d_model, num_classes)
self.w_att = nn.Linear(2 * d_model, d_model)
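        # Learnable token used to replace masked-out visual features during training.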
v_token = torch.empty((1, d_model))
self.v_token = nn.Parameter(v_token)
torch.nn.init.uniform_(self.v_token, -0.001, 0.001)
self.cls = nn.Linear(d_model, num_classes)
    def forward(self, l_feature, v_feature, lengths_l=None, v_attn=None):
        """
        Args:
            l_feature: (N, T, E) where N is batch size, T is the sequence
                length and E is the model dimension
            v_feature: (N, E, H, W)
            lengths_l: (N,) predicted text lengths from the language branch
            v_attn: (N, T, H, W) position-attention maps from the vision branch
        """
N, E, H, W = v_feature.size()
v_feature = v_feature.flatten(2, 3).transpose(1, 2) #(N, H*W, E)
v_attn = v_attn.flatten(2, 3) # (N, T, H*W)
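        # During training, randomly mask visual clues: for a randomly chosen
        # character position, the visual locations it attends to most strongly
        # are (with some probability) replaced by the learnable v_token.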
if self.training:
for idx, length in enumerate(lengths_l):
if np.random.random() <= self.mask_example_prob:
l_idx = np.random.randint(int(length))
v_random_idx = v_attn[idx, l_idx].argsort(
descending=True).cpu().numpy()[:self.num_vis_mask, ]
v_random_idx = v_random_idx[np.random.random(
v_random_idx.shape) <= self.mask_candidate_prob]
v_feature[idx, v_random_idx] = self.v_token
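        # Positional encodings of the flattened visual grid are projected onto
        # character positions via the attention maps, giving the language
        # features spatial information before fusion.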
zeros = v_feature.new_zeros((N, H * W, E)) # (N, H*W, E)
base_pos = self.pos_encoder_tfm(zeros) # (N, H*W, E)
base_pos = torch.bmm(v_attn, base_pos) # (N, T, E)
l_feature = l_feature + base_pos
        sv_feature = torch.cat((v_feature, l_feature), dim=1)  # (N, H*W+T, E)
        for decoder_layer in self.model1:
            sv_feature = decoder_layer(sv_feature)  # (N, H*W+T, E)
sv_to_v_feature = sv_feature[:, :H * W] # (N, H*W, E)
sv_to_s_feature = sv_feature[:, H * W:] # (N, T, E)
sv_to_v_feature = sv_to_v_feature.transpose(1, 2).reshape(N, E, H, W)
sv_to_v_feature, _ = self.model2_vis(sv_to_v_feature) # (N, T, E)
sv_to_v_logits = self.cls_vis(sv_to_v_feature) # (N, T, C)
pt_v_lengths = _get_length(sv_to_v_logits) # (N,)
sv_to_s_logits = self.cls_sem(sv_to_s_feature) # (N, T, C)
pt_s_lengths = _get_length(sv_to_s_logits) # (N,)
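        # Gated fusion of the vision-aligned and semantics-aligned features.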
f = torch.cat((sv_to_v_feature, sv_to_s_feature), dim=2)
f_att = torch.sigmoid(self.w_att(f))
output = f_att * sv_to_v_feature + (1 - f_att) * sv_to_s_feature
logits = self.cls(output) # (N, T, C)
pt_lengths = _get_length(logits)
return {
'logits': logits,
'pt_lengths': pt_lengths,
'v_logits': sv_to_v_logits,
'pt_v_lengths': pt_v_lengths,
's_logits': sv_to_s_logits,
'pt_s_lengths': pt_s_lengths,
'name': 'alignment'
}
class MATRNDecoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
nhead=8,
num_layers=3,
dim_feedforward=2048,
dropout=0.1,
max_length=25,
iter_size=3,
**kwargs):
super().__init__()
self.max_length = max_length + 1
d_model = in_channels
self.pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model)
self.encoder = nn.ModuleList([
TransformerBlock(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
attention_dropout_rate=dropout,
residual_dropout_rate=dropout,
with_self_attn=True,
with_cross_attn=False,
) for _ in range(num_layers)
])
self.decoder = PositionAttention(
max_length=self.max_length, # additional stop token
in_channels=d_model,
num_channels=d_model // 8,
mode='nearest',
)
self.out_channels = out_channels
self.cls = nn.Linear(d_model, self.out_channels)
self.iter_size = iter_size
if iter_size > 0:
self.language = BCNLanguage(
d_model=d_model,
nhead=nhead,
num_layers=4,
dim_feedforward=dim_feedforward,
dropout=dropout,
max_length=max_length,
num_classes=self.out_channels,
)
# alignment
self.semantic_visual = BaseSemanticVisual_backbone_feature(
d_model=d_model,
nhead=nhead,
num_layers=2,
dim_feedforward=dim_feedforward,
max_length=max_length,
num_classes=self.out_channels)
def forward(self, x, data=None):
# bs, c, h, w
x = x.permute([0, 2, 3, 1]) # bs, h, w, c
_, H, W, C = x.shape
# assert H % 8 == 0 and W % 16 == 0, 'The height and width should be multiples of 8 and 16.'
feature = x.flatten(1, 2) # bs, h*w, c
feature = self.pos_encoder(feature) # bs, h*w, c
for encoder_layer in self.encoder:
feature = encoder_layer(feature)
# bs, h*w, c
        feature = feature.reshape([-1, H, W, C]).permute(0, 3, 1, 2)  # bs, c, h, w
        v_feature, v_attn_input = self.decoder(feature)  # (N, T, E), (N, T, H, W)
        vis_logits = self.cls(v_feature)  # (N, T, C)
align_lengths = _get_length(vis_logits)
align_logits = vis_logits
all_l_res, all_a_res = [], []
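        # Iterative refinement: feed the current predictions to the language
        # model, align its output with the visual features, and repeat.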
for _ in range(self.iter_size):
tokens = F.softmax(align_logits, dim=-1)
lengths = torch.clamp(
align_lengths, 2,
self.max_length) # TODO: move to language model
l_feature, l_logits = self.language(tokens, lengths)
all_l_res.append(l_logits)
# alignment
lengths_l = _get_length(l_logits)
lengths_l.clamp_(2, self.max_length)
a_res = self.semantic_visual(l_feature,
feature,
lengths_l=lengths_l,
v_attn=v_attn_input)
            a_v_res = a_res['v_logits']
            all_a_res.append(a_v_res)
            a_s_res = a_res['s_logits']
            align_logits = a_res['logits']
            all_a_res.append(a_s_res)
            all_a_res.append(align_logits)
            align_lengths = a_res['pt_lengths']
if self.training:
return {
'align': all_a_res,
'lang': all_l_res,
'vision': vis_logits
}
else:
return F.softmax(align_logits, -1)
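

# Minimal usage sketch: drive the decoder with a dummy backbone feature map.
# The 512-channel, 8x32 feature map and 37 character classes below are
# illustrative assumptions, not values fixed by this module.
if __name__ == '__main__':
    decoder = MATRNDecoder(in_channels=512, out_channels=37)
    decoder.eval()
    dummy_feature = torch.randn(2, 512, 8, 32)  # (bs, c, h, w) from a backbone
    with torch.no_grad():
        probs = decoder(dummy_feature)  # (bs, max_length + 1, num_classes)
    print(probs.shape)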