Spaces:
Runtime error
Runtime error
File size: 11,802 Bytes
153628e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 |
# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
from copy import deepcopy
from itertools import groupby
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch import nn
from torch.nn import functional as F
from doctr.datasets import VOCABS, decode_sequence
from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
from ...utils.pytorch import load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor
__all__ = [
    "CRNN",
    "crnn_vgg16_bn",
    "crnn_mobilenet_v3_small",
    "crnn_mobilenet_v3_large",
]

# Per-architecture default configuration. Every CRNN variant shares the same normalization
# statistics and input size; they only differ by vocabulary and checkpoint location.
default_cfgs: Dict[str, Dict[str, Any]] = {
    arch_name: {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 128),
        "vocab": VOCABS[vocab_name],
        "url": checkpoint_url,
    }
    for arch_name, vocab_name, checkpoint_url in (
        (
            "crnn_vgg16_bn",
            "legacy_french",
            "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_vgg16_bn-9762b0b0.pt&src=0",
        ),
        (
            "crnn_mobilenet_v3_small",
            "french",
            "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_mobilenet_v3_small_pt-3b919a02.pt&src=0",
        ),
        (
            "crnn_mobilenet_v3_large",
            "french",
            "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_mobilenet_v3_large_pt-f5259ec2.pt&src=0",
        ),
    )
}
class CTCPostProcessor(RecognitionPostProcessor):
    """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding

    Args:
    ----
        vocab: string containing the ordered sequence of supported characters
    """

    @staticmethod
    def ctc_best_path(
        logits: torch.Tensor,
        vocab: str = VOCABS["french"],
        blank: int = 0,
    ) -> List[Tuple[str, float]]:
        """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from
        `<https://github.com/githubharald/CTCDecoder>`_.

        Args:
        ----
            logits: model output, shape: N x T x C (batch, time steps, classes)
            vocab: vocabulary to use
            blank: index of blank label

        Returns:
        -------
            A list of tuples: (word, confidence), one per batch element
        """
        # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
        probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values

        # collapse best path (using itertools.groupby): merge consecutive repeats, then drop blanks,
        # map to chars and join the char list into a string
        words = [
            decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab)
            for seq in torch.argmax(logits, dim=-1)
        ]

        return list(zip(words, probs.tolist()))

    def __call__(self, logits: torch.Tensor) -> List[Tuple[str, float]]:
        """Performs decoding of raw output with CTC and decoding of CTC predictions
        with label_to_idx mapping dictionary

        Args:
        ----
            logits: raw output of the model, shape (N, seq_len, len(vocab) + 1)

        Returns:
        -------
            A list of tuples: (word, confidence), one per batch element
        """
        # Decode CTC: the blank label is the extra last class, i.e. index len(vocab)
        return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab))
class CRNN(RecognitionModel, nn.Module):
    """Implements a CRNN architecture as described in `"An End-to-End Trainable Neural Network for Image-based
    Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.

    Args:
    ----
        feature_extractor: the backbone serving as feature extractor
        vocab: vocabulary used for encoding
        rnn_units: number of units in the LSTM layers
        input_shape: shape of a single input image (C, H, W), used to infer the LSTM input size
        exportable: onnx exportable returns only logits
        cfg: configuration dictionary
    """

    _children_names: List[str] = ["feat_extractor", "decoder", "linear", "postprocessor"]

    def __init__(
        self,
        feature_extractor: nn.Module,
        vocab: str,
        rnn_units: int = 128,
        input_shape: Tuple[int, int, int] = (3, 32, 128),
        exportable: bool = False,
        cfg: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.vocab = vocab
        self.cfg = cfg
        self.max_length = 32
        self.exportable = exportable
        self.feat_extractor = feature_extractor

        # Resolve the input_size of the LSTM with a dummy forward pass on an all-zeros image.
        # The pass runs with the backbone in eval mode: in training mode its BatchNorm layers would fold
        # the statistics of this dummy batch into their running estimates (corrupting those of a
        # pretrained backbone). The previous train/eval flag of every submodule is restored afterwards.
        prev_modes = [(module, module.training) for module in self.feat_extractor.modules()]
        self.feat_extractor.eval()
        try:
            with torch.inference_mode():
                out_shape = self.feat_extractor(torch.zeros((1, *input_shape))).shape
        finally:
            for module, mode in prev_modes:
                module.training = mode
        # The LSTM consumes one C*H feature vector per horizontal position (see forward)
        lstm_in = out_shape[1] * out_shape[2]

        self.decoder = nn.LSTM(
            input_size=lstm_in,
            hidden_size=rnn_units,
            batch_first=True,
            num_layers=2,
            bidirectional=True,
        )

        # features units = 2 * rnn_units because bidirectional layers
        # one extra output class for the CTC blank label (index len(vocab))
        self.linear = nn.Linear(in_features=2 * rnn_units, out_features=len(vocab) + 1)

        self.postprocessor = CTCPostProcessor(vocab=vocab)

        for n, m in self.named_modules():
            # Don't override the initialization of the backbone
            if n.startswith("feat_extractor."):
                continue
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1.0)
                m.bias.data.zero_()

    def compute_loss(
        self,
        model_output: torch.Tensor,
        target: List[str],
    ) -> torch.Tensor:
        """Compute CTC loss for the model.

        Args:
        ----
            model_output: predicted logits of the model, shape N x T x C
            target: list of target strings

        Returns:
        -------
            The loss of the model on the batch
        """
        gt, seq_len = self.build_target(target)
        batch_len = model_output.shape[0]
        # Every sequence in the batch uses all T time steps of the model output
        input_length = model_output.shape[1] * torch.ones(size=(batch_len,), dtype=torch.int32)
        # N x T x C -> T x N x C
        logits = model_output.permute(1, 0, 2)
        probs = F.log_softmax(logits, dim=-1)
        ctc_loss = F.ctc_loss(
            probs,
            torch.from_numpy(gt),
            input_length,
            torch.tensor(seq_len, dtype=torch.int),
            # blank label index: the extra last class of the linear head
            len(self.vocab),
            zero_infinity=True,
        )

        return ctc_loss

    def forward(
        self,
        x: torch.Tensor,
        target: Optional[List[str]] = None,
        return_model_output: bool = False,
        return_preds: bool = False,
    ) -> Dict[str, Any]:
        """Run the recognition model on a batch of images.

        Args:
        ----
            x: batch of input images, fed directly to the feature extractor
            target: list of target strings; mandatory in training mode
            return_model_output: if True, the raw logits are returned under "out_map"
            return_preds: if True, decoded predictions are returned under "preds" even when target is given

        Returns:
        -------
            A dictionary which, depending on the flags, holds "logits" (exportable mode only),
            "out_map", "preds" (list of (word, confidence)) and "loss"

        Raises:
        ------
            ValueError: if the model is in training mode and no target is provided
        """
        if self.training and target is None:
            raise ValueError("Need to provide labels during training")

        features = self.feat_extractor(x)
        # B x C x H x W --> B x C*H x W --> B x W x C*H
        c, h, w = features.shape[1], features.shape[2], features.shape[3]
        features_seq = torch.reshape(features, shape=(-1, h * c, w))
        features_seq = torch.transpose(features_seq, 1, 2)
        logits, _ = self.decoder(features_seq)
        logits = self.linear(logits)

        out: Dict[str, Any] = {}
        if self.exportable:
            out["logits"] = logits
            return out

        if return_model_output:
            out["out_map"] = logits

        if target is None or return_preds:
            # Decode the logits into (word, confidence) predictions
            out["preds"] = self.postprocessor(logits)

        if target is not None:
            out["loss"] = self.compute_loss(logits, target)

        return out
