import json
import os
import uuid
from abc import ABC
from pathlib import Path
from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple
from zipfile import ZipFile

import numpy.typing as npt
import torch
from numpy import uint8

from doctr.models.preprocessor import PreProcessor
from doctr.models.recognition.predictor import RecognitionPredictor  # pylint: disable=W0611
from doctr.models.recognition.zoo import ARCHS, recognition

from utils import WordAnnotation, getlogger

# Numpy image type
ImageType = npt.NDArray[uint8]


class DoctrTextRecognizer:
    def __init__(
        self,
        architecture: str,
        path_weights: str,
        path_config_json: Optional[str] = None,
    ) -> None:
        """
        :param architecture: DocTR supports various text recognition models, e.g. "crnn_vgg16_bn",
                             "crnn_mobilenet_v3_small". The full list can be found here:
                             https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py#L16.
        :param path_weights: Path to the weights of the model.
        :param path_config_json: Path to a JSON file containing the configuration of the model.
                                 Useful if you have a model trained on a custom vocab.
        """
        self.architecture = architecture
        self.path_weights = path_weights
        self.name = self.get_name(self.path_weights, self.architecture)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.path_config_json = path_config_json
        self.built_model = self.build_model(self.architecture, self.path_config_json)
        self.load_model(self.path_weights, self.built_model, self.device)
        self.doctr_predictor = self.get_wrapped_model()

    def predict(self, inputs: Dict[uuid.UUID, Tuple[ImageType, WordAnnotation]]) -> List[WordAnnotation]:
        """
        Prediction on a batch of text lines.

        :param inputs: Dictionary where the key is the word's object id and the value is a tuple of
                       the cropped image and the word annotation.
        :return: A list of WordAnnotation objects, each with its recognized text set.
        """
        if inputs:
            predictor = self.doctr_predictor
            word_uuids = list(inputs.keys())
            cropped_images = [value[0] for value in inputs.values()]
            raw_output = predictor(list(cropped_images))
            det_results = []
            for word_uuid, output in zip(word_uuids, raw_output):
                ann = inputs[word_uuid][1]
                ann.text = output[0]
                det_results.append(ann)
            return det_results
        return []

    def predict_for_tables(self, inputs: List[ImageType]) -> List[str]:
        """Prediction on a batch of cropped cell images; returns the recognized text per crop."""
        if inputs:
            predictor = self.doctr_predictor
            raw_output = predictor(list(inputs))
            det_results = []
            for output in raw_output:
                det_results.append(output[0])
            return det_results
        return []

    @staticmethod
    def load_model(path_weights: str, doctr_predictor: Any, device: torch.device) -> None:
        """Loading model weights.

        1. Load the state dictionary: ``torch.load(path_weights, map_location=device)`` loads the
           state dictionary from the specified file path and maps it to the specified device.
        2. Modify keys in the state dictionary: ``"model."`` is prepended to each key so that the
           keys match those expected by ``doctr_predictor``, which wraps the recognition model
           under its ``model`` attribute.
        3. Load the state dictionary into the model: ``doctr_predictor.load_state_dict(state_dict)``.
        4. Move the model to the device: ``doctr_predictor.to(device)``.
        """
        state_dict = torch.load(path_weights, map_location=device)
        for key in list(state_dict.keys()):
            state_dict["model." + key] = state_dict.pop(key)
        doctr_predictor.load_state_dict(state_dict)
        doctr_predictor.to(device)

    @staticmethod
    def build_model(architecture: str, path_config_json: Optional[str] = None) -> "RecognitionPredictor":
        """Building the model.

        1. If a config JSON is given, the keys ``arch``, ``url`` and ``task`` are removed from
           ``custom_configs``, and the ``mean`` and ``std`` values are moved to
           ``recognition_configs``.
        2. Creating the model:
           Case 1: If ``architecture`` is a string, it must be in the predefined set of
           architectures (``ARCHS``). If valid, an instance of the model is created with the
           specified architecture and custom configurations.
           Case 2: If ``architecture`` is not a string, it must be an **instance** of one of the
           recognized model classes (e.g. ``recognition.CRNN``, ``recognition.SAR``). If valid, the
           provided instance is used as the model.
        3. The ``input_shape`` is retrieved from the model's configuration.
        4. Returns an instance of ``RecognitionPredictor`` initialized with a ``PreProcessor`` and
           the model.
        """
        # inspired and adapted from https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py
        custom_configs: Dict[str, Any] = {}
        batch_size = 1024
        recognition_configs: Dict[str, Any] = {}
        if path_config_json:
            with open(path_config_json, "r", encoding="utf-8") as f:
                custom_configs = json.load(f)
            custom_configs.pop("arch", None)
            custom_configs.pop("url", None)
            custom_configs.pop("task", None)
            recognition_configs["mean"] = custom_configs.pop("mean")
            recognition_configs["std"] = custom_configs.pop("std")
            # batch_size = custom_configs.pop("batch_size")
        recognition_configs["batch_size"] = batch_size

        if isinstance(architecture, str):
            if architecture not in ARCHS:
                raise ValueError(f"unknown architecture '{architecture}'")
            model = recognition.__dict__[architecture](pretrained=True, pretrained_backbone=True, **custom_configs)
        else:
            if not isinstance(
                architecture,
                (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq),
            ):
                raise ValueError(f"unknown architecture: {type(architecture)}")
            model = architecture

        input_shape = model.cfg["input_shape"][-2:]
        # PreProcessor implements casting, resizing, batching and normalization.
        #   output_size: expected size of each page in format (H, W)
        #   batch_size: the size of page batches
        #   mean: mean value of the training distribution by channel
        #   std: standard deviation of the training distribution by channel
        return RecognitionPredictor(
            PreProcessor(input_shape, preserve_aspect_ratio=True, **recognition_configs), model
        )

    def get_wrapped_model(self) -> Any:
        """Get the inner (wrapped) model."""
        doctr_predictor = self.build_model(self.architecture, self.path_config_json)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_model(self.path_weights, doctr_predictor, device)
        return doctr_predictor

    @staticmethod
    def get_name(path_weights: str, architecture: str) -> str:
        """Returns the name of the model."""
        return f"doctr_{architecture}" + "_".join(Path(path_weights).parts[-2:])
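

if __name__ == "__main__":
    # Usage sketch, illustrative only: shows the expected call pattern, assuming a locally
    # available weights file. The weights path below is hypothetical, and DocTR's default
    # vocab for "crnn_vgg16_bn" is assumed (no custom config JSON is passed).
    import numpy as np

    recognizer = DoctrTextRecognizer(
        architecture="crnn_vgg16_bn",
        path_weights="/path/to/crnn_vgg16_bn_weights.pt",  # hypothetical path
    )

    # predict_for_tables takes a plain list of crops and returns the recognized strings.
    dummy_crop: ImageType = np.zeros((32, 128, 3), dtype=np.uint8)
    print(recognizer.predict_for_tables([dummy_crop]))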