# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch.nn as nn
from mmdet.models.backbones import MobileNetV2 as MMDet_MobileNetV2
from torch import Tensor

from mmocr.registry import MODELS
from mmocr.utils.typing_utils import InitConfigType
@MODELS.register_module()
class MobileNetV2(MMDet_MobileNetV2):
    """MobileNetV2 backbone variant used for text recognition.

    See ``mmdet.models.backbones.MobileNetV2`` for the base implementation.
    This subclass overrides ``arch_settings`` and replaces ``forward``. It
    runs ``self.conv1`` followed by every stage in ``self.layers``. After
    each stage whose index is listed in ``pooling_layers``, it applies an
    extra max-pooling step. It returns only the final feature map.

    Args:
        pooling_layers (list[int], optional): 0-based indices, in the
            enumeration order of ``self.layers``, of the stages after which
            ``self.pooling`` is applied. Defaults to ``[3, 4, 5]``.
        init_cfg (InitConfigType, optional): Initialization config dict.
            Defaults to None.
    """

    # Parameters to build layers. 4 parameters are needed to construct a
    # layer, from left to right: expand_ratio, channel, num_blocks, stride.
    # This overrides the parent's class attribute. Presumably the parent
    # constructor reads it to build ``self.layers`` (TODO confirm against
    # mmdet). Only the second entry has stride 2; after that, extra spatial
    # downsampling comes only from ``self.pooling`` in ``forward``.
    arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 1],
                     [6, 64, 4, 1], [6, 96, 3, 1], [6, 160, 3, 1],
                     [6, 320, 1, 1]]

    def __init__(self,
                 pooling_layers: Optional[List[int]] = None,
                 init_cfg: Optional[InitConfigType] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        # kernel (2, 2), stride (2, 1), padding (0, 1). Per the
        # nn.MaxPool2d output-size formula, this floors the height to H // 2
        # and grows the width by one. Most horizontal resolution is kept,
        # which suits wide text-line images.
        self.pooling = nn.MaxPool2d((2, 2), (2, 1), (0, 1))
        # Use a None sentinel instead of a list default. A list literal in
        # the signature would be one object shared by every instance built
        # without an explicit argument.
        self.pooling_layers = ([3, 4, 5]
                               if pooling_layers is None else pooling_layers)

    def forward(self, x: Tensor) -> Tensor:
        """Forward function.

        Args:
            x (Tensor): Input batch, presumably of shape (N, C, H, W).
                TODO confirm against the data pipeline.

        Returns:
            Tensor: Output of the last stage in ``self.layers``. If that
            stage's index is in ``pooling_layers``, the output has also been
            max-pooled.
        """
        x = self.conv1(x)
        # ``self.layers`` holds attribute names of the stage submodules. It
        # is not set in this class, so it comes from the parent constructor.
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.pooling_layers:
                x = self.pooling(x)
        return x