# NOTE: scraped snapshot of this file (revision 9bf4bd7, 6,921 bytes); the
# hosting page's status header and line-number gutter have been removed.
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Tuple, Union

import cv2
import torch
from mmdet.structures import DetDataSample
from mmdet.structures import SampleList as MMDET_SampleList
from mmdet.structures.mask import bitmap_to_polygon
from mmengine.model import BaseModel
from mmengine.structures import InstanceData

from mmocr.registry import MODELS
from mmocr.utils.bbox_utils import bbox2poly
from mmocr.utils.typing_utils import DetSampleList
# Every value ``MMDetWrapper.forward`` may return; which one depends on the
# ``mode`` argument (see that method's docstring).
ForwardResults = Union[
    Dict[str, torch.Tensor],  # mode='loss'
    List[DetDataSample],  # mode='predict'
    Tuple[torch.Tensor],  # mode='tensor' (tuple of tensors)
    torch.Tensor,  # mode='tensor' (single tensor)
]
@MODELS.register_module()
class MMDetWrapper(BaseModel):
    """A wrapper of MMDet's model.

    The wrapped MMDet detector is built inside the ``mmdet`` registry scope.
    Its predictions are converted into MMOCR text-detection samples that
    carry ``pred_instances.polygons`` and ``pred_instances.scores``.

    Args:
        cfg (dict): The config of the model. It must contain a
            ``data_preprocessor`` entry. The dict is copied, so the caller's
            config is left untouched.
        text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
            It only affects predictions that contain masks. Defaults to
            'poly'.
    """

    def __init__(self, cfg: Dict, text_repr_type: str = 'poly') -> None:
        # Work on a private copy. The steps below pop and update keys, and
        # doing that in place would break any later build from the same
        # config dict (the second ``pop('data_preprocessor')`` would raise
        # KeyError).
        cfg = copy.deepcopy(cfg)
        data_preprocessor = cfg.pop('data_preprocessor')
        data_preprocessor.update(_scope_='mmdet')
        super().__init__(data_preprocessor=data_preprocessor, init_cfg=None)
        cfg['_scope_'] = 'mmdet'
        self.wrapped_model = MODELS.build(cfg)
        self.text_repr_type = text_repr_type

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[Union[DetSampleList,
                                             MMDET_SampleList]] = None,
                mode: str = 'tensor',
                **kwargs) -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method works in three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`TextDetDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method doesn't handle either back propagation or
        parameter update, which are supposed to be done in :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (list[:obj:`DetDataSample`] or
                list[:obj:`TextDetDataSample`]): The annotation data of every
                sample. When in "predict" mode, it should be a list of
                :obj:`TextDetDataSample`. Otherwise they are
                :obj:`DetDataSample`s. Defaults to None.
            mode (str): Running mode. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of
              :obj:`TextDetDataSample`.
            - If ``mode="loss"``, return a dict of tensor.
        """
        if mode == 'predict':
            ocr_data_samples = data_samples
            # The wrapped detector gets fresh DetDataSamples that carry only
            # the meta info. The original OCR samples receive the adapted
            # predictions afterwards.
            data_samples = [
                DetDataSample(metainfo=sample.metainfo)
                for sample in ocr_data_samples
            ]
        results = self.wrapped_model.forward(inputs, data_samples, mode,
                                             **kwargs)
        if mode == 'predict':
            results = self.adapt_predictions(results, ocr_data_samples)
        return results

    def adapt_predictions(self, data: MMDET_SampleList,
                          data_samples: DetSampleList) -> DetSampleList:
        """Convert instance predictions from MMDet into MMOCR's format.

        Args:
            data (list[DetDataSample]): Detection results of the input
                images. Each DetDataSample usually contains
                ``pred_instances``, and ``pred_instances`` usually contains
                the following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instances, ).
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arranged as (x1, y1, x2, y2).
                - masks (Tensor, Optional): Has a shape
                  (num_instances, H, W).
            data_samples (list[:obj:`TextDetDataSample`]): The annotation
                data of every sample. Updated in place.

        Returns:
            list[TextDetDataSample]: The same ``data_samples`` list with
            prediction results attached.
            The polygon results are saved in
            ``TextDetDataSample.pred_instances.polygons``.
            The confidence scores (float32, CPU) are saved in
            ``TextDetDataSample.pred_instances.scores``.
        """
        for i, det_data_sample in enumerate(data):
            det_instances = det_data_sample.pred_instances
            data_samples[i].pred_instances = InstanceData()
            if 'masks' in det_instances.keys():
                # Masks are available: derive polygons from mask contours.
                polygons, scores = self._masks_to_polygons(det_instances)
            else:
                # No masks: each (x1, y1, x2, y2) box becomes a polygon.
                bboxes = det_instances.bboxes.cpu().numpy()
                polygons = [bbox2poly(bbox) for bbox in bboxes]
                # ``.float()`` also accepts non-float32 scores (e.g. fp16),
                # which the legacy ``torch.FloatTensor(tensor)`` rejects.
                scores = det_instances.scores.cpu().float()
            data_samples[i].pred_instances.polygons = polygons
            data_samples[i].pred_instances.scores = scores
        return data_samples

    def _masks_to_polygons(
            self, pred_instances: InstanceData) -> Tuple[List, torch.Tensor]:
        """Turn predicted instance masks into flattened polygons.

        Every contour of every mask becomes one polygon and inherits the
        score of the mask it came from. Contours with fewer than 6
        coordinates (3 points) are dropped. When
        ``self.text_repr_type == 'quad'``, each kept contour is replaced by
        the 4 corners of its minimum-area bounding rectangle.

        Args:
            pred_instances (InstanceData): MMDet predictions with ``masks``
                and ``scores``.

        Returns:
            tuple(list[np.ndarray], torch.Tensor): The flattened polygons
            ``[x1, y1, x2, y2, ...]`` and their float32 CPU scores, aligned
            index by index.
        """
        masks = pred_instances.masks.cpu().numpy()
        mask_scores = pred_instances.scores.cpu().float()
        polygons = []
        # For every kept polygon, the index of the mask it was traced from.
        src_mask_idx = []
        for mask_idx, mask in enumerate(masks):
            contours, _ = bitmap_to_polygon(mask)
            for contour in contours:
                polygon = contour.reshape(-1)
                if len(polygon) < 6:
                    # Fewer than 3 points cannot form a valid polygon.
                    continue
                if self.text_repr_type == 'quad':
                    # minAreaRect needs an (N, 2) point set; a flat
                    # coordinate array is rejected by OpenCV.
                    rect = cv2.minAreaRect(polygon.reshape(-1, 2))
                    polygon = cv2.boxPoints(rect).flatten()
                polygons.append(polygon)
                src_mask_idx.append(mask_idx)
        scores = mask_scores[torch.as_tensor(src_mask_idx, dtype=torch.long)]
        return polygons, scores
|