def _crnn(
    arch: str,
    pretrained: bool,
    backbone_fn: Callable[[Any], nn.Module],
    pretrained_backbone: bool = True,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> CRNN:
    """Instantiate a CRNN for a given architecture name, optionally restoring its checkpoint.

    Args:
    ----
        arch: key of the architecture in `default_cfgs`
        pretrained: whether to load the full model checkpoint
        backbone_fn: constructor of the classification model whose `.features` serve as backbone
        pretrained_backbone: whether to load backbone weights (ignored when `pretrained` is True)
        ignore_keys: checkpoint keys to skip when the vocab differs from the default one
        **kwargs: keyword arguments forwarded to `CRNN`

    Returns:
    -------
        the CRNN model
    """
    # Backbone-only weights are pointless when the full checkpoint is loaded afterwards
    use_backbone_weights = pretrained_backbone and not pretrained
    backbone = backbone_fn(pretrained=use_backbone_weights).features  # type: ignore[call-arg]

    # Fall back on the architecture defaults for anything the caller did not override
    for cfg_key in ("vocab", "input_shape"):
        kwargs[cfg_key] = kwargs.get(cfg_key, default_cfgs[arch][cfg_key])

    model_cfg = deepcopy(default_cfgs[arch])
    model_cfg.update(vocab=kwargs["vocab"], input_shape=kwargs["input_shape"])

    model = CRNN(backbone, cfg=model_cfg, **kwargs)

    if pretrained:
        # A custom vocab changes the size of the classification head, whose checkpoint
        # weights then no longer fit and must be skipped
        vocab_is_custom = model_cfg["vocab"] != default_cfgs[arch]["vocab"]
        load_pretrained_params(
            model,
            model_cfg["url"],
            ignore_keys=ignore_keys if vocab_is_custom else None,
        )

    return model
def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
    """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based
    Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.

    >>> import torch
    >>> from doctr.models import crnn_vgg16_bn
    >>> model = crnn_vgg16_bn(pretrained=True)
    >>> input_tensor = torch.rand(1, 3, 32, 128)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the CRNN architecture

    Returns:
    -------
        text recognition architecture
    """
    # Weights of the output head, skipped at checkpoint loading when a custom vocab is used
    head_keys = ["linear.weight", "linear.bias"]
    return _crnn(
        "crnn_vgg16_bn",
        pretrained,
        vgg16_bn_r,
        ignore_keys=head_keys,
        **kwargs,
    )
def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
    """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based
    Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.

    >>> import torch
    >>> from doctr.models import crnn_mobilenet_v3_small
    >>> model = crnn_mobilenet_v3_small(pretrained=True)
    >>> input_tensor = torch.rand(1, 3, 32, 128)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the CRNN architecture

    Returns:
    -------
        text recognition architecture
    """
    # Weights of the output head, skipped at checkpoint loading when a custom vocab is used
    head_keys = ["linear.weight", "linear.bias"]
    return _crnn("crnn_mobilenet_v3_small", pretrained, mobilenet_v3_small_r, ignore_keys=head_keys, **kwargs)
def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
    """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based
    Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.

    >>> import torch
    >>> from doctr.models import crnn_mobilenet_v3_large
    >>> model = crnn_mobilenet_v3_large(pretrained=True)
    >>> input_tensor = torch.rand(1, 3, 32, 128)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the CRNN architecture

    Returns:
    -------
        text recognition architecture
    """
    # Weights of the output head, skipped at checkpoint loading when a custom vocab is used
    head_keys = ["linear.weight", "linear.bias"]
    return _crnn("crnn_mobilenet_v3_large", pretrained, mobilenet_v3_large_r, ignore_keys=head_keys, **kwargs)
|