Mountchicken's picture
Upload 704 files
9bf4bd7
raw
history blame
12.6 kB
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmcv.transforms import to_tensor
from mmcv.transforms.base import BaseTransform
from mmengine.structures import InstanceData, LabelData
from mmocr.registry import TRANSFORMS
from mmocr.structures import (KIEDataSample, TextDetDataSample,
TextRecogDataSample)
@TRANSFORMS.register_module()
class PackTextDetInputs(BaseTransform):
    """Bundle a text detection pipeline result into model-ready form.

    ``transform`` produces a ``dict`` holding:

    - inputs (torch.Tensor): The image with its last axis moved to the
      front, i.e. (C, H, W) for an (H, W, C) array. Only present when
      ``results`` contains ``'img'``.
    - data_samples (TextDetDataSample): Two of its components are filled:

      - gt_instances (InstanceData): Receives whichever of the following
        annotations are present in ``results``:

        - bboxes (torch.Tensor): Renamed from 'gt_bboxes'. Expected to be
          ground-truth boxes of shape (N, 4) in [x1, y1, x2, y2] form.
        - labels (torch.Tensor): Renamed from 'gt_bboxes_labels'. Expected
          to hold one label per instance.
        - polygons (list[np.ndarray]): Renamed from 'gt_polygons' and
          stored untouched. Expected to be flat [x1, y1, ..., xk, yk]
          arrays whose lengths may differ. Polygons stay as numpy arrays
          because they are usually not a model output and are processed
          on cpu.
        - ignored (torch.Tensor): Renamed from 'gt_ignored'. Expected to
          flag, per instance, whether it should be ignored.
        - texts (list[str]): Renamed from 'gt_texts' and stored untouched.

      - metainfo (dict): Always set. It contains exactly the entries of
        ``results`` named by ``meta_keys``, copied without modification.
        Their meaning is defined by the transforms that produced them;
        the default selection is 'img_path', 'ori_shape', 'img_shape',
        'scale_factor', 'flip' and 'flip_direction'.

    Args:
        meta_keys (Sequence[str], optional): Keys of ``results`` to copy
            into the metainfo of ``TextDetDataSample``. Every listed key
            must exist in ``results``. Defaults to ``('img_path',
            'ori_shape', 'img_shape', 'scale_factor', 'flip',
            'flip_direction')``.
    """

    # Pipeline annotation key -> attribute name on ``gt_instances``.
    mapping_table = {
        'gt_bboxes': 'bboxes',
        'gt_bboxes_labels': 'labels',
        'gt_polygons': 'polygons',
        'gt_texts': 'texts',
        'gt_ignored': 'ignored'
    }

    def __init__(self,
                 meta_keys=('img_path', 'ori_shape', 'img_shape',
                            'scale_factor', 'flip', 'flip_direction')):
        self.meta_keys = meta_keys

    @staticmethod
    def _image_to_chw_tensor(img):
        """Convert an image array into a contiguous channel-first tensor.

        Args:
            img (np.ndarray): Image array. Arrays with fewer than three
                dimensions get a trailing axis appended first.

        Returns:
            torch.Tensor: The image with its last axis moved to the front.
        """
        if len(img.shape) < 3:
            img = np.expand_dims(img, -1)
        # Permuting after the tensor conversion speeds up formatting by
        # 3-5 times when OMP_NUM_THREADS != 1, so that route is taken
        # whenever the array is already C-contiguous. See
        # https://github.com/open-mmlab/mmdetection/pull/9533 for details.
        if not img.flags.c_contiguous:
            return to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1)))
        return to_tensor(img).permute(2, 0, 1).contiguous()

    def transform(self, results: dict) -> dict:
        """Pack the pipeline result dict.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): Data for model forwarding.
              Omitted when ``results`` has no ``'img'`` entry.
            - 'data_samples' (obj:`TextDetDataSample`): The annotation
              info of the sample.

        Raises:
            KeyError: If a key listed in ``meta_keys`` is missing from
                ``results``.
        """
        packed = {}
        if 'img' in results:
            packed['inputs'] = self._image_to_chw_tensor(results['img'])

        sample = TextDetDataSample()
        instances = InstanceData()
        # Only these annotations become tensors; the rest are stored as-is.
        tensor_keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_ignored')
        for src_key, dst_key in self.mapping_table.items():
            if src_key in results:
                value = results[src_key]
                if src_key in tensor_keys:
                    value = to_tensor(value)
                instances[dst_key] = value
        sample.gt_instances = instances

        sample.set_metainfo({key: results[key] for key in self.meta_keys})
        packed['data_samples'] = sample
        return packed

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(meta_keys={self.meta_keys})'
@TRANSFORMS.register_module()
class PackTextRecogInputs(BaseTransform):
    """Bundle a text recognition pipeline result into model-ready form.

    ``transform`` produces a ``dict`` holding:

    - inputs (torch.Tensor): The image with its last axis moved to the
      front, i.e. (C, H, W) for an (H, W, C) array. Only present when
      ``results`` contains ``'img'``.
    - data_samples (TextRecogDataSample): Two of its components are filled:

      - gt_text (LabelData): Always attached. Its ``item`` (str) is set to
        the single entry of 'gt_texts' when that annotation is present
        and non-empty.
      - metainfo (dict): Always set. It contains the entries of
        ``results`` named by ``meta_keys``, copied without modification.
        Their meaning is defined by the transforms that produced them;
        the default selection is 'img_path', 'ori_shape', 'img_shape',
        'pad_shape' and 'valid_ratio'. 'valid_ratio' is the only optional
        one: it falls back to 1 when the pipeline did not set it.

    Args:
        meta_keys (Sequence[str], optional): Keys of ``results`` to copy
            into the metainfo of ``TextRecogDataSample``. Every listed
            key except 'valid_ratio' must exist in ``results``. Defaults
            to ``('img_path', 'ori_shape', 'img_shape', 'pad_shape',
            'valid_ratio')``.
    """

    def __init__(self,
                 meta_keys=('img_path', 'ori_shape', 'img_shape', 'pad_shape',
                            'valid_ratio')):
        self.meta_keys = meta_keys

    @staticmethod
    def _image_to_chw_tensor(img):
        """Convert an image array into a contiguous channel-first tensor.

        Args:
            img (np.ndarray): Image array. Arrays with fewer than three
                dimensions get a trailing axis appended first.

        Returns:
            torch.Tensor: The image with its last axis moved to the front.
        """
        if len(img.shape) < 3:
            img = np.expand_dims(img, -1)
        # Permuting after the tensor conversion speeds up formatting by
        # 3-5 times when OMP_NUM_THREADS != 1, so that route is taken
        # whenever the array is already C-contiguous. See
        # https://github.com/open-mmlab/mmdetection/pull/9533 for details.
        if not img.flags.c_contiguous:
            return to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1)))
        return to_tensor(img).permute(2, 0, 1).contiguous()

    def transform(self, results: dict) -> dict:
        """Pack the pipeline result dict.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): Data for model forwarding.
              Omitted when ``results`` has no ``'img'`` entry.
            - 'data_samples' (obj:`TextRecogDataSample`): The annotation
              info of the sample.

        Raises:
            AssertionError: If 'gt_texts' holds more than one entry.
            KeyError: If a key listed in ``meta_keys`` other than
                'valid_ratio' is missing from ``results``.
        """
        packed = {}
        if 'img' in results:
            packed['inputs'] = self._image_to_chw_tensor(results['img'])

        sample = TextRecogDataSample()
        text_label = LabelData()
        texts = results.get('gt_texts')
        # A missing or empty annotation leaves the label without an item.
        if texts:
            assert len(texts) == 1, (
                'Each image sample should have one text annotation only')
            text_label.item = texts[0]
        sample.gt_text = text_label

        meta = {}
        for key in self.meta_keys:
            if key != 'valid_ratio':
                meta[key] = results[key]
            else:
                # Pipelines that never pad do not record a ratio.
                meta[key] = results.get('valid_ratio', 1)
        sample.set_metainfo(meta)

        packed['data_samples'] = sample
        return packed

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(meta_keys={self.meta_keys})'
@TRANSFORMS.register_module()
class PackKIEInputs(BaseTransform):
    """Pack the inputs data for key information extraction.

    ``transform`` produces a ``dict`` holding:

    - inputs (torch.Tensor): The image with its last axis moved to the
      front, i.e. (C, H, W) for an (H, W, C) array. When ``results`` has
      no ``'img'`` entry, an empty float tensor of shape (0, 0, 0) is
      used instead, so this key is always present.
    - data_samples (KIEDataSample): Two of its components are filled:

      - gt_instances (InstanceData): Receives whichever of the following
        annotations are present in ``results``:

        - bboxes (torch.Tensor): Renamed from 'gt_bboxes'. Expected to be
          ground-truth boxes of shape (N, 4) in [x1, y1, x2, y2] form.
        - labels (torch.Tensor): Renamed from 'gt_bboxes_labels'. Expected
          to hold one label per instance.
        - edge_labels (torch.Tensor): Renamed from 'gt_edges_labels'.
          Expected to be the (N, N) edge labels.
        - texts (list[str]): Renamed from 'gt_texts' and stored untouched.

      - metainfo (dict): Always set. It contains exactly the entries of
        ``results`` named by ``meta_keys``, copied without modification.
        With the default empty ``meta_keys`` nothing is recorded.

    Args:
        meta_keys (Sequence[str], optional): Keys of ``results`` to copy
            into the metainfo of ``KIEDataSample``. Every listed key must
            exist in ``results``. Defaults to ``()``.
    """

    # Pipeline annotation key -> attribute name on ``gt_instances``.
    mapping_table = {
        'gt_bboxes': 'bboxes',
        'gt_bboxes_labels': 'labels',
        'gt_edges_labels': 'edge_labels',
        'gt_texts': 'texts',
    }

    def __init__(self, meta_keys=()):
        self.meta_keys = meta_keys

    @staticmethod
    def _image_to_chw_tensor(img):
        """Convert an image array into a contiguous channel-first tensor.

        Args:
            img (np.ndarray): Image array. Arrays with fewer than three
                dimensions get a trailing axis appended first.

        Returns:
            torch.Tensor: The image with its last axis moved to the front.
        """
        if len(img.shape) < 3:
            img = np.expand_dims(img, -1)
        # Permuting after the tensor conversion speeds up formatting by
        # 3-5 times when OMP_NUM_THREADS != 1, so that route is taken
        # whenever the array is already C-contiguous. See
        # https://github.com/open-mmlab/mmdetection/pull/9533 for details.
        if not img.flags.c_contiguous:
            return to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1)))
        return to_tensor(img).permute(2, 0, 1).contiguous()

    def transform(self, results: dict) -> dict:
        """Method to pack the input data.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): Data for model forwarding.
            - 'data_samples' (obj:`KIEDataSample`): The annotation info of
              the sample.

        Raises:
            KeyError: If a key listed in ``meta_keys`` is missing from
                ``results``.
        """
        packed = {}
        if 'img' in results:
            packed['inputs'] = self._image_to_chw_tensor(results['img'])
        else:
            # No image in the pipeline result: emit an empty placeholder
            # so that 'inputs' is always present.
            packed['inputs'] = torch.FloatTensor().reshape(0, 0, 0)

        sample = KIEDataSample()
        instances = InstanceData()
        # Only these annotations become tensors; texts are stored as-is.
        tensor_keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_edges_labels')
        for src_key, dst_key in self.mapping_table.items():
            if src_key not in results:
                continue
            value = results[src_key]
            if src_key in tensor_keys:
                value = to_tensor(value)
            instances[dst_key] = value
        sample.gt_instances = instances

        sample.set_metainfo({key: results[key] for key in self.meta_keys})
        packed['data_samples'] = sample
        return packed

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(meta_keys={self.meta_keys})'