# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Union

import torch
from mmengine.model import BaseModule

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS, TASK_UTILS
from mmocr.structures import TextRecogDataSample


@MODELS.register_module()
class BaseDecoder(BaseModule):
    """Base decoder for text recognition that builds the module loss and
    postprocessor.

    Args:
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        module_loss (dict, optional): Config to build the decoder loss module.
            Defaults to None.
        postprocessor (dict, optional): Config to build the postprocessor.
            Defaults to None.
        max_seq_len (int): Maximum length of the output sequence, which is
            usually generated by the decoder. Defaults to 40.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 dictionary: Union[Dict, Dictionary],
                 module_loss: Optional[Dict] = None,
                 postprocessor: Optional[Dict] = None,
                 max_seq_len: int = 40,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        if isinstance(dictionary, dict):
            self.dictionary = TASK_UTILS.build(dictionary)
        elif isinstance(dictionary, Dictionary):
            self.dictionary = dictionary
        else:
            raise TypeError(
                'The type of dictionary should be `Dictionary` or dict, '
                f'but got {type(dictionary)}')
        self.module_loss = None
        self.postprocessor = None
        self.max_seq_len = max_seq_len
        if module_loss is not None:
            assert isinstance(module_loss, dict)
            module_loss.update(dictionary=dictionary)
            module_loss.update(max_seq_len=max_seq_len)
            self.module_loss = MODELS.build(module_loss)
        if postprocessor is not None:
            assert isinstance(postprocessor, dict)
            postprocessor.update(dictionary=dictionary)
            postprocessor.update(max_seq_len=max_seq_len)
            self.postprocessor = MODELS.build(postprocessor)
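
    # A minimal sketch of a config that ``__init__`` accepts. All values
    # below are illustrative: ``SomeConcreteDecoder`` is a hypothetical
    # subclass, and the loss/postprocessor types are assumptions drawn from
    # mmocr's registries rather than requirements of this base class.
    #
    #   decoder = MODELS.build(
    #       dict(
    #           type='SomeConcreteDecoder',
    #           dictionary=dict(
    #               type='Dictionary',
    #               dict_file='dicts/lower_english_digits.txt',
    #               with_start=True,
    #               with_end=True,
    #               with_padding=True,
    #               with_unknown=True),
    #           module_loss=dict(type='CEModuleLoss'),
    #           postprocessor=dict(type='AttentionPostprocessor'),
    #           max_seq_len=40))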

    def forward_train(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for training.

        Args:
            feat (torch.Tensor, optional): The feature map from backbone of
                shape :math:`(N, E, H, W)`. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to
                None.
            data_samples (Sequence[TextRecogDataSample], optional): Batch of
                TextRecogDataSample, containing gt_text information. Defaults
                to None.
        """
        raise NotImplementedError

    def forward_test(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for testing.

        Args:
            feat (torch.Tensor, optional): The feature map from backbone of
                shape :math:`(N, E, H, W)`. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to
                None.
            data_samples (Sequence[TextRecogDataSample], optional): Batch of
                TextRecogDataSample, containing gt_text information. Defaults
                to None.
        """
        raise NotImplementedError

    def loss(self,
             feat: Optional[torch.Tensor] = None,
             out_enc: Optional[torch.Tensor] = None,
             data_samples: Optional[Sequence[TextRecogDataSample]] = None
             ) -> Dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            feat (Tensor, optional): Features from the backbone. Defaults
                to None.
            out_enc (Tensor, optional): Features from the encoder.
                Defaults to None.
            data_samples (list[TextRecogDataSample], optional): A list of
                N datasamples, containing meta information and gold
                annotations for each of the images. Defaults to None.

        Returns:
            dict[str, tensor]: A dictionary of loss components.
        """
        out_dec = self(feat, out_enc, data_samples)
        return self.module_loss(out_dec, data_samples)

    def predict(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> Sequence[TextRecogDataSample]:
        """Perform forward propagation of the decoder and postprocessor.

        Args:
            feat (Tensor, optional): Features from the backbone. Defaults
                to None.
            out_enc (Tensor, optional): Features from the encoder. Defaults
                to None.
            data_samples (list[TextRecogDataSample], optional): A list of N
                datasamples, containing meta information and gold annotations
                for each of the images. Defaults to None.

        Returns:
            list[TextRecogDataSample]: A list of N datasamples of prediction
            results. Results are stored in ``pred_text``.
        """
        out_dec = self(feat, out_enc, data_samples)
        return self.postprocessor(out_dec, data_samples)
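
    # Typical call sites, as a sketch (in mmocr these calls are driven by
    # the surrounding recognizer, e.g. ``EncoderDecoderRecognizer``):
    #
    #   losses = decoder.loss(feat, out_enc, data_samples)     # training
    #   preds = decoder.predict(feat, out_enc, data_samples)   # inference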

    def forward(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Decoder forward.

        Args:
            feat (Tensor, optional): Features from the backbone. Defaults
                to None.
            out_enc (Tensor, optional): Features from the encoder.
                Defaults to None.
            data_samples (list[TextRecogDataSample], optional): A list of N
                datasamples, containing meta information and gold annotations
                for each of the images. Defaults to None.

        Returns:
            Tensor: Features from ``decoder`` forward.
        """
        if self.training:
            # The module loss may expand data_samples with the decoding
            # targets before the train-time forward consumes them.
            if self.module_loss is not None:
                data_samples = self.module_loss.get_targets(data_samples)
            return self.forward_train(feat, out_enc, data_samples)
        else:
            return self.forward_test(feat, out_enc, data_samples)
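

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the upstream class: a minimal concrete
# decoder showing what subclasses must implement. ``ToyLinearDecoder`` and
# ``in_channels`` are hypothetical names; real decoders live in
# mmocr/models/textrecog/decoders/.
# ---------------------------------------------------------------------------
class ToyLinearDecoder(BaseDecoder):
    """Project per-step encoder features straight to character logits."""

    def __init__(self, in_channels: int = 512, **kwargs) -> None:
        super().__init__(**kwargs)
        # One logit per dictionary entry at every decoding step.
        self.fc = torch.nn.Linear(in_channels, self.dictionary.num_classes)

    def forward_train(self, feat=None, out_enc=None, data_samples=None):
        # Assumes ``out_enc`` of shape (N, T, E); a real decoder would also
        # consume the gt_text targets prepared in ``data_samples``.
        return self.fc(out_enc)

    def forward_test(self, feat=None, out_enc=None, data_samples=None):
        # Per-class probabilities for the postprocessor to decode greedily.
        return torch.softmax(self.fc(out_enc), dim=-1)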