# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence, Union

import torch

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseTextRecogModuleLoss
from .ce_module_loss import CEModuleLoss


@MODELS.register_module()
class ABIModuleLoss(BaseTextRecogModuleLoss):
    """Implementation of ABINet multiloss that allows mixing different types of
    losses with weights.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        max_seq_len (int): Maximum sequence length. The sequence is usually
            generated from decoder. Defaults to 40.
        letter_case (str): There are three options to alter the letter cases
            of gt texts:
            - unchanged: Do not change gt texts.
            - upper: Convert gt texts into uppercase characters.
            - lower: Convert gt texts into lowercase characters.
            Usually, it only works for English characters. Defaults to
            'unchanged'.
        weight_vis (float or int): The weight of the vision decoder loss.
            Defaults to 1.0.
        weight_lang (float or int): The weight of the language decoder loss.
            Defaults to 1.0.
        weight_fusion (float or int): The weight of the fuser (aligner) loss.
            Defaults to 1.0.
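
    Example:
        A minimal, illustrative config sketch. The ``dict_file`` path below
        is a placeholder for a character-list file, not something shipped
        with this module::

            loss = ABIModuleLoss(
                dictionary=dict(
                    type='Dictionary',
                    dict_file='dicts/lower_english_digits.txt',
                    with_start=True,
                    with_end=True,
                    same_start_end=True,
                    with_padding=True,
                    with_unknown=True),
                weight_vis=1.0,
                weight_lang=1.0,
                weight_fusion=1.0)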
    """

    def __init__(self,
                 dictionary: Union[Dict, Dictionary],
                 max_seq_len: int = 40,
                 letter_case: str = 'unchanged',
                 weight_vis: Union[float, int] = 1.0,
                 weight_lang: Union[float, int] = 1.0,
                 weight_fusion: Union[float, int] = 1.0,
                 **kwargs) -> None:
        assert isinstance(weight_vis, (float, int))
        assert isinstance(weight_lang, (float, int))
        assert isinstance(weight_fusion, (float, int))
        super().__init__(
            dictionary=dictionary,
            max_seq_len=max_seq_len,
            letter_case=letter_case)
        self.weight_vis = weight_vis
        self.weight_lang = weight_lang
        self.weight_fusion = weight_fusion
        self._ce_loss = CEModuleLoss(
            self.dictionary,
            max_seq_len,
            letter_case,
            reduction='mean',
            ignore_first_char=True)

    def forward(self, outputs: Dict,
                data_samples: Sequence[TextRecogDataSample]) -> Dict:
        """
        Args:
            outputs (dict): The output dictionary with at least one of
                ``out_vis``, ``out_langs`` and ``out_fusers`` specified.
            data_samples (list[TextRecogDataSample]): List of
                ``TextRecogDataSample`` that have been processed by
                ``get_targets``.

        Returns:
            dict: A loss dictionary that may contain ``loss_visual``,
            ``loss_lang`` and ``loss_fusion``. A key is present only when the
            output of its corresponding module is given in ``outputs``.
        """
        assert 'out_vis' in outputs or \
            'out_langs' in outputs or 'out_fusers' in outputs
        losses = {}

        if outputs.get('out_vis', None):
            losses['loss_visual'] = self.weight_vis * self._ce_loss(
                outputs['out_vis']['logits'], data_samples)['loss_ce']
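        # Each entry in ``out_langs`` holds the logits of one iteration of
        # the language decoder; the per-iteration CE losses are averaged
        # before the branch weight is applied.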
        if outputs.get('out_langs', None):
            lang_losses = []
            for out_lang in outputs['out_langs']:
                lang_losses.append(
                    self._ce_loss(out_lang['logits'], data_samples)['loss_ce'])
            losses['loss_lang'] = self.weight_lang * torch.mean(
                torch.stack(lang_losses))
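        # Same scheme for the fuser (aligner) outputs: average across
        # iterations, then weight.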
        if outputs.get('out_fusers', None):
            fuser_losses = []
            for out_fuser in outputs['out_fusers']:
                fuser_losses.append(
                    self._ce_loss(out_fuser['logits'],
                                  data_samples)['loss_ce'])
            losses['loss_fusion'] = self.weight_fusion * torch.mean(
                torch.stack(fuser_losses))
        return losses
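

# Usage sketch (not part of the upstream module): a minimal, self-contained
# illustration of the inputs that ``forward`` expects. The character set,
# tensor shapes and single-iteration outputs below are assumptions for the
# demo, not values mandated by ABINet. Assuming this file lives at its usual
# MMOCR path, run it with:
#   python -m mmocr.models.textrecog.module_losses.abi_module_loss
if __name__ == '__main__':
    import tempfile

    from mmengine.structures import LabelData

    # Write a throwaway character dictionary so the example can run anywhere.
    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
        f.write('\n'.join('0123456789abcdefghijklmnopqrstuvwxyz'))

    loss_fn = ABIModuleLoss(
        dictionary=dict(
            type='Dictionary',
            dict_file=f.name,
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True),
        max_seq_len=40)

    # One ground-truth sample; ``get_targets`` fills in ``padded_indexes``.
    sample = TextRecogDataSample()
    sample.gt_text = LabelData(item='hello')
    data_samples = loss_fn.get_targets([sample])

    # Random logits standing in for the decoder outputs, shaped
    # (batch, max_seq_len, num_classes).
    logits = torch.randn(1, 40, loss_fn.dictionary.num_classes)
    outputs = dict(
        out_vis=dict(logits=logits),
        out_langs=[dict(logits=logits)],
        out_fusers=[dict(logits=logits)])

    # -> {'loss_visual': ..., 'loss_lang': ..., 'loss_fusion': ...}
    print(loss_fn(outputs, data_samples))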