# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, Sequence, Union

import torch
import torch.nn as nn

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseTextRecogModuleLoss


@MODELS.register_module()
class CTCModuleLoss(BaseTextRecogModuleLoss):
    """Implementation of loss module for CTC-loss based text recognition.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        letter_case (str): There are three options to alter the letter cases
            of gt texts:
            - unchanged: Do not change gt texts.
            - upper: Convert gt texts into uppercase characters.
            - lower: Convert gt texts into lowercase characters.
            Usually, it only works for English characters. Defaults to
            'unchanged'.
        flatten (bool): If True, use flattened (concatenated) targets;
            otherwise use padded targets. Defaults to True.
        reduction (str): Specifies the reduction to apply to the output.
            Should be one of 'none', 'mean' or 'sum'. Defaults to 'mean'.
        zero_infinity (bool): Whether to zero infinite losses and the
            associated gradients. Infinite losses mainly occur when the
            inputs are too short to be aligned to the targets. Defaults
            to False.
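
    Examples:
        A minimal, illustrative construction (the ``dict_file`` path below is
        a placeholder, not a real file shipped with this module):

        >>> loss_module = CTCModuleLoss(
        ...     dictionary=dict(
        ...         type='Dictionary',
        ...         dict_file='path/to/dict_file.txt',
        ...         with_padding=True),
        ...     flatten=True,
        ...     reduction='mean',
        ...     zero_infinity=True)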
    """

    def __init__(self,
                 dictionary: Union[Dict, Dictionary],
                 letter_case: str = 'unchanged',
                 flatten: bool = True,
                 reduction: str = 'mean',
                 zero_infinity: bool = False,
                 **kwargs) -> None:
        super().__init__(dictionary=dictionary, letter_case=letter_case)
        assert isinstance(flatten, bool)
        assert isinstance(reduction, str)
        assert isinstance(zero_infinity, bool)

        self.flatten = flatten
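        # The dictionary's padding index is reused as the CTC blank label.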
        self.ctc_loss = nn.CTCLoss(
            blank=self.dictionary.padding_idx,
            reduction=reduction,
            zero_infinity=zero_infinity)

    def forward(self, outputs: torch.Tensor,
                data_samples: Sequence[TextRecogDataSample]) -> Dict:
        """
        Args:
            outputs (Tensor): A raw logit tensor of shape :math:`(N, T, C)`.
            data_samples (list[TextRecogDataSample]): List of
                ``TextRecogDataSample`` which are processed by ``get_target``.

        Returns:
            dict: The loss dict with key ``loss_ctc``.
        """
        valid_ratios = None
        if data_samples is not None:
            valid_ratios = [
                img_meta.get('valid_ratio', 1.0) for img_meta in data_samples
            ]

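        # nn.CTCLoss expects log-probabilities of shape (T, N, C), so apply
        # log_softmax over the class dimension and transpose time and batch.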
        outputs = torch.log_softmax(outputs, dim=2)
        bsz, seq_len = outputs.size(0), outputs.size(1)
        outputs_for_loss = outputs.permute(1, 0, 2).contiguous()  # T * N * C
        targets = [
            data_sample.gt_text.indexes[:seq_len]
            for data_sample in data_samples
        ]
        target_lengths = torch.IntTensor([len(t) for t in targets])
        target_lengths = torch.clamp(target_lengths, max=seq_len).long()
        input_lengths = torch.full(
            size=(bsz, ), fill_value=seq_len, dtype=torch.long)
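        # nn.CTCLoss accepts either a 1-D tensor of concatenated targets
        # (flatten=True) or an (N, S) batch padded with the padding index.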
        if self.flatten:
            targets = torch.cat(targets)
        else:
            padded_targets = torch.full(
                size=(bsz, seq_len),
                fill_value=self.dictionary.padding_idx,
                dtype=torch.long)
            for idx, valid_len in enumerate(target_lengths):
                padded_targets[idx, :valid_len] = targets[idx][:valid_len]
            targets = padded_targets

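            # Shrink each sample's input length according to its valid_ratio,
            # i.e. the fraction of the (padded) sequence that is actual
            # content.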
            if valid_ratios is not None:
                input_lengths = [
                    math.ceil(valid_ratio * seq_len)
                    for valid_ratio in valid_ratios
                ]
                input_lengths = torch.Tensor(input_lengths).long()
        loss_ctc = self.ctc_loss(outputs_for_loss, targets, input_lengths,
                                 target_lengths)
        losses = dict(loss_ctc=loss_ctc)

        return losses

    def get_targets(
        self, data_samples: Sequence[TextRecogDataSample]
    ) -> Sequence[TextRecogDataSample]:
        """Target generator.

        Args:
            data_samples (list[TextRecogDataSample]): It usually includes
                ``gt_text`` information.

        Returns:
            list[TextRecogDataSample]: Updated data_samples. Each data_sample
            gains one new key:

            - indexes (torch.IntTensor): Character indexes corresponding to
              ``gt_text.item``.
        """

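        # Convert each gt text to dictionary indexes, applying the optional
        # case conversion first.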
        for data_sample in data_samples:
            text = data_sample.gt_text.item
            if self.letter_case in ['upper', 'lower']:
                text = getattr(text, self.letter_case)()
            indexes = self.dictionary.str2idx(text)
            indexes = torch.IntTensor(indexes)
            data_sample.gt_text.indexes = indexes
        return data_samples