# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
from typing import Dict, List, Optional, Sequence, Union

import torch.nn as nn
from mmengine.model import ImgDataPreprocessor

from mmocr.registry import MODELS


@MODELS.register_module()
class TextRecogDataPreprocessor(ImgDataPreprocessor):
    """Image pre-processor for recognition tasks.

    Compared with :class:`mmengine.ImgDataPreprocessor`,

    1. It supports batch augmentations.
    2. It will additionally append batch_input_shape and valid_ratio
       to data_samples for the text recognition task.

    It provides the data pre-processing as follows:

    - Collate and move data to the target device.
    - Pad inputs to the maximum size of the current batch with the defined
      ``pad_value``. The padded size is divisible by the defined
      ``pad_size_divisor``.
    - Stack inputs to batch inputs.
    - Convert inputs from BGR to RGB if the shape of the input is (3, H, W).
    - Normalize images with the defined std and mean.
    - Do batch augmentations during training.

    Args:
        mean (Sequence[Number], optional): The pixel mean of R, G, B
            channels. Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of the padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        bgr_to_rgb (bool): Whether to convert images from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert images from RGB to BGR.
            Defaults to False.
        batch_augments (list[dict], optional): Batch-level augmentations.
            Defaults to None.
    """

    def __init__(self,
                 mean: Optional[Sequence[Number]] = None,
                 std: Optional[Sequence[Number]] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 batch_augments: Optional[List[Dict]] = None) -> None:
        super().__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr)
        if batch_augments is not None:
            self.batch_augments = nn.ModuleList(
                [MODELS.build(aug) for aug in batch_augments])
        else:
            self.batch_augments = None

    def forward(self, data: Dict, training: bool = False) -> Dict:
        """Perform normalization, padding and BGR-to-RGB conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (dict): Data sampled from the dataloader.
            training (bool): Whether to enable training-time augmentation.

        Returns:
            dict: Data in the same format as the model input.
        """
        data = super().forward(data=data, training=training)
        inputs, data_samples = data['inputs'], data['data_samples']

        if data_samples is not None:
            # All images in the batch have been padded to the same size, so
            # the spatial size of the first one is the batch input shape.
            batch_input_shape = tuple(inputs[0].size()[-2:])
            for data_sample in data_samples:
                # Rescale valid_ratio from the unpadded image width to the
                # padded batch width.
                valid_ratio = data_sample.valid_ratio * \
                    data_sample.img_shape[1] / batch_input_shape[1]
                data_sample.set_metainfo(
                    dict(
                        valid_ratio=valid_ratio,
                        batch_input_shape=batch_input_shape))

        if training and self.batch_augments is not None:
            for batch_aug in self.batch_augments:
                inputs, data_samples = batch_aug(inputs, data_samples)

        return dict(inputs=inputs, data_samples=data_samples)
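

if __name__ == '__main__':
    # A minimal, illustrative sketch (not part of the library API) showing how
    # the preprocessor normalizes, pads, and annotates a dummy batch. The
    # mean/std values below are the common ImageNet statistics and are only an
    # assumption; real configs should use dataset-specific values.
    import torch

    from mmocr.structures import TextRecogDataSample

    preprocessor = TextRecogDataPreprocessor(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True)

    # One fake 32x100 BGR image and its data sample. ``valid_ratio`` is the
    # proportion of the resized width that contains real (unpadded) content.
    data = dict(
        inputs=[torch.randint(0, 256, (3, 32, 100), dtype=torch.uint8)],
        data_samples=[
            TextRecogDataSample(
                metainfo=dict(img_shape=(32, 100), valid_ratio=1.0))
        ])

    out = preprocessor(data, training=False)
    print(out['inputs'].shape)  # torch.Size([1, 3, 32, 100])
    print(out['data_samples'][0].batch_input_shape)  # (32, 100)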