'''This code is adapted from:
https://github.com/AlibabaResearch/AdvancedLiterateMachinery/blob/main/OCR/MGP-STR
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class TokenLearner(nn.Module):
    """Pools a token sequence into a fixed number of output tokens.

    Learns `out_token` spatial attention maps over the input tokens and
    uses them to aggregate features, producing one pooled token per
    output position. Shape comments below assume MGP-STR's ViT settings
    (257 input tokens, embed dim 768, out_token=27).
    """

    def __init__(self, input_embed_dim, out_token=30):
        super().__init__()
        self.token_norm = nn.LayerNorm(input_embed_dim)
        # Two 1x1 convolutions that map features to per-position
        # attention logits, one map per output token.
        self.tokenLearner = nn.Sequential(
            nn.Conv2d(input_embed_dim, input_embed_dim, kernel_size=(1, 1),
                      stride=1, groups=8, bias=False),
            nn.Conv2d(input_embed_dim, out_token, kernel_size=(1, 1),
                      stride=1, bias=False))
        # 1x1 convolution that produces the features to be pooled.
        self.feat = nn.Conv2d(input_embed_dim, input_embed_dim,
                              kernel_size=(1, 1), stride=1, groups=8,
                              bias=False)
        self.norm = nn.LayerNorm(input_embed_dim)

    def forward(self, x):
        x = self.token_norm(x)  # [bs, 257, 768]
        x = x.transpose(1, 2).unsqueeze(-1)  # [bs, 768, 257, 1]
        selected = self.tokenLearner(x)  # [bs, 27, 257, 1]
        selected = selected.flatten(2)  # [bs, 27, 257]
        # Normalize each attention map over the input-token axis.
        selected = F.softmax(selected, dim=-1)

        feat = self.feat(x)  # [bs, 768, 257, 1]
        feat = feat.flatten(2).transpose(1, 2)  # [bs, 257, 768]

        # Weighted sum of input features under each attention map.
        x = torch.einsum('...si,...id->...sd', selected, feat)  # [bs, 27, 768]
        x = self.norm(x)
        return selected, x


class MGPDecoder(nn.Module):
    """Multi-Granularity Prediction decoder (MGP-STR).

    Predicts the text at three granularities: characters, BPE subwords,
    and WordPiece subwords. Each granularity gets its own TokenLearner
    and linear classification head.
    """

    def __init__(self, in_channels, out_channels, max_len=25,
                 only_char=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        num_classes = out_channels
        embed_dim = in_channels
        # Two extra positions for the begin/end special tokens.
        self.batch_max_length = max_len + 2

        self.char_tokenLearner = TokenLearner(embed_dim,
                                              self.batch_max_length)
        self.char_head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.only_char = only_char
        if not only_char:
            self.bpe_tokenLearner = TokenLearner(embed_dim,
                                                 self.batch_max_length)
            self.wp_tokenLearner = TokenLearner(embed_dim,
                                                self.batch_max_length)
            # 50257 = GPT-2 BPE vocabulary size.
            self.bpe_head = nn.Linear(
                embed_dim, 50257) if num_classes > 0 else nn.Identity()
            # 30522 = BERT (bert-base-uncased) WordPiece vocabulary size.
            self.wp_head = nn.Linear(
                embed_dim, 30522) if num_classes > 0 else nn.Identity()

    def forward(self, x, data=None):
        # `data` is kept for interface compatibility and is unused here.
        # The attention maps returned by each TokenLearner are likewise
        # unused; only the pooled features feed the classification heads.

        # char branch
        char_attn, x_char = self.char_tokenLearner(x)
        char_out = self.char_head(x_char)

        if not self.only_char:
            # bpe branch
            bpe_attn, x_bpe = self.bpe_tokenLearner(x)
            bpe_out = self.bpe_head(x_bpe)
            # wp branch
            wp_attn, x_wp = self.wp_tokenLearner(x)
            wp_out = self.wp_head(x_wp)
            # Raw logits during training; per-position probabilities at
            # inference time.
            return [char_out, bpe_out, wp_out] if self.training else [
                F.softmax(char_out, -1),
                F.softmax(bpe_out, -1),
                F.softmax(wp_out, -1)
            ]

        return char_out if self.training else F.softmax(char_out, -1)
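

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): runs a dummy forward
# pass through MGPDecoder to show the expected input and output shapes. The
# 257-token / 768-dim input assumes MGP-STR's ViT backbone configuration, and
# out_channels=38 is a hypothetical character vocabulary size; substitute the
# real backbone features and charset size in practice.
if __name__ == '__main__':
    decoder = MGPDecoder(in_channels=768, out_channels=38, max_len=25)
    decoder.eval()

    feats = torch.randn(2, 257, 768)  # [bs, num_tokens, embed_dim]
    with torch.no_grad():
        char_probs, bpe_probs, wp_probs = decoder(feats)

    # Each head yields probabilities over its own vocabulary for all
    # 27 (= max_len + 2) output positions:
    #   char: [2, 27, 38], bpe: [2, 27, 50257], wp: [2, 27, 30522]
    print(char_probs.shape, bpe_probs.shape, wp_probs.shape)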