File size: 4,818 Bytes
9bf4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from mmengine.model import BaseModule
from torch import nn

from mmocr.registry import MODELS


class UpBlock(BaseModule):
    """Upsample block for DRRG and TextSnake.

    DRRG: `Deep Relational Reasoning Graph Network for Arbitrary Shape
    Text Detection <https://arxiv.org/abs/2003.07493>`_.

    TextSnake: `A Flexible Representation for Detecting Text of Arbitrary
    Shapes <https://arxiv.org/abs/1807.01544>`_.

    The block applies ``1x1 conv -> ReLU -> 3x3 conv -> ReLU -> 4x4
    transposed conv (stride 2, padding 1)``. The 1x1 and 3x3 convolutions
    keep the spatial size, and the transposed convolution exactly doubles
    it, so an input of shape :math:`(N, C_{in}, H, W)` produces an output of
    shape :math:`(N, C_{out}, 2H, 2W)`. No activation is applied after the
    transposed convolution.

    Args:
        in_channels (int): Number of channels of the input feature map.
        out_channels (int): Number of channels of the output feature map.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)

        assert isinstance(in_channels, int), \
            f'in_channels must be an int, got {type(in_channels)}'
        assert isinstance(out_channels, int), \
            f'out_channels must be an int, got {type(out_channels)}'

        # Channel mixing only; keeps both channel count and spatial size.
        self.conv1x1 = nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        # padding=1 keeps the spatial size while reducing to out_channels.
        self.conv3x3 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # (H - 1) * 2 - 2 * 1 + 4 = 2H: exact 2x upsampling.
        self.deconv = nn.ConvTranspose2d(
            out_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward propagation.

        Args:
            x (Tensor): Feature map of shape :math:`(N, C_{in}, H, W)`.

        Returns:
            Tensor: Feature map of shape :math:`(N, C_{out}, 2H, 2W)`.
        """
        x = F.relu(self.conv1x1(x))
        x = F.relu(self.conv3x3(x))
        x = self.deconv(x)
        return x


@MODELS.register_module()
class FPN_UNet(BaseModule):
    """The class for implementing DRRG and TextSnake U-Net-like FPN.

    DRRG: `Deep Relational Reasoning Graph Network for Arbitrary Shape
    Text Detection <https://arxiv.org/abs/2003.07493>`_.

    TextSnake: `A Flexible Representation for Detecting Text of Arbitrary
    Shapes <https://arxiv.org/abs/1807.01544>`_.

    The decoder starts from the coarsest feature ``c5`` and upsamples it by
    2x five times. The first upsampling is a plain transposed conv. The rest
    are :class:`UpBlock` stages. Before each of the three middle stages, the
    next finer lateral feature (``c4``, ``c3``, ``c2``) is resized to the
    current decoder map and concatenated along the channel axis.

    Decoder channel widths are ``min(out_channels * 2**i, 256)`` for
    ``i = 3, 2, 1, 0`` (from coarse to fine), followed by ``out_channels``
    for the final output.

    Args:
        in_channels (list[int]): Number of input channels at each scale, for
            C2, C3, C4, C5 in that order. The length of the list must be 4.
        out_channels (int): The number of output channels.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to Xavier (uniform distribution) initialization of all
            ``Conv2d`` and ``ConvTranspose2d`` layers.
    """

    def __init__(
        self,
        in_channels: List[int],
        out_channels: int,
        init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
            type='Xavier',
            layer=['Conv2d', 'ConvTranspose2d'],
            distribution='uniform')
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        assert len(in_channels) == 4, \
            f'in_channels must have 4 entries, got {len(in_channels)}'
        assert isinstance(out_channels, int), \
            f'out_channels must be an int, got {type(out_channels)}'

        # Output widths of [up_block0, up_block1, up_block2, up_block3, up4].
        blocks_out_channels = [out_channels] + [
            min(out_channels * 2**i, 256) for i in range(4)
        ]
        # Input widths of the same modules. Stage i (1..3) consumes the
        # previous stage's output concatenated with lateral in_channels[i-1].
        # up_block0 sees no lateral; up4 consumes c5 directly.
        blocks_in_channels = [blocks_out_channels[1]] + [
            in_channels[i] + blocks_out_channels[i + 2] for i in range(3)
        ] + [in_channels[3]]

        self.up4 = nn.ConvTranspose2d(
            blocks_in_channels[4],
            blocks_out_channels[4],
            kernel_size=4,
            stride=2,
            padding=1)
        self.up_block3 = UpBlock(blocks_in_channels[3], blocks_out_channels[3])
        self.up_block2 = UpBlock(blocks_in_channels[2], blocks_out_channels[2])
        self.up_block1 = UpBlock(blocks_in_channels[1], blocks_out_channels[1])
        self.up_block0 = UpBlock(blocks_in_channels[0], blocks_out_channels[0])

    def forward(
        self, x: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]]
    ) -> torch.Tensor:
        """
        Args:
            x (list[Tensor] | tuple[Tensor]): A sequence of four tensors of
                shape :math:`(N, C_i, H_i, W_i)`, representing C2, C3, C4, C5
                features respectively. :math:`C_i` should match the
                corresponding entry of ``in_channels``.

        Returns:
            Tensor: Shape :math:`(N, C, H, W)` with ``C == out_channels``.
            Every lateral is resized to the current decoder map, so the
            output size is determined only by C5:
            :math:`H = 32 H_{C5}` and :math:`W = 32 W_{C5}`. For a standard
            pyramid where C2 is 8x larger than C5, this equals 4x the size
            of C2.
        """
        c2, c3, c4, c5 = x

        out = F.relu(self.up4(c5))

        # Coarse-to-fine fusion. Each lateral is bilinearly resized to the
        # current decoder map, concatenated, then upsampled 2x by an UpBlock.
        stages = ((c4, self.up_block3), (c3, self.up_block2),
                  (c2, self.up_block1))
        for lateral, up_block in stages:
            lateral = F.interpolate(
                lateral,
                size=out.shape[2:],
                mode='bilinear',
                align_corners=True)
            out = torch.cat([out, lateral], dim=1)
            out = F.relu(up_block(out))

        out = self.up_block0(out)
        # NOTE(review): this matches the backbone input size only if C2 has
        # stride 4 relative to that input (the usual ResNet-style layout).
        return out