Spaces:
Running
Running
# Copyright (c) OpenMMLab. All rights reserved.
import math | |
from typing import Sequence, Tuple | |
import torch | |
from mmocr.registry import MODELS | |
from mmocr.structures import TextRecogDataSample | |
from .base import BaseTextRecogPostprocessor | |
# TODO: support beam search
class CTCPostProcessor(BaseTextRecogPostprocessor):
    """PostProcessor for CTC.

    Decodes greedily: at every time step the most probable class is taken,
    consecutive repeats are collapsed and the indexes listed in
    ``self.ignore_indexes`` are dropped. ``self.dictionary`` and
    ``self.ignore_indexes`` are not set in this class; they are presumably
    provided by the base class -- confirm against
    ``BaseTextRecogPostprocessor``.
    """

    def get_single_prediction(self, probs: torch.Tensor,
                              data_sample: TextRecogDataSample
                              ) -> Tuple[Sequence[int], Sequence[float]]:
        """Convert the output probabilities of a single image to index and
        score.

        Args:
            probs (torch.Tensor): Character probabilities with shape
                :math:`(T, C)`.
            data_sample (TextRecogDataSample): Datasample of an image. Its
                optional ``valid_ratio`` entry (default 1) limits decoding
                to the first ``ceil(T * valid_ratio)`` time steps.

        Returns:
            tuple(list[int], list[float]): index and score.
        """
        feat_len = probs.size(0)
        max_value, max_idx = torch.max(probs, -1)
        valid_ratio = data_sample.get('valid_ratio', 1)
        decode_len = min(feat_len, math.ceil(feat_len * valid_ratio))
        # Clamp at zero so that slicing below decodes nothing for a
        # non-positive ratio instead of wrapping around from the end.
        decode_len = max(0, decode_len)

        # Convert the decoded prefix to Python lists once rather than
        # calling ``.item()`` on a tensor element at every time step.
        step_indexes = max_idx[:decode_len].tolist()
        step_scores = max_value[:decode_len].tolist()
        # Loop invariant: build the ignore collection a single time.
        ignore_indexes = tuple(self.ignore_indexes)

        index = []
        score = []
        prev_idx = self.dictionary.padding_idx
        for cur_idx, cur_score in zip(step_indexes, step_scores):
            # Keep a step only if it starts a new run and is not ignored.
            if cur_idx != prev_idx and cur_idx not in ignore_indexes:
                index.append(cur_idx)
                score.append(cur_score)
            # Always track the previous step, even when it was skipped, so
            # that an ignored index between two equal characters separates
            # them.
            prev_idx = cur_idx
        return index, score

    def __call__(
        self, outputs: torch.Tensor,
        data_samples: Sequence[TextRecogDataSample]
    ) -> Sequence[TextRecogDataSample]:
        """Detach the outputs, move them to CPU and delegate to the base
        class.

        Args:
            outputs (torch.Tensor): Model outputs for a batch; forwarded to
                the parent ``__call__`` after detaching.
            data_samples (Sequence[TextRecogDataSample]): Data samples
                forwarded unchanged to the parent ``__call__``.

        Returns:
            Sequence[TextRecogDataSample]: Whatever the parent ``__call__``
            returns.
        """
        # Detach before the device copy so the copy itself is never tracked
        # by autograd; the resulting tensor is the same.
        outputs = outputs.detach().cpu()
        return super().__call__(outputs, data_samples)