# openrec/modeling/decoders/cdistnet_decoder.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from openrec.modeling.decoders.nrtr_decoder import (Embeddings,
                                                    PositionalEncoding,
                                                    TransformerBlock)  # , Beam
from openrec.modeling.decoders.visionlan_decoder import Transformer_Encoder


def generate_square_subsequent_mask(sz):
r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
Unmasked positions are filled with float(0.0).
"""
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = (mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
mask == 1, float(0.0)))
return mask
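
# Illustrative example (not from the original source): for sz=3 the mask is
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# so position i may only attend to positions j <= i.
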
class SEM_Pre(nn.Module):
def __init__(
self,
d_model=512,
dst_vocab_size=40,
residual_dropout_rate=0.1,
):
super(SEM_Pre, self).__init__()
self.embedding = Embeddings(d_model=d_model, vocab=dst_vocab_size)
self.positional_encoding = PositionalEncoding(
dropout=residual_dropout_rate,
dim=d_model,
)

    def forward(self, tgt):
tgt = self.embedding(tgt)
tgt = self.positional_encoding(tgt)
tgt_mask = generate_square_subsequent_mask(tgt.shape[1]).to(tgt.device)
return tgt, tgt_mask
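
# Note (added commentary): SEM_Pre returns the embedded target sequence
# together with a causal (look-ahead) mask sized to that sequence, so callers
# always receive a matching pair.
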
class POS_Pre(nn.Module):
def __init__(
self,
d_model=512,
):
super(POS_Pre, self).__init__()
self.pos_encoding = PositionalEncoding(
dropout=0.1,
dim=d_model,
)
self.linear1 = nn.Linear(d_model, d_model)
self.linear2 = nn.Linear(d_model, d_model)
self.norm2 = nn.LayerNorm(d_model)

    def forward(self, tgt):
pos = tgt.new_zeros(*tgt.shape)
pos = self.pos_encoding(pos)
pos2 = self.linear2(F.relu(self.linear1(pos)))
pos = self.norm2(pos + pos2)
return pos
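
# Note (added commentary): POS_Pre runs the positional encoding on a zero
# tensor, so its output depends only on token positions, never on token
# identity; the two-layer MLP with a residual connection then refines these
# position embeddings.
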
class DSF(nn.Module):
def __init__(self, d_model, fusion_num):
super(DSF, self).__init__()
self.w_att = nn.Linear(fusion_num * d_model, d_model)

    def forward(self, l_feature, v_feature):
"""
Args:
l_feature: (N, T, E) where T is length, N is batch size and d is dim of model
v_feature: (N, T, E) shape the same as l_feature
l_lengths: (N,)
v_lengths: (N,)
"""
f = torch.cat((l_feature, v_feature), dim=2)
f_att = torch.sigmoid(self.w_att(f))
output = f_att * v_feature + (1 - f_att) * l_feature
return output
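
# Note (added commentary): the sigmoid gate f_att lies in (0, 1) elementwise,
# so the output is a per-element convex combination of the two branches:
# f_att -> 1 follows v_feature, f_att -> 0 follows l_feature.
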
class MDCDP(nn.Module):
r"""
Multi-Domain CharacterDistance Perception
"""
def __init__(self, d_model, n_head, d_inner, num_layers):
super(MDCDP, self).__init__()
self.num_layers = num_layers
        # Step 1: SAE (Self-Attention Enhancement).
self.layers_pos = nn.ModuleList([
TransformerBlock(d_model, n_head, d_inner)
for _ in range(num_layers)
])
        # Step 2: CBI (Cross-Branch Interaction).
self.layers2 = nn.ModuleList([
TransformerBlock(
d_model,
n_head,
d_inner,
with_self_attn=False,
with_cross_attn=True,
) for _ in range(num_layers)
])
self.layers3 = nn.ModuleList([
TransformerBlock(
d_model,
n_head,
d_inner,
with_self_attn=False,
with_cross_attn=True,
) for _ in range(num_layers)
])
        # Step 3: DSF (Dynamic Shared Fusion).
self.dynamic_shared_fusion = DSF(d_model, 2)

    def forward(
self,
sem,
vis,
pos,
tgt_mask=None,
memory_mask=None,
):
for i in range(self.num_layers):
            # ---------- Step 1: SAE (Self-Attention Enhancement) ----------
pos = self.layers_pos[i](pos, self_mask=tgt_mask)
            # ---------- Step 2: CBI (Cross-Branch Interaction) ----------
# CBI-V
pos_vis = self.layers2[i](
pos,
vis,
cross_mask=memory_mask,
)
# CBI-S
pos_sem = self.layers3[i](
pos,
sem,
cross_mask=tgt_mask,
)
            # ---------- Step 3: DSF (Dynamic Shared Fusion) ----------
pos = self.dynamic_shared_fusion(pos_vis, pos_sem)
        return pos
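
# Note (added commentary): the positional stream `pos` is the only state
# carried across MDCDP layers; the semantic and visual features act as fixed
# keys/values for the two cross-attention branches.
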
class ConvBnRelu(nn.Module):
    # Conv -> BatchNorm -> ReLU, with padding derived from the kernel size.
def __init__(
self,
in_channels,
out_channels,
kernel_size,
conv=nn.Conv2d,
stride=2,
inplace=True,
):
super().__init__()
        # "Same"-style padding per kernel dimension (kernel_size must be a
        # sequence, e.g. (1, 3)).
        p_size = [k // 2 for k in kernel_size]
self.conv = conv(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=p_size,
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
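
# Example (illustrative, not from the original source): with the settings used
# below in CDistNetDecoder,
#   ConvBnRelu(512, 512, kernel_size=(1, 3), stride=(1, 2))
# keeps the height of a (B, 512, H, W) feature map and roughly halves its
# width (padding (0, 1) gives W_out = floor((W - 1) / 2) + 1).
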
class CDistNetDecoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
n_head=None,
num_encoder_blocks=3,
num_decoder_blocks=3,
beam_size=0,
max_len=25,
residual_dropout_rate=0.1,
add_conv=False,
**kwargs):
super(CDistNetDecoder, self).__init__()
        dst_vocab_size = out_channels
        # Special-token layout: <eos> = 0, <bos> = vocab - 2,
        # <pad>/ignore = vocab - 1.
        self.ignore_index = dst_vocab_size - 1
        self.bos = dst_vocab_size - 2
        self.eos = 0
self.beam_size = beam_size
self.max_len = max_len
self.add_conv = add_conv
d_model = in_channels
dim_feedforward = d_model * 4
n_head = n_head if n_head is not None else d_model // 32
if add_conv:
self.convbnrelu = ConvBnRelu(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=(1, 3),
stride=(1, 2),
)
if num_encoder_blocks > 0:
self.positional_encoding = PositionalEncoding(
dropout=0.1,
dim=d_model,
)
self.trans_encoder = Transformer_Encoder(
n_layers=num_encoder_blocks,
n_head=n_head,
d_model=d_model,
d_inner=dim_feedforward,
)
else:
self.trans_encoder = None
self.semantic_branch = SEM_Pre(
d_model=d_model,
dst_vocab_size=dst_vocab_size,
residual_dropout_rate=residual_dropout_rate,
)
self.positional_branch = POS_Pre(d_model=d_model)
self.mdcdp = MDCDP(d_model, n_head, dim_feedforward // 2,
num_decoder_blocks)
        # Initialize the parameters created so far; tgt_word_prj is created
        # afterwards and gets its own normal init below.
        self._reset_parameters()
        self.tgt_word_prj = nn.Linear(
            d_model, dst_vocab_size - 2,
            bias=False)  # we don't predict <bos> or <pad>
        self.tgt_word_prj.weight.data.normal_(mean=0.0, std=d_model**-0.5)

    def forward(self, x, data=None):
if self.add_conv:
x = self.convbnrelu(x)
        # Flatten spatial dims: (B, C, H, W) -> (B, H*W, C).
        x = x.flatten(2).transpose(1, 2)
if self.trans_encoder is not None:
x = self.positional_encoding(x)
vis_feat = self.trans_encoder(x, src_mask=None)
else:
vis_feat = x
        if self.training:
            # data[0]: padded target token ids (beginning with <bos>);
            # data[1]: target lengths.
            max_len = data[1].max()
            tgt = data[0][:, :1 + max_len]
res = self.forward_train(vis_feat, tgt)
else:
if self.beam_size > 0:
res = self.forward_beam(vis_feat)
else:
res = self.forward_test(vis_feat)
return res

    def forward_train(self, vis_feat, tgt):
        # Teacher forcing: the causal mask produced by the semantic branch
        # keeps each position from attending to future target tokens.
        sem_feat, sem_mask = self.semantic_branch(tgt)
pos_feat = self.positional_branch(sem_feat)
output = self.mdcdp(
sem_feat,
vis_feat,
pos_feat,
tgt_mask=sem_mask,
memory_mask=None,
)
logit = self.tgt_word_prj(output)
return logit

    def forward_test(self, vis_feat):
        bs = vis_feat.size(0)
        # Start every sequence with <bos>; the remaining slots hold the
        # padding id until greedy decoding fills them in.
        dec_seq = torch.full(
            (bs, self.max_len + 1),
            self.ignore_index,
            dtype=torch.int64,
            device=vis_feat.device,
        )
        dec_seq[:, 0] = self.bos
        # Per-step output distributions (post-softmax), collected for return.
        logits = []
        for len_dec_seq in range(0, self.max_len):
            sem_feat, sem_mask = self.semantic_branch(
                dec_seq[:, :len_dec_seq + 1])
pos_feat = self.positional_branch(sem_feat)
output = self.mdcdp(
sem_feat,
vis_feat,
pos_feat,
tgt_mask=sem_mask,
memory_mask=None,
)
            # Keep only the prediction for the latest time step.
            dec_output = output[:, -1:, :]
            word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
            logits.append(word_prob)
            # Greedy decoding: feed the most likely token back as the next
            # input (len_dec_seq + 1 <= self.max_len, so the write is always
            # in bounds).
            dec_seq[:, len_dec_seq + 1] = word_prob.squeeze(1).argmax(-1)
            # Early exit once every sequence in the batch has produced at
            # least one <eos> token.
            if (dec_seq == self.eos).any(dim=-1).all():
                break
logits = torch.cat(logits, dim=1)
return logits

    def forward_beam(self, x):
        """Beam-search decoding over one batch (not yet implemented)."""
        # The Beam helper import at the top of the file is commented out;
        # raise instead of silently returning None.
        raise NotImplementedError(
            'CDistNetDecoder: beam-search decoding is not implemented.')

    def _reset_parameters(self):
r"""Initiate parameters in the transformer model."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
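

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the original
    # repo): the vocab size and feature-map geometry below are assumed values.
    # Requires the openrec package on the import path.
    vocab_size = 40  # assumed charset size, with <eos>=0, <bos>=38, <pad>=39
    decoder = CDistNetDecoder(in_channels=512, out_channels=vocab_size)
    decoder.eval()
    feats = torch.randn(2, 512, 8, 32)  # assumed (B, C, H, W) visual features
    with torch.no_grad():
        probs = decoder(feats)  # greedy decoding path (beam_size == 0)
    print(probs.shape)  # (2, num_steps, vocab_size - 2)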