# Copyright (c) OpenMMLab. All rights reserved. from typing import Any, Dict, List, Optional, Sequence import numpy as np import torch from mmengine.structures import InstanceData from mmocr.structures import TextDetDataSample def create_dummy_textdet_inputs(input_shape: Sequence[int] = (1, 3, 300, 300), num_items: Optional[Sequence[int]] = None ) -> Dict[str, Any]: """Create dummy inputs to test text detectors. Args: input_shape (tuple(int)): 4-d shape of the input image. Defaults to (1, 3, 300, 300). num_items (list[int], optional): Number of bboxes to create for each image. If None, they will be randomly generated. Defaults to None. Returns: Dict[str, Any]: A dictionary of demo inputs. """ (N, C, H, W) = input_shape rng = np.random.RandomState(0) imgs = rng.rand(*input_shape) metainfo = dict( img_shape=(H, W, C), ori_shape=(H, W, C), pad_shape=(H, W, C), filename='test.jpg', scale_factor=(1, 1), flip=False) gt_masks = [] gt_kernels = [] gt_effective_mask = [] data_samples = [] for batch_idx in range(N): if num_items is None: num_boxes = rng.randint(1, 10) else: num_boxes = num_items[batch_idx] data_sample = TextDetDataSample( metainfo=metainfo, gt_instances=InstanceData()) cx, cy, bw, bh = rng.rand(num_boxes, 4).T tl_x = ((cx * W) - (W * bw / 2)).clip(0, W) tl_y = ((cy * H) - (H * bh / 2)).clip(0, H) br_x = ((cx * W) + (W * bw / 2)).clip(0, W) br_y = ((cy * H) + (H * bh / 2)).clip(0, H) boxes = np.vstack([tl_x, tl_y, br_x, br_y]).T class_idxs = [0] * num_boxes data_sample.gt_instances.bboxes = torch.FloatTensor(boxes) data_sample.gt_instances.labels = torch.LongTensor(class_idxs) data_sample.gt_instances.ignored = torch.BoolTensor([False] * num_boxes) data_samples.append(data_sample) # kernels = [] # TODO: add support for multiple kernels (if necessary) # for _ in range(num_kernels): # kernel = np.random.rand(H, W) # kernels.append(kernel) gt_kernels.append(np.random.rand(H, W)) gt_effective_mask.append(np.ones((H, W))) mask = np.random.randint(0, 2, (len(boxes), H, W), dtype=np.uint8) gt_masks.append(mask) mm_inputs = { 'imgs': torch.FloatTensor(imgs).requires_grad_(True), 'data_samples': data_samples, 'gt_masks': gt_masks, 'gt_kernels': gt_kernels, 'gt_mask': gt_effective_mask, 'gt_thr_mask': gt_effective_mask, 'gt_text_mask': gt_effective_mask, 'gt_center_region_mask': gt_effective_mask, 'gt_radius_map': gt_kernels, 'gt_sin_map': gt_kernels, 'gt_cos_map': gt_kernels, } return mm_inputs def create_dummy_dict_file( dict_file: str, chars: List[str] = list('0123456789abcdefghijklmnopqrstuvwxyz') ) -> None: # NOQA """Create a dummy dictionary file. Args: dict_file (str): Path to the dummy dictionary file. chars (list[str]): List of characters in dictionary. Defaults to ``list('0123456789abcdefghijklmnopqrstuvwxyz')``. """ with open(dict_file, 'w') as f: for char in chars: f.write(char + '\n')