Spaces:
Running
Running
# Copyright (c) OpenMMLab. All rights reserved. | |
import math | |
from typing import Dict, List, Optional, Sequence, Union | |
import torch | |
import torch.nn as nn | |
from mmengine.model import ModuleList | |
from mmocr.models.common import PositionalEncoding, TFDecoderLayer | |
from mmocr.models.common.dictionary import Dictionary | |
from mmocr.registry import MODELS | |
from mmocr.structures import TextRecogDataSample | |
from .base import BaseDecoder | |
class MAERecDecoder(BaseDecoder): | |
"""Transformer Decoder block with self attention mechanism. | |
Args: | |
n_layers (int): Number of attention layers. Defaults to 6. | |
d_embedding (int): Language embedding dimension. Defaults to 512. | |
n_head (int): Number of parallel attention heads. Defaults to 8. | |
d_k (int): Dimension of the key vector. Defaults to 64. | |
d_v (int): Dimension of the value vector. Defaults to 64 | |
d_model (int): Dimension :math:`D_m` of the input from previous model. | |
Defaults to 512. | |
d_inner (int): Hidden dimension of feedforward layers. Defaults to 256. | |
n_position (int): Length of the positional encoding vector. Must be | |
greater than ``max_seq_len``. Defaults to 200. | |
dropout (float): Dropout rate for text embedding, MHSA, FFN. Defaults | |
to 0.1. | |
module_loss (dict, optional): Config to build module_loss. Defaults | |
to None. | |
postprocessor (dict, optional): Config to build postprocessor. | |
Defaults to None. | |
dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or | |
the instance of `Dictionary`. | |
max_seq_len (int): Maximum output sequence length :math:`T`. Defaults | |
to 30. | |
init_cfg (dict or list[dict], optional): Initialization configs. | |
""" | |
def __init__(self, | |
n_layers: int = 6, | |
d_embedding: int = 512, | |
n_head: int = 8, | |
d_k: int = 64, | |
d_v: int = 64, | |
d_model: int = 512, | |
d_inner: int = 256, | |
n_position: int = 200, | |
dropout: float = 0.1, | |
module_loss: Optional[Dict] = None, | |
postprocessor: Optional[Dict] = None, | |
dictionary: Optional[Union[Dict, Dictionary]] = None, | |
max_seq_len: int = 30, | |
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None: | |
super().__init__( | |
module_loss=module_loss, | |
postprocessor=postprocessor, | |
dictionary=dictionary, | |
init_cfg=init_cfg, | |
max_seq_len=max_seq_len) | |
self.padding_idx = self.dictionary.padding_idx | |
self.start_idx = self.dictionary.start_idx | |
self.max_seq_len = max_seq_len | |
self.trg_word_emb = nn.Embedding( | |
self.dictionary.num_classes, | |
d_embedding, | |
padding_idx=self.padding_idx) | |
self.position_enc = PositionalEncoding( | |
d_embedding, n_position=n_position) | |
self.dropout = nn.Dropout(p=dropout) | |
self.layer_stack = ModuleList([ | |
TFDecoderLayer( | |
d_model, d_inner, n_head, d_k, d_v, dropout=dropout) | |
for _ in range(n_layers) | |
]) | |
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) | |
pred_num_class = self.dictionary.num_classes | |
self.classifier = nn.Linear(d_model, pred_num_class) | |
self.softmax = nn.Softmax(dim=-1) | |
def _get_target_mask(self, trg_seq: torch.Tensor) -> torch.Tensor: | |
"""Generate mask for target sequence. | |
Args: | |
trg_seq (torch.Tensor): Input text sequence. Shape :math:`(N, T)`. | |
Returns: | |
Tensor: Target mask. Shape :math:`(N, T, T)`. | |
E.g.: | |
seq = torch.Tensor([[1, 2, 0, 0]]), pad_idx = 0, then | |
target_mask = | |
torch.Tensor([[[True, False, False, False], | |
[True, True, False, False], | |
[True, True, False, False], | |
[True, True, False, False]]]) | |
""" | |
pad_mask = (trg_seq != self.padding_idx).unsqueeze(-2) | |
len_s = trg_seq.size(1) | |
subsequent_mask = 1 - torch.triu( | |
torch.ones((len_s, len_s), device=trg_seq.device), diagonal=1) | |
subsequent_mask = subsequent_mask.unsqueeze(0).bool() | |
return pad_mask & subsequent_mask | |
def _get_source_mask(self, src_seq: torch.Tensor, | |
valid_ratios: Sequence[float]) -> torch.Tensor: | |
"""Generate mask for source sequence. | |
Args: | |
src_seq (torch.Tensor): Image sequence. Shape :math:`(N, T, C)`. | |
valid_ratios (list[float]): The valid ratio of input image. For | |
example, if the width of the original image is w1 and the width | |
after padding is w2, then valid_ratio = w1/w2. Source mask is | |
used to cover the area of the padding region. | |
Returns: | |
Tensor or None: Source mask. Shape :math:`(N, T)`. The region of | |
padding area are False, and the rest are True. | |
""" | |
N, T, _ = src_seq.size() | |
mask = None | |
if len(valid_ratios) > 0: | |
mask = src_seq.new_zeros((N, T), device=src_seq.device) | |
for i, valid_ratio in enumerate(valid_ratios): | |
valid_width = min(T, math.ceil(T * valid_ratio)) | |
mask[i, :valid_width] = 1 | |
return mask | |
def _attention(self, | |
trg_seq: torch.Tensor, | |
src: torch.Tensor, | |
src_mask: Optional[torch.Tensor] = None) -> torch.Tensor: | |
"""A wrapped process for transformer based decoder including text | |
embedding, position embedding, N x transformer decoder and a LayerNorm | |
operation. | |
Args: | |
trg_seq (Tensor): Target sequence in. Shape :math:`(N, T)`. | |
src (Tensor): Source sequence from encoder in shape | |
Shape :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``. | |
src_mask (Tensor, Optional): Mask for source sequence. | |
Shape :math:`(N, T)`. Defaults to None. | |
Returns: | |
Tensor: Output sequence from transformer decoder. | |
Shape :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``. | |
""" | |
trg_embedding = self.trg_word_emb(trg_seq) | |
trg_pos_encoded = self.position_enc(trg_embedding) | |
trg_mask = self._get_target_mask(trg_seq) | |
tgt_seq = self.dropout(trg_pos_encoded) | |
output = tgt_seq | |
for dec_layer in self.layer_stack: | |
output = dec_layer( | |
output, | |
src, | |
self_attn_mask=trg_mask, | |
dec_enc_attn_mask=src_mask) | |
output = self.layer_norm(output) | |
return output | |
def forward_train( | |
self, | |
feat: Optional[torch.Tensor] = None, | |
out_enc: torch.Tensor = None, | |
data_samples: Sequence[TextRecogDataSample] = None | |
) -> torch.Tensor: | |
"""Forward for training. Source mask will be used here. | |
Args: | |
feat (Tensor, optional): Unused. | |
out_enc (Tensor): Encoder output of shape : math:`(N, T, D_m)` | |
where :math:`D_m` is ``d_model``. Defaults to None. | |
data_samples (list[TextRecogDataSample]): Batch of | |
TextRecogDataSample, containing gt_text and valid_ratio | |
information. Defaults to None. | |
Returns: | |
Tensor: The raw logit tensor. Shape :math:`(N, T, C)` where | |
:math:`C` is ``num_classes``. | |
""" | |
valid_ratios = [] | |
for data_sample in data_samples: | |
valid_ratios.append(data_sample.get('valid_ratio')) | |
src_mask = self._get_source_mask(feat, valid_ratios) | |
trg_seq = [] | |
for data_sample in data_samples: | |
trg_seq.append(data_sample.gt_text.padded_indexes.to(feat.device)) | |
trg_seq = torch.stack(trg_seq, dim=0) | |
attn_output = self._attention(trg_seq, feat, src_mask=src_mask) | |
outputs = self.classifier(attn_output) | |
return outputs | |
def forward_test( | |
self, | |
feat: Optional[torch.Tensor] = None, | |
out_enc: torch.Tensor = None, | |
data_samples: Sequence[TextRecogDataSample] = None | |
) -> torch.Tensor: | |
"""Forward for testing. | |
Args: | |
feat (Tensor, optional): Unused. | |
out_enc (Tensor): Encoder output of shape: | |
math:`(N, T, D_m)` where :math:`D_m` is ``d_model``. | |
Defaults to None. | |
data_samples (list[TextRecogDataSample]): Batch of | |
TextRecogDataSample, containing gt_text and valid_ratio | |
information. Defaults to None. | |
Returns: | |
Tensor: Character probabilities. of shape | |
:math:`(N, self.max_seq_len, C)` where :math:`C` is | |
``num_classes``. | |
""" | |
valid_ratios = [] | |
for data_sample in data_samples: | |
valid_ratios.append(data_sample.get('valid_ratio')) | |
src_mask = self._get_source_mask(feat, valid_ratios) | |
N = feat.size(0) | |
init_target_seq = torch.full((N, self.max_seq_len + 1), | |
self.padding_idx, | |
device=feat.device, | |
dtype=torch.long) | |
# bsz * seq_len | |
init_target_seq[:, 0] = self.start_idx | |
outputs = [] | |
for step in range(0, self.max_seq_len): | |
decoder_output = self._attention( | |
init_target_seq, feat, src_mask=src_mask) | |
# bsz * seq_len * C | |
step_result = self.classifier(decoder_output[:, step, :]) | |
# bsz * num_classes | |
outputs.append(step_result) | |
_, step_max_index = torch.max(step_result, dim=-1) | |
init_target_seq[:, step + 1] = step_max_index | |
outputs = torch.stack(outputs, dim=1) | |
return self.softmax(outputs) | |