# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models._utils import IntermediateLayerGetter

from doctr.datasets import VOCABS
from doctr.models.classification import magc_resnet31
from doctr.models.modules.transformer import Decoder, PositionalEncoding

from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
from .base import _MASTER, _MASTERPostProcessor

__all__ = ["MASTER", "master"]


default_cfgs: Dict[str, Dict[str, Any]] = {
    "master": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 128),
        "vocab": VOCABS["french"],
        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/master-fde31e4a.pt&src=0",
    },
}


class MASTER(_MASTER, nn.Module):
    """Implements MASTER as described in paper: <https://arxiv.org/pdf/1910.02562.pdf>`_.
    Implementation based on the official Pytorch implementation: <https://github.com/wenwenyu/MASTER-pytorch>`_.

    Args:
    ----
        feature_extractor: the backbone serving as feature extractor
        vocab: vocabulary (without EOS, SOS, PAD)
        d_model: d parameter for the transformer decoder
        dff: depth of the pointwise feed-forward layer
        num_heads: number of heads for the multi-head attention module
        num_layers: number of decoder layers to stack
        max_length: maximum length of character sequence handled by the model
        dropout: dropout probability of the decoder
        input_shape: size of the image inputs
        exportable: onnx exportable, returns only logits
        cfg: dictionary containing information about the model
    """

    def __init__(
        self,
        feature_extractor: nn.Module,
        vocab: str,
        d_model: int = 512,
        dff: int = 2048,
        num_heads: int = 8,  # number of heads in the transformer decoder
        num_layers: int = 3,
        max_length: int = 50,
        dropout: float = 0.2,
        input_shape: Tuple[int, int, int] = (3, 32, 128),  # different from the paper
        exportable: bool = False,
        cfg: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()

        self.exportable = exportable
        self.max_length = max_length
        self.d_model = d_model
        self.vocab = vocab
        self.cfg = cfg
        self.vocab_size = len(vocab)

        self.feat_extractor = feature_extractor
        self.positional_encoding = PositionalEncoding(self.d_model, dropout, max_len=input_shape[1] * input_shape[2])

        self.decoder = Decoder(
            num_layers=num_layers,
            d_model=self.d_model,
            num_heads=num_heads,
            vocab_size=self.vocab_size + 3,  # EOS, SOS, PAD
            dff=dff,
            dropout=dropout,
            maximum_position_encoding=self.max_length,
        )

        self.linear = nn.Linear(self.d_model, self.vocab_size + 3)
        self.postprocessor = MASTERPostProcessor(vocab=self.vocab)

        for n, m in self.named_modules():
            # Don't override the initialization of the backbone
            if n.startswith("feat_extractor."):
                continue
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def make_source_and_target_mask(
        self, source: torch.Tensor, target: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # borrowed and slightly modified from https://github.com/wenwenyu/MASTER-pytorch
        # NOTE: nn.TransformerDecoder takes the inverse from this implementation
        # [True, True, True, ..., False, False, False] -> False is masked
        # (N, 1, 1, max_length)
        target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
        target_length = target.size(1)
        # sub mask filled diagonal with True = see and False = masked (max_length, max_length)
        # NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)
        target_sub_mask = torch.tril(torch.ones((target_length, target_length), device=source.device), diagonal=0).to(
            dtype=torch.bool
        )
        # source mask filled with ones (max_length, positional_encoded_seq_len)
        source_mask = torch.ones((target_length, source.size(1)), dtype=torch.uint8, device=source.device)
        # combine the two masks into one (N, 1, max_length, max_length)
        target_mask = target_pad_mask & target_sub_mask
        return source_mask, target_mask.int()
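
    # Illustration (comment only): the causal sub-mask above lets each target position attend to
    # itself and to earlier positions only. For a target of length 4,
    # torch.tril(torch.ones((4, 4))).to(dtype=torch.bool) yields
    #   [[ True, False, False, False],
    #    [ True,  True, False, False],
    #    [ True,  True,  True, False],
    #    [ True,  True,  True,  True]]
    # ANDing it with the padding mask additionally hides <pad> positions from the decoder.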

    @staticmethod
    def compute_loss(
        model_output: torch.Tensor,
        gt: torch.Tensor,
        seq_len: torch.Tensor,
    ) -> torch.Tensor:
        """Compute categorical cross-entropy loss for the model.
        Sequences are masked after the EOS character.

        Args:
        ----
            gt: the encoded tensor with gt labels
            model_output: predicted logits of the model
            seq_len: lengths of each gt word inside the batch

        Returns:
        -------
            The loss of the model on the batch
        """
        # Input length : number of timesteps
        input_len = model_output.shape[1]
        # Add one for the additional <eos> token (<sos> disappears in the shift)
        seq_len = seq_len + 1
        # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
        # The "masked" first gt char is <sos>. Delete last logit of the model output.
        cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
        # Compute mask, remove 1 timestep here as well
        mask_2d = torch.arange(input_len - 1, device=model_output.device)[None, :] >= seq_len[:, None]
        cce[mask_2d] = 0

        ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype)
        return ce_loss.mean()
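
    # Worked example (comment only; assumes build_target encodes "ab" as [<sos>, a, b, <eos>, <pad>, ...]
    # with seq_len = 2): after the shift, logits[:, t] is compared against gt[:, t + 1], i.e. the model
    # predicts [a, b, <eos>, ...] from [<sos>, a, b, ...]. With seq_len + 1 = 3, only the timesteps
    # predicting a, b and <eos> contribute to the loss; later timesteps are zeroed out through mask_2d.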
""" # Encode features = self.feat_extractor(x)["features"] b, c, h, w = features.shape # (N, C, H, W) --> (N, H * W, C) features = features.view(b, c, h * w).permute((0, 2, 1)) # add positional encoding to features encoded = self.positional_encoding(features) out: Dict[str, Any] = {} if self.training and target is None: raise ValueError("Need to provide labels during training") if target is not None: # Compute target: tensor of gts and sequence lengths _gt, _seq_len = self.build_target(target) gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len) gt, seq_len = gt.to(x.device), seq_len.to(x.device) # Compute source mask and target mask source_mask, target_mask = self.make_source_and_target_mask(encoded, gt) output = self.decoder(gt, encoded, source_mask, target_mask) # Compute logits logits = self.linear(output) else: logits = self.decode(encoded) logits = _bf16_to_float32(logits) if self.exportable: out["logits"] = logits return out if target is not None: out["loss"] = self.compute_loss(logits, gt, seq_len) if return_model_output: out["out_map"] = logits if return_preds: out["preds"] = self.postprocessor(logits) return out def decode(self, encoded: torch.Tensor) -> torch.Tensor: """Decode function for prediction Args: ---- encoded: input tensor Returns: ------- A Tuple of torch.Tensor: predictions, logits """ b = encoded.size(0) # Padding symbol + SOS at the beginning ys = torch.full((b, self.max_length), self.vocab_size + 2, dtype=torch.long, device=encoded.device) # pad ys[:, 0] = self.vocab_size + 1 # sos # Final dimension include EOS/SOS/PAD for i in range(self.max_length - 1): source_mask, target_mask = self.make_source_and_target_mask(encoded, ys) output = self.decoder(ys, encoded, source_mask, target_mask) logits = self.linear(output) prob = torch.softmax(logits, dim=-1) next_token = torch.max(prob, dim=-1).indices # update ys with the next token and ignore the first token (SOS) ys[:, i + 1] = next_token[:, i] # Shape (N, max_length, vocab_size + 1) return logits class MASTERPostProcessor(_MASTERPostProcessor): """Post processor for MASTER architectures""" def __call__( self, logits: torch.Tensor, ) -> List[Tuple[str, float]]: # compute pred with argmax for attention models out_idxs = logits.argmax(-1) # N x L probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1) # Take the minimum confidence of the sequence probs = probs.min(dim=1).values.detach().cpu() # Manual decoding word_values = [ "".join(self._embedding[idx] for idx in encoded_seq).split("")[0] for encoded_seq in out_idxs.cpu().numpy() ] return list(zip(word_values, probs.numpy().clip(0, 1).tolist())) def _master( arch: str, pretrained: bool, backbone_fn: Callable[[bool], nn.Module], layer: str, pretrained_backbone: bool = True, ignore_keys: Optional[List[str]] = None, **kwargs: Any, ) -> MASTER: pretrained_backbone = pretrained_backbone and not pretrained # Patch the config _cfg = deepcopy(default_cfgs[arch]) _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"]) _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"]) kwargs["vocab"] = _cfg["vocab"] kwargs["input_shape"] = _cfg["input_shape"] # Build the model feat_extractor = IntermediateLayerGetter( backbone_fn(pretrained_backbone), {layer: "features"}, ) model = MASTER(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => # remove the last layer weights _ignore_keys = ignore_keys if 
_cfg["vocab"] != default_cfgs[arch]["vocab"] else None load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) return model def master(pretrained: bool = False, **kwargs: Any) -> MASTER: """MASTER as described in paper: `_. >>> import torch >>> from doctr.models import master >>> model = master(pretrained=False) >>> input_tensor = torch.rand((1, 3, 32, 128)) >>> out = model(input_tensor) Args: ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset **kwargs: keywoard arguments passed to the MASTER architecture Returns: ------- text recognition architecture """ return _master( "master", pretrained, magc_resnet31, "10", ignore_keys=[ "decoder.embed.weight", "linear.weight", "linear.bias", ], **kwargs, )