File size: 1,774 Bytes
9bf4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Sequence, Tuple

import torch

from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseTextRecogPostprocessor


# TODO support beam search
@MODELS.register_module()
class CTCPostProcessor(BaseTextRecogPostprocessor):
    """PostProcessor for CTC.

    Decoding is greedy: at every timestep the highest-scoring class is
    taken, consecutive repeats of the same class are collapsed into one, and
    classes listed in ``self.ignore_indexes`` are dropped.
    """

    def get_single_prediction(self, probs: torch.Tensor,
                              data_sample: TextRecogDataSample
                              ) -> Tuple[Sequence[int], Sequence[float]]:
        """Convert the output probabilities of a single image to index and
        score.

        Only the first ``ceil(T * valid_ratio)`` timesteps (capped at ``T``)
        are decoded, where ``valid_ratio`` is read from ``data_sample`` and
        defaults to 1 when absent.

        Args:
            probs (torch.Tensor): Character probabilities with shape
                :math:`(T, C)`.
            data_sample (TextRecogDataSample): Datasample of an image.

        Returns:
            tuple(list[int], list[float]): index and score.
        """
        feat_len = probs.size(0)
        max_value, max_idx = torch.max(probs, -1)
        valid_ratio = data_sample.get('valid_ratio', 1)
        # Clamp at zero: a negative length must decode nothing, whereas a
        # negative slice bound would count from the end of the sequence.
        decode_len = max(0, min(feat_len, math.ceil(feat_len * valid_ratio)))

        # Convert the decoded span to Python scalars once instead of calling
        # ``.item()`` twice per timestep.
        step_indexes = max_idx[:decode_len].tolist()
        step_scores = max_value[:decode_len].tolist()
        # Loop-invariant, so build it once rather than on every iteration.
        ignore_indexes = tuple(self.ignore_indexes)

        index = []
        score = []
        prev_idx = self.dictionary.padding_idx
        for cur_idx, cur_score in zip(step_indexes, step_scores):
            # Keep a class only when it differs from the previous timestep
            # (collapses repeats) and is not an ignored index.
            if cur_idx != prev_idx and cur_idx not in ignore_indexes:
                index.append(cur_idx)
                score.append(cur_score)
            # Updated unconditionally so an ignored index still separates
            # two identical classes on either side of it.
            prev_idx = cur_idx
        return index, score

    def __call__(
        self, outputs: torch.Tensor,
        data_samples: Sequence[TextRecogDataSample]
    ) -> Sequence[TextRecogDataSample]:
        """Move ``outputs`` to CPU, detach it from the autograd graph, and
        delegate to the parent class's ``__call__``.

        Args:
            outputs (torch.Tensor): Model outputs for a batch.
            data_samples (Sequence[TextRecogDataSample]): Data samples that
                are passed through to the parent implementation.

        Returns:
            Sequence[TextRecogDataSample]: Whatever the parent ``__call__``
            returns for the CPU copy of ``outputs``.
        """
        outputs = outputs.cpu().detach()
        return super().__call__(outputs, data_samples)