# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import numpy as np
from mmengine.model import BaseTTAModel

from mmocr.registry import MODELS
from mmocr.utils.typing_utils import RecSampleList


@MODELS.register_module()
class EncoderDecoderRecognizerTTAModel(BaseTTAModel):
"""Merge augmented recognition results. It will select the best result
according average scores from all augmented results.

    Examples:
        >>> tta_model = dict(
        >>>     type='EncoderDecoderRecognizerTTAModel')
        >>>
        >>> tta_pipeline = [
        >>>     dict(
        >>>         type='LoadImageFromFile',
        >>>         color_type='grayscale'),
        >>>     dict(
        >>>         type='TestTimeAug',
        >>>         transforms=[
        >>>             [
        >>>                 dict(
        >>>                     type='ConditionApply',
        >>>                     true_transforms=[
        >>>                         dict(
        >>>                             type='ImgAugWrapper',
        >>>                             args=[dict(cls='Rot90', k=0, keep_size=False)]) # noqa: E501
        >>>                     ],
        >>>                     condition="results['img_shape'][1]<results['img_shape'][0]" # noqa: E501
        >>>                 ),
        >>>                 dict(
        >>>                     type='ConditionApply',
        >>>                     true_transforms=[
        >>>                         dict(
        >>>                             type='ImgAugWrapper',
        >>>                             args=[dict(cls='Rot90', k=1, keep_size=False)]) # noqa: E501
        >>>                     ],
        >>>                     condition="results['img_shape'][1]<results['img_shape'][0]" # noqa: E501
        >>>                 ),
        >>>                 dict(
        >>>                     type='ConditionApply',
        >>>                     true_transforms=[
        >>>                         dict(
        >>>                             type='ImgAugWrapper',
        >>>                             args=[dict(cls='Rot90', k=3, keep_size=False)])
        >>>                     ],
        >>>                     condition="results['img_shape'][1]<results['img_shape'][0]"
        >>>                 ),
        >>>             ],
        >>>             [
        >>>                 dict(
        >>>                     type='RescaleToHeight',
        >>>                     height=32,
        >>>                     min_width=32,
        >>>                     max_width=None,
        >>>                     width_divisor=16)
        >>>             ],
        >>>             # add loading annotation after ``Resize`` because ground truth
        >>>             # does not need to do resize data transform
        >>>             [dict(type='LoadOCRAnnotations', with_text=True)],
        >>>             [
        >>>                 dict(
        >>>                     type='PackTextRecogInputs',
        >>>                     meta_keys=('img_path', 'ori_shape', 'img_shape',
        >>>                                'valid_ratio'))
        >>>             ]
        >>>         ])
        >>> ]
    """

    def merge_preds(self,
                    data_samples_list: List[RecSampleList]) -> RecSampleList:
        """Merge predictions of augmented data into one prediction.

        Args:
            data_samples_list (List[RecSampleList]): List of predictions of
                all augmented data. The shape of data_samples_list is (B, M),
                where B is the batch size and M is the number of augmented
                views per image.

        Returns:
            RecSampleList: Merged prediction.
        """
        predictions = list()
        for data_samples in data_samples_list:
            # Per-character confidence scores of each augmented prediction.
            scores = [
                data_sample.pred_text.score for data_sample in data_samples
            ]
            # Keep the prediction whose mean character score is highest.
            average_scores = np.array(
                [sum(score) / max(1, len(score)) for score in scores])
            max_idx = np.argmax(average_scores)
            predictions.append(data_samples[max_idx])
        return predictions
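

if __name__ == '__main__':
    # Minimal usage sketch (illustration only, not part of the MMOCR API):
    # it shows how ``merge_preds`` keeps, for each image in the batch, the
    # augmented prediction with the highest mean character score. The
    # ``SimpleNamespace`` objects are hypothetical stand-ins for MMOCR's
    # ``TextRecogDataSample``; since ``merge_preds`` never touches ``self``,
    # it is called unbound here to avoid building a full recognizer.
    from types import SimpleNamespace

    def mock_sample(char_scores):
        return SimpleNamespace(pred_text=SimpleNamespace(score=char_scores))

    # One image with three augmented views; their mean scores are
    # 0.85, 0.45 and 0.825, so the first view should be selected.
    batch = [[
        mock_sample([0.9, 0.8]),
        mock_sample([0.5, 0.4]),
        mock_sample([0.7, 0.95]),
    ]]
    merged = EncoderDecoderRecognizerTTAModel.merge_preds(None, batch)
    assert merged[0] is batch[0][0]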