# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

# Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py

from copy import deepcopy
from typing import Any, Dict, List, Optional

from torchvision.models import mobilenetv3

from doctr.datasets import VOCABS

from ...utils import load_pretrained_params

__all__ = [
    "mobilenet_v3_small",
    "mobilenet_v3_small_r",
    "mobilenet_v3_large",
    "mobilenet_v3_large_r",
    "mobilenet_v3_small_crop_orientation",
    "mobilenet_v3_small_page_orientation",
]

# Per-architecture defaults: normalization statistics, expected input shape,
# label set and the location of the pretrained checkpoint.
default_cfgs: Dict[str, Dict[str, Any]] = {
    "mobilenet_v3_large": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large-11fc8cb9.pt&src=0",
    },
    "mobilenet_v3_large_r": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large_r-74a22066.pt&src=0",
    },
    "mobilenet_v3_small": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small-6a4bfa6b.pt&src=0",
    },
    "mobilenet_v3_small_r": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-1a8a3530.pt&src=0",
    },
    "mobilenet_v3_small_crop_orientation": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 256, 256),
        "classes": [0, -90, 180, 90],
        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_crop_orientation-f0847a18.pt&src=0",
    },
    "mobilenet_v3_small_page_orientation": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 512, 512),
        "classes": [0, -90, 180, 90],
        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_page_orientation-8e60325c.pt&src=0",
    },
}


def _mobilenet_v3(
    arch: str,
    pretrained: bool,
    rect_strides: Optional[List[str]] = None,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> mobilenetv3.MobileNetV3:
    """Build a torchvision MobileNetV3 and attach the matching doctr configuration.

    Args:
    ----
        arch: key of ``default_cfgs`` describing the variant to build; names starting with
            ``"mobilenet_v3_small"`` select the small builder, anything else the large one
        pretrained: if True, load the checkpoint referenced by ``default_cfgs[arch]["url"]``
        rect_strides: dotted attribute paths (relative to the model) of the modules whose
            ``stride`` is replaced with ``(2, 1)``
        ignore_keys: checkpoint keys handed to ``load_pretrained_params`` as ``ignore_keys``
            when the requested number of classes differs from the default one
        **kwargs: forwarded to the torchvision builder; ``classes`` is consumed here and
            ``num_classes`` defaults to the number of classes

    Returns:
    -------
        the model, with its configuration stored in the ``cfg`` attribute

    Raises:
    ------
        KeyError: if ``arch`` is not a key of ``default_cfgs``
        AttributeError: if an entry of ``rect_strides`` does not resolve to a module of the model
    """
    default_classes = default_cfgs[arch]["classes"]

    # `classes` is kept as doctr metadata in the config and is not forwarded to torchvision
    classes = kwargs.pop("classes", default_classes)
    # Size the head from the effective label set, so that passing only `classes` cannot
    # produce a classifier whose width disagrees with `cfg["classes"]`
    kwargs["num_classes"] = kwargs.get("num_classes", len(classes))

    _cfg = deepcopy(default_cfgs[arch])
    _cfg["num_classes"] = kwargs["num_classes"]
    _cfg["classes"] = classes

    # Weights are never fetched by torchvision: doctr checkpoints are loaded below
    if arch.startswith("mobilenet_v3_small"):
        model = mobilenetv3.mobilenet_v3_small(**kwargs, weights=None)
    else:
        model = mobilenetv3.mobilenet_v3_large(**kwargs, weights=None)

    # Rectangular strides
    if isinstance(rect_strides, list):
        for layer_name in rect_strides:
            # Walk the dotted path down to the target module
            m = model
            for child in layer_name.split("."):
                m = getattr(m, child)
            m.stride = (2, 1)

    # Load pretrained parameters
    if pretrained:
        # The number of classes is not the same as the number of classes in the pretrained model =>
        # remove the last layer weights
        _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_classes) else None
        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

    model.cfg = _cfg

    return model


def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
    """MobileNetV3-Small architecture as described in
    `"Searching for MobileNetV3",
    <https://arxiv.org/pdf/1905.02244.pdf>`_.

    >>> import torch
    >>> from doctr.models import mobilenet_v3_small
    >>> model = mobilenet_v3_small(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the MobileNetV3 architecture

    Returns:
    -------
        a torch.nn.Module
    """
    return _mobilenet_v3(
        "mobilenet_v3_small", pretrained, ignore_keys=["classifier.3.weight", "classifier.3.bias"], **kwargs
    )


def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
    """MobileNetV3-Small architecture as described in
    `"Searching for MobileNetV3",
    <https://arxiv.org/pdf/1905.02244.pdf>`_, with rectangular pooling.

    >>> import torch
    >>> from doctr.models import mobilenet_v3_small_r
    >>> model = mobilenet_v3_small_r(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the MobileNetV3 architecture

    Returns:
    -------
        a torch.nn.Module
    """
    return _mobilenet_v3(
        "mobilenet_v3_small_r",
        pretrained,
        ["features.2.block.1.0", "features.4.block.1.0", "features.9.block.1.0"],
        ignore_keys=["classifier.3.weight", "classifier.3.bias"],
        **kwargs,
    )


def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
    """MobileNetV3-Large architecture as described in
    `"Searching for MobileNetV3",
    <https://arxiv.org/pdf/1905.02244.pdf>`_.

    >>> import torch
    >>> from doctr.models import mobilenet_v3_large
    >>> model = mobilenet_v3_large(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the MobileNetV3 architecture

    Returns:
    -------
        a torch.nn.Module
    """
    return _mobilenet_v3(
        "mobilenet_v3_large",
        pretrained,
        ignore_keys=["classifier.3.weight", "classifier.3.bias"],
        **kwargs,
    )


def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
    """MobileNetV3-Large architecture as described in
    `"Searching for MobileNetV3",
    <https://arxiv.org/pdf/1905.02244.pdf>`_, with rectangular pooling.

    >>> import torch
    >>> from doctr.models import mobilenet_v3_large_r
    >>> model = mobilenet_v3_large_r(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the MobileNetV3 architecture

    Returns:
    -------
        a torch.nn.Module
    """
    return _mobilenet_v3(
        "mobilenet_v3_large_r",
        pretrained,
        ["features.4.block.1.0", "features.7.block.1.0", "features.13.block.1.0"],
        ignore_keys=["classifier.3.weight", "classifier.3.bias"],
        **kwargs,
    )


def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
    """MobileNetV3-Small architecture as described in
    `"Searching for MobileNetV3",
    <https://arxiv.org/pdf/1905.02244.pdf>`_.

    >>> import torch
    >>> from doctr.models import mobilenet_v3_small_crop_orientation
    >>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the MobileNetV3 architecture

    Returns:
    -------
        a torch.nn.Module
    """
    return _mobilenet_v3(
        "mobilenet_v3_small_crop_orientation",
        pretrained,
        ignore_keys=["classifier.3.weight", "classifier.3.bias"],
        **kwargs,
    )


def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
    """MobileNetV3-Small architecture as described in
    `"Searching for MobileNetV3",
    <https://arxiv.org/pdf/1905.02244.pdf>`_.

    >>> import torch
    >>> from doctr.models import mobilenet_v3_small_page_orientation
    >>> model = mobilenet_v3_small_page_orientation(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the MobileNetV3 architecture

    Returns:
    -------
        a torch.nn.Module
    """
    return _mobilenet_v3(
        "mobilenet_v3_small_page_orientation",
        pretrained,
        ignore_keys=["classifier.3.weight", "classifier.3.bias"],
        **kwargs,
    )