Mountchicken's picture
Upload 704 files
9bf4bd7
raw
history blame
1.48 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch.nn as nn
from mmdet.models.backbones import MobileNetV2 as MMDet_MobileNetV2
from torch import Tensor
from mmocr.registry import MODELS
from mmocr.utils.typing_utils import InitConfigType
@MODELS.register_module()
class MobileNetV2(MMDet_MobileNetV2):
"""See mmdet.models.backbones.MobileNetV2 for details.
Args:
pooling_layers (list): List of indices of pooling layers.
init_cfg (InitConfigType, optional): Initialization config dict.
"""
# Parameters to build layers. 4 parameters are needed to construct a
# layer, from left to right: expand_ratio, channel, num_blocks, stride.
arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 1],
[6, 64, 4, 1], [6, 96, 3, 1], [6, 160, 3, 1],
[6, 320, 1, 1]]
def __init__(self,
pooling_layers: List = [3, 4, 5],
init_cfg: InitConfigType = None) -> None:
super().__init__(init_cfg=init_cfg)
self.pooling = nn.MaxPool2d((2, 2), (2, 1), (0, 1))
self.pooling_layers = pooling_layers
def forward(self, x: Tensor) -> Tensor:
"""Forward function."""
x = self.conv1(x)
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
if i in self.pooling_layers:
x = self.pooling(x)
return x