# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models._utils import IntermediateLayerGetter

from doctr.datasets import VOCABS

from ...classification import resnet31
from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor

__all__ = ["SAR", "sar_resnet31"]


default_cfgs: Dict[str, Dict[str, Any]] = {
"sar_resnet31": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 128),
"vocab": VOCABS["french"],
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/sar_resnet31-9a1deedf.pt&src=0",
},
}


class SAREncoder(nn.Module):
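    """LSTM encoder producing the holistic feature of the SAR model.

    Args:
    ----
        in_feats: number of channels of the (vertically pooled) input features
        rnn_units: number of hidden units in each LSTM layer
        dropout_prob: dropout probability between the two LSTM layers
    """
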
def __init__(self, in_feats: int, rnn_units: int, dropout_prob: float = 0.0) -> None:
super().__init__()
self.rnn = nn.LSTM(in_feats, rnn_units, 2, batch_first=True, dropout=dropout_prob)
self.linear = nn.Linear(rnn_units, rnn_units)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, W, in_feats) --> (N, W, rnn_units)
encoded = self.rnn(x)[0]
# (N, C)
return self.linear(encoded[:, -1, :])


class AttentionModule(nn.Module):
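    """2D attention module computing a glimpse of the feature map, conditioned on the decoder state.

    Args:
    ----
        feat_chans: number of channels of the visual feature map
        state_chans: number of channels of the decoder hidden state
        attention_units: number of hidden attention units
    """
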
def __init__(self, feat_chans: int, state_chans: int, attention_units: int) -> None:
super().__init__()
self.feat_conv = nn.Conv2d(feat_chans, attention_units, kernel_size=3, padding=1)
# No need to add another bias since both tensors are summed together
self.state_conv = nn.Conv2d(state_chans, attention_units, kernel_size=1, bias=False)
self.attention_projector = nn.Conv2d(attention_units, 1, kernel_size=1, bias=False)

    def forward(
self,
features: torch.Tensor, # (N, C, H, W)
hidden_state: torch.Tensor, # (N, C)
) -> torch.Tensor:
H_f, W_f = features.shape[2:]
# (N, feat_chans, H, W) --> (N, attention_units, H, W)
feat_projection = self.feat_conv(features)
# (N, state_chans, 1, 1) --> (N, attention_units, 1, 1)
hidden_state = hidden_state.view(hidden_state.size(0), hidden_state.size(1), 1, 1)
state_projection = self.state_conv(hidden_state)
        # (N, attention_units, 1, 1) --> (N, attention_units, H_f, W_f)
        state_projection = state_projection.expand(-1, -1, H_f, W_f)
        attention_weights = torch.tanh(feat_projection + state_projection)
# (N, attention_units, H_f, W_f) --> (N, 1, H_f, W_f)
attention_weights = self.attention_projector(attention_weights)
        B, C, H, W = attention_weights.size()
        # softmax over the flattened spatial locations: (N, 1, H, W) --> (N, 1, H, W)
        attention_weights = torch.softmax(attention_weights.view(B, -1), dim=-1).view(B, C, H, W)
# fuse features and attention weights (N, C)
return (features * attention_weights).sum(dim=(2, 3))
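

# Usage sketch for AttentionModule (illustrative shapes, matching the default 512-unit config):
#   >>> attn = AttentionModule(feat_chans=512, state_chans=512, attention_units=512)
#   >>> glimpse = attn(torch.rand(2, 512, 4, 32), torch.rand(2, 512))  # --> (2, 512)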


class SARDecoder(nn.Module):
    """Implements decoder module of the SAR model

    Args:
    ----
        rnn_units: number of hidden units in recurrent cells
        max_length: maximum length of a sequence
        vocab_size: number of classes in the model alphabet
        embedding_units: number of hidden embedding units
        attention_units: number of hidden attention units
        feat_chans: number of channels of the visual feature map
        dropout_prob: dropout probability applied before the output layer
    """

    def __init__(
self,
rnn_units: int,
max_length: int,
vocab_size: int,
embedding_units: int,
attention_units: int,
feat_chans: int = 512,
dropout_prob: float = 0.0,
) -> None:
super().__init__()
self.vocab_size = vocab_size
self.max_length = max_length
self.embed = nn.Linear(self.vocab_size + 1, embedding_units)
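        # NOTE: indices fed to the lookup below are character classes (at most vocab_size),
        # so vocab_size + 1 is assumed not to exceed embedding_units (its num_embeddings)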
self.embed_tgt = nn.Embedding(embedding_units, self.vocab_size + 1)
self.attention_module = AttentionModule(feat_chans, rnn_units, attention_units)
self.lstm_cell = nn.LSTMCell(rnn_units, rnn_units)
self.output_dense = nn.Linear(2 * rnn_units, self.vocab_size + 1)
self.dropout = nn.Dropout(dropout_prob)

    def forward(
self,
features: torch.Tensor, # (N, C, H, W)
holistic: torch.Tensor, # (N, C)
gt: Optional[torch.Tensor] = None, # (N, L)
) -> torch.Tensor:
if gt is not None:
gt_embedding = self.embed_tgt(gt)
logits_list: List[torch.Tensor] = []
        for t in range(self.max_length + 1):  # the t == 0 output is discarded below
if t == 0:
# step to init the first states of the LSTMCell
hidden_state_init = cell_state_init = torch.zeros(
features.size(0), features.size(1), device=features.device, dtype=features.dtype
)
hidden_state, cell_state = hidden_state_init, cell_state_init
prev_symbol = holistic
elif t == 1:
# step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
# (N, vocab_size + 1) --> (N, embedding_units)
prev_symbol = torch.zeros(
features.size(0), self.vocab_size + 1, device=features.device, dtype=features.dtype
)
prev_symbol = self.embed(prev_symbol)
else:
                if gt is not None and self.training:
                    # (N, embedding_units); t - 2 offsets the two initialization timesteps
                    prev_symbol = self.embed(gt_embedding[:, t - 2])
                else:
                    # feed back the argmax of the previous timestep's logits
                    index = logits_list[t - 1].argmax(-1)
                    # (N,) --> (N, embedding_units)
                    prev_symbol = self.embed(self.embed_tgt(index))
            # two passes through the shared LSTM cell update the (N, C) hidden and cell states
hidden_state_init, cell_state_init = self.lstm_cell(prev_symbol, (hidden_state_init, cell_state_init))
hidden_state, cell_state = self.lstm_cell(hidden_state_init, (hidden_state, cell_state))
# (N, C, H, W), (N, C) --> (N, C)
glimpse = self.attention_module(features, hidden_state)
# (N, C), (N, C) --> (N, 2 * C)
logits = torch.cat([hidden_state, glimpse], dim=1)
logits = self.dropout(logits)
# (N, vocab_size + 1)
logits_list.append(self.output_dense(logits))
# (max_length + 1, N, vocab_size + 1) --> (N, max_length + 1, vocab_size + 1)
return torch.stack(logits_list[1:]).permute(1, 0, 2)
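

# Usage sketch for SARDecoder (illustrative; 126 is just an example vocab size):
#   >>> dec = SARDecoder(rnn_units=512, max_length=31, vocab_size=126, embedding_units=512, attention_units=512)
#   >>> logits = dec(torch.rand(2, 512, 4, 32), torch.rand(2, 512))  # --> (2, 31, 127)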


class SAR(nn.Module, RecognitionModel):
    """Implements a SAR architecture as described in `"Show, Attend and Read: A Simple and Strong Baseline for
    Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.

    Args:
    ----
        feature_extractor: the backbone serving as feature extractor
        vocab: vocabulary used for encoding
        rnn_units: number of hidden units in both encoder and decoder LSTM
        embedding_units: number of embedding units
        attention_units: number of hidden units in attention module
        max_length: maximum word length handled by the model
        dropout_prob: dropout probability of the encoder and decoder
        input_shape: size of the input images (channels, height, width)
        exportable: if True, the forward pass returns only the logits (for ONNX export)
        cfg: dictionary containing information about the model
    """

    def __init__(
self,
feature_extractor,
vocab: str,
rnn_units: int = 512,
embedding_units: int = 512,
attention_units: int = 512,
max_length: int = 30,
dropout_prob: float = 0.0,
input_shape: Tuple[int, int, int] = (3, 32, 128),
exportable: bool = False,
cfg: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()
self.vocab = vocab
self.exportable = exportable
self.cfg = cfg
self.max_length = max_length + 1 # Add 1 timestep for EOS after the longest word
self.feat_extractor = feature_extractor
        # Infer the encoder input size with a dummy forward pass
self.feat_extractor.eval()
with torch.no_grad():
out_shape = self.feat_extractor(torch.zeros((1, *input_shape)))["features"].shape
# Switch back to original mode
self.feat_extractor.train()
self.encoder = SAREncoder(out_shape[1], rnn_units, dropout_prob)
self.decoder = SARDecoder(
rnn_units,
self.max_length,
len(self.vocab),
embedding_units,
attention_units,
dropout_prob=dropout_prob,
)
self.postprocessor = SARPostProcessor(vocab=vocab)
for n, m in self.named_modules():
# Don't override the initialization of the backbone
if n.startswith("feat_extractor."):
continue
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

    def forward(
self,
x: torch.Tensor,
target: Optional[List[str]] = None,
return_model_output: bool = False,
return_preds: bool = False,
) -> Dict[str, Any]:
features = self.feat_extractor(x)["features"]
# NOTE: use max instead of functional max_pool2d which leads to ONNX incompatibility (kernel_size)
# Vertical max pooling (N, C, H, W) --> (N, C, W)
pooled_features = features.max(dim=-2).values
# (N, W, C)
pooled_features = pooled_features.permute(0, 2, 1).contiguous()
# (N, C)
encoded = self.encoder(pooled_features)
if target is not None:
_gt, _seq_len = self.build_target(target)
gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len)
gt, seq_len = gt.to(x.device), seq_len.to(x.device)
if self.training and target is None:
raise ValueError("Need to provide labels during training for teacher forcing")
decoded_features = _bf16_to_float32(self.decoder(features, encoded, gt=None if target is None else gt))
out: Dict[str, Any] = {}
if self.exportable:
out["logits"] = decoded_features
return out
if return_model_output:
out["out_map"] = decoded_features
        if target is None or return_preds:
            # Post-process logits into (word, confidence) pairs
            out["preds"] = self.postprocessor(decoded_features)
if target is not None:
out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
return out
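
    # Usage sketch (illustrative):
    #   out = model(x)                                        # inference: out["preds"]
    #   out = model(x, target=["hello"], return_preds=True)   # training: out["loss"] (+ preds)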

    @staticmethod
def compute_loss(
model_output: torch.Tensor,
gt: torch.Tensor,
seq_len: torch.Tensor,
) -> torch.Tensor:
"""Compute categorical cross-entropy loss for the model.
Sequences are masked after the EOS character.
Args:
----
model_output: predicted logits of the model
gt: the encoded tensor with gt labels
seq_len: lengths of each gt word inside the batch
Returns:
-------
The loss of the model on the batch
"""
# Input length : number of timesteps
input_len = model_output.shape[1]
# Add one for additional <eos> token
seq_len = seq_len + 1
        # Per-timestep cross-entropy: logits (N, vocab_size + 1, L) vs gt (N, L) --> (N, L)
        cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
mask_2d = torch.arange(input_len, device=model_output.device)[None, :] >= seq_len[:, None]
cce[mask_2d] = 0
ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype)
return ce_loss.mean()
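
    # Worked example (illustrative): for gt word "cat" (seq_len = 3) and input_len = 31,
    # timesteps 0..3 (three characters + <eos>) keep their cross-entropy, timesteps 4..30
    # are zeroed by the mask, and the per-sample sum is divided by 4 before batch averaging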


class SARPostProcessor(RecognitionPostProcessor):
    """Post processor for SAR architectures

    Args:
    ----
        vocab: string containing the ordered sequence of supported characters
    """

    def __call__(
self,
logits: torch.Tensor,
) -> List[Tuple[str, float]]:
# compute pred with argmax for attention models
out_idxs = logits.argmax(-1)
# N x L
probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
# Take the minimum confidence of the sequence
probs = probs.min(dim=1).values.detach().cpu()
# Manual decoding
word_values = [
"".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
for encoded_seq in out_idxs.detach().cpu().numpy()
]
return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
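
    # Output sketch (illustrative): for a batch of two crops this returns something like
    # [("hello", 0.98), ("world", 0.87)], each confidence being the minimum per-character
    # probability of its decoded sequence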


def _sar(
arch: str,
pretrained: bool,
backbone_fn: Callable[[bool], nn.Module],
layer: str,
pretrained_backbone: bool = True,
ignore_keys: Optional[List[str]] = None,
**kwargs: Any,
) -> SAR:
pretrained_backbone = pretrained_backbone and not pretrained
# Patch the config
_cfg = deepcopy(default_cfgs[arch])
_cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
_cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
# Feature extractor
feat_extractor = IntermediateLayerGetter(
backbone_fn(pretrained_backbone),
{layer: "features"},
)
kwargs["vocab"] = _cfg["vocab"]
kwargs["input_shape"] = _cfg["input_shape"]
# Build the model
model = SAR(feat_extractor, cfg=_cfg, **kwargs)
# Load pretrained parameters
if pretrained:
# The number of classes is not the same as the number of classes in the pretrained model =>
# remove the last layer weights
_ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
return model


def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
    """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read: A Simple and Strong
    Baseline for Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.

    >>> import torch
>>> from doctr.models import sar_resnet31
>>> model = sar_resnet31(pretrained=False)
>>> input_tensor = torch.rand((1, 3, 32, 128))
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the SAR architecture

    Returns:
    -------
        text recognition architecture
    """
return _sar(
"sar_resnet31",
pretrained,
resnet31,
"10",
ignore_keys=[
"decoder.embed.weight",
"decoder.embed_tgt.weight",
"decoder.output_dense.weight",
"decoder.output_dense.bias",
],
**kwargs,
)
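

# End-to-end sketch (illustrative, mirroring the docstring above):
#   >>> import torch
#   >>> from doctr.models import sar_resnet31
#   >>> model = sar_resnet31(pretrained=False)
#   >>> out = model(torch.rand(2, 3, 32, 128), target=["word", "text"])
#   >>> out["loss"].backward()  # teacher-forced training step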