adirathor07's picture
added doctr folder
153628e
raw
history blame
2.55 kB
# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
from typing import Any, List, Tuple, Union
import numpy as np
import tensorflow as tf
from doctr.models.preprocessor import PreProcessor
from doctr.utils.repr import NestedObject
from ..core import RecognitionModel
from ._utils import remap_preds, split_crops
__all__ = ["RecognitionPredictor"]
class RecognitionPredictor(NestedObject):
"""Implements an object able to identify character sequences in images
Args:
----
pre_processor: transform inputs for easier batched model inference
model: core detection architecture
split_wide_crops: wether to use crop splitting for high aspect ratio crops
"""
_children_names: List[str] = ["pre_processor", "model"]
def __init__(
self,
pre_processor: PreProcessor,
model: RecognitionModel,
split_wide_crops: bool = True,
) -> None:
super().__init__()
self.pre_processor = pre_processor
self.model = model
self.split_wide_crops = split_wide_crops
self.critical_ar = 8 # Critical aspect ratio
self.dil_factor = 1.4 # Dilation factor to overlap the crops
self.target_ar = 6 # Target aspect ratio
def __call__(
self,
crops: List[Union[np.ndarray, tf.Tensor]],
**kwargs: Any,
) -> List[Tuple[str, float]]:
if len(crops) == 0:
return []
# Dimension check
if any(crop.ndim != 3 for crop in crops):
raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")
# Split crops that are too wide
remapped = False
if self.split_wide_crops:
new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.dil_factor)
if remapped:
crops = new_crops
# Resize & batch them
processed_batches = self.pre_processor(crops)
# Forward it
raw = [
self.model(batch, return_preds=True, training=False, **kwargs)["preds"] # type: ignore[operator]
for batch in processed_batches
]
# Process outputs
out = [charseq for batch in raw for charseq in batch]
# Remap crops
if self.split_wide_crops and remapped:
out = remap_preds(out, crop_map, self.dil_factor)
return out