topdu's picture
openocr demo
29f689c
raw
history blame
5.24 kB
"""This code is adapted from:
https://github.com/jjwei66/BUSNet
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .nrtr_decoder import PositionalEncoding, TransformerBlock
from .abinet_decoder import _get_mask, _get_length
class BUSDecoder(nn.Module):
    """Three-pass cross-attention decoder (vision / language / fused).

    One stack of cross-attention-only ``TransformerBlock`` layers is shared
    by three decoding passes that all use the same positional queries and
    differ only in the memory they attend to:

    1. vision:   ``[img_feat, l_mask]``     (language slots are placeholders)
    2. language: ``[v_mask, token_embed]``  (vision slots are placeholders)
    3. fused:    ``[img_feat, token_embed]``

    Args:
        in_channels: Width of the incoming visual features; used as the
            model dimension ``d_model`` everywhere in the decoder.
        out_channels: Number of output classes. Also the width of the token
            distribution that is projected back into ``d_model``.
        nhead: Attention heads per ``TransformerBlock``.
        num_layers: Number of stacked ``TransformerBlock`` layers.
        dim_feedforward: Hidden width of each block's feed-forward part.
        dropout: Attention and residual dropout rate of each block. The two
            positional encoders keep their own fixed rate of 0.1.
        max_length: Maximum text length; one extra position is added for the
            stop token.
        ignore_index: Label value that is replaced by class 0 before one-hot
            encoding in pretraining mode.
        pretraining: If True, and the module is in training mode, the
            language branch is fed one-hot ground-truth labels from ``data``
            instead of the vision predictions.
        detach: If True (default), the token distribution is detached so
            that no gradient flows from the language / fused passes back
            through the vision prediction that produced it.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_length=25,
                 ignore_index=100,
                 pretraining=False,
                 detach=True):
        super().__init__()
        d_model = in_channels
        self.ignore_index = ignore_index
        self.pretraining = pretraining
        self.d_model = d_model
        self.detach = detach
        self.max_length = max_length + 1  # additional stop token
        self.out_channels = out_channels

        # ------------------------------------------------------------------
        # decoder specifics
        # Maps a (soft or one-hot) class distribution back to model width.
        self.proj = nn.Linear(out_channels, d_model, False)
        # NOTE: both positional encoders use a fixed dropout of 0.1; the
        # ``dropout`` argument only configures the transformer blocks.
        self.token_encoder = PositionalEncoding(dropout=0.1,
                                                dim=d_model,
                                                max_len=self.max_length)
        self.pos_encoder = PositionalEncoding(dropout=0.1,
                                              dim=d_model,
                                              max_len=self.max_length)

        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=False,
                with_cross_attn=True,
            ) for _ in range(num_layers)
        ])

        # Learned placeholders that stand in for the missing modality in the
        # vision-only (l_mask) and language-only (v_mask) passes.
        v_mask = torch.empty((1, 1, d_model))
        l_mask = torch.empty((1, 1, d_model))
        self.v_mask = nn.Parameter(v_mask)
        self.l_mask = nn.Parameter(l_mask)
        torch.nn.init.uniform_(self.v_mask, -0.001, 0.001)
        torch.nn.init.uniform_(self.l_mask, -0.001, 0.001)

        # Learned modality embeddings added to the visual features and to
        # the token embeddings respectively.
        v_embeding = torch.empty((1, 1, d_model))
        l_embeding = torch.empty((1, 1, d_model))
        self.v_embeding = nn.Parameter(v_embeding)
        self.l_embeding = nn.Parameter(l_embeding)
        torch.nn.init.uniform_(self.v_embeding, -0.001, 0.001)
        torch.nn.init.uniform_(self.l_embeding, -0.001, 0.001)

        self.cls = nn.Linear(d_model, out_channels)

    def forward_decoder(self, q, x, mask=None):
        """Run the shared decoder stack and classify every query position.

        Args:
            q: Query tensor, (N, T, E).
            x: Memory tensor the queries cross-attend to.
            mask: Optional mask forwarded to every layer as ``cross_mask``.

        Returns:
            Class logits, (N, T, C).
        """
        for decoder_layer in self.decoder:
            q = decoder_layer(q, x, cross_mask=mask)
        output = q  # (N, T, E)
        logits = self.cls(output)  # (N, T, C)
        return logits

    def forward(self, img_feat, data=None):
        """Decode text from visual features.

        Args:
            img_feat: Visual features of shape (B, L, C) with
                C == ``in_channels``.
            data: Only read when the module is in training mode and
                ``pretraining`` is True. ``data[0]`` holds the integer label
                ids and ``data[-1]`` the label lengths. The labels presumably
                need ``max_length + 1`` positions per sample so they line up
                with the attention mask -- confirm against the data loader.

        Returns:
            In training mode, a dict with keys ``'align'`` (single-element
            list holding the fused logits), ``'lang'`` (language logits) and
            ``'vision'`` (vision logits). Otherwise the softmax of the fused
            logits over the class axis.

        Raises:
            ValueError: If ground-truth labels are required (training with
                ``pretraining=True``) but ``data`` is None.
        """
        use_gt_tokens = self.training and self.pretraining
        if use_gt_tokens and data is None:
            raise ValueError(
                'BUSDecoder: `data` (labels, ..., lengths) is required when '
                'training with pretraining=True')

        img_feat = img_feat + self.v_embeding
        B, L, C = img_feat.shape

        # ------------------------------------------------------------------
        # decoder procedure
        T = self.max_length
        zeros = img_feat.new_zeros((B, T, C))
        zeros_len = img_feat.new_zeros(B)
        # Queries carry positional information only.
        query = self.pos_encoder(zeros)

        # 1. vision decode
        v_embed = torch.cat((img_feat, self.l_mask.repeat(B, T, 1)),
                            dim=1)  # v
        # Every sample is treated as full length here, so the padding part
        # of the mask depends only on what `_get_mask` does for full-length
        # sequences.
        padding_mask = _get_mask(self.max_length + zeros_len,
                                 self.max_length)
        # All-zero additive mask: every query may attend to every visual
        # position.
        v_mask = torch.zeros((1, 1, self.max_length, L),
                             device=img_feat.device).tile([B, 1, 1, 1])
        mask = torch.cat((v_mask, padding_mask), 3)
        v_logits = self.forward_decoder(query, v_embed, mask=mask)

        # 2. language decode
        if use_gt_tokens:
            # Ignored label positions are mapped to class 0 so that one_hot
            # receives valid class ids.
            tgt = torch.where(data[0] == self.ignore_index, 0, data[0])
            tokens = F.one_hot(tgt, num_classes=self.out_channels)
            tokens = tokens.float()
            lengths = data[-1]
        else:
            tokens = torch.softmax(v_logits, dim=-1)
            lengths = _get_length(v_logits)
        if self.detach:
            # Stop gradients from the language / fused passes reaching the
            # vision prediction that produced the tokens.
            tokens = tokens.detach()
        token_embed = self.proj(tokens)  # (N, T, E)
        token_embed = self.token_encoder(token_embed)
        token_embed = token_embed + self.l_embeding

        # Mask the padding positions beyond each sequence's token length.
        padding_mask = _get_mask(lengths, self.max_length)
        mask = torch.cat((v_mask, padding_mask), 3)
        l_embed = torch.cat((self.v_mask.repeat(B, L, 1), token_embed), dim=1)
        l_logits = self.forward_decoder(query, l_embed, mask=mask)

        # 3. vision language decode
        vl_embed = torch.cat((img_feat, token_embed), dim=1)
        vl_logits = self.forward_decoder(query, vl_embed, mask=mask)

        if self.training:
            return {'align': [vl_logits], 'lang': l_logits, 'vision': v_logits}
        else:
            return F.softmax(vl_logits, -1)