# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence, Union

import torch
import torch.nn as nn

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import TASK_UTILS
from mmocr.structures import TextRecogDataSample


class BaseTextRecogModuleLoss(nn.Module):
    """Base recognition loss.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        max_seq_len (int): Maximum sequence length. The sequence is usually
            generated from decoder. Defaults to 40.
        letter_case (str): There are three options to alter the letter cases
            of gt texts:

            - unchanged: Do not change gt texts.
            - upper: Convert gt texts into uppercase characters.
            - lower: Convert gt texts into lowercase characters.

            Usually, it only works for English characters. Defaults to
            'unchanged'.
        pad_with (str): The padding strategy for ``gt_text.padded_indexes``.
            Defaults to 'auto'. Options are:

            - 'auto': Use dictionary.padding_idx to pad gt texts, or
              dictionary.end_idx if dictionary.padding_idx is None.
            - 'padding': Always use dictionary.padding_idx to pad gt texts.
            - 'end': Always use dictionary.end_idx to pad gt texts.
            - 'none': Do not pad gt texts.
    """

    def __init__(self,
                 dictionary: Union[Dict, Dictionary],
                 max_seq_len: int = 40,
                 letter_case: str = 'unchanged',
                 pad_with: str = 'auto',
                 **kwargs) -> None:
        super().__init__()
        if isinstance(dictionary, dict):
            self.dictionary = TASK_UTILS.build(dictionary)
        elif isinstance(dictionary, Dictionary):
            self.dictionary = dictionary
        else:
            raise TypeError(
                'The type of dictionary should be `Dictionary` or dict, '
                f'but got {type(dictionary)}')
        self.max_seq_len = max_seq_len
        assert letter_case in ['unchanged', 'upper', 'lower']
        self.letter_case = letter_case

        assert pad_with in ['auto', 'padding', 'end', 'none']
        if pad_with == 'auto':
            self.pad_idx = self.dictionary.padding_idx or \
                self.dictionary.end_idx
        elif pad_with == 'padding':
            self.pad_idx = self.dictionary.padding_idx
        elif pad_with == 'end':
            self.pad_idx = self.dictionary.end_idx
        else:
            self.pad_idx = None
        if self.pad_idx is None and pad_with != 'none':
            if pad_with == 'auto':
                raise ValueError('pad_with="auto", but dictionary.end_idx'
                                 ' and dictionary.padding_idx are both None')
            else:
                raise ValueError(
                    f'pad_with="{pad_with}", but dictionary.{pad_with}_idx is'
                    ' None')

    def get_targets(
        self, data_samples: Sequence[TextRecogDataSample]
    ) -> Sequence[TextRecogDataSample]:
        """Target generator.

        Args:
            data_samples (list[TextRecogDataSample]): It usually includes
                ``gt_text`` information.

        Returns:
            list[TextRecogDataSample]: Updated data_samples. Two keys will be
            added to each data_sample:

            - indexes (torch.LongTensor): Character indexes representing gt
              texts. All special tokens are excluded, except for UKN.
            - padded_indexes (torch.LongTensor): Character indexes
              representing gt texts with BOS and EOS if applicable, followed
              by several padding indexes until the length reaches
              ``max_seq_len``. In particular, if ``pad_with='none'``, no
              padding will be applied.
        """
        for data_sample in data_samples:
            if data_sample.get('have_target', False):
                continue
            text = data_sample.gt_text.item
            if self.letter_case in ['upper', 'lower']:
                text = getattr(text, self.letter_case)()
            indexes = self.dictionary.str2idx(text)
            indexes = torch.LongTensor(indexes)

            # target indexes for loss
            src_target = torch.LongTensor(indexes.size(0) + 2).fill_(0)
            src_target[1:-1] = indexes
            if self.dictionary.start_idx is not None:
                src_target[0] = self.dictionary.start_idx
                slice_start = 0
            else:
                slice_start = 1
            if self.dictionary.end_idx is not None:
                src_target[-1] = self.dictionary.end_idx
                slice_end = src_target.size(0)
            else:
                slice_end = src_target.size(0) - 1
            src_target = src_target[slice_start:slice_end]

            if self.pad_idx is not None:
                padded_indexes = (torch.ones(self.max_seq_len) *
                                  self.pad_idx).long()
                char_num = min(src_target.size(0), self.max_seq_len)
                padded_indexes[:char_num] = src_target[:char_num]
            else:
                padded_indexes = src_target
            # put in DataSample
            data_sample.gt_text.indexes = indexes
            data_sample.gt_text.padded_indexes = padded_indexes
            data_sample.set_metainfo(dict(have_target=True))
        return data_samples
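

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream module): a minimal, guarded example
# of how ``get_targets`` is typically driven. The temporary dict file, the
# digit character set and the ground-truth string are illustrative
# assumptions, not values shipped with MMOCR.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import tempfile

    from mmengine.structures import LabelData

    # Write a tiny character dictionary, one character per line.
    with tempfile.NamedTemporaryFile(
            'w', suffix='.txt', delete=False) as dict_f:
        dict_f.write('\n'.join('0123456789'))

    dictionary = Dictionary(
        dict_file=dict_f.name,
        with_start=True,
        with_end=True,
        with_padding=True,
        with_unknown=True)
    module_loss = BaseTextRecogModuleLoss(
        dictionary=dictionary, max_seq_len=10)

    sample = TextRecogDataSample()
    sample.gt_text = LabelData(item='0123')
    sample = module_loss.get_targets([sample])[0]

    # ``indexes`` holds the raw character indexes; ``padded_indexes`` adds
    # BOS/EOS and pads with the padding index up to ``max_seq_len``.
    print(sample.gt_text.indexes)
    print(sample.gt_text.padded_indexes)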