# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Union

import torch
import torch.nn as nn

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseDecoder


@MODELS.register_module()
class SVTRDecoder(BaseDecoder):
    """Linear classification head used as the decoder of SVTR.

    Every horizontal position of the (height-1) encoder feature map is
    projected independently onto the dictionary classes by a single
    ``nn.Linear`` layer.

    Args:
        in_channels (int): Number of channels of the encoder output; used as
            ``in_features`` of the linear layer.
        dictionary (Union[Dict, Dictionary]): The config for `Dictionary` or
            the instance of `Dictionary`. Defaults to None.
        module_loss (Optional[Dict], optional): Cfg to build module_loss.
            Defaults to None.
        postprocessor (Optional[Dict], optional): Cfg to build postprocessor.
            Defaults to None.
        max_seq_len (int, optional): Maximum output sequence length
            :math:`T`, forwarded to the base decoder. Defaults to 25.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 dictionary: Union[Dict, Dictionary] = None,
                 module_loss: Optional[Dict] = None,
                 postprocessor: Optional[Dict] = None,
                 max_seq_len: int = 25,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        # The base class is handed the dictionary first; ``self.dictionary``
        # is read right below to size the classifier.
        base_kwargs = dict(
            dictionary=dictionary,
            module_loss=module_loss,
            postprocessor=postprocessor,
            max_seq_len=max_seq_len,
            init_cfg=init_cfg)
        super().__init__(**base_kwargs)

        num_classes = self.dictionary.num_classes
        self.decoder = nn.Linear(
            in_features=in_channels, out_features=num_classes)
        # Applied over the class dimension (last axis) at test time only.
        self.softmax = nn.Softmax(dim=-1)

    def forward_train(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for training.

        Args:
            feat (torch.Tensor, optional): The feature map. Not used by this
                decoder. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output. The code
                requires dimension 2 to have size 1 and feeds dimension 1 to
                the linear layer, i.e. a shape of :math:`(N, C, 1, W)` with
                :math:`C` equal to ``in_channels``. Defaults to None, but a
                tensor must be supplied.
            data_samples (Sequence[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text information. Not
                used by this decoder. Defaults to None.

        Returns:
            Tensor: The raw logit tensor of shape :math:`(N, W, C_{cls})`,
            where :math:`W` is the width of ``out_enc`` and
            :math:`C_{cls}` is ``num_classes``.
        """
        assert out_enc.size(2) == 1, 'feature height must be 1'
        # (N, C, 1, W) -> (N, C, W) -> (N, W, C): one feature vector per
        # horizontal position, channels last for the linear layer.
        seq_feat = out_enc.squeeze(2).permute(0, 2, 1)
        return self.decoder(seq_feat)

    def forward_test(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for testing.

        Runs :meth:`forward_train` and normalises the logits with a softmax
        over the class dimension.

        Args:
            feat (torch.Tensor, optional): The feature map. Not used by this
                decoder. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output; same
                requirements as in :meth:`forward_train`. Defaults to None.
            data_samples (Sequence[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text information. Not
                used by this decoder. Defaults to None.

        Returns:
            Tensor: Character probabilities with the same shape as the
            logits returned by :meth:`forward_train`; the last dimension
            (``num_classes``) sums to 1.
        """
        logits = self.forward_train(feat, out_enc, data_samples)
        return self.softmax(logits)