Spaces:
Running
Running
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Dict, List, Optional, Sequence, Union | |
import torch | |
import torch.nn as nn | |
from mmocr.models.common.dictionary import Dictionary | |
from mmocr.registry import MODELS | |
from mmocr.structures import TextRecogDataSample | |
from .base import BaseDecoder | |
class SVTRDecoder(BaseDecoder):
    """Decoder module in `SVTR <https://arxiv.org/abs/2205.00159>`_.

    The decoder is a single linear layer that projects every horizontal
    position of the encoder output onto ``dictionary.num_classes`` logits.
    At test time the logits are turned into probabilities with a softmax
    over the class dimension.

    Args:
        in_channels (int): The num of input channels.
        dictionary (Union[Dict, Dictionary]): The config for `Dictionary` or
            the instance of `Dictionary`. Defaults to None.
        module_loss (Optional[Dict], optional): Cfg to build module_loss.
            Defaults to None.
        postprocessor (Optional[Dict], optional): Cfg to build postprocessor.
            Defaults to None.
        max_seq_len (int, optional): Maximum output sequence length :math:`T`.
            Forwarded to ``BaseDecoder``; the forward methods of this class
            do not read it, the output length follows the encoder width.
            Defaults to 25.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 dictionary: Union[Dict, Dictionary] = None,
                 module_loss: Optional[Dict] = None,
                 postprocessor: Optional[Dict] = None,
                 max_seq_len: int = 25,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(
            dictionary=dictionary,
            module_loss=module_loss,
            postprocessor=postprocessor,
            max_seq_len=max_seq_len,
            init_cfg=init_cfg)

        # NOTE(review): ``self.dictionary`` is presumably set up by
        # ``BaseDecoder.__init__`` from the ``dictionary`` argument - confirm.
        self.decoder = nn.Linear(
            in_features=in_channels, out_features=self.dictionary.num_classes)
        # Normalises over the last (class) dimension; only used at test time.
        self.softmax = nn.Softmax(dim=-1)

    def forward_train(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for training.

        Args:
            feat (torch.Tensor, optional): The feature map. Not used by this
                decoder. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output of shape
                :math:`(N, C, 1, W)` where :math:`C` equals ``in_channels``.
                Although it defaults to None for signature compatibility, it
                is required.
            data_samples (Sequence[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text information. Not used
                by this decoder. Defaults to None.

        Returns:
            Tensor: The raw logit tensor. Shape :math:`(N, W, C)` where
            :math:`W` is the width of ``out_enc`` and :math:`C` is
            ``num_classes``.

        Raises:
            ValueError: If ``out_enc`` is None.
            AssertionError: If the height (dim 2) of ``out_enc`` is not 1.
        """
        if out_enc is None:
            # Fail with a clear message instead of an opaque AttributeError
            # on ``None.size``.
            raise ValueError('out_enc must be provided')
        assert out_enc.size(2) == 1, 'feature height must be 1'

        # (N, C, 1, W) -> (N, C, W) -> (N, W, C): the linear layer acts on
        # the last dimension, so channels have to be moved there.
        x = out_enc.squeeze(2)
        x = x.permute(0, 2, 1)
        predicts = self.decoder(x)
        return predicts

    def forward_test(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for testing.

        Args:
            feat (torch.Tensor, optional): The feature map. Not used by this
                decoder. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output of shape
                :math:`(N, C, 1, W)` where :math:`C` equals ``in_channels``.
                Although it defaults to None for signature compatibility, it
                is required.
            data_samples (Sequence[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text information. Not used
                by this decoder. Defaults to None.

        Returns:
            Tensor: Character probabilities of shape :math:`(N, W, C)` where
            :math:`W` is the width of ``out_enc`` and :math:`C` is
            ``num_classes``. Values along the last dimension sum to 1.

        Raises:
            ValueError: If ``out_enc`` is None.
            AssertionError: If the height (dim 2) of ``out_enc`` is not 1.
        """
        return self.softmax(self.forward_train(feat, out_enc, data_samples))