# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models._utils import IntermediateLayerGetter

from doctr.datasets import VOCABS

from ...classification import resnet31
from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor

__all__ = ["SAR", "sar_resnet31"]

default_cfgs: Dict[str, Dict[str, Any]] = {
    "sar_resnet31": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 128),
        "vocab": VOCABS["french"],
        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/sar_resnet31-9a1deedf.pt&src=0",
    },
}


class SAREncoder(nn.Module):
    """Implements the holistic encoder of the SAR model.

    Args:
    ----
        in_feats: number of input channels of the pooled feature map
        rnn_units: number of hidden units in the LSTM layers
        dropout_prob: dropout probability between the two LSTM layers
    """

    def __init__(self, in_feats: int, rnn_units: int, dropout_prob: float = 0.0) -> None:
        super().__init__()
        self.rnn = nn.LSTM(in_feats, rnn_units, 2, batch_first=True, dropout=dropout_prob)
        self.linear = nn.Linear(rnn_units, rnn_units)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, L, C) --> (N, T, C)
        encoded = self.rnn(x)[0]
        # (N, C): keep only the last timestep
        return self.linear(encoded[:, -1, :])
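
# Shape sketch for the encoder (illustrative only, not part of the model; the 512-unit width below
# matches the resnet31 default feature width, it is not a requirement):
#   >>> enc = SAREncoder(in_feats=512, rnn_units=512)
#   >>> enc(torch.rand(2, 32, 512)).shape  # (N, W, C) pooled features -> (N, rnn_units)
#   torch.Size([2, 512])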


class AttentionModule(nn.Module):
    """Implements the attention module of the SAR model.

    Args:
    ----
        feat_chans: number of channels of the visual feature map
        state_chans: number of channels of the decoder hidden state
        attention_units: number of hidden attention units
    """

    def __init__(self, feat_chans: int, state_chans: int, attention_units: int) -> None:
        super().__init__()
        self.feat_conv = nn.Conv2d(feat_chans, attention_units, kernel_size=3, padding=1)
        # No need to add another bias since both tensors are summed together
        self.state_conv = nn.Conv2d(state_chans, attention_units, kernel_size=1, bias=False)
        self.attention_projector = nn.Conv2d(attention_units, 1, kernel_size=1, bias=False)

    def forward(
        self,
        features: torch.Tensor,  # (N, C, H, W)
        hidden_state: torch.Tensor,  # (N, C)
    ) -> torch.Tensor:
        H_f, W_f = features.shape[2:]
        # (N, feat_chans, H, W) --> (N, attention_units, H, W)
        feat_projection = self.feat_conv(features)
        # (N, state_chans) --> (N, state_chans, 1, 1) --> (N, attention_units, 1, 1)
        hidden_state = hidden_state.view(hidden_state.size(0), hidden_state.size(1), 1, 1)
        state_projection = self.state_conv(hidden_state)
        # (N, attention_units, 1, 1) --> (N, attention_units, H_f, W_f)
        state_projection = state_projection.expand(-1, -1, H_f, W_f)
        attention_weights = torch.tanh(feat_projection + state_projection)
        # (N, attention_units, H_f, W_f) --> (N, 1, H_f, W_f)
        attention_weights = self.attention_projector(attention_weights)
        B, C, H, W = attention_weights.size()
        # Spatial softmax over the (H, W) locations
        attention_weights = torch.softmax(attention_weights.view(B, -1), dim=-1).view(B, C, H, W)
        # Fuse features and attention weights: (N, C, H, W) --> (N, C)
        return (features * attention_weights).sum(dim=(2, 3))
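
# Shape sketch for the attention module (illustrative sizes, not part of the model):
#   >>> attn = AttentionModule(feat_chans=512, state_chans=512, attention_units=512)
#   >>> attn(torch.rand(2, 512, 4, 32), torch.rand(2, 512)).shape  # glimpse vector (N, feat_chans)
#   torch.Size([2, 512])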


class SARDecoder(nn.Module):
    """Implements the decoder module of the SAR model.

    Args:
    ----
        rnn_units: number of hidden units in recurrent cells
        max_length: maximum length of a sequence
        vocab_size: number of classes in the model alphabet
        embedding_units: number of hidden embedding units
        attention_units: number of hidden attention units
        feat_chans: number of channels of the visual feature map
        dropout_prob: dropout probability applied before the output layer
    """

    def __init__(
        self,
        rnn_units: int,
        max_length: int,
        vocab_size: int,
        embedding_units: int,
        attention_units: int,
        feat_chans: int = 512,
        dropout_prob: float = 0.0,
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.max_length = max_length

        self.embed = nn.Linear(self.vocab_size + 1, embedding_units)
        self.embed_tgt = nn.Embedding(embedding_units, self.vocab_size + 1)
        self.attention_module = AttentionModule(feat_chans, rnn_units, attention_units)
        self.lstm_cell = nn.LSTMCell(rnn_units, rnn_units)
        self.output_dense = nn.Linear(2 * rnn_units, self.vocab_size + 1)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(
        self,
        features: torch.Tensor,  # (N, C, H, W)
        holistic: torch.Tensor,  # (N, C)
        gt: Optional[torch.Tensor] = None,  # (N, L)
    ) -> torch.Tensor:
        if gt is not None:
            gt_embedding = self.embed_tgt(gt)

        logits_list: List[torch.Tensor] = []

        for t in range(self.max_length + 1):  # 32
            if t == 0:
                # step to init the first states of the LSTMCell
                hidden_state_init = cell_state_init = torch.zeros(
                    features.size(0), features.size(1), device=features.device, dtype=features.dtype
                )
                hidden_state, cell_state = hidden_state_init, cell_state_init
                prev_symbol = holistic
            elif t == 1:
                # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
                # (N, vocab_size + 1) --> (N, embedding_units)
                prev_symbol = torch.zeros(
                    features.size(0), self.vocab_size + 1, device=features.device, dtype=features.dtype
                )
                prev_symbol = self.embed(prev_symbol)
            else:
                if gt is not None and self.training:
                    # (N, embedding_units) -2 because of <bos> and <eos> (same)
                    prev_symbol = self.embed(gt_embedding[:, t - 2])
                else:
                    # -1 to start at timestep where prev_symbol was initialized
                    index = logits_list[t - 1].argmax(-1)
                    # update prev_symbol with ones at the index of the previous logit vector
                    prev_symbol = self.embed(self.embed_tgt(index))

            # (N, C), (N, C) take the last hidden state and cell state from current timestep
            hidden_state_init, cell_state_init = self.lstm_cell(prev_symbol, (hidden_state_init, cell_state_init))
            hidden_state, cell_state = self.lstm_cell(hidden_state_init, (hidden_state, cell_state))
            # (N, C, H, W), (N, C) --> (N, C)
            glimpse = self.attention_module(features, hidden_state)
            # (N, C), (N, C) --> (N, 2 * C)
            logits = torch.cat([hidden_state, glimpse], dim=1)
            logits = self.dropout(logits)
            # (N, vocab_size + 1)
            logits_list.append(self.output_dense(logits))

        # (max_length + 1, N, vocab_size + 1) --> (N, max_length + 1, vocab_size + 1)
        return torch.stack(logits_list[1:]).permute(1, 0, 2)
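
# Shape sketch for the decoder (illustrative only; assumes rnn_units == embedding_units == 512,
# matching the feature channel width, and a hypothetical 126-character vocab):
#   >>> dec = SARDecoder(rnn_units=512, max_length=31, vocab_size=126, embedding_units=512, attention_units=512)
#   >>> dec(torch.rand(2, 512, 4, 32), torch.rand(2, 512)).shape  # (N, max_length, vocab_size + 1)
#   torch.Size([2, 31, 127])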


class SAR(nn.Module, RecognitionModel):
    """Implements a SAR architecture as described in `"Show, Attend and Read: A Simple and Strong Baseline for
    Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.

    Args:
    ----
        feature_extractor: the backbone serving as feature extractor
        vocab: vocabulary used for encoding
        rnn_units: number of hidden units in both encoder and decoder LSTM
        embedding_units: number of embedding units
        attention_units: number of hidden units in attention module
        max_length: maximum word length handled by the model
        dropout_prob: dropout probability of the encoder LSTM
        input_shape: size of the input image (channels, height, width)
        exportable: onnx exportable returns only logits
        cfg: dictionary containing information about the model
    """

    def __init__(
        self,
        feature_extractor,
        vocab: str,
        rnn_units: int = 512,
        embedding_units: int = 512,
        attention_units: int = 512,
        max_length: int = 30,
        dropout_prob: float = 0.0,
        input_shape: Tuple[int, int, int] = (3, 32, 128),
        exportable: bool = False,
        cfg: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.vocab = vocab
        self.exportable = exportable
        self.cfg = cfg
        self.max_length = max_length + 1  # Add 1 timestep for EOS after the longest word

        self.feat_extractor = feature_extractor
        # Size the LSTM
        self.feat_extractor.eval()
        with torch.no_grad():
            out_shape = self.feat_extractor(torch.zeros((1, *input_shape)))["features"].shape
        # Switch back to original mode
        self.feat_extractor.train()

        self.encoder = SAREncoder(out_shape[1], rnn_units, dropout_prob)

        self.decoder = SARDecoder(
            rnn_units,
            self.max_length,
            len(self.vocab),
            embedding_units,
            attention_units,
            dropout_prob=dropout_prob,
        )

        self.postprocessor = SARPostProcessor(vocab=vocab)

        for n, m in self.named_modules():
            # Don't override the initialization of the backbone
            if n.startswith("feat_extractor."):
                continue
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    def forward(
        self,
        x: torch.Tensor,
        target: Optional[List[str]] = None,
        return_model_output: bool = False,
        return_preds: bool = False,
    ) -> Dict[str, Any]:
        features = self.feat_extractor(x)["features"]
        # NOTE: use max instead of functional max_pool2d which leads to ONNX incompatibility (kernel_size)
        # Vertical max pooling (N, C, H, W) --> (N, C, W)
        pooled_features = features.max(dim=-2).values
        # (N, W, C)
        pooled_features = pooled_features.permute(0, 2, 1).contiguous()
        # (N, C)
        encoded = self.encoder(pooled_features)
        if target is not None:
            _gt, _seq_len = self.build_target(target)
            gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len)
            gt, seq_len = gt.to(x.device), seq_len.to(x.device)

        if self.training and target is None:
            raise ValueError("Need to provide labels during training for teacher forcing")

        decoded_features = _bf16_to_float32(self.decoder(features, encoded, gt=None if target is None else gt))

        out: Dict[str, Any] = {}
        if self.exportable:
            out["logits"] = decoded_features
            return out

        if return_model_output:
            out["out_map"] = decoded_features

        if target is None or return_preds:
            # Post-process the raw logits into (word, confidence) predictions
            out["preds"] = self.postprocessor(decoded_features)

        if target is not None:
            out["loss"] = self.compute_loss(decoded_features, gt, seq_len)

        return out
    @staticmethod
    def compute_loss(
        model_output: torch.Tensor,
        gt: torch.Tensor,
        seq_len: torch.Tensor,
    ) -> torch.Tensor:
        """Compute categorical cross-entropy loss for the model.
        Sequences are masked after the EOS character.

        Args:
        ----
            model_output: predicted logits of the model
            gt: the encoded tensor with gt labels
            seq_len: lengths of each gt word inside the batch

        Returns:
        -------
            The loss of the model on the batch
        """
        # Input length: number of timesteps
        input_len = model_output.shape[1]
        # Add one for additional <eos> token
        seq_len = seq_len + 1
        # Compute loss: (N, L, vocab_size + 1) logits vs. (N, L) targets
        cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
        # Mask the timesteps after the <eos> token before averaging per word
        mask_2d = torch.arange(input_len, device=model_output.device)[None, :] >= seq_len[:, None]
        cce[mask_2d] = 0

        ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype)
        return ce_loss.mean()
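
# Masking sketch for compute_loss (illustrative, not part of the model): with 4 timesteps and a
# 2-character word, seq_len becomes 3 (2 characters + <eos>) and the trailing timestep is zeroed:
#   >>> torch.arange(4)[None, :] >= torch.tensor([3])[:, None]
#   tensor([[False, False, False,  True]])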


class SARPostProcessor(RecognitionPostProcessor):
    """Post processor for SAR architectures.

    Args:
    ----
        vocab: string containing the ordered sequence of supported characters
    """

    def __call__(
        self,
        logits: torch.Tensor,
    ) -> List[Tuple[str, float]]:
        # compute pred with argmax for attention models
        out_idxs = logits.argmax(-1)
        # N x L
        probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
        # Take the minimum confidence of the sequence
        probs = probs.min(dim=1).values.detach().cpu()

        # Manual decoding
        word_values = [
            "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
            for encoded_seq in out_idxs.detach().cpu().numpy()
        ]

        return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
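
# Post-processing sketch (illustrative only; the logits shape follows the decoder output above):
#   >>> post = SARPostProcessor(vocab=VOCABS["french"])
#   >>> preds = post(torch.rand(2, 31, len(VOCABS["french"]) + 1))
#   >>> # preds is a list of (word, confidence) tuples, one per input sequence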


def _sar(
    arch: str,
    pretrained: bool,
    backbone_fn: Callable[[bool], nn.Module],
    layer: str,
    pretrained_backbone: bool = True,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> SAR:
    pretrained_backbone = pretrained_backbone and not pretrained

    # Patch the config
    _cfg = deepcopy(default_cfgs[arch])
    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
    _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])

    # Feature extractor
    feat_extractor = IntermediateLayerGetter(
        backbone_fn(pretrained_backbone),
        {layer: "features"},
    )
    kwargs["vocab"] = _cfg["vocab"]
    kwargs["input_shape"] = _cfg["input_shape"]

    # Build the model
    model = SAR(feat_extractor, cfg=_cfg, **kwargs)
    # Load pretrained parameters
    if pretrained:
        # The number of classes is not the same as the number of classes in the pretrained model =>
        # remove the last layer weights
        _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

    return model


def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
    """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read: A Simple and Strong
    Baseline for Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.

    >>> import torch
    >>> from doctr.models import sar_resnet31
    >>> model = sar_resnet31(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 32, 128))
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the SAR architecture

    Returns:
    -------
        text recognition architecture
    """
    return _sar(
        "sar_resnet31",
        pretrained,
        resnet31,
        "10",
        ignore_keys=[
            "decoder.embed.weight",
            "decoder.embed_tgt.weight",
            "decoder.output_dense.weight",
            "decoder.output_dense.bias",
        ],
        **kwargs,
    )