# Copyright (C) 2021-2024, Mindee. # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np from doctr.models.builder import DocumentBuilder from doctr.utils.geometry import extract_crops, extract_rcrops from .._utils import rectify_crops, rectify_loc_preds from ..classification import crop_orientation_predictor from ..classification.predictor import OrientationPredictor __all__ = ["_OCRPredictor"] class _OCRPredictor: """Implements an object able to localize and identify text elements in a set of documents Args: ---- assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages without rotated textual elements. straighten_pages: if True, estimates the page general orientation based on the median line orientation. Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped accordingly. Doing so will improve performances for documents with page-uniform rotations. preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding) symmetric_pad: if True and preserve_aspect_ratio is True, pas the image symmetrically. **kwargs: keyword args of `DocumentBuilder` """ crop_orientation_predictor: Optional[OrientationPredictor] def __init__( self, assume_straight_pages: bool = True, straighten_pages: bool = False, preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, **kwargs: Any, ) -> None: self.assume_straight_pages = assume_straight_pages self.straighten_pages = straighten_pages self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True) self.doc_builder = DocumentBuilder(**kwargs) self.preserve_aspect_ratio = preserve_aspect_ratio self.symmetric_pad = symmetric_pad self.hooks: List[Callable] = [] @staticmethod def _generate_crops( pages: List[np.ndarray], loc_preds: List[np.ndarray], channels_last: bool, assume_straight_pages: bool = False, ) -> List[List[np.ndarray]]: extraction_fn = extract_crops if assume_straight_pages else extract_rcrops crops = [ extraction_fn(page, _boxes[:, :4], channels_last=channels_last) # type: ignore[operator] for page, _boxes in zip(pages, loc_preds) ] return crops @staticmethod def _prepare_crops( pages: List[np.ndarray], loc_preds: List[np.ndarray], channels_last: bool, assume_straight_pages: bool = False, ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]: crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages) # Avoid sending zero-sized crops is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops] crops = [ [crop for crop, _kept in zip(page_crops, page_kept) if _kept] for page_crops, page_kept in zip(crops, is_kept) ] loc_preds = [_boxes[_kept] for _boxes, _kept in zip(loc_preds, is_kept)] return crops, loc_preds def _rectify_crops( self, crops: List[List[np.ndarray]], loc_preds: List[np.ndarray], ) -> Tuple[List[List[np.ndarray]], List[np.ndarray], List[Tuple[int, float]]]: # Work at a page level orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops]) # type: ignore[misc] rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)] rect_loc_preds = [ rectify_loc_preds(page_loc_preds, orientation) if len(page_loc_preds) > 0 else page_loc_preds for page_loc_preds, orientation in zip(loc_preds, orientations) ] # Flatten to list of tuples with (value, confidence) crop_orientations = [ (orientation, prob) for page_classes, page_probs in zip(classes, probs) for orientation, prob in zip(page_classes, page_probs) ] return rect_crops, rect_loc_preds, crop_orientations # type: ignore[return-value] def _remove_padding( self, pages: List[np.ndarray], loc_preds: List[np.ndarray], ) -> List[np.ndarray]: if self.preserve_aspect_ratio: # Rectify loc_preds to remove padding rectified_preds = [] for page, loc_pred in zip(pages, loc_preds): h, w = page.shape[0], page.shape[1] if h > w: # y unchanged, dilate x coord if self.symmetric_pad: if self.assume_straight_pages: loc_pred[:, [0, 2]] = np.clip((loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1) else: loc_pred[:, :, 0] = np.clip((loc_pred[:, :, 0] - 0.5) * h / w + 0.5, 0, 1) else: if self.assume_straight_pages: loc_pred[:, [0, 2]] *= h / w else: loc_pred[:, :, 0] *= h / w elif w > h: # x unchanged, dilate y coord if self.symmetric_pad: if self.assume_straight_pages: loc_pred[:, [1, 3]] = np.clip((loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5, 0, 1) else: loc_pred[:, :, 1] = np.clip((loc_pred[:, :, 1] - 0.5) * w / h + 0.5, 0, 1) else: if self.assume_straight_pages: loc_pred[:, [1, 3]] *= w / h else: loc_pred[:, :, 1] *= w / h rectified_preds.append(loc_pred) return rectified_preds return loc_preds @staticmethod def _process_predictions( loc_preds: List[np.ndarray], word_preds: List[Tuple[str, float]], crop_orientations: List[Dict[str, Any]], ) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]], List[List[Dict[str, Any]]]]: text_preds = [] crop_orientation_preds = [] if len(loc_preds) > 0: # Text & crop orientation predictions at page level _idx = 0 for page_boxes in loc_preds: text_preds.append(word_preds[_idx : _idx + page_boxes.shape[0]]) crop_orientation_preds.append(crop_orientations[_idx : _idx + page_boxes.shape[0]]) _idx += page_boxes.shape[0] return loc_preds, text_preds, crop_orientation_preds def add_hook(self, hook: Callable) -> None: """Add a hook to the predictor Args: ---- hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds` """ self.hooks.append(hook)