File size: 1,483 Bytes
9bf4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch.nn as nn
from mmdet.models.backbones import MobileNetV2 as MMDet_MobileNetV2
from torch import Tensor

from mmocr.registry import MODELS
from mmocr.utils.typing_utils import InitConfigType


@MODELS.register_module()
class MobileNetV2(MMDet_MobileNetV2):
    """MobileNetV2 backbone variant used by MMOCR.

    See ``mmdet.models.backbones.MobileNetV2`` for details on the underlying
    inverted-residual stages. This subclass differs in two ways:

    - ``arch_settings`` below gives stride 2 only to the second stage; every
      other stage uses stride 1.
    - An extra ``nn.MaxPool2d`` (kernel ``(2, 2)``, stride ``(2, 1)``,
      padding ``(0, 1)``) runs after each stage whose index is listed in
      ``pooling_layers``. Each such pooling maps height ``H`` to
      ``floor(H / 2)`` and width ``W`` to ``W + 1``. It therefore shrinks the
      feature map vertically while keeping (slightly growing) its width.

    Args:
        pooling_layers (list[int], optional): 0-based positions in
            ``self.layers`` after which the extra pooling is applied.
            ``None`` means ``[3, 4, 5]``. The values are copied, so later
            changes to the caller's list do not affect this module.
        init_cfg (InitConfigType, optional): Initialization config dict.
            Defaults to None.
    """
    # Parameters to build layers. 4 parameters are needed to construct a
    # layer, from left to right: expand_ratio, channel, num_blocks, stride.
    # NOTE(review): presumably consumed by the parent's ``__init__`` when it
    # builds the stages; confirm against mmdet's MobileNetV2.
    arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 1],
                     [6, 64, 4, 1], [6, 96, 3, 1], [6, 160, 3, 1],
                     [6, 320, 1, 1]]

    def __init__(self,
                 pooling_layers: Optional[List[int]] = None,
                 init_cfg: InitConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        # Avoid a mutable default argument: a shared ``[3, 4, 5]`` list would
        # be aliased by every instance that relies on the default.
        if pooling_layers is None:
            pooling_layers = [3, 4, 5]
        # kernel (2, 2), stride (2, 1), padding (0, 1): halves the height,
        # and the width grows by one column instead of shrinking.
        self.pooling = nn.MaxPool2d((2, 2), (2, 1), (0, 1))
        # Private copy so neither the caller's list nor another instance's
        # configuration can be mutated through this attribute.
        self.pooling_layers = list(pooling_layers)

    def forward(self, x: Tensor) -> Tensor:
        """Run the stem, every stage, and the configured extra poolings.

        Args:
            x (Tensor): Input batch. Presumably ``(N, C, H, W)`` image
                tensors; the expected channel count is set by the parent's
                ``conv1`` (TODO confirm).

        Returns:
            Tensor: Output of the last entry in ``self.layers``. The extra
            pooling is applied first if that entry's index is in
            ``pooling_layers``. Only this final feature map is returned.
        """
        x = self.conv1(x)
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            # ``i`` is the 0-based position of the stage in ``self.layers``.
            if i in self.pooling_layers:
                x = self.pooling(x)

        return x