import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from .abinet_decoder import PositionAttention
from .nrtr_decoder import PositionalEncoding, TransformerBlock


class Trans(nn.Module):
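    """Self-attention refinement module.

    Flattens an (N, C, H, W) feature map into a token sequence, adds
    positional encoding, and applies ``num_layers`` self-attention-only
    ``TransformerBlock`` layers (cross-attention disabled). An optional
    location mask derived from character attention maps restricts which
    positions may attend to each other.
    """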

    def __init__(self, dim, nhead, dim_feedforward, dropout, num_layers):
        super().__init__()
        self.d_model = dim
        self.nhead = nhead

        self.pos_encoder = PositionalEncoding(dropout=0.0,
                                              dim=self.d_model,
                                              max_len=512)

        self.transformer = nn.ModuleList([
            TransformerBlock(
                dim,
                nhead,
                dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ])

    def forward(self, feature, attn_map=None, use_mask=False):
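        """Refine flattened visual features with self-attention.

        Args:
            feature: visual features of shape (N, C, H, W).
            attn_map: character attention maps of shape (N, T, H, W); only
                used when ``use_mask`` is True.
            use_mask: if True, build an additive mask assigning ``-inf`` to
                pixel pairs whose thresholded (> 0.05) attention maps overlap
                for at least one character position.

        Returns:
            The refined (N, C, H, W) features and the location mask
            (``None`` when ``use_mask`` is False).
        """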
        n, c, h, w = feature.shape
        feature = feature.flatten(2).transpose(1, 2)

        if use_mask:
            _, t, h, w = attn_map.shape
            location_mask = (attn_map.view(n, t, -1).transpose(1, 2) >
                             0.05).type(torch.float)  # n,hw,t
            location_mask = location_mask.bmm(location_mask.transpose(
                1, 2))  # n, hw, hw
            location_mask = location_mask.new_zeros(
                (h * w, h * w)).masked_fill(location_mask > 0, float('-inf'))
            location_mask = location_mask.unsqueeze(1)  # n, 1, hw, hw
        else:
            location_mask = None

        feature = self.pos_encoder(feature)
        for layer in self.transformer:
            feature = layer(feature, self_mask=location_mask)
        feature = feature.transpose(1, 2).view(n, c, h, w)
        return feature, location_mask


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class LPVDecoder(nn.Module):
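    """Iterative decoder built from cloned position-attention rounds.

    Round 0 applies ``PositionAttention`` directly to the visual features;
    each later round first refines the features with a ``Trans`` block
    (optionally masked by the previous round's attention maps) and then
    re-attends. Every round owns its own linear classifier over
    ``out_channels - 2`` classes. Training returns the logits of all rounds;
    inference returns only the softmax output of the final round.
    """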

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_layer=3,
                 max_len=25,
                 use_mask=False,
                 dim_feedforward=1024,
                 nhead=8,
                 dropout=0.1,
                 trans_layer=2):
        super().__init__()
        self.use_mask = use_mask
        self.max_len = max_len
        attn_layer = PositionAttention(max_length=max_len + 1,
                                       mode='nearest',
                                       in_channels=in_channels,
                                       num_channels=in_channels // 8)
        # ``trans_layer`` is an int (number of Transformer layers); bind the
        # module prototype to a separate name to avoid shadowing it.
        trans_block = Trans(dim=in_channels,
                            nhead=nhead,
                            dim_feedforward=dim_feedforward,
                            dropout=dropout,
                            num_layers=trans_layer)
        cls_layer = nn.Linear(in_channels, out_channels - 2)

        self.attention = _get_clones(attn_layer, num_layer)
        self.trans = _get_clones(trans_block, num_layer - 1)
        self.cls = _get_clones(cls_layer, num_layer)

    def forward(self, x, data=None):
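        """Decode character sequences from backbone features.

        Args:
            x: backbone features of shape (N, E, H, W).
            data: optional ground-truth info; ``data[1]`` is expected to hold
                the target lengths, whose maximum caps the decoded sequence
                length. Falls back to ``self.max_len`` when ``data`` is None.

        Returns:
            In training mode, a list of per-round logits, each of shape
            (N, T, out_channels - 2); in eval mode, the softmax
            probabilities of the final round.
        """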
        if data is not None:
            max_len = data[1].max()
        else:
            max_len = self.max_len
        features = x  # (N, E, H, W)

        attn_vecs, attn_scores_map = self.attention[0](features)
        attn_vecs = attn_vecs[:, :max_len + 1, :]
        if not self.training:
            for i in range(1, len(self.attention)):
                features, mask = self.trans[i - 1](features,
                                                   attn_scores_map,
                                                   use_mask=self.use_mask)
                attn_vecs, attn_scores_map = self.attention[i](
                    features, attn_vecs)  # (N, T, E), (N, T, H, W)
            return F.softmax(self.cls[-1](attn_vecs), -1)
        else:
            logits = []
            logit = self.cls[0](attn_vecs)  # (N, T, C)
            logits.append(logit)
            for i in range(1, len(self.attention)):
                features, mask = self.trans[i - 1](features,
                                                   attn_scores_map,
                                                   use_mask=self.use_mask)
                attn_vecs, attn_scores_map = self.attention[i](
                    features, attn_vecs)  # (N, T, E), (N, T, H, W)
                logit = self.cls[i](attn_vecs)  # (N, T, C)
                logits.append(logit)
            return logits
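

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The channel,
# charset, and feature-map sizes below are hypothetical; shapes follow the
# forward() contract documented above. The relative imports at the top mean
# this file is meant to be used from within its package:
#
#   decoder = LPVDecoder(in_channels=512, out_channels=39, use_mask=True)
#   feats = torch.randn(2, 512, 8, 32)   # (N, E, H, W) backbone output
#
#   decoder.train()
#   logits = decoder(feats)              # list of (N, 26, 37) logits per round
#
#   decoder.eval()
#   with torch.no_grad():
#       probs = decoder(feats)           # (N, 26, 37) softmax probabilities
# ---------------------------------------------------------------------------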