File size: 5,971 Bytes
9bf4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, Sequence, Union

import torch
import torch.nn as nn

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseTextRecogModuleLoss


@MODELS.register_module()
class CEModuleLoss(BaseTextRecogModuleLoss):
    """Implementation of loss module for encoder-decoder based text recognition
    method with CrossEntropy loss.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        max_seq_len (int): Maximum sequence length. The sequence is usually
            generated from decoder. Defaults to 40.
        letter_case (str): There are three options to alter the letter cases
            of gt texts:
            - unchanged: Do not change gt texts.
            - upper: Convert gt texts into uppercase characters.
            - lower: Convert gt texts into lowercase characters.
            Usually, it only works for English characters. Defaults to
            'unchanged'.
        pad_with (str): The padding strategy for ``gt_text.padded_indexes``.
            Defaults to 'auto'. Options are:
            - 'auto': Use dictionary.padding_idx to pad gt texts, or
              dictionary.end_idx if dictionary.padding_idx
              is None.
            - 'padding': Always use dictionary.padding_idx to pad gt texts.
            - 'end': Always use dictionary.end_idx to pad gt texts.
            - 'none': Do not pad gt texts.
        ignore_char (int or str): Specifies a target value that is
            ignored and does not contribute to the input gradient.
            ignore_char can be int or str. If int, it is the index of
            the ignored char. If str, it is the character to ignore.
            Apart from single characters, each item can be one of the
            following reserved keywords: 'padding', 'start', 'end',
            and 'unknown', which refer to their corresponding special
            tokens in the dictionary. It will not ignore any special
            tokens when ignore_char == -1 or 'none'. If a str cannot be
            resolved to a token of the dictionary, a ``UserWarning`` is
            emitted and nothing is ignored (index -1 is used).
            Defaults to 'padding'.
        flatten (bool): Whether to flatten the output and target before
            computing CE loss. Defaults to False.
        reduction (str): Specifies the reduction to apply to the output,
            should be one of the following: ('none', 'mean', 'sum'). Defaults
            to 'none'.
        ignore_first_char (bool): Whether to ignore the first token in target (
            usually the start token). If ``True``, the last token of the output
            sequence will also be removed to be aligned with the target length.
            Defaults to ``False``.
    """

    def __init__(self,
                 dictionary: Union[Dict, Dictionary],
                 max_seq_len: int = 40,
                 letter_case: str = 'unchanged',
                 pad_with: str = 'auto',
                 ignore_char: Union[int, str] = 'padding',
                 flatten: bool = False,
                 reduction: str = 'none',
                 ignore_first_char: bool = False):
        super().__init__(
            dictionary=dictionary,
            max_seq_len=max_seq_len,
            letter_case=letter_case,
            pad_with=pad_with)
        assert isinstance(ignore_char, (int, str))
        assert isinstance(reduction, str)
        assert reduction in ['none', 'mean', 'sum']
        assert isinstance(ignore_first_char, bool)
        assert isinstance(flatten, bool)
        self.flatten = flatten

        self.ignore_first_char = ignore_first_char

        if isinstance(ignore_char, int):
            ignore_index = ignore_char
        else:
            reserved_keywords = {
                'none': -1,
                'start': self.dictionary.start_idx,
                'padding': self.dictionary.padding_idx,
                'end': self.dictionary.end_idx,
                'unknown': self.dictionary.unknown_idx,
            }
            if ignore_char in reserved_keywords:
                ignore_index = reserved_keywords[ignore_char]
                # A special token the dictionary lacks has a None index.
                not_found = ignore_index is None
            else:
                # Only plain characters are looked up in the dictionary, so
                # reserved keywords never go through char2idx.
                ignore_index = self.dictionary.char2idx(
                    ignore_char, strict=False)
                # Treat None or the unknown-token index as "not in the
                # dictionary" for a plain character.
                not_found = (
                    ignore_index is None
                    or ignore_index == self.dictionary.unknown_idx)
            if not_found:
                warnings.warn(
                    f'{ignore_char} does not exist in the dictionary',
                    UserWarning)
                # -1 never matches a real target index, so nothing is ignored.
                ignore_index = -1

        self.ignore_char = ignore_char
        self.ignore_index = ignore_index
        self.loss_ce = nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction=reduction)

    def forward(self, outputs: torch.Tensor,
                data_samples: Sequence[TextRecogDataSample]) -> Dict:
        """Compute the cross-entropy loss between the decoder logits and the
        padded ground-truth indexes of ``data_samples``.

        Args:
            outputs (Tensor): A raw logit tensor of shape :math:`(N, T, C)`.
            data_samples (list[TextRecogDataSample]): List of
                ``TextRecogDataSample`` which are processed by ``get_target``.

        Returns:
            dict: A loss dict with the key ``loss_ce``.
        """
        targets = torch.stack(
            [sample.gt_text.padded_indexes for sample in data_samples],
            dim=0).long()
        if self.ignore_first_char:
            # Drop the first target token and the last prediction step so
            # both sequences keep the same length.
            targets = targets[:, 1:].contiguous()
            outputs = outputs[:, :-1, :].contiguous()
        if self.flatten:
            # reshape instead of view: the caller's ``outputs`` is not
            # guaranteed to be contiguous when ignore_first_char is False,
            # and view() raises on non-contiguous tensors.
            outputs = outputs.reshape(-1, outputs.size(-1))
            targets = targets.reshape(-1)
        else:
            # CrossEntropyLoss expects the class dimension at dim 1:
            # (N, T, C) -> (N, C, T).
            outputs = outputs.permute(0, 2, 1).contiguous()

        loss_ce = self.loss_ce(outputs, targets.to(outputs.device))
        return dict(loss_ce=loss_ce)