'''
This code is adapted from:
https://github.com/AlibabaResearch/AdvancedLiterateMachinery/blob/main/OCR/MGP-STR
'''

import torch
import torch.nn as nn
import torch.nn.functional as F


class TokenLearner(nn.Module):
    '''Learn ``out_token`` spatial attention maps and use them to pool the
    input token sequence into a fixed number of output tokens.'''

    def __init__(self, input_embed_dim, out_token=30):
        super().__init__()
        self.token_norm = nn.LayerNorm(input_embed_dim)
        self.tokenLearner = nn.Sequential(
            nn.Conv2d(input_embed_dim,
                      input_embed_dim,
                      kernel_size=(1, 1),
                      stride=1,
                      groups=8,
                      bias=False),
            nn.Conv2d(input_embed_dim,
                      out_token,
                      kernel_size=(1, 1),
                      stride=1,
                      bias=False))
        self.feat = nn.Conv2d(input_embed_dim,
                              input_embed_dim,
                              kernel_size=(1, 1),
                              stride=1,
                              groups=8,
                              bias=False)
        self.norm = nn.LayerNorm(input_embed_dim)

    def forward(self, x):
        x = self.token_norm(x)  # [bs, 257, 768]
        x = x.transpose(1, 2).unsqueeze(-1)  # [bs, 768, 257, 1]
        selected = self.tokenLearner(x)  # [bs, 27, 257, 1].
        selected = selected.flatten(2)  # [bs, 27, 257].
        selected = F.softmax(selected, dim=-1)
        feat = self.feat(x)  #  [bs, 768, 257, 1].
        feat = feat.flatten(2).transpose(1, 2)  # [bs, 257, 768]
        x = torch.einsum('...si,...id->...sd', selected, feat)  # [bs, 27, 768]

        x = self.norm(x)
        return selected, x
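

# Hypothetical usage sketch (not part of the original file). The batch size
# and the ViT-style [bs, 257, 768] input below are illustrative assumptions:
#
#   learner = TokenLearner(input_embed_dim=768, out_token=27)
#   feats = torch.randn(2, 257, 768)   # [bs, seq_len, embed_dim]
#   attn, tokens = learner(feats)
#   # attn:   [2, 27, 257], softmax attention over the input positions
#   # tokens: [2, 27, 768], attention-pooled output tokens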


class MGPDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 max_len=25,
                 only_char=False,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        num_classes = out_channels
        embed_dim = in_channels
        # +2 reserves positions for the begin/end special tokens.
        self.batch_max_length = max_len + 2
        self.char_tokenLearner = TokenLearner(embed_dim, self.batch_max_length)
        self.char_head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.only_char = only_char
        if not only_char:
            # Subword branches: BPE (GPT-2 vocabulary, 50257 tokens) and
            # WordPiece (BERT vocabulary, 30522 tokens).
            self.bpe_tokenLearner = TokenLearner(embed_dim,
                                                 self.batch_max_length)
            self.wp_tokenLearner = TokenLearner(embed_dim,
                                                self.batch_max_length)
            self.bpe_head = nn.Linear(
                embed_dim, 50257) if num_classes > 0 else nn.Identity()
            self.wp_head = nn.Linear(
                embed_dim, 30522) if num_classes > 0 else nn.Identity()

    def forward(self, x, data=None):
        # Character branch.
        char_attn, x_char = self.char_tokenLearner(x)
        char_out = self.char_head(x_char)
        if not self.only_char:
            # BPE subword branch.
            bpe_attn, x_bpe = self.bpe_tokenLearner(x)
            bpe_out = self.bpe_head(x_bpe)
            # WordPiece subword branch.
            wp_attn, x_wp = self.wp_tokenLearner(x)
            wp_out = self.wp_head(x_wp)
            # Raw logits during training, probabilities at inference.
            return [char_out, bpe_out, wp_out] if self.training else [
                F.softmax(char_out, -1),
                F.softmax(bpe_out, -1),
                F.softmax(wp_out, -1)
            ]

        return char_out if self.training else F.softmax(char_out, -1)
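

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original file. The backbone
    # shape (257 tokens of width 768) and the 38-way character set below are
    # illustrative assumptions, not values fixed by this module.
    decoder = MGPDecoder(in_channels=768, out_channels=38, max_len=25)
    decoder.eval()
    feats = torch.randn(2, 257, 768)  # [bs, num_tokens, embed_dim]
    with torch.no_grad():
        char_prob, bpe_prob, wp_prob = decoder(feats)
    print(char_prob.shape)  # torch.Size([2, 27, 38])
    print(bpe_prob.shape)   # torch.Size([2, 27, 50257])
    print(wp_prob.shape)    # torch.Size([2, 27, 30522])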