# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Union

import torch
import torch.nn as nn

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseDecoder


@MODELS.register_module()
class SVTRDecoder(BaseDecoder):
"""Decoder module in `SVTR <https://arxiv.org/abs/2205.00159>`_.
Args:
in_channels (int): The num of input channels.
dictionary (Union[Dict, Dictionary]): The config for `Dictionary` or
the instance of `Dictionary`. Defaults to None.
module_loss (Optional[Dict], optional): Cfg to build module_loss.
Defaults to None.
postprocessor (Optional[Dict], optional): Cfg to build postprocessor.
Defaults to None.
max_seq_len (int, optional): Maximum output sequence length :math:`T`.
Defaults to 25.
init_cfg (dict or list[dict], optional): Initialization configs.
Defaults to None.
"""

    def __init__(self,
                 in_channels: int,
                 dictionary: Optional[Union[Dict, Dictionary]] = None,
                 module_loss: Optional[Dict] = None,
                 postprocessor: Optional[Dict] = None,
                 max_seq_len: int = 25,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(
            dictionary=dictionary,
            module_loss=module_loss,
            postprocessor=postprocessor,
            max_seq_len=max_seq_len,
            init_cfg=init_cfg)

        # A single linear head maps each per-position feature vector of size
        # ``in_channels`` to ``num_classes`` logits. Softmax is applied only
        # at test time, turning logits into character probabilities.
        self.decoder = nn.Linear(
            in_features=in_channels, out_features=self.dictionary.num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward_train(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for training.

        Args:
            feat (torch.Tensor, optional): The feature map. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output of shape
                :math:`(N, C, 1, W)`, whose height dimension must be 1.
                Defaults to None.
            data_samples (Sequence[TextRecogDataSample], optional): Batch of
                TextRecogDataSample, containing gt_text information. Defaults
                to None.

        Returns:
            Tensor: The raw logit tensor of shape :math:`(N, T, C)`, where
            :math:`T` is the sequence length (the feature width :math:`W`)
            and :math:`C` is ``num_classes``.
        """
        assert out_enc.size(2) == 1, 'feature height must be 1'
        # (N, C, 1, W) -> (N, C, W): drop the unit height dimension.
        x = out_enc.squeeze(2)
        # (N, C, W) -> (N, W, C): one feature vector per horizontal position.
        x = x.permute(0, 2, 1)
        # Project each position to per-class logits: (N, W, num_classes).
        predicts = self.decoder(x)
        return predicts

    def forward_test(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for testing.

        Args:
            feat (torch.Tensor, optional): The feature map. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output of shape
                :math:`(N, C, 1, W)`, whose height dimension must be 1.
                Defaults to None.
            data_samples (Sequence[TextRecogDataSample], optional): Batch of
                TextRecogDataSample, containing gt_text information. Defaults
                to None.

        Returns:
            Tensor: Character probabilities of shape :math:`(N, T, C)`, where
            :math:`T` is the sequence length and :math:`C` is
            ``num_classes``.
        """
        # At test time the logits are normalized with softmax so that the
        # postprocessor can read them as per-character probabilities.
        return self.softmax(self.forward_train(feat, out_enc, data_samples))
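

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the upstream module): it
# instantiates the decoder on a dummy encoder output and checks the output
# shape. The toy dictionary, the 192-channel size, and the width of 25 are
# arbitrary assumptions for the demo, not values mandated by SVTR.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import tempfile

    # Build a toy 36-class dictionary (digits + lowercase letters), one
    # character per line, as `Dictionary` expects.
    chars = '0123456789abcdefghijklmnopqrstuvwxyz'
    with tempfile.NamedTemporaryFile(
            'w', suffix='.txt', delete=False) as tmp:
        tmp.write('\n'.join(chars))
        dict_file = tmp.name

    decoder = SVTRDecoder(
        in_channels=192, dictionary=Dictionary(dict_file=dict_file))

    # Dummy encoder output of shape (N, C, 1, W): batch 2, 192 channels,
    # height already pooled to 1, width 25.
    out_enc = torch.rand(2, 192, 1, 25)
    probs = decoder.forward_test(out_enc=out_enc)
    print(probs.shape)  # torch.Size([2, 25, 36])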