# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, Sequence, Union

import torch
import torch.nn as nn

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseTextRecogModuleLoss


@MODELS.register_module()
class CTCModuleLoss(BaseTextRecogModuleLoss):
"""Implementation of loss module for CTC-loss based text recognition.
Args:
dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
the instance of `Dictionary`.
letter_case (str): There are three options to alter the letter cases
of gt texts:
- unchanged: Do not change gt texts.
- upper: Convert gt texts into uppercase characters.
- lower: Convert gt texts into lowercase characters.
Usually, it only works for English characters. Defaults to
'unchanged'.
flatten (bool): If True, use flattened targets, else padded targets.
reduction (str): Specifies the reduction to apply to the output,
should be one of the following: ('none', 'mean', 'sum').
zero_infinity (bool): Whether to zero infinite losses and
the associated gradients. Default: False.
Infinite losses mainly occur when the inputs
are too short to be aligned to the targets.
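
    Example:
        A minimal construction sketch (the ``dict_file`` path below is an
        assumption; any character dictionary built with ``with_padding=True``
        works, since CTC reuses ``padding_idx`` as the blank index)::

            >>> loss_module = CTCModuleLoss(
            ...     dictionary=dict(
            ...         type='Dictionary',
            ...         dict_file='dicts/lower_english_digits.txt',
            ...         with_padding=True),
            ...     flatten=True,
            ...     reduction='mean')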
"""
def __init__(self,
dictionary: Union[Dict, Dictionary],
letter_case: str = 'unchanged',
flatten: bool = True,
reduction: str = 'mean',
zero_infinity: bool = False,
**kwargs) -> None:
        super().__init__(dictionary=dictionary, letter_case=letter_case)
        assert isinstance(flatten, bool)
        assert isinstance(reduction, str)
        assert isinstance(zero_infinity, bool)

        self.flatten = flatten
        # The dictionary's padding index doubles as the CTC blank index.
        self.ctc_loss = nn.CTCLoss(
            blank=self.dictionary.padding_idx,
            reduction=reduction,
            zero_infinity=zero_infinity)

    def forward(self, outputs: torch.Tensor,
                data_samples: Sequence[TextRecogDataSample]) -> Dict:
        """
        Args:
            outputs (Tensor): A raw logit tensor of shape :math:`(N, T, C)`.
            data_samples (list[TextRecogDataSample]): List of
                ``TextRecogDataSample`` which are processed by
                ``get_targets``.

        Returns:
            dict: The loss dict with key ``loss_ctc``.
        """
        valid_ratios = None
        if data_samples is not None:
            valid_ratios = [
                img_meta.get('valid_ratio', 1.0) for img_meta in data_samples
            ]
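
        # nn.CTCLoss expects log-probabilities of shape (T, N, C), so apply
        # log_softmax over the class dimension and move time to the front.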
        outputs = torch.log_softmax(outputs, dim=2)
        bsz, seq_len = outputs.size(0), outputs.size(1)
        outputs_for_loss = outputs.permute(1, 0, 2).contiguous()  # T * N * C
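
        # Truncate each target to the output length: CTC cannot align a
        # target longer than the input sequence.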
        targets = [
            data_sample.gt_text.indexes[:seq_len]
            for data_sample in data_samples
        ]
        target_lengths = torch.IntTensor([len(t) for t in targets])
        target_lengths = torch.clamp(target_lengths, max=seq_len).long()
        input_lengths = torch.full(
            size=(bsz, ), fill_value=seq_len, dtype=torch.long)
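
        # With flatten=True the targets are concatenated into one 1-D tensor;
        # otherwise they are padded with padding_idx into an (N, S) tensor.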
        if self.flatten:
            targets = torch.cat(targets)
        else:
            padded_targets = torch.full(
                size=(bsz, seq_len),
                fill_value=self.dictionary.padding_idx,
                dtype=torch.long)
            for idx, valid_len in enumerate(target_lengths):
                padded_targets[idx, :valid_len] = targets[idx][:valid_len]
            targets = padded_targets
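
        # Images may be right-padded before feature extraction; shrink each
        # input length to the valid (non-padded) part of the sequence.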
        if valid_ratios is not None:
            input_lengths = [
                math.ceil(valid_ratio * seq_len)
                for valid_ratio in valid_ratios
            ]
            input_lengths = torch.Tensor(input_lengths).long()

        loss_ctc = self.ctc_loss(outputs_for_loss, targets, input_lengths,
                                 target_lengths)
        losses = dict(loss_ctc=loss_ctc)
        return losses

    def get_targets(
        self, data_samples: Sequence[TextRecogDataSample]
    ) -> Sequence[TextRecogDataSample]:
        """Target generator.

        Args:
            data_samples (list[TextRecogDataSample]): It usually includes
                ``gt_text`` information.

        Returns:
            list[TextRecogDataSample]: Updated data_samples. One key is added
            to the ``gt_text`` field of each data_sample:

            - indexes (torch.IntTensor): The character indexes corresponding
              to the gt text.
        """
        for data_sample in data_samples:
            text = data_sample.gt_text.item
            if self.letter_case in ['upper', 'lower']:
                text = getattr(text, self.letter_case)()
            indexes = self.dictionary.str2idx(text)
            indexes = torch.IntTensor(indexes)
            data_sample.gt_text.indexes = indexes
        return data_samples
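

# Usage sketch (not part of the original module): a minimal smoke test.
# The dict file path is a hypothetical example; the file should list one
# character per line and must cover every character in the gt text.
if __name__ == '__main__':
    from mmengine.structures import LabelData

    loss_module = CTCModuleLoss(
        dictionary=dict(
            type='Dictionary',
            dict_file='dicts/lower_english_digits.txt',  # assumed path
            with_padding=True),
        flatten=True)

    # Build a data sample carrying the ground-truth text.
    data_sample = TextRecogDataSample()
    data_sample.gt_text = LabelData(item='hello')
    data_samples = loss_module.get_targets([data_sample])

    # Fake recognition logits of shape (N, T, C).
    num_classes = loss_module.dictionary.num_classes
    logits = torch.randn(1, 40, num_classes)
    print(loss_module(logits, data_samples)['loss_ctc'])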