# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence, Union

import torch

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseTextRecogModuleLoss
from .ce_module_loss import CEModuleLoss


@MODELS.register_module()
class ABIModuleLoss(BaseTextRecogModuleLoss):
    """Implementation of ABINet multiloss that allows mixing different types
    of losses with weights.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        max_seq_len (int): Maximum sequence length. The sequence is usually
            generated by the decoder. Defaults to 40.
        letter_case (str): There are three options to alter the letter cases
            of gt texts:

            - unchanged: Do not change gt texts.
            - upper: Convert gt texts into uppercase characters.
            - lower: Convert gt texts into lowercase characters.

            Usually, it only works for English characters. Defaults to
            'unchanged'.
        weight_vis (float or int): The weight of the vision decoder loss.
            Defaults to 1.0.
        weight_lang (float or int): The weight of the language decoder loss.
            Defaults to 1.0.
        weight_fusion (float or int): The weight of the fuser (aligner) loss.
            Defaults to 1.0.
    """

    def __init__(self,
                 dictionary: Union[Dict, Dictionary],
                 max_seq_len: int = 40,
                 letter_case: str = 'unchanged',
                 weight_vis: Union[float, int] = 1.0,
                 weight_lang: Union[float, int] = 1.0,
                 weight_fusion: Union[float, int] = 1.0,
                 **kwargs) -> None:
        assert isinstance(weight_vis, (float, int))
        assert isinstance(weight_lang, (float, int))
        assert isinstance(weight_fusion, (float, int))
        super().__init__(
            dictionary=dictionary,
            max_seq_len=max_seq_len,
            letter_case=letter_case)
        self.weight_vis = weight_vis
        self.weight_lang = weight_lang
        self.weight_fusion = weight_fusion
        # All three branches share one cross-entropy loss. Targets start with
        # a start token, so the first character of each sequence is ignored.
        self._ce_loss = CEModuleLoss(
            self.dictionary,
            max_seq_len,
            letter_case,
            reduction='mean',
            ignore_first_char=True)

    def forward(self, outputs: Dict,
                data_samples: Sequence[TextRecogDataSample]) -> Dict:
        """
        Args:
            outputs (dict): The output dictionary with at least one of
                ``out_vis``, ``out_langs`` and ``out_fusers`` specified.
            data_samples (list[TextRecogDataSample]): List of
                ``TextRecogDataSample`` which are processed by
                ``get_targets``.

        Returns:
            dict: A loss dictionary with ``loss_visual``, ``loss_lang`` and
            ``loss_fusion``. Each term is included only when the output of
            its corresponding module is present in ``outputs``.
        """
        assert 'out_vis' in outputs or \
            'out_langs' in outputs or 'out_fusers' in outputs
        losses = {}

        if outputs.get('out_vis', None):
            losses['loss_visual'] = self.weight_vis * self._ce_loss(
                outputs['out_vis']['logits'], data_samples)['loss_ce']
        if outputs.get('out_langs', None):
            # The language decoder may run for several refinement iterations;
            # average the cross-entropy over all of them before weighting.
            lang_losses = []
            for out_lang in outputs['out_langs']:
                lang_losses.append(
                    self._ce_loss(out_lang['logits'],
                                  data_samples)['loss_ce'])
            losses['loss_lang'] = self.weight_lang * torch.mean(
                torch.stack(lang_losses))
        if outputs.get('out_fusers', None):
            # Likewise, one fuser output per iteration is averaged.
            fuser_losses = []
            for out_fuser in outputs['out_fusers']:
                fuser_losses.append(
                    self._ce_loss(out_fuser['logits'],
                                  data_samples)['loss_ce'])
            losses['loss_fusion'] = self.weight_fusion * torch.mean(
                torch.stack(fuser_losses))

        return losses
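

# ----------------------------------------------------------------------------
# Illustrative sketch, not part of MMOCR: in practice this loss is built
# through the registry from a config, typically something along the lines of
#     module_loss=dict(type='ABIModuleLoss', letter_case='lower')
# inside the recognizer's decoder config. The block below mocks the
# per-module cross-entropy values (the tensors are made-up numbers, and
# building a real Dictionary/CEModuleLoss is skipped) purely to show how
# ``forward`` weights and aggregates the three loss terms.
if __name__ == '__main__':
    weight_vis, weight_lang, weight_fusion = 1.0, 1.0, 1.0

    # Stand-ins for the 'loss_ce' values that self._ce_loss would return
    # for the vision decoder, each language-decoder iteration, and each
    # fuser output.
    vis_ce = torch.tensor(2.3)
    lang_ces = [torch.tensor(1.9), torch.tensor(1.7)]
    fuser_ces = [torch.tensor(1.5)]

    losses = {
        'loss_visual': weight_vis * vis_ce,
        # Iterated terms are averaged over iterations before weighting,
        # mirroring the torch.mean(torch.stack(...)) calls in forward().
        'loss_lang': weight_lang * torch.mean(torch.stack(lang_ces)),
        'loss_fusion': weight_fusion * torch.mean(torch.stack(fuser_ces)),
    }
    print(losses)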