# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
from typing import Dict, Optional, Sequence, Union

import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmengine.model import ModuleList

from mmocr.models.common.dictionary import Dictionary
from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseDecoder


def clones(module: nn.Module, N: int) -> nn.ModuleList:
    """Produce N identical layers.

    Args:
        module (nn.Module): A pytorch nn.module.
        N (int): Number of copies.

    Returns:
        nn.ModuleList: A pytorch nn.ModuleList with the copies.
    """
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class Embeddings(nn.Module):
    """Construct the word embeddings given vocab size and embed dim.

    Args:
        d_model (int): The embedding dimension.
        vocab (int): Vocabulary size.
    """

    def __init__(self, d_model: int, vocab: int):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, *input: torch.Tensor) -> torch.Tensor:
        """Forward the embeddings.

        Args:
            input (torch.Tensor): The input tensors.

        Returns:
            torch.Tensor: The embeddings.
        """
        x = input[0]
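        # Scale the embeddings by sqrt(d_model), following the convention of
        # the original Transformer ("Attention Is All You Need").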
        return self.lut(x) * math.sqrt(self.d_model)


@MODELS.register_module()
class MasterDecoder(BaseDecoder):
    """Decoder module in `MASTER <https://arxiv.org/abs/1910.02562>`_.

    Code is partially modified from https://github.com/wenwenyu/MASTER-pytorch.

    Args:
        n_layers (int): Number of attention layers. Defaults to 3.
        n_head (int): Number of parallel attention heads. Defaults to 8.
        d_model (int): Dimension :math:`E` of the input from previous model.
            Defaults to 512.
        feat_size (int): The size of the input feature from previous model,
            usually :math:`H * W`. Defaults to 6 * 40.
        d_inner (int): Hidden dimension of feedforward layers.
            Defaults to 2048.
        attn_drop (float): Dropout rate of the attention layer. Defaults to 0.
        ffn_drop (float): Dropout rate of the feedforward layer. Defaults to 0.
        feat_pe_drop (float): Dropout rate of the feature positional encoding
            layer. Defaults to 0.2.
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`. Defaults to None.
        module_loss (dict, optional): Config to build module_loss. Defaults
            to None.
        postprocessor (dict, optional): Config to build postprocessor.
            Defaults to None.
        max_seq_len (int): Maximum output sequence length :math:`T`. Defaults
            to 30.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(
        self,
        n_layers: int = 3,
        n_head: int = 8,
        d_model: int = 512,
        feat_size: int = 6 * 40,
        d_inner: int = 2048,
        attn_drop: float = 0.,
        ffn_drop: float = 0.,
        feat_pe_drop: float = 0.2,
        module_loss: Optional[Dict] = None,
        postprocessor: Optional[Dict] = None,
        dictionary: Optional[Union[Dict, Dictionary]] = None,
        max_seq_len: int = 30,
        init_cfg: Optional[Union[Dict, Sequence[Dict]]] = None,
    ):
        super().__init__(
            module_loss=module_loss,
            postprocessor=postprocessor,
            dictionary=dictionary,
            init_cfg=init_cfg,
            max_seq_len=max_seq_len)
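
        # Each decoder layer applies, in order: LayerNorm + masked
        # self-attention over the partially decoded sequence, LayerNorm +
        # cross-attention over the flattened visual feature, and
        # LayerNorm + a feedforward block (pre-norm ordering).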
        operation_order = ('norm', 'self_attn', 'norm', 'cross_attn', 'norm',
                           'ffn')
        decoder_layer = BaseTransformerLayer(
            operation_order=operation_order,
            attn_cfgs=dict(
                type='MultiheadAttention',
                embed_dims=d_model,
                num_heads=n_head,
                attn_drop=attn_drop,
                dropout_layer=dict(type='Dropout', drop_prob=attn_drop),
            ),
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=d_model,
                feedforward_channels=d_inner,
                ffn_drop=ffn_drop,
                dropout_layer=dict(type='Dropout', drop_prob=ffn_drop),
            ),
            norm_cfg=dict(type='LN'),
            batch_first=True,
        )
        self.decoder_layers = ModuleList(
            [copy.deepcopy(decoder_layer) for _ in range(n_layers)])

        self.cls = nn.Linear(d_model, self.dictionary.num_classes)

        self.SOS = self.dictionary.start_idx
        self.PAD = self.dictionary.padding_idx
        self.max_seq_len = max_seq_len
        self.feat_size = feat_size
        self.n_head = n_head

        self.embedding = Embeddings(
            d_model=d_model, vocab=self.dictionary.num_classes)

        # TODO:
        self.positional_encoding = PositionalEncoding(
            d_hid=d_model, n_position=self.max_seq_len + 1)
        self.feat_positional_encoding = PositionalEncoding(
            d_hid=d_model, n_position=self.feat_size, dropout=feat_pe_drop)
        self.norm = nn.LayerNorm(d_model)
        self.softmax = nn.Softmax(dim=-1)

    def make_target_mask(self, tgt: torch.Tensor,
                         device: torch.device) -> torch.Tensor:
        """Make target mask for self attention.

        Args:
            tgt (Tensor): Shape [N, l_tgt]
            device (torch.device): Mask device.

        Returns:
            Tensor: Mask of shape [N * self.n_head, l_tgt, l_tgt]
        """
        trg_pad_mask = (tgt != self.PAD).unsqueeze(1).unsqueeze(3).bool()
        tgt_len = tgt.size(1)
        trg_sub_mask = torch.tril(
            torch.ones((tgt_len, tgt_len), dtype=torch.bool, device=device))
        tgt_mask = trg_pad_mask & trg_sub_mask

        # inverse for mmcv's BaseTransformerLayer
        tril_mask = tgt_mask.clone()
        tgt_mask = tgt_mask.float().masked_fill_(tril_mask == 0, -1e9)
        tgt_mask = tgt_mask.masked_fill_(tril_mask, 0)
        tgt_mask = tgt_mask.repeat(1, self.n_head, 1, 1)
        tgt_mask = tgt_mask.view(-1, tgt_len, tgt_len)
        return tgt_mask

    def decode(self, tgt_seq: torch.Tensor, feature: torch.Tensor,
               src_mask: torch.BoolTensor,
               tgt_mask: torch.BoolTensor) -> torch.Tensor:
        """Decode the input sequence.

        Args:
            tgt_seq (Tensor): Target sequence of character indices, of shape
                :math:`(N, T)`.
            feature (Tensor): Flattened feature map from the encoder, of
                shape :math:`(N, H*W, E)`.
            src_mask (BoolTensor): The source mask of shape :math:`(N, H*W)`.
            tgt_mask (BoolTensor): The target mask of shape :math:`(N, T, T)`.

        Returns:
            Tensor: The decoded sequence.
        """
        tgt_seq = self.embedding(tgt_seq)
        x = self.positional_encoding(tgt_seq)
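        # ``attn_masks`` are consumed in the order of the attention ops in
        # ``operation_order``: the first entry masks self-attention, the
        # second masks cross-attention over the visual feature.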
        attn_masks = [tgt_mask, src_mask]
        for layer in self.decoder_layers:
            x = layer(
                query=x, key=feature, value=feature, attn_masks=attn_masks)
        x = self.norm(x)
        return self.cls(x)

    def forward_train(self,
                      feat: Optional[torch.Tensor] = None,
                      out_enc: torch.Tensor = None,
                      data_samples: Sequence[TextRecogDataSample] = None
                      ) -> torch.Tensor:
        """Forward for training. Source mask will not be used here.

        Args:
            feat (Tensor, optional): Input feature map from backbone.
            out_enc (Tensor): Unused.
            data_samples (list[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text and valid_ratio
                information.

        Returns:
            Tensor: The raw logit tensor. Shape :math:`(N, T, C)` where
            :math:`C` is ``num_classes``.
        """
        # flatten 2D feature map
        if len(feat.shape) > 3:
            b, c, h, w = feat.shape
            feat = feat.view(b, c, h * w)
            feat = feat.permute((0, 2, 1))
        feat = self.feat_positional_encoding(feat)
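
        # Teacher forcing: the padded ground-truth character indices are fed
        # as the decoder input, so the whole sequence is decoded in one pass.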
        trg_seq = []
        for target in data_samples:
            trg_seq.append(target.gt_text.padded_indexes.to(feat.device))
        trg_seq = torch.stack(trg_seq, dim=0)

        src_mask = None
        tgt_mask = self.make_target_mask(trg_seq, device=feat.device)
        return self.decode(trg_seq, feat, src_mask, tgt_mask)

    def forward_test(self,
                     feat: Optional[torch.Tensor] = None,
                     out_enc: torch.Tensor = None,
                     data_samples: Sequence[TextRecogDataSample] = None
                     ) -> torch.Tensor:
        """Forward for testing.

        Args:
            feat (Tensor, optional): Input feature map from backbone.
            out_enc (Tensor): Unused.
            data_samples (list[TextRecogDataSample]): Unused.

        Returns:
            Tensor: Character probabilities of shape
            :math:`(N, self.max_seq_len, C)` where :math:`C` is
            ``num_classes``.
        """
        # flatten 2D feature map
        if len(feat.shape) > 3:
            b, c, h, w = feat.shape
            feat = feat.view(b, c, h * w)
            feat = feat.permute((0, 2, 1))
        feat = self.feat_positional_encoding(feat)

        N = feat.shape[0]
        input = torch.full((N, 1),
                           self.SOS,
                           device=feat.device,
                           dtype=torch.long)
        output = None
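        # Greedy autoregressive decoding: start from SOS, re-decode the whole
        # prefix at each step, and append the argmax of the last position
        # until ``max_seq_len`` tokens have been produced.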
        for _ in range(self.max_seq_len):
            target_mask = self.make_target_mask(input, device=feat.device)
            out = self.decode(input, feat, None, target_mask)
            output = out
            _, next_word = torch.max(out, dim=-1)
            input = torch.cat([input, next_word[:, -1].unsqueeze(-1)], dim=1)
        return self.softmax(output)
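

# Illustrative usage sketch (not part of the original module). It assumes a
# hypothetical dictionary file path and a backbone/encoder producing a
# (N, 512, 6, 40) feature map, matching the defaults above:
#
#   dictionary = dict(
#       type='Dictionary',
#       dict_file='dicts/english_digits_symbols.txt',  # hypothetical path
#       with_start=True,
#       with_end=True,
#       same_start_end=True,
#       with_padding=True,
#       with_unknown=True)
#   decoder = MasterDecoder(
#       d_model=512, feat_size=6 * 40, dictionary=dictionary, max_seq_len=30)
#   feat = torch.rand(2, 512, 6, 40)
#   probs = decoder.forward_test(feat)  # (2, 30, num_classes)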