adirathor07's picture
added doctr folder
153628e
raw
history blame
3.14 kB
# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
from copy import deepcopy
from typing import Any, Dict, List, Optional
from torch import nn
from torchvision.models import vgg as tv_vgg
from doctr.datasets import VOCABS
from ...utils import load_pretrained_params
__all__ = ["vgg16_bn_r"]

# Per-architecture defaults. `_vgg` reads `classes` to derive the default
# `num_classes`, copies the whole entry into `model.cfg`, and downloads `url`
# when `pretrained=True`. `mean`/`std` are presumably the input normalization
# statistics used by callers — confirm against the preprocessing code.
default_cfgs: Dict[str, Dict[str, Any]] = {
    "vgg16_bn_r": dict(
        mean=(0.694, 0.695, 0.693),
        std=(0.299, 0.296, 0.301),
        input_shape=(3, 32, 32),
        classes=list(VOCABS["french"]),
        url="https://doctr-static.mindee.com/models?id=v0.4.1/vgg16_bn_r-d108c19c.pt&src=0",
    ),
}
def _vgg(
    arch: str,
    pretrained: bool,
    tv_arch: str,
    num_rect_pools: int = 3,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> tv_vgg.VGG:
    """Build a torchvision VGG variant with rectangular pooling and a single linear head.

    Args:
    ----
        arch: key of the configuration to use in `default_cfgs`
        pretrained: if True, load the checkpoint referenced by `default_cfgs[arch]["url"]`
        tv_arch: name of the builder function in `torchvision.models.vgg` (e.g. "vgg16_bn")
        num_rect_pools: number of trailing MaxPool2d layers whose kernel is replaced by (2, 1);
            0 leaves every pooling layer untouched
        ignore_keys: state-dict keys skipped when loading pretrained weights; only applied when the
            requested number of classes differs from the default configuration
        **kwargs: forwarded to the torchvision builder. `num_classes` and `classes` default to the
            values derived from `default_cfgs[arch]["classes"]`

    Returns:
    -------
        the patched VGG model, with its (copied and updated) configuration stored in `model.cfg`

    Raises:
    ------
        ValueError: if `num_rect_pools` is negative
    """
    if num_rect_pools < 0:
        raise ValueError(f"num_rect_pools must be non-negative, got {num_rect_pools}")

    default_classes = default_cfgs[arch]["classes"]
    num_classes = kwargs.setdefault("num_classes", len(default_classes))
    # `classes` is only recorded in the config; pop it so it is not forwarded to the builder
    classes = kwargs.pop("classes", default_classes)

    _cfg = deepcopy(default_cfgs[arch])
    _cfg["num_classes"] = num_classes
    _cfg["classes"] = classes

    # Build the model
    model = tv_vgg.__dict__[tv_arch](**kwargs, weights=None)

    # Replace the kernel of the last `num_rect_pools` MaxPool2d layers with a rectangular one.
    # A computed start index is used instead of `[-num_rect_pools:]`, which would select
    # *every* pooling layer when num_rect_pools == 0.
    pool_idcs = [idx for idx, m in enumerate(model.features) if isinstance(m, nn.MaxPool2d)]
    for idx in pool_idcs[max(len(pool_idcs) - num_rect_pools, 0):]:
        model.features[idx] = nn.MaxPool2d((2, 1))

    # Patch average pool & classification head: global (1, 1) pooling leaves one value per
    # channel of the last convolution, so the linear head's input size is that channel count
    last_conv = [m for m in model.features if isinstance(m, nn.Conv2d)][-1]
    model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    model.classifier = nn.Linear(last_conv.out_channels, num_classes)

    # Load pretrained parameters
    if pretrained:
        # The number of classes is not the same as the number of classes in the pretrained model =>
        # remove the last layer weights
        _ignore_keys = ignore_keys if num_classes != len(default_classes) else None
        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

    model.cfg = _cfg

    return model
def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> tv_vgg.VGG:
    """VGG-16 architecture as described in `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
    <https://arxiv.org/pdf/1409.1556.pdf>`_, modified by adding batch normalization, rectangular pooling and a simpler
    classification head.

    >>> import torch
    >>> from doctr.models import vgg16_bn_r
    >>> model = vgg16_bn_r(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, loads the checkpoint referenced by the "vgg16_bn_r" default configuration,
            whose classification head matches the default class set `list(VOCABS["french"])`. When a different
            `num_classes` is requested, the checkpoint's `classifier.weight` / `classifier.bias` are skipped.
        **kwargs: keyword arguments of the VGG architecture (e.g. `num_classes`, `classes`)

    Returns:
    -------
        VGG-16-BN model whose last 3 max-pooling layers use a (2, 1) kernel, followed by global average pooling
        and a single linear classification head
    """
    return _vgg(
        "vgg16_bn_r",
        pretrained,
        "vgg16_bn",
        num_rect_pools=3,
        ignore_keys=["classifier.weight", "classifier.bias"],
        **kwargs,
    )