# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, Sequence, Union

import torch
import torch.nn as nn

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseTextRecogModuleLoss


@MODELS.register_module()
class CEModuleLoss(BaseTextRecogModuleLoss):
"""Implementation of loss module for encoder-decoder based text recognition
    method with CrossEntropy loss.

    Args:
dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
the instance of `Dictionary`.
max_seq_len (int): Maximum sequence length. The sequence is usually
generated from decoder. Defaults to 40.
letter_case (str): There are three options to alter the letter cases
of gt texts:
- unchanged: Do not change gt texts.
- upper: Convert gt texts into uppercase characters.
- lower: Convert gt texts into lowercase characters.
Usually, it only works for English characters. Defaults to
'unchanged'.
pad_with (str): The padding strategy for ``gt_text.padded_indexes``.
Defaults to 'auto'. Options are:
- 'auto': Use dictionary.padding_idx to pad gt texts, or
dictionary.end_idx if dictionary.padding_idx
is None.
- 'padding': Always use dictionary.padding_idx to pad gt texts.
- 'end': Always use dictionary.end_idx to pad gt texts.
- 'none': Do not pad gt texts.
ignore_char (int or str): Specifies a target value that is
ignored and does not contribute to the input gradient.
ignore_char can be int or str. If int, it is the index of
the ignored char. If str, it is the character to ignore.
Apart from single characters, each item can be one of the
            following reserved keywords: 'padding', 'start', 'end',
and 'unknown', which refer to their corresponding special
tokens in the dictionary. It will not ignore any special
tokens when ignore_char == -1 or 'none'. Defaults to 'padding'.
flatten (bool): Whether to flatten the output and target before
computing CE loss. Defaults to False.
reduction (str): Specifies the reduction to apply to the output,
should be one of the following: ('none', 'mean', 'sum'). Defaults
to 'none'.
ignore_first_char (bool): Whether to ignore the first token in target (
usually the start token). If ``True``, the last token of the output
sequence will also be removed to be aligned with the target length.
Defaults to ``False``.
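
    Example:
        A minimal usage sketch: the ``dict_file`` path below is a
        placeholder and the logits are random. In practice ``outputs``
        comes from a decoder head and ``data_samples`` from the data
        pipeline::

            import torch
            from mmengine.structures import LabelData
            from mmocr.structures import TextRecogDataSample

            loss_module = CEModuleLoss(
                dictionary=dict(
                    type='Dictionary',
                    dict_file='path/to/dict.txt',  # placeholder path
                    with_start=True,
                    with_end=True,
                    with_padding=True,
                    with_unknown=True),
                max_seq_len=40,
                flatten=True,
                reduction='mean')

            data_sample = TextRecogDataSample()
            data_sample.gt_text = LabelData(item='hello')
            data_samples = loss_module.get_targets([data_sample])

            outputs = torch.rand(1, 40, loss_module.dictionary.num_classes)
            losses = loss_module(outputs, data_samples)  # {'loss_ce': ...}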
"""
def __init__(self,
dictionary: Union[Dict, Dictionary],
max_seq_len: int = 40,
letter_case: str = 'unchanged',
pad_with: str = 'auto',
ignore_char: Union[int, str] = 'padding',
flatten: bool = False,
reduction: str = 'none',
ignore_first_char: bool = False):
super().__init__(
dictionary=dictionary,
max_seq_len=max_seq_len,
letter_case=letter_case,
pad_with=pad_with)
assert isinstance(ignore_char, (int, str))
assert isinstance(reduction, str)
assert reduction in ['none', 'mean', 'sum']
assert isinstance(ignore_first_char, bool)
assert isinstance(flatten, bool)
self.flatten = flatten
self.ignore_first_char = ignore_first_char
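        # Resolve ``ignore_char`` into a concrete index: integers are used
        # as-is, reserved keywords map to the dictionary's special-token
        # indexes, and any other string is looked up as a plain character.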
if isinstance(ignore_char, int):
ignore_index = ignore_char
else:
mapping_table = {
'none': -1,
'start': self.dictionary.start_idx,
'padding': self.dictionary.padding_idx,
'end': self.dictionary.end_idx,
'unknown': self.dictionary.unknown_idx,
}
ignore_index = mapping_table.get(
ignore_char,
self.dictionary.char2idx(ignore_char, strict=False))
if ignore_index is None or (ignore_index
== self.dictionary.unknown_idx
and ignore_char != 'unknown'):
warnings.warn(
f'{ignore_char} does not exist in the dictionary',
UserWarning)
ignore_index = -1
self.ignore_char = ignore_char
self.ignore_index = ignore_index
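        # Positions whose target equals ``ignore_index`` contribute neither
        # to the loss value nor to the gradient.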
self.loss_ce = nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction=reduction)

    def forward(self, outputs: torch.Tensor,
data_samples: Sequence[TextRecogDataSample]) -> Dict:
"""
Args:
outputs (Tensor): A raw logit tensor of shape :math:`(N, T, C)`.
data_samples (list[TextRecogDataSample]): List of
                ``TextRecogDataSample`` which are processed by ``get_targets``.

        Returns:
dict: A loss dict with the key ``loss_ce``.
"""
targets = list()
for data_sample in data_samples:
targets.append(data_sample.gt_text.padded_indexes)
targets = torch.stack(targets, dim=0).long()
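        # Optionally drop the start token from the targets and the last
        # prediction step from the outputs so both keep the same length.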
if self.ignore_first_char:
targets = targets[:, 1:].contiguous()
outputs = outputs[:, :-1, :].contiguous()
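        # nn.CrossEntropyLoss accepts either flattened (N*T, C) logits with
        # (N*T,) targets, or (N, C, T) logits with (N, T) targets.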
if self.flatten:
outputs = outputs.view(-1, outputs.size(-1))
targets = targets.view(-1)
else:
outputs = outputs.permute(0, 2, 1).contiguous()
loss_ce = self.loss_ce(outputs, targets.to(outputs.device))
losses = dict(loss_ce=loss_ce)
return losses