adirathor07's picture
added doctr folder
153628e
raw
history blame
1.59 kB
# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
from typing import List, Tuple
import numpy as np
from doctr.datasets import encode_sequences
from doctr.utils.repr import NestedObject
__all__ = ["RecognitionPostProcessor", "RecognitionModel"]
class RecognitionModel(NestedObject):
    """Abstract base class for text recognition models.

    Subclasses are expected to provide the two attributes annotated below;
    this class only declares them and does not assign any value.
    """

    # Ordered characters the model can emit; `len(vocab)` is used as the EOS index.
    vocab: str
    # Forwarded to `encode_sequences` as `target_size` when building targets.
    max_length: int

    def build_target(
        self,
        gts: List[str],
    ) -> Tuple[np.ndarray, List[int]]:
        """Encode a batch of ground-truth strings and report the length of each one.

        Args:
        ----
            gts: list of ground-truth labels

        Returns:
        -------
            A tuple of 2 elements: the encoded labels, as returned by `encode_sequences`,
            and the number of characters of each label (one entry per element of the batch).
            The lengths are the raw `len()` of each label and are not clipped to `max_length`.
        """
        # The EOS index is the first index past the vocabulary, i.e. `len(self.vocab)`.
        targets = encode_sequences(
            sequences=gts,
            vocab=self.vocab,
            target_size=self.max_length,
            eos=len(self.vocab),
        )
        lengths = list(map(len, gts))
        return targets, lengths
class RecognitionPostProcessor(NestedObject):
    """Abstract base class for turning raw model outputs into text.

    Args:
    ----
        vocab: string containing the ordered sequence of supported characters
    """

    def __init__(
        self,
        vocab: str,
    ) -> None:
        self.vocab = vocab
        # Index-to-symbol table: one entry per vocab character, followed by the
        # "<eos>" marker at index `len(vocab)`.
        self._embedding = [*self.vocab, "<eos>"]

    def extra_repr(self) -> str:
        """Return the extra representation string, which reports the vocabulary size."""
        return f"vocab_size={len(self.vocab)}"