# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import List, Tuple

import numpy as np

from doctr.datasets import encode_sequences
from doctr.utils.repr import NestedObject

__all__ = ["RecognitionPostProcessor", "RecognitionModel"]


class RecognitionModel(NestedObject):
    """Implements abstract RecognitionModel class"""

    # String containing the ordered sequence of supported characters (set by subclasses).
    vocab: str
    # Target size handed to ``encode_sequences`` when building training targets (set by subclasses).
    max_length: int

    def build_target(
        self,
        gts: List[str],
    ) -> Tuple[np.ndarray, List[int]]:
        """Encode a list of gts sequences into a np array and gives the corresponding
        sequence lengths.

        Args:
        ----
            gts: list of ground-truth labels

        Returns:
        -------
            A tuple ``(encoded, seq_len)``: the encoded labels as a numpy array (presumably of
            shape (len(gts), max_length) -- TODO confirm against ``encode_sequences``), and the
            number of characters of each ground-truth label (for each entry of the batch)
        """
        # EOS is encoded as len(vocab), i.e. one past the index of the last vocab character.
        encoded = encode_sequences(sequences=gts, vocab=self.vocab, target_size=self.max_length, eos=len(self.vocab))
        # Raw label lengths in characters (EOS/padding not counted).
        seq_len = [len(word) for word in gts]
        return encoded, seq_len


class RecognitionPostProcessor(NestedObject):
    """Abstract class to postprocess the raw output of the model

    Args:
    ----
        vocab: string containing the ordered sequence of supported characters
    """

    def __init__(
        self,
        vocab: str,
    ) -> None:
        self.vocab = vocab
        # One entry per vocab character, plus a trailing "" so that index len(vocab) -- the value
        # RecognitionModel.build_target uses for EOS -- maps to an empty string.
        # Presumably used by subclasses as an index -> character table when decoding.
        self._embedding = list(self.vocab) + [""]

    def extra_repr(self) -> str:
        """Return the extra text for this object's representation (hook presumably consumed by
        ``NestedObject``): the vocabulary size.
        """
        return f"vocab_size={len(self.vocab)}"