# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict

import torch

from mmocr.registry import MODELS
from mmocr.utils.typing_utils import (ConfigType, InitConfigType,
                                      OptConfigType, OptRecSampleList,
                                      RecForwardResults, RecSampleList)
from .base import BaseRecognizer


@MODELS.register_module()
class EncoderDecoderRecognizer(BaseRecognizer):
"""Base class for encode-decode recognizer. | |
Args: | |
preprocessor (dict, optional): Config dict for preprocessor. Defaults | |
to None. | |
backbone (dict, optional): Backbone config. Defaults to None. | |
encoder (dict, optional): Encoder config. If None, the output from | |
backbone will be directly fed into ``decoder``. Defaults to None. | |
decoder (dict, optional): Decoder config. Defaults to None. | |
data_preprocessor (dict, optional): Model preprocessing config | |
for processing the input image data. Keys allowed are | |
``to_rgb``(bool), ``pad_size_divisor``(int), ``pad_value``(int or | |
float), ``mean``(int or float) and ``std``(int or float). | |
Preprcessing order: 1. to rgb; 2. normalization 3. pad. | |
Defaults to None. | |
init_cfg (dict or list[dict], optional): Initialization configs. | |
Defaults to None. | |
""" | |

    def __init__(self,
                 preprocessor: OptConfigType = None,
                 backbone: OptConfigType = None,
                 encoder: OptConfigType = None,
                 decoder: OptConfigType = None,
                 data_preprocessor: ConfigType = None,
                 init_cfg: InitConfigType = None) -> None:
        super().__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        # Preprocessor module, e.g., TPS
        if preprocessor is not None:
            self.preprocessor = MODELS.build(preprocessor)

        # Backbone
        if backbone is not None:
            self.backbone = MODELS.build(backbone)

        # Encoder module
        if encoder is not None:
            self.encoder = MODELS.build(encoder)

        # Decoder module (required)
        assert decoder is not None
        self.decoder = MODELS.build(decoder)

    def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor:
        """Directly extract features from the backbone."""
        if self.with_preprocessor:
            inputs = self.preprocessor(inputs)
        if self.with_backbone:
            inputs = self.backbone(inputs)
        return inputs

    def loss(self, inputs: torch.Tensor, data_samples: RecSampleList,
             **kwargs) -> Dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            data_samples (list[TextRecogDataSample]): A list of N
                datasamples, containing meta information and gold
                annotations for each of the images.

        Returns:
            dict[str, tensor]: A dictionary of loss components.
        """
        feat = self.extract_feat(inputs)
        out_enc = None
        if self.with_encoder:
            out_enc = self.encoder(feat, data_samples)
        return self.decoder.loss(feat, out_enc, data_samples)

    def predict(self, inputs: torch.Tensor, data_samples: RecSampleList,
                **kwargs) -> RecSampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (torch.Tensor): Image input tensor.
            data_samples (list[TextRecogDataSample]): A list of N datasamples,
                containing meta information and gold annotations for each of
                the images.

        Returns:
            list[TextRecogDataSample]: A list of N datasamples of prediction
            results. Results are stored in ``pred_text``.
        """
        feat = self.extract_feat(inputs)
        out_enc = None
        if self.with_encoder:
            out_enc = self.encoder(feat, data_samples)
        return self.decoder.predict(feat, out_enc, data_samples)

    def _forward(self,
                 inputs: torch.Tensor,
                 data_samples: OptRecSampleList = None,
                 **kwargs) -> RecForwardResults:
        """Network forward process. Usually includes backbone, encoder and
        decoder forward without any post-processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (list[TextRecogDataSample]): A list of N
                datasamples, containing meta information and gold
                annotations for each of the images.

        Returns:
            Tensor: Features from the ``decoder`` forward, without any
            post-processing.
        """
        feat = self.extract_feat(inputs)
        out_enc = None
        if self.with_encoder:
            out_enc = self.encoder(feat, data_samples)
        return self.decoder(feat, out_enc, data_samples)
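

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed by this module). It assumes
# the class is registered in ``MODELS`` and that a concrete, buildable config
# is available; ``cfg`` below is a hypothetical, abbreviated config, and the
# ``mode`` values follow MMEngine's ``BaseModel.forward`` dispatch convention.
#
#   from mmocr.registry import MODELS
#   recognizer = MODELS.build(cfg)  # cfg = dict(type='EncoderDecoderRecognizer', ...)
#
#   # Training step: dispatches to ``loss()`` and returns a dict of losses.
#   losses = recognizer(inputs, data_samples, mode='loss')
#
#   # Inference: dispatches to ``predict()``; results land in ``pred_text``.
#   preds = recognizer(inputs, data_samples, mode='predict')
#
#   # Raw network outputs without post-processing, via ``_forward()``.
#   raw = recognizer(inputs, data_samples, mode='tensor')
# ---------------------------------------------------------------------------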