Spaces:
Running
Running
# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
from typing import Dict, List, Optional, Sequence, Union | |
import torch | |
import torch.nn as nn | |
from mmocr.models.common.dictionary import Dictionary | |
from mmocr.registry import MODELS | |
from mmocr.structures import TextRecogDataSample | |
from .base import BaseDecoder | |
class RobustScannerFuser(BaseDecoder):
    """Decoder for RobustScanner that fuses two attention decoders.

    A hybrid decoder and a position decoder are built from their configs.
    Their outputs are concatenated along ``dim``, passed through a linear
    layer and a GLU, and finally projected to ``dictionary.num_classes``
    scores per decoding step.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        module_loss (dict, optional): Config to build module_loss. Defaults
            to None.
        postprocessor (dict, optional): Config to build postprocessor.
            Defaults to None.
        hybrid_decoder (dict): Config to build hybrid_decoder. Defaults to
            dict(type='SequenceAttentionDecoder'). If None, no hybrid
            decoder is built. The passed config is not modified.
        position_decoder (dict): Config to build position_decoder. Defaults to
            dict(type='PositionAttentionDecoder'). If None, no position
            decoder is built. The passed config is not modified.
        max_seq_len (int): Maximum sequence length. The
            sequence is usually generated from decoder. Defaults to 30.
        in_channels (list[int]): List of input channels; their sum is the
            width of the fusion linear layer. Defaults to [512, 512].
        dim (int): The dimension along which the two decoder outputs are
            concatenated and along which the GLU splits its input.
            Defaults to -1.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 dictionary: Union[Dict, Dictionary],
                 module_loss: Optional[Dict] = None,
                 postprocessor: Optional[Dict] = None,
                 hybrid_decoder: Dict = dict(type='SequenceAttentionDecoder'),
                 position_decoder: Dict = dict(
                     type='PositionAttentionDecoder'),
                 max_seq_len: int = 30,
                 in_channels: List[int] = [512, 512],
                 dim: int = -1,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(
            dictionary=dictionary,
            module_loss=module_loss,
            postprocessor=postprocessor,
            max_seq_len=max_seq_len,
            init_cfg=init_cfg)

        sub_decoder_cfgs = (
            ('hybrid_decoder', hybrid_decoder),
            ('position_decoder', position_decoder),
        )
        for cfg_name, cfg in sub_decoder_cfgs:
            if cfg is None:
                continue
            # Work on a shallow copy: the defaults are shared dict objects
            # and the caller may reuse its config, so updating in place would
            # leak this instance's dictionary / max_seq_len into later
            # instantiations.
            cfg = cfg.copy()
            if cfg.get('dictionary', None) is None:
                cfg.update(dictionary=self.dictionary)
            else:
                warnings.warn(f"Using dictionary {cfg['dictionary']} "
                              "in decoder's config.")
            if cfg.get('max_seq_len', None) is None:
                cfg.update(max_seq_len=max_seq_len)
            else:
                warnings.warn(f"Using max_seq_len {cfg['max_seq_len']} "
                              "in decoder's config.")
            setattr(self, cfg_name, MODELS.build(cfg))

        # ``in_channels`` is only read here, never mutated.
        fused_channels = sum(in_channels)
        self.dim = dim
        self.linear_layer = nn.Linear(fused_channels, fused_channels)
        self.glu_layer = nn.GLU(dim=dim)
        # The prediction layer takes half of the fused width as input,
        # matching the halving done by the GLU.
        self.prediction = nn.Linear(fused_channels // 2,
                                    self.dictionary.num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward_train(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for training.

        Args:
            feat (torch.Tensor, optional): The feature map from backbone of
                shape :math:`(N, E, H, W)`. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to
                None.
            data_samples (Sequence[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text information. Defaults
                to None.

        Returns:
            Tensor: Output of the prediction layer. No softmax is applied
            here.
        """
        hybrid_glimpse = self.hybrid_decoder(feat, out_enc, data_samples)
        position_glimpse = self.position_decoder(feat, out_enc, data_samples)
        fusion_input = torch.cat([hybrid_glimpse, position_glimpse], self.dim)
        outputs = self.linear_layer(fusion_input)
        outputs = self.glu_layer(outputs)
        return self.prediction(outputs)

    def forward_test(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for testing.

        Decoding is done step by step: the class with the highest score at
        step ``t`` is written to position ``t + 1`` of the decoded sequence
        that is passed to the hybrid decoder at the next step.

        Args:
            feat (torch.Tensor, optional): The feature map from backbone of
                shape :math:`(N, E, H, W)`. Although typed as optional, it
                is required here because the batch size is read from it.
                Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to
                None.
            data_samples (Sequence[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing valid_ratio information.
                Defaults to None.

        Returns:
            Tensor: Character probabilities of shape
            :math:`(N, self.max_seq_len, C)` where :math:`C` is
            ``num_classes``.
        """
        # The position branch does not depend on previously decoded
        # characters, so it is computed once for all steps.
        position_glimpse = self.position_decoder(feat, out_enc, data_samples)

        batch_size = feat.size(0)
        # Build the index tensor directly as int64 instead of scaling a
        # tensor of ones in ``feat``'s (possibly half-precision) dtype.
        decode_sequence = feat.new_full((batch_size, self.max_seq_len),
                                        self.dictionary.start_idx,
                                        dtype=torch.long)

        last_step = self.max_seq_len - 1
        outputs = []
        for step in range(self.max_seq_len):
            hybrid_glimpse_step = self.hybrid_decoder.forward_test_step(
                feat, out_enc, decode_sequence, step, data_samples)
            fusion_input = torch.cat(
                [hybrid_glimpse_step, position_glimpse[:, step, :]], self.dim)
            output = self.linear_layer(fusion_input)
            output = self.glu_layer(output)
            output = self.prediction(output)
            outputs.append(output)

            # Feed the greedy prediction to the next step; the final step
            # has no successor position to fill.
            if step < last_step:
                _, max_idx = torch.max(output, dim=1, keepdim=False)
                decode_sequence[:, step + 1] = max_idx

        outputs = torch.stack(outputs, 1)
        return self.softmax(outputs)