Spaces:
Sleeping
Sleeping
File size: 3,725 Bytes
9bf4bd7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict
from mmcv.transforms.base import BaseTransform
from mmdet.structures.mask import PolygonMasks, bitmap_to_polygon
from mmocr.registry import TRANSFORMS
@TRANSFORMS.register_module()
class MMDet2MMOCR(BaseTransform):
    """Translate MMDet-style annotation keys into MMOCR-style keys.

    Required Keys:

    - gt_masks (PolygonMasks | BitmapMasks) (optional)
    - gt_ignore_flags (np.bool) (optional)

    Added Keys:

    - gt_polygons (list[np.ndarray])
    - gt_ignored (np.ndarray)
    """

    def transform(self, results: Dict) -> Dict:
        """Replace MMDet keys in ``results`` with their MMOCR equivalents.

        ``gt_masks`` is removed and ``gt_polygons`` is stored in its place;
        ``gt_ignore_flags`` is removed and stored again as ``gt_ignored``.
        Keys that are absent are simply skipped.

        Args:
            results (Dict): Result dict containing the data to transform.
                It is modified in place.

        Returns:
            (Dict): The same dict, after the keys have been converted.
        """
        if 'gt_masks' in results:
            masks = results.pop('gt_masks')
            converted = []
            if len(masks) > 0:
                if isinstance(masks[0], PolygonMasks):
                    # Polygon masks: keep the first polygon of each instance.
                    converted = [parts[0] for parts in masks.masks]
                else:
                    # Otherwise treat the masks as bitmaps and trace their
                    # contours into flattened polygons.
                    for bitmap in masks.masks:
                        contours, _ = bitmap_to_polygon(bitmap)
                        for contour in contours:
                            flat = contour.reshape(-1)
                            # Drop contours with fewer than 6 values
                            # (presumably fewer than three x/y points, so
                            # not a valid polygon).
                            if len(flat) >= 6:
                                converted.append(flat)
            results['gt_polygons'] = converted

        if 'gt_ignore_flags' in results:
            results['gt_ignored'] = results.pop('gt_ignore_flags')

        return results

    def __repr__(self) -> str:
        """Return the class name (this transform has no parameters)."""
        return self.__class__.__name__
@TRANSFORMS.register_module()
class MMOCR2MMDet(BaseTransform):
    """Convert transforms' data format from MMOCR to MMDet.

    Required Keys:

    - img_shape (read only when gt_polygons is present)
    - gt_polygons (List[ndarray]) (optional)
    - gt_ignored (np.bool) (optional)

    Added Keys:

    - gt_masks (PolygonMasks | BitmapMasks) (optional)
    - gt_ignore_flags (np.bool) (optional)

    Args:
        poly2mask (bool): Whether to convert the polygon masks to bitmap
            masks. Defaults to False.
    """

    def __init__(self, poly2mask: bool = False) -> None:
        self.poly2mask = poly2mask

    def transform(self, results: Dict) -> Dict:
        """Convert MMOCR's data format to MMDet's data format.

        ``gt_polygons`` is removed and stored again as ``gt_masks`` (a
        ``PolygonMasks``, or its ``to_bitmap()`` result when ``poly2mask``
        is set); ``gt_ignored`` is removed and stored again as
        ``gt_ignore_flags``. Keys that are absent are simply skipped.

        Args:
            results (Dict): Result dict containing the data to transform.
                It is modified in place.

        Returns:
            (Dict): The transformed data.

        Raises:
            KeyError: If ``gt_polygons`` is present but ``img_shape`` is not.
        """
        # gt_polygons -> gt_masks
        if 'gt_polygons' in results:
            gt_polygons = results.pop('gt_polygons')
            # Each instance is wrapped in its own list: one polygon per
            # instance.
            gt_polygons = [[gt_polygon] for gt_polygon in gt_polygons]
            # Forward only the first two entries (height, width) so that a
            # 3-element shape such as (h, w, c) does not break the call.
            gt_masks = PolygonMasks(gt_polygons, *results['img_shape'][:2])

            if self.poly2mask:
                gt_masks = gt_masks.to_bitmap()

            results['gt_masks'] = gt_masks

        # gt_ignored -> gt_ignore_flags
        if 'gt_ignored' in results:
            gt_ignored = results.pop('gt_ignored')
            results['gt_ignore_flags'] = gt_ignored

        return results

    def __repr__(self) -> str:
        """Return the class name followed by the ``poly2mask`` setting."""
        repr_str = self.__class__.__name__
        repr_str += f'(poly2mask = {self.poly2mask})'
        return repr_str
|