Spaces:
Running
Running
'''
This code is adapted from:
https://github.com/AlibabaResearch/AdvancedLiterateMachinery/blob/main/OCR/MGP-STR
'''
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class TokenLearner(nn.Module):
    """Pool a token sequence into a fixed number of learned output tokens.

    Each of the ``out_token`` outputs is a softmax-weighted sum, over the
    input sequence positions, of a grouped 1x1-conv projection of the
    (layer-normed) input features.

    Args:
        input_embed_dim: Channel size ``C`` of the input tokens. Must be
            divisible by ``groups`` (enforced by ``nn.Conv2d``).
        out_token: Number of output tokens ``S`` to produce.
        groups: Number of groups used by the grouped 1x1 convolutions.
            Defaults to 8, the value that was previously hard-coded.
    """

    def __init__(self, input_embed_dim, out_token=30, groups=8):
        super().__init__()
        self.token_norm = nn.LayerNorm(input_embed_dim)
        # Two 1x1 convs map every input position to ``out_token`` logits.
        self.tokenLearner = nn.Sequential(
            nn.Conv2d(input_embed_dim,
                      input_embed_dim,
                      kernel_size=(1, 1),
                      stride=1,
                      groups=groups,
                      bias=False),
            nn.Conv2d(input_embed_dim,
                      out_token,
                      kernel_size=(1, 1),
                      stride=1,
                      bias=False))
        # Grouped 1x1 conv producing the per-position values that get pooled.
        self.feat = nn.Conv2d(input_embed_dim,
                              input_embed_dim,
                              kernel_size=(1, 1),
                              stride=1,
                              groups=groups,
                              bias=False)
        self.norm = nn.LayerNorm(input_embed_dim)

    def forward(self, x):
        """Pool ``x`` into ``out_token`` learned tokens.

        Args:
            x: Tensor of shape ``(B, N, C)`` with ``C == input_embed_dim``.

        Returns:
            Tuple ``(selected, x)`` where ``selected`` is the ``(B, S, N)``
            attention map (softmax over the ``N`` input positions) and ``x``
            is the ``(B, S, C)`` pooled, layer-normed output
            (``S == out_token``).
        """
        x = self.token_norm(x)                   # (B, N, C)
        # View the sequence as an N x 1 "image" so 1x1 convs act per position.
        x = x.transpose(1, 2).unsqueeze(-1)      # (B, C, N, 1)
        selected = self.tokenLearner(x)          # (B, S, N, 1)
        selected = selected.flatten(2)           # (B, S, N)
        # Each output token's weights over the N positions sum to 1.
        selected = F.softmax(selected, dim=-1)
        feat = self.feat(x)                      # (B, C, N, 1)
        feat = feat.flatten(2).transpose(1, 2)   # (B, N, C)
        # Batched matmul: (B, S, N) x (B, N, C) -> (B, S, C).
        x = torch.einsum('...si,...id->...sd', selected, feat)
        x = self.norm(x)
        return selected, x
class MGPDecoder(nn.Module):
    """Multi-granularity prediction head.

    Always runs a character branch and, unless ``only_char`` is set, also a
    BPE branch and a WordPiece branch. Each branch pools
    ``max_len + 2`` tokens from the input features with its own
    ``TokenLearner`` and applies a linear classifier.

    Args:
        in_channels: Channel size (embedding dim) of the input features.
        out_channels: Number of character classes. If it is not positive,
            every head is ``nn.Identity`` and pooled features are returned
            instead of logits.
        max_len: Maximum text length; the number of pooled tokens is
            ``max_len + 2`` (presumably room for start/end tokens — confirm
            against the label converter).
        only_char: If True, build and run only the character branch.
        bpe_classes: Output size of the BPE head (keyword-only). Defaults to
            ``BPE_VOCAB_SIZE``, the value previously hard-coded.
        wp_classes: Output size of the WordPiece head (keyword-only).
            Defaults to ``WP_VOCAB_SIZE``, the value previously hard-coded.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    # Head sizes that were previously hard-coded; presumably the GPT-2 BPE
    # and BERT WordPiece vocabulary sizes — confirm against the tokenizers.
    BPE_VOCAB_SIZE = 50257
    WP_VOCAB_SIZE = 30522

    def __init__(self,
                 in_channels,
                 out_channels,
                 max_len=25,
                 only_char=False,
                 *args,
                 bpe_classes=BPE_VOCAB_SIZE,
                 wp_classes=WP_VOCAB_SIZE,
                 **kwargs):
        super().__init__(*args, **kwargs)
        num_classes = out_channels
        embed_dim = in_channels
        self.batch_max_length = max_len + 2
        self.only_char = only_char

        def make_head(n_out):
            # A real classifier only when character classes are configured;
            # otherwise pooled features pass through unchanged.
            if num_classes > 0:
                return nn.Linear(embed_dim, n_out)
            return nn.Identity()

        # Creation order (and therefore parameter-init RNG order) and
        # attribute names (state_dict keys) match the original layout.
        self.char_tokenLearner = TokenLearner(embed_dim, self.batch_max_length)
        self.char_head = make_head(num_classes)
        if not only_char:
            self.bpe_tokenLearner = TokenLearner(embed_dim,
                                                 self.batch_max_length)
            self.wp_tokenLearner = TokenLearner(embed_dim,
                                                self.batch_max_length)
            self.bpe_head = make_head(bpe_classes)
            self.wp_head = make_head(wp_classes)

    def forward(self, x, data=None):
        """Run the prediction branches on input features ``x``.

        Args:
            x: Input features; ``TokenLearner`` expects ``(B, N, in_channels)``.
            data: Unused by this method.

        Returns:
            Raw head outputs in training mode, or their softmax over the last
            dimension in eval mode. With ``only_char`` this is a single tensor;
            otherwise a list ``[char, bpe, wp]``.
        """
        # The attention maps returned by the token learners are not used here.
        _, x_char = self.char_tokenLearner(x)
        outputs = [self.char_head(x_char)]
        if not self.only_char:
            _, x_bpe = self.bpe_tokenLearner(x)
            outputs.append(self.bpe_head(x_bpe))
            _, x_wp = self.wp_tokenLearner(x)
            outputs.append(self.wp_head(x_wp))

        if not self.training:
            outputs = [F.softmax(out, -1) for out in outputs]
        return outputs[0] if self.only_char else outputs