diff --git a/doctr/__init__.py b/doctr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5620ba3fd315d2e8cc4e01f9306c1eff3b2bd12a
--- /dev/null
+++ b/doctr/__init__.py
@@ -0,0 +1,3 @@
+from . import io, models, datasets, contrib, transforms, utils
+from .file_utils import is_tf_available, is_torch_available
+from .version import __version__ # noqa: F401
diff --git a/doctr/contrib/__init__.py b/doctr/contrib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/doctr/contrib/artefacts.py
b/doctr/contrib/artefacts.py new file mode 100644 index 0000000000000000000000000000000000000000..646e1991869afb156e8c38a05f1fd8cd5bec3abf --- /dev/null +++ b/doctr/contrib/artefacts.py @@ -0,0 +1,131 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, Dict, List, Optional, Tuple + +import cv2 +import numpy as np + +from doctr.file_utils import requires_package + +from .base import _BasePredictor + +__all__ = ["ArtefactDetector"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "yolov8_artefact": { + "input_shape": (3, 1024, 1024), + "labels": ["bar_code", "qr_code", "logo", "photo"], + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/yolo_artefact-f9d66f14.onnx&src=0", + }, +} + + +class ArtefactDetector(_BasePredictor): + """ + A class to detect artefacts in images + + >>> from doctr.io import DocumentFile + >>> from doctr.contrib.artefacts import ArtefactDetector + >>> doc = DocumentFile.from_images(["path/to/image.jpg"]) + >>> detector = ArtefactDetector() + >>> results = detector(doc) + + Args: + ---- + arch: the architecture to use + batch_size: the batch size to use + model_path: the path to the model to use + labels: the labels to use + input_shape: the input shape to use + mask_labels: the mask labels to use + conf_threshold: the confidence threshold to use + iou_threshold: the intersection over union threshold to use + **kwargs: additional arguments to be passed to `download_from_url` + """ + + def __init__( + self, + arch: str = "yolov8_artefact", + batch_size: int = 2, + model_path: Optional[str] = None, + labels: Optional[List[str]] = None, + input_shape: Optional[Tuple[int, int, int]] = None, + conf_threshold: float = 0.5, + iou_threshold: float = 0.5, + **kwargs: Any, + ) -> None: + super().__init__(batch_size=batch_size, url=default_cfgs[arch]["url"], model_path=model_path, **kwargs) + self.labels = labels or default_cfgs[arch]["labels"] + self.input_shape = input_shape or default_cfgs[arch]["input_shape"] + self.conf_threshold = conf_threshold + self.iou_threshold = iou_threshold + + def preprocess(self, img: np.ndarray) -> np.ndarray: + return np.transpose(cv2.resize(img, (self.input_shape[2], self.input_shape[1])), (2, 0, 1)) / np.array(255.0) + + def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> List[List[Dict[str, Any]]]: + results = [] + + for batch in zip(output, input_images): + for out, img in zip(batch[0], batch[1]): + org_height, org_width = img.shape[:2] + width_scale, height_scale = org_width / self.input_shape[2], org_height / self.input_shape[1] + for res in out: + sample_results = [] + for row in np.transpose(np.squeeze(res)): + classes_scores = row[4:] + max_score = np.amax(classes_scores) + if max_score >= self.conf_threshold: + class_id = np.argmax(classes_scores) + x, y, w, h = row[0], row[1], row[2], row[3] + # to rescaled xmin, ymin, xmax, ymax + xmin = int((x - w / 2) * width_scale) + ymin = int((y - h / 2) * height_scale) + xmax = int((x + w / 2) * width_scale) + ymax = int((y + h / 2) * height_scale) + + sample_results.append({ + "label": self.labels[class_id], + "confidence": float(max_score), + "box": [xmin, ymin, xmax, ymax], + }) + + # Filter out overlapping boxes + boxes = [res["box"] for res in sample_results] + scores = [res["confidence"] for res in sample_results] + keep_indices = cv2.dnn.NMSBoxes(boxes, scores, self.conf_threshold, self.iou_threshold) # type: 
ignore[arg-type] + sample_results = [sample_results[i] for i in keep_indices] + + results.append(sample_results) + + self._results = results + return results + + def show(self, **kwargs: Any) -> None: + """ + Display the results + + Args: + ---- + **kwargs: additional keyword arguments to be passed to `plt.show` + """ + requires_package("matplotlib", "`.show()` requires matplotlib installed") + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + + # visualize the results with matplotlib + if self._results and self._inputs: + for img, res in zip(self._inputs, self._results): + plt.figure(figsize=(10, 10)) + plt.imshow(img) + for obj in res: + xmin, ymin, xmax, ymax = obj["box"] + label = obj["label"] + plt.text(xmin, ymin, f"{label} {obj['confidence']:.2f}", color="red") + plt.gca().add_patch( + Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor="red", linewidth=2) + ) + plt.show(**kwargs) diff --git a/doctr/contrib/base.py b/doctr/contrib/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4b6583438333c4ac0d16b54c7b5210e822cca73d --- /dev/null +++ b/doctr/contrib/base.py @@ -0,0 +1,105 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, List, Optional + +import numpy as np + +from doctr.file_utils import requires_package +from doctr.utils.data import download_from_url + + +class _BasePredictor: + """ + Base class for all predictors + + Args: + ---- + batch_size: the batch size to use + url: the url to use to download a model if needed + model_path: the path to the model to use + **kwargs: additional arguments to be passed to `download_from_url` + """ + + def __init__(self, batch_size: int, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs) -> None: + self.batch_size = batch_size + self.session = self._init_model(url, model_path, **kwargs) + + self._inputs: List[np.ndarray] = [] + self._results: List[Any] = [] + + def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = None, **kwargs: Any) -> Any: + """ + Download the model from the given url if needed + + Args: + ---- + url: the url to use + model_path: the path to the model to use + **kwargs: additional arguments to be passed to `download_from_url` + + Returns: + ------- + Any: the ONNX loaded model + """ + requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.") + import onnxruntime as ort + + if not url and not model_path: + raise ValueError("You must provide either a url or a model_path") + onnx_model_path = model_path if model_path else str(download_from_url(url, cache_subdir="models", **kwargs)) # type: ignore[arg-type] + return ort.InferenceSession(onnx_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) + + def preprocess(self, img: np.ndarray) -> np.ndarray: + """ + Preprocess the input image + + Args: + ---- + img: the input image to preprocess + + Returns: + ------- + np.ndarray: the preprocessed image + """ + raise NotImplementedError + + def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> Any: + """ + Postprocess the model output + + Args: + ---- + output: the model output to postprocess + input_images: the input images used to generate the output + + Returns: + ------- + Any: the postprocessed output + """ + raise NotImplementedError + + def __call__(self, inputs: List[np.ndarray]) -> Any: + """ + Call 
the model on the given inputs
+
+        Args:
+        ----
+            inputs: the inputs to use
+
+        Returns:
+        -------
+            Any: the postprocessed output
+        """
+        self._inputs = inputs
+        model_inputs = self.session.get_inputs()
+
+        batched_inputs = [inputs[i : i + self.batch_size] for i in range(0, len(inputs), self.batch_size)]
+        processed_batches = [
+            np.array([self.preprocess(img) for img in batch], dtype=np.float32) for batch in batched_inputs
+        ]
+
+        outputs = [self.session.run(None, {model_inputs[0].name: batch}) for batch in processed_batches]
+        return self.postprocess(outputs, batched_inputs)
diff --git a/doctr/datasets/__init__.py b/doctr/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b52035ff062341470f53badc8e023477fbfb2b3
--- /dev/null
+++ b/doctr/datasets/__init__.py
@@ -0,0 +1,26 @@
+from doctr.file_utils import is_tf_available
+
+from .generator import *
+from .cord import *
+from .detection import *
+from .doc_artefacts import *
+from .funsd import *
+from .ic03 import *
+from .ic13 import *
+from .iiit5k import *
+from .iiithws import *
+from .imgur5k import *
+from .mjsynth import *
+from .ocr import *
+from .recognition import *
+from .orientation import *
+from .sroie import *
+from .svhn import *
+from .svt import *
+from .synthtext import *
+from .utils import *
+from .vocabs import *
+from .wildreceipt import *
+
+if is_tf_available():
+    from .loader import *
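A note on the predictor contract above: subclasses only have to supply `preprocess` and `postprocess`, while `_BasePredictor.__call__` handles batching and ONNX inference. The sketch below is illustrative only — the `CustomPredictor` name and the `model.onnx` path are placeholders, not part of this patch — and assumes `onnxruntime` plus a valid ONNX file are available.

import numpy as np

from doctr.contrib.base import _BasePredictor


class CustomPredictor(_BasePredictor):
    """Toy predictor: HWC uint8 images in, raw ONNX session outputs out."""

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        # HWC -> CHW, scaled to [0, 1]; the shape must match the ONNX model input
        return np.transpose(img, (2, 0, 1)) / np.array(255.0)

    def postprocess(self, output, input_images):
        # keep the raw session outputs, one entry per batch
        return output


predictor = CustomPredictor(batch_size=4, model_path="model.onnx")  # placeholder path
# __call__ splits the image list into chunks of 4, stacks each chunk as float32,
# runs the ONNX session per chunk and forwards (outputs, batched_inputs) to postprocess
results = predictor([np.zeros((640, 480, 3), dtype=np.uint8) for _ in range(10)])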
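Before the individual dataset diffs, a minimal consumption sketch for the PyTorch backend, assuming the patched package is installed with torch available; `download=True` fetches and caches the archive under `~/.cache/doctr/datasets` by default, and `collate_fn` is the static method added in `doctr/datasets/datasets/pytorch.py` further down.

from torch.utils.data import DataLoader

from doctr.datasets import CORD

train_set = CORD(train=True, download=True)  # downloads cord_train.zip on first use
loader = DataLoader(
    train_set,
    batch_size=1,  # CORD pages differ in size; pass an img_transforms resize before using larger batches
    collate_fn=train_set.collate_fn,  # stacks image tensors, keeps the targets as a plain list
)
images, targets = next(iter(loader))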
diff --git a/doctr/datasets/cord.py b/doctr/datasets/cord.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88fbb28e89e3327b5ce5603bf6cd865b8febb3b
--- /dev/null
+++ b/doctr/datasets/cord.py
@@ -0,0 +1,121 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to for full license details.
+
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import numpy as np
+from tqdm import tqdm
+
+from .datasets import VisionDataset
+from .utils import convert_target_to_relative, crop_bboxes_from_image
+
+__all__ = ["CORD"]
+
+
+class CORD(VisionDataset):
+    """CORD dataset from `"CORD: A Consolidated Receipt Dataset for Post-OCR Parsing"
+    `_.
+
+    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/cord-grid.png&src=0
+        :align: center
+
+    >>> from doctr.datasets import CORD
+    >>> train_set = CORD(train=True, download=True)
+    >>> img, target = train_set[0]
+
+    Args:
+    ----
+        train: whether the subset should be the training one
+        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+        recognition_task: whether the dataset should be used for recognition task
+        **kwargs: keyword arguments from `VisionDataset`.
+ """ + + TRAIN = ( + "https://doctr-static.mindee.com/models?id=v0.1.1/cord_train.zip&src=0", + "45f9dc77f126490f3e52d7cb4f70ef3c57e649ea86d19d862a2757c9c455d7f8", + "cord_train.zip", + ) + + TEST = ( + "https://doctr-static.mindee.com/models?id=v0.1.1/cord_test.zip&src=0", + "8c895e3d6f7e1161c5b7245e3723ce15c04d84be89eaa6093949b75a66fb3c58", + "cord_test.zip", + ) + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + recognition_task: bool = False, + **kwargs: Any, + ) -> None: + url, sha256, name = self.TRAIN if train else self.TEST + super().__init__( + url, + name, + sha256, + True, + pre_transforms=convert_target_to_relative if not recognition_task else None, + **kwargs, + ) + + # List images + tmp_root = os.path.join(self.root, "image") + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.train = train + np_dtype = np.float32 + for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking CORD", total=len(os.listdir(tmp_root))): + # File existence check + if not os.path.exists(os.path.join(tmp_root, img_path)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}") + + stem = Path(img_path).stem + _targets = [] + with open(os.path.join(self.root, "json", f"{stem}.json"), "rb") as f: + label = json.load(f) + for line in label["valid_line"]: + for word in line["words"]: + if len(word["text"]) > 0: + x = word["quad"]["x1"], word["quad"]["x2"], word["quad"]["x3"], word["quad"]["x4"] + y = word["quad"]["y1"], word["quad"]["y2"], word["quad"]["y3"], word["quad"]["y4"] + box: Union[List[float], np.ndarray] + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + box = np.array( + [ + [x[0], y[0]], + [x[1], y[1]], + [x[2], y[2]], + [x[3], y[3]], + ], + dtype=np_dtype, + ) + else: + # Reduce 8 coords to 4 -> xmin, ymin, xmax, ymax + box = [min(x), min(y), max(x), max(y)] + _targets.append((word["text"], box)) + + text_targets, box_targets = zip(*_targets) + + if recognition_task: + crops = crop_bboxes_from_image( + img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0) + ) + for crop, label in zip(crops, list(text_targets)): + self.data.append((crop, label)) + else: + self.data.append(( + img_path, + dict(boxes=np.asarray(box_targets, dtype=int).clip(min=0), labels=list(text_targets)), + )) + + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/datasets/__init__.py b/doctr/datasets/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/datasets/datasets/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/datasets/datasets/__pycache__/__init__.cpython-311.pyc b/doctr/datasets/datasets/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09bc9e02c5a68966fb136d72785841b1109f4ee1 Binary files /dev/null and b/doctr/datasets/datasets/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/datasets/datasets/__pycache__/__init__.cpython-38.pyc b/doctr/datasets/datasets/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f7f4de158abd3c979eb0c275145b9f5345265ff 
Binary files /dev/null and b/doctr/datasets/datasets/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/datasets/datasets/__pycache__/base.cpython-311.pyc b/doctr/datasets/datasets/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8831ecaa1533722b7c8f9836812ec882a221c42e Binary files /dev/null and b/doctr/datasets/datasets/__pycache__/base.cpython-311.pyc differ diff --git a/doctr/datasets/datasets/__pycache__/base.cpython-38.pyc b/doctr/datasets/datasets/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72030334c6cbe8c618c45b848f811032ffefd0b8 Binary files /dev/null and b/doctr/datasets/datasets/__pycache__/base.cpython-38.pyc differ diff --git a/doctr/datasets/datasets/__pycache__/pytorch.cpython-311.pyc b/doctr/datasets/datasets/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c83d640bf112cc5988e49f9a1c730e4fadf06dd Binary files /dev/null and b/doctr/datasets/datasets/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/datasets/datasets/__pycache__/tensorflow.cpython-311.pyc b/doctr/datasets/datasets/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..524c349f17e1b9f512608be7847e787a6b0aa67a Binary files /dev/null and b/doctr/datasets/datasets/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/datasets/datasets/__pycache__/tensorflow.cpython-38.pyc b/doctr/datasets/datasets/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efd2c5dbe8b6a95e78d92853be39bda0b84d5894 Binary files /dev/null and b/doctr/datasets/datasets/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/datasets/datasets/base.py b/doctr/datasets/datasets/base.py new file mode 100644 index 0000000000000000000000000000000000000000..58f1ca29f6b5e3eae62d587a9444fb63b2e4c340 --- /dev/null +++ b/doctr/datasets/datasets/base.py @@ -0,0 +1,132 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import os +import shutil +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np + +from doctr.io.image import get_img_shape +from doctr.utils.data import download_from_url + +from ...models.utils import _copy_tensor + +__all__ = ["_AbstractDataset", "_VisionDataset"] + + +class _AbstractDataset: + data: List[Any] = [] + _pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None + + def __init__( + self, + root: Union[str, Path], + img_transforms: Optional[Callable[[Any], Any]] = None, + sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None, + pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None, + ) -> None: + if not Path(root).is_dir(): + raise ValueError(f"expected a path to a reachable folder: {root}") + + self.root = root + self.img_transforms = img_transforms + self.sample_transforms = sample_transforms + self._pre_transforms = pre_transforms + self._get_img_shape = get_img_shape + + def __len__(self) -> int: + return len(self.data) + + def _read_sample(self, index: int) -> Tuple[Any, Any]: + raise NotImplementedError + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + # Read image + img, target = self._read_sample(index) + # Pre-transforms (format conversion at run-time etc.) 
+ if self._pre_transforms is not None: + img, target = self._pre_transforms(img, target) + + if self.img_transforms is not None: + # typing issue cf. https://github.com/python/mypy/issues/5485 + img = self.img_transforms(img) + + if self.sample_transforms is not None: + # Conditions to assess it is detection model with multiple classes and avoid confusion with other tasks. + if ( + isinstance(target, dict) + and all(isinstance(item, np.ndarray) for item in target.values()) + and set(target.keys()) != {"boxes", "labels"} # avoid confusion with obj detection target + ): + img_transformed = _copy_tensor(img) + for class_name, bboxes in target.items(): + img_transformed, target[class_name] = self.sample_transforms(img, bboxes) + img = img_transformed + else: + img, target = self.sample_transforms(img, target) + + return img, target + + def extra_repr(self) -> str: + return "" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.extra_repr()})" + + +class _VisionDataset(_AbstractDataset): + """Implements an abstract dataset + + Args: + ---- + url: URL of the dataset + file_name: name of the file once downloaded + file_hash: expected SHA256 of the file + extract_archive: whether the downloaded file is an archive to be extracted + download: whether the dataset should be downloaded if not present on disk + overwrite: whether the archive should be re-extracted + cache_dir: cache directory + cache_subdir: subfolder to use in the cache + """ + + def __init__( + self, + url: str, + file_name: Optional[str] = None, + file_hash: Optional[str] = None, + extract_archive: bool = False, + download: bool = False, + overwrite: bool = False, + cache_dir: Optional[str] = None, + cache_subdir: Optional[str] = None, + **kwargs: Any, + ) -> None: + cache_dir = ( + str(os.environ.get("DOCTR_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "doctr"))) + if cache_dir is None + else cache_dir + ) + + cache_subdir = "datasets" if cache_subdir is None else cache_subdir + + file_name = file_name if isinstance(file_name, str) else os.path.basename(url) + # Download the file if not present + archive_path: Union[str, Path] = os.path.join(cache_dir, cache_subdir, file_name) + + if not os.path.exists(archive_path) and not download: + raise ValueError("the dataset needs to be downloaded first with download=True") + + archive_path = download_from_url(url, file_name, file_hash, cache_dir=cache_dir, cache_subdir=cache_subdir) + + # Extract the archive + if extract_archive: + archive_path = Path(archive_path) + dataset_path = archive_path.parent.joinpath(archive_path.stem) + if not dataset_path.is_dir() or overwrite: + shutil.unpack_archive(archive_path, dataset_path) + + super().__init__(dataset_path if extract_archive else archive_path, **kwargs) diff --git a/doctr/datasets/datasets/pytorch.py b/doctr/datasets/datasets/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..bd4d8401680152c35cdd10e16d970115751ef596 --- /dev/null +++ b/doctr/datasets/datasets/pytorch.py @@ -0,0 +1,59 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
+ +import os +from copy import deepcopy +from typing import Any, List, Tuple + +import numpy as np +import torch + +from doctr.io import read_img_as_tensor, tensor_from_numpy + +from .base import _AbstractDataset, _VisionDataset + +__all__ = ["AbstractDataset", "VisionDataset"] + + +class AbstractDataset(_AbstractDataset): + """Abstract class for all datasets""" + + def _read_sample(self, index: int) -> Tuple[torch.Tensor, Any]: + img_name, target = self.data[index] + + # Check target + if isinstance(target, dict): + assert "boxes" in target, "Target should contain 'boxes' key" + assert "labels" in target, "Target should contain 'labels' key" + elif isinstance(target, tuple): + assert len(target) == 2 + assert isinstance(target[0], str) or isinstance( + target[0], np.ndarray + ), "first element of the tuple should be a string or a numpy array" + assert isinstance(target[1], list), "second element of the tuple should be a list" + else: + assert isinstance(target, str) or isinstance( + target, np.ndarray + ), "Target should be a string or a numpy array" + + # Read image + img = ( + tensor_from_numpy(img_name, dtype=torch.float32) + if isinstance(img_name, np.ndarray) + else read_img_as_tensor(os.path.join(self.root, img_name), dtype=torch.float32) + ) + + return img, deepcopy(target) + + @staticmethod + def collate_fn(samples: List[Tuple[torch.Tensor, Any]]) -> Tuple[torch.Tensor, List[Any]]: + images, targets = zip(*samples) + images = torch.stack(images, dim=0) # type: ignore[assignment] + + return images, list(targets) # type: ignore[return-value] + + +class VisionDataset(AbstractDataset, _VisionDataset): # noqa: D101 + pass diff --git a/doctr/datasets/datasets/tensorflow.py b/doctr/datasets/datasets/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..86b7b79289a9cf89f7f3e0c4cf0e7046ba002f75 --- /dev/null +++ b/doctr/datasets/datasets/tensorflow.py @@ -0,0 +1,59 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
+ +import os +from copy import deepcopy +from typing import Any, List, Tuple + +import numpy as np +import tensorflow as tf + +from doctr.io import read_img_as_tensor, tensor_from_numpy + +from .base import _AbstractDataset, _VisionDataset + +__all__ = ["AbstractDataset", "VisionDataset"] + + +class AbstractDataset(_AbstractDataset): + """Abstract class for all datasets""" + + def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]: + img_name, target = self.data[index] + + # Check target + if isinstance(target, dict): + assert "boxes" in target, "Target should contain 'boxes' key" + assert "labels" in target, "Target should contain 'labels' key" + elif isinstance(target, tuple): + assert len(target) == 2 + assert isinstance(target[0], str) or isinstance( + target[0], np.ndarray + ), "first element of the tuple should be a string or a numpy array" + assert isinstance(target[1], list), "second element of the tuple should be a list" + else: + assert isinstance(target, str) or isinstance( + target, np.ndarray + ), "Target should be a string or a numpy array" + + # Read image + img = ( + tensor_from_numpy(img_name, dtype=tf.float32) + if isinstance(img_name, np.ndarray) + else read_img_as_tensor(os.path.join(self.root, img_name), dtype=tf.float32) + ) + + return img, deepcopy(target) + + @staticmethod + def collate_fn(samples: List[Tuple[tf.Tensor, Any]]) -> Tuple[tf.Tensor, List[Any]]: + images, targets = zip(*samples) + images = tf.stack(images, axis=0) + + return images, list(targets) + + +class VisionDataset(AbstractDataset, _VisionDataset): # noqa: D101 + pass diff --git a/doctr/datasets/detection.py b/doctr/datasets/detection.py new file mode 100644 index 0000000000000000000000000000000000000000..0000704dfa2a2e074924a9efef5e6fae75d4d018 --- /dev/null +++ b/doctr/datasets/detection.py @@ -0,0 +1,98 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import json +import os +from typing import Any, Dict, List, Tuple, Type, Union + +import numpy as np + +from doctr.file_utils import CLASS_NAME + +from .datasets import AbstractDataset +from .utils import pre_transform_multiclass + +__all__ = ["DetectionDataset"] + + +class DetectionDataset(AbstractDataset): + """Implements a text detection dataset + + >>> from doctr.datasets import DetectionDataset + >>> train_set = DetectionDataset(img_folder="/path/to/images", + >>> label_path="/path/to/labels.json") + >>> img, target = train_set[0] + + Args: + ---- + img_folder: folder with all the images of the dataset + label_path: path to the annotations of each image + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `AbstractDataset`. 
+ """ + + def __init__( + self, + img_folder: str, + label_path: str, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + img_folder, + pre_transforms=pre_transform_multiclass, + **kwargs, + ) + + # File existence check + self._class_names: List = [] + if not os.path.exists(label_path): + raise FileNotFoundError(f"unable to locate {label_path}") + with open(label_path, "rb") as f: + labels = json.load(f) + + self.data: List[Tuple[str, Tuple[np.ndarray, List[str]]]] = [] + np_dtype = np.float32 + for img_name, label in labels.items(): + # File existence check + if not os.path.exists(os.path.join(self.root, img_name)): + raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}") + + geoms, polygons_classes = self.format_polygons(label["polygons"], use_polygons, np_dtype) + + self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes))) + + def format_polygons( + self, polygons: Union[List, Dict], use_polygons: bool, np_dtype: Type + ) -> Tuple[np.ndarray, List[str]]: + """Format polygons into an array + + Args: + ---- + polygons: the bounding boxes + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + np_dtype: dtype of array + + Returns: + ------- + geoms: bounding boxes as np array + polygons_classes: list of classes for each bounding box + """ + if isinstance(polygons, list): + self._class_names += [CLASS_NAME] + polygons_classes = [CLASS_NAME for _ in polygons] + _polygons: np.ndarray = np.asarray(polygons, dtype=np_dtype) + elif isinstance(polygons, dict): + self._class_names += list(polygons.keys()) + polygons_classes = [k for k, v in polygons.items() for _ in v] + _polygons = np.concatenate([np.asarray(poly, dtype=np_dtype) for poly in polygons.values() if poly], axis=0) + else: + raise TypeError(f"polygons should be a dictionary or list, it was {type(polygons)}") + geoms = _polygons if use_polygons else np.concatenate((_polygons.min(axis=1), _polygons.max(axis=1)), axis=1) + return geoms, polygons_classes + + @property + def class_names(self): + return sorted(set(self._class_names)) diff --git a/doctr/datasets/doc_artefacts.py b/doctr/datasets/doc_artefacts.py new file mode 100644 index 0000000000000000000000000000000000000000..6a05a01150316970521086e7722ce628808287a2 --- /dev/null +++ b/doctr/datasets/doc_artefacts.py @@ -0,0 +1,82 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import json +import os +from typing import Any, Dict, List, Tuple + +import numpy as np + +from .datasets import VisionDataset + +__all__ = ["DocArtefacts"] + + +class DocArtefacts(VisionDataset): + """Object detection dataset for non-textual elements in documents. + The dataset includes a variety of synthetic document pages with non-textual elements. + + .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/artefacts-grid.png&src=0 + :align: center + + >>> from doctr.datasets import DocArtefacts + >>> train_set = DocArtefacts(train=True, download=True) + >>> img, target = train_set[0] + + Args: + ---- + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `VisionDataset`. 
+ """ + + URL = "https://doctr-static.mindee.com/models?id=v0.4.0/artefact_detection-13fab8ce.zip&src=0" + SHA256 = "13fab8ced7f84583d9dccd0c634f046c3417e62a11fe1dea6efbbaba5052471b" + CLASSES = ["background", "qr_code", "bar_code", "logo", "photo"] + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(self.URL, None, self.SHA256, True, **kwargs) + self.train = train + + # Update root + self.root = os.path.join(self.root, "train" if train else "val") + # List images + tmp_root = os.path.join(self.root, "images") + with open(os.path.join(self.root, "labels.json"), "rb") as f: + labels = json.load(f) + self.data: List[Tuple[str, Dict[str, Any]]] = [] + img_list = os.listdir(tmp_root) + if len(labels) != len(img_list): + raise AssertionError("the number of images and labels do not match") + np_dtype = np.float32 + for img_name, label in labels.items(): + # File existence check + if not os.path.exists(os.path.join(tmp_root, img_name)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}") + + # xmin, ymin, xmax, ymax + boxes: np.ndarray = np.asarray([obj["geometry"] for obj in label], dtype=np_dtype) + classes: np.ndarray = np.asarray([self.CLASSES.index(obj["label"]) for obj in label], dtype=np.int64) + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + boxes = np.stack( + [ + np.stack([boxes[:, 0], boxes[:, 1]], axis=-1), + np.stack([boxes[:, 2], boxes[:, 1]], axis=-1), + np.stack([boxes[:, 2], boxes[:, 3]], axis=-1), + np.stack([boxes[:, 0], boxes[:, 3]], axis=-1), + ], + axis=1, + ) + self.data.append((img_name, dict(boxes=boxes, labels=classes))) + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/funsd.py b/doctr/datasets/funsd.py new file mode 100644 index 0000000000000000000000000000000000000000..0580b473a7ad39b56c3a6593948d7234b7a787bf --- /dev/null +++ b/doctr/datasets/funsd.py @@ -0,0 +1,112 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +from tqdm import tqdm + +from .datasets import VisionDataset +from .utils import convert_target_to_relative, crop_bboxes_from_image + +__all__ = ["FUNSD"] + + +class FUNSD(VisionDataset): + """FUNSD dataset from `"FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents" + `_. + + .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/funsd-grid.png&src=0 + :align: center + + >>> from doctr.datasets import FUNSD + >>> train_set = FUNSD(train=True, download=True) + >>> img, target = train_set[0] + + Args: + ---- + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + recognition_task: whether the dataset should be used for recognition task + **kwargs: keyword arguments from `VisionDataset`. 
+ """ + + URL = "https://guillaumejaume.github.io/FUNSD/dataset.zip" + SHA256 = "c31735649e4f441bcbb4fd0f379574f7520b42286e80b01d80b445649d54761f" + FILE_NAME = "funsd.zip" + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + recognition_task: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + self.URL, + self.FILE_NAME, + self.SHA256, + True, + pre_transforms=convert_target_to_relative if not recognition_task else None, + **kwargs, + ) + self.train = train + np_dtype = np.float32 + + # Use the subset + subfolder = os.path.join("dataset", "training_data" if train else "testing_data") + + # # List images + tmp_root = os.path.join(self.root, subfolder, "images") + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking FUNSD", total=len(os.listdir(tmp_root))): + # File existence check + if not os.path.exists(os.path.join(tmp_root, img_path)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}") + + stem = Path(img_path).stem + with open(os.path.join(self.root, subfolder, "annotations", f"{stem}.json"), "rb") as f: + data = json.load(f) + + _targets = [ + (word["text"], word["box"]) + for block in data["form"] + for word in block["words"] + if len(word["text"]) > 0 + ] + text_targets, box_targets = zip(*_targets) + if use_polygons: + # xmin, ymin, xmax, ymax -> (x, y) coordinates of top left, top right, bottom right, bottom left corners + box_targets = [ # type: ignore[assignment] + [ + [box[0], box[1]], + [box[2], box[1]], + [box[2], box[3]], + [box[0], box[3]], + ] + for box in box_targets + ] + + if recognition_task: + crops = crop_bboxes_from_image( + img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=np_dtype) + ) + for crop, label in zip(crops, list(text_targets)): + # filter labels with unknown characters + if not any(char in label for char in ["☑", "☐", "\uf703", "\uf702"]): + self.data.append((crop, label)) + else: + self.data.append(( + img_path, + dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=list(text_targets)), + )) + + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/generator/__init__.py b/doctr/datasets/generator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/datasets/generator/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/datasets/generator/__pycache__/__init__.cpython-311.pyc b/doctr/datasets/generator/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bd0799a88d878df08726e31f4e28e9e662af0b8 Binary files /dev/null and b/doctr/datasets/generator/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/datasets/generator/__pycache__/__init__.cpython-38.pyc b/doctr/datasets/generator/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaf6e319cdb90fc2da7e1a2427dd8b093b421611 Binary files /dev/null and b/doctr/datasets/generator/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/datasets/generator/__pycache__/base.cpython-311.pyc b/doctr/datasets/generator/__pycache__/base.cpython-311.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..82509b58b10b0eb7bc0fc504f96300510aa20ffe Binary files /dev/null and b/doctr/datasets/generator/__pycache__/base.cpython-311.pyc differ diff --git a/doctr/datasets/generator/__pycache__/base.cpython-38.pyc b/doctr/datasets/generator/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc31ced46df456d923c3eb8cbd347c30fcbaaf5c Binary files /dev/null and b/doctr/datasets/generator/__pycache__/base.cpython-38.pyc differ diff --git a/doctr/datasets/generator/__pycache__/pytorch.cpython-311.pyc b/doctr/datasets/generator/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9406f8a41a844ece016856584367fddec207c52 Binary files /dev/null and b/doctr/datasets/generator/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/datasets/generator/__pycache__/tensorflow.cpython-311.pyc b/doctr/datasets/generator/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0f4f44986b0ac1a83bcf955ff5649ee10079a91 Binary files /dev/null and b/doctr/datasets/generator/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/datasets/generator/__pycache__/tensorflow.cpython-38.pyc b/doctr/datasets/generator/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fbbcf5dc3c335f2c1e8bff94ee51656557131f1 Binary files /dev/null and b/doctr/datasets/generator/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/datasets/generator/base.py b/doctr/datasets/generator/base.py new file mode 100644 index 0000000000000000000000000000000000000000..424f59563d1165989dfe12ea06ab6410e7241fb9 --- /dev/null +++ b/doctr/datasets/generator/base.py @@ -0,0 +1,155 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
+ +import random +from typing import Any, Callable, List, Optional, Tuple, Union + +from PIL import Image, ImageDraw + +from doctr.io.image import tensor_from_pil +from doctr.utils.fonts import get_font + +from ..datasets import AbstractDataset + + +def synthesize_text_img( + text: str, + font_size: int = 32, + font_family: Optional[str] = None, + background_color: Optional[Tuple[int, int, int]] = None, + text_color: Optional[Tuple[int, int, int]] = None, +) -> Image.Image: + """Generate a synthetic text image + + Args: + ---- + text: the text to render as an image + font_size: the size of the font + font_family: the font family (has to be installed on your system) + background_color: background color of the final image + text_color: text color on the final image + + Returns: + ------- + PIL image of the text + """ + background_color = (0, 0, 0) if background_color is None else background_color + text_color = (255, 255, 255) if text_color is None else text_color + + font = get_font(font_family, font_size) + left, top, right, bottom = font.getbbox(text) + text_w, text_h = right - left, bottom - top + h, w = int(round(1.3 * text_h)), int(round(1.1 * text_w)) + # If single letter, make the image square, otherwise expand to meet the text size + img_size = (h, w) if len(text) > 1 else (max(h, w), max(h, w)) + + img = Image.new("RGB", img_size[::-1], color=background_color) + d = ImageDraw.Draw(img) + + # Offset so that the text is centered + text_pos = (int(round((img_size[1] - text_w) / 2)), int(round((img_size[0] - text_h) / 2))) + # Draw the text + d.text(text_pos, text, font=font, fill=text_color) + return img + + +class _CharacterGenerator(AbstractDataset): + def __init__( + self, + vocab: str, + num_samples: int, + cache_samples: bool = False, + font_family: Optional[Union[str, List[str]]] = None, + img_transforms: Optional[Callable[[Any], Any]] = None, + sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None, + ) -> None: + self.vocab = vocab + self._num_samples = num_samples + self.font_family = font_family if isinstance(font_family, list) else [font_family] # type: ignore[list-item] + # Validate fonts + if isinstance(font_family, list): + for font in self.font_family: + try: + _ = get_font(font, 10) + except OSError: + raise ValueError(f"unable to locate font: {font}") + self.img_transforms = img_transforms + self.sample_transforms = sample_transforms + + self._data: List[Image.Image] = [] + if cache_samples: + self._data = [ + (synthesize_text_img(char, font_family=font), idx) # type: ignore[misc] + for idx, char in enumerate(self.vocab) + for font in self.font_family + ] + + def __len__(self) -> int: + return self._num_samples + + def _read_sample(self, index: int) -> Tuple[Any, int]: + # Samples are already cached + if len(self._data) > 0: + idx = index % len(self._data) + pil_img, target = self._data[idx] # type: ignore[misc] + else: + target = index % len(self.vocab) + pil_img = synthesize_text_img(self.vocab[target], font_family=random.choice(self.font_family)) + img = tensor_from_pil(pil_img) + + return img, target + + +class _WordGenerator(AbstractDataset): + def __init__( + self, + vocab: str, + min_chars: int, + max_chars: int, + num_samples: int, + cache_samples: bool = False, + font_family: Optional[Union[str, List[str]]] = None, + img_transforms: Optional[Callable[[Any], Any]] = None, + sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None, + ) -> None: + self.vocab = vocab + self.wordlen_range = (min_chars, max_chars) + 
self._num_samples = num_samples + self.font_family = font_family if isinstance(font_family, list) else [font_family] # type: ignore[list-item] + # Validate fonts + if isinstance(font_family, list): + for font in self.font_family: + try: + _ = get_font(font, 10) + except OSError: + raise ValueError(f"unable to locate font: {font}") + self.img_transforms = img_transforms + self.sample_transforms = sample_transforms + + self._data: List[Image.Image] = [] + if cache_samples: + _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)] + self._data = [ + (synthesize_text_img(text, font_family=random.choice(self.font_family)), text) # type: ignore[misc] + for text in _words + ] + + def _generate_string(self, min_chars: int, max_chars: int) -> str: + num_chars = random.randint(min_chars, max_chars) + return "".join(random.choice(self.vocab) for _ in range(num_chars)) + + def __len__(self) -> int: + return self._num_samples + + def _read_sample(self, index: int) -> Tuple[Any, str]: + # Samples are already cached + if len(self._data) > 0: + pil_img, target = self._data[index] # type: ignore[misc] + else: + target = self._generate_string(*self.wordlen_range) + pil_img = synthesize_text_img(target, font_family=random.choice(self.font_family)) + img = tensor_from_pil(pil_img) + + return img, target diff --git a/doctr/datasets/generator/pytorch.py b/doctr/datasets/generator/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..b254c91e4a383f49721a58939bdba92660b00cba --- /dev/null +++ b/doctr/datasets/generator/pytorch.py @@ -0,0 +1,54 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from torch.utils.data._utils.collate import default_collate + +from .base import _CharacterGenerator, _WordGenerator + +__all__ = ["CharacterGenerator", "WordGenerator"] + + +class CharacterGenerator(_CharacterGenerator): + """Implements a character image generation dataset + + >>> from doctr.datasets import CharacterGenerator + >>> ds = CharacterGenerator(vocab='abdef', num_samples=100) + >>> img, target = ds[0] + + Args: + ---- + vocab: vocabulary to take the character from + num_samples: number of samples that will be generated iterating over the dataset + cache_samples: whether generated images should be cached firsthand + font_family: font to use to generate the text images + img_transforms: composable transformations that will be applied to each image + sample_transforms: composable transformations that will be applied to both the image and the target + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + setattr(self, "collate_fn", default_collate) + + +class WordGenerator(_WordGenerator): + """Implements a character image generation dataset + + >>> from doctr.datasets import WordGenerator + >>> ds = WordGenerator(vocab='abdef', min_chars=1, max_chars=32, num_samples=100) + >>> img, target = ds[0] + + Args: + ---- + vocab: vocabulary to take the character from + min_chars: minimum number of characters in a word + max_chars: maximum number of characters in a word + num_samples: number of samples that will be generated iterating over the dataset + cache_samples: whether generated images should be cached firsthand + font_family: font to use to generate the text images + img_transforms: composable transformations that will be applied to each image + sample_transforms: composable transformations that will be applied to both the image 
and the target + """ + + pass diff --git a/doctr/datasets/generator/tensorflow.py b/doctr/datasets/generator/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..82e205e03862db57360a7ec3c38a350c12cd7bb7 --- /dev/null +++ b/doctr/datasets/generator/tensorflow.py @@ -0,0 +1,60 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import tensorflow as tf + +from .base import _CharacterGenerator, _WordGenerator + +__all__ = ["CharacterGenerator", "WordGenerator"] + + +class CharacterGenerator(_CharacterGenerator): + """Implements a character image generation dataset + + >>> from doctr.datasets import CharacterGenerator + >>> ds = CharacterGenerator(vocab='abdef', num_samples=100) + >>> img, target = ds[0] + + Args: + ---- + vocab: vocabulary to take the character from + num_samples: number of samples that will be generated iterating over the dataset + cache_samples: whether generated images should be cached firsthand + font_family: font to use to generate the text images + img_transforms: composable transformations that will be applied to each image + sample_transforms: composable transformations that will be applied to both the image and the target + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + @staticmethod + def collate_fn(samples): + images, targets = zip(*samples) + images = tf.stack(images, axis=0) + + return images, tf.convert_to_tensor(targets) + + +class WordGenerator(_WordGenerator): + """Implements a character image generation dataset + + >>> from doctr.datasets import WordGenerator + >>> ds = WordGenerator(vocab='abdef', min_chars=1, max_chars=32, num_samples=100) + >>> img, target = ds[0] + + Args: + ---- + vocab: vocabulary to take the character from + min_chars: minimum number of characters in a word + max_chars: maximum number of characters in a word + num_samples: number of samples that will be generated iterating over the dataset + cache_samples: whether generated images should be cached firsthand + font_family: font to use to generate the text images + img_transforms: composable transformations that will be applied to each image + sample_transforms: composable transformations that will be applied to both the image and the target + """ + + pass diff --git a/doctr/datasets/ic03.py b/doctr/datasets/ic03.py new file mode 100644 index 0000000000000000000000000000000000000000..6f080e4d450b1eac9f630435f3d93b239095db0f --- /dev/null +++ b/doctr/datasets/ic03.py @@ -0,0 +1,126 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import os +from typing import Any, Dict, List, Tuple, Union + +import defusedxml.ElementTree as ET +import numpy as np +from tqdm import tqdm + +from .datasets import VisionDataset +from .utils import convert_target_to_relative, crop_bboxes_from_image + +__all__ = ["IC03"] + + +class IC03(VisionDataset): + """IC03 dataset from `"ICDAR 2003 Robust Reading Competitions: Entries, Results and Future Directions" + `_. + + .. 
image:: https://doctr-static.mindee.com/models?id=v0.5.0/ic03-grid.png&src=0 + :align: center + + >>> from doctr.datasets import IC03 + >>> train_set = IC03(train=True, download=True) + >>> img, target = train_set[0] + + Args: + ---- + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + recognition_task: whether the dataset should be used for recognition task + **kwargs: keyword arguments from `VisionDataset`. + """ + + TRAIN = ( + "http://www.iapr-tc11.org/dataset/ICDAR2003_RobustReading/TrialTrain/scene.zip", + "9d86df514eb09dd693fb0b8c671ef54a0cfe02e803b1bbef9fc676061502eb94", + "ic03_train.zip", + ) + TEST = ( + "http://www.iapr-tc11.org/dataset/ICDAR2003_RobustReading/TrialTest/scene.zip", + "dbc4b5fd5d04616b8464a1b42ea22db351ee22c2546dd15ac35611857ea111f8", + "ic03_test.zip", + ) + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + recognition_task: bool = False, + **kwargs: Any, + ) -> None: + url, sha256, file_name = self.TRAIN if train else self.TEST + super().__init__( + url, + file_name, + sha256, + True, + pre_transforms=convert_target_to_relative if not recognition_task else None, + **kwargs, + ) + self.train = train + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + np_dtype = np.float32 + + # Load xml data + tmp_root = ( + os.path.join(self.root, "SceneTrialTrain" if self.train else "SceneTrialTest") if sha256 else self.root + ) + xml_tree = ET.parse(os.path.join(tmp_root, "words.xml")) + xml_root = xml_tree.getroot() + + for image in tqdm(iterable=xml_root, desc="Unpacking IC03", total=len(xml_root)): + name, _resolution, rectangles = image + + # File existence check + if not os.path.exists(os.path.join(tmp_root, name.text)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, name.text)}") + + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + _boxes = [ + [ + [float(rect.attrib["x"]), float(rect.attrib["y"])], + [float(rect.attrib["x"]) + float(rect.attrib["width"]), float(rect.attrib["y"])], + [ + float(rect.attrib["x"]) + float(rect.attrib["width"]), + float(rect.attrib["y"]) + float(rect.attrib["height"]), + ], + [float(rect.attrib["x"]), float(rect.attrib["y"]) + float(rect.attrib["height"])], + ] + for rect in rectangles + ] + else: + # x_min, y_min, x_max, y_max + _boxes = [ + [ + float(rect.attrib["x"]), # type: ignore[list-item] + float(rect.attrib["y"]), # type: ignore[list-item] + float(rect.attrib["x"]) + float(rect.attrib["width"]), # type: ignore[list-item] + float(rect.attrib["y"]) + float(rect.attrib["height"]), # type: ignore[list-item] + ] + for rect in rectangles + ] + + # filter images without boxes + if len(_boxes) > 0: + boxes: np.ndarray = np.asarray(_boxes, dtype=np_dtype) + # Get the labels + labels = [lab.text for rect in rectangles for lab in rect if lab.text] + + if recognition_task: + crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, name.text), geoms=boxes) + for crop, label in zip(crops, labels): + if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0: + self.data.append((crop, label)) + else: + self.data.append((name.text, dict(boxes=boxes, labels=labels))) + + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/ic13.py b/doctr/datasets/ic13.py new file mode 100644 index 
0000000000000000000000000000000000000000..81ba62f00145487d5af0f6937305019da65ce210 --- /dev/null +++ b/doctr/datasets/ic13.py @@ -0,0 +1,99 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import csv +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +from tqdm import tqdm + +from .datasets import AbstractDataset +from .utils import convert_target_to_relative, crop_bboxes_from_image + +__all__ = ["IC13"] + + +class IC13(AbstractDataset): + """IC13 dataset from `"ICDAR 2013 Robust Reading Competition" `_. + + .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/ic13-grid.png&src=0 + :align: center + + >>> # NOTE: You need to download both image and label parts from Focused Scene Text challenge Task2.1 2013-2015. + >>> from doctr.datasets import IC13 + >>> train_set = IC13(img_folder="/path/to/Challenge2_Training_Task12_Images", + >>> label_folder="/path/to/Challenge2_Training_Task1_GT") + >>> img, target = train_set[0] + >>> test_set = IC13(img_folder="/path/to/Challenge2_Test_Task12_Images", + >>> label_folder="/path/to/Challenge2_Test_Task1_GT") + >>> img, target = test_set[0] + + Args: + ---- + img_folder: folder with all the images of the dataset + label_folder: folder with all annotation files for the images + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + recognition_task: whether the dataset should be used for recognition task + **kwargs: keyword arguments from `AbstractDataset`. + """ + + def __init__( + self, + img_folder: str, + label_folder: str, + use_polygons: bool = False, + recognition_task: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs + ) + + # File existence check + if not os.path.exists(label_folder) or not os.path.exists(img_folder): + raise FileNotFoundError( + f"unable to locate {label_folder if not os.path.exists(label_folder) else img_folder}" + ) + + self.data: List[Tuple[Union[Path, np.ndarray], Union[str, Dict[str, Any]]]] = [] + np_dtype = np.float32 + + img_names = os.listdir(img_folder) + + for img_name in tqdm(iterable=img_names, desc="Unpacking IC13", total=len(img_names)): + img_path = Path(img_folder, img_name) + label_path = Path(label_folder, "gt_" + Path(img_name).stem + ".txt") + + with open(label_path, newline="\n") as f: + _lines = [ + [val[:-1] if val.endswith(",") else val for val in row] + for row in csv.reader(f, delimiter=" ", quotechar="'") + ] + labels = [line[-1].replace('"', "") for line in _lines] + # xmin, ymin, xmax, ymax + box_targets: np.ndarray = np.array([list(map(int, line[:4])) for line in _lines], dtype=np_dtype) + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + box_targets = np.array( + [ + [ + [coords[0], coords[1]], + [coords[2], coords[1]], + [coords[2], coords[3]], + [coords[0], coords[3]], + ] + for coords in box_targets + ], + dtype=np_dtype, + ) + + if recognition_task: + crops = crop_bboxes_from_image(img_path=img_path, geoms=box_targets) + for crop, label in zip(crops, labels): + self.data.append((crop, label)) + else: + self.data.append((img_path, dict(boxes=box_targets, labels=labels))) diff --git a/doctr/datasets/iiit5k.py b/doctr/datasets/iiit5k.py new file mode 100644 index 
0000000000000000000000000000000000000000..2b33ebb50b3297b27a0db4184af283fb6a2e0d2f --- /dev/null +++ b/doctr/datasets/iiit5k.py @@ -0,0 +1,103 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import os +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +import scipy.io as sio +from tqdm import tqdm + +from .datasets import VisionDataset +from .utils import convert_target_to_relative + +__all__ = ["IIIT5K"] + + +class IIIT5K(VisionDataset): + """IIIT-5K character-level localization dataset from + `"BMVC 2012 Scene Text Recognition using Higher Order Language Priors" + `_. + + .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/iiit5k-grid.png&src=0 + :align: center + + >>> # NOTE: this dataset is for character-level localization + >>> from doctr.datasets import IIIT5K + >>> train_set = IIIT5K(train=True, download=True) + >>> img, target = train_set[0] + + Args: + ---- + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + recognition_task: whether the dataset should be used for recognition task + **kwargs: keyword arguments from `VisionDataset`. + """ + + URL = "https://cvit.iiit.ac.in/images/Projects/SceneTextUnderstanding/IIIT5K-Word_V3.0.tar.gz" + SHA256 = "7872c9efbec457eb23f3368855e7738f72ce10927f52a382deb4966ca0ffa38e" + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + recognition_task: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + self.URL, + None, + file_hash=self.SHA256, + extract_archive=True, + pre_transforms=convert_target_to_relative if not recognition_task else None, + **kwargs, + ) + self.train = train + + # Load mat data + tmp_root = os.path.join(self.root, "IIIT5K") if self.SHA256 else self.root + mat_file = "trainCharBound" if self.train else "testCharBound" + mat_data = sio.loadmat(os.path.join(tmp_root, f"{mat_file}.mat"))[mat_file][0] + + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + np_dtype = np.float32 + + for img_path, label, box_targets in tqdm(iterable=mat_data, desc="Unpacking IIIT5K", total=len(mat_data)): + _raw_path = img_path[0] + _raw_label = label[0] + + # File existence check + if not os.path.exists(os.path.join(tmp_root, _raw_path)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, _raw_path)}") + + if recognition_task: + self.data.append((_raw_path, _raw_label)) + else: + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + box_targets = [ + [ + [box[0], box[1]], + [box[0] + box[2], box[1]], + [box[0] + box[2], box[1] + box[3]], + [box[0], box[1] + box[3]], + ] + for box in box_targets + ] + else: + # xmin, ymin, xmax, ymax + box_targets = [[box[0], box[1], box[0] + box[2], box[1] + box[3]] for box in box_targets] + + # label are casted to list where each char corresponds to the character's bounding box + self.data.append(( + _raw_path, + dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=list(_raw_label)), + )) + + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/iiithws.py b/doctr/datasets/iiithws.py new file mode 100644 index 0000000000000000000000000000000000000000..e33e3acd536af612c65125e1a138ec29d8b62727 --- /dev/null +++ b/doctr/datasets/iiithws.py @@ -0,0 +1,75 @@ +# Copyright (C) 2021-2024, 
Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import os +from random import sample +from typing import Any, List, Tuple + +from tqdm import tqdm + +from .datasets import AbstractDataset + +__all__ = ["IIITHWS"] + + +class IIITHWS(AbstractDataset): + """IIITHWS dataset from `"Generating Synthetic Data for Text Recognition" + `_ | `"repository" `_ | + `"website" `_. + + >>> # NOTE: This is a pure recognition dataset without bounding box labels. + >>> # NOTE: You need to download the dataset. + >>> from doctr.datasets import IIITHWS + >>> train_set = IIITHWS(img_folder="/path/to/iiit-hws/Images_90K_Normalized", + >>> label_path="/path/to/IIIT-HWS-90K.txt", + >>> train=True) + >>> img, target = train_set[0] + >>> test_set = IIITHWS(img_folder="/path/to/iiit-hws/Images_90K_Normalized", + >>> label_path="/path/to/IIIT-HWS-90K.txt") + >>> train=False) + >>> img, target = test_set[0] + + Args: + ---- + img_folder: folder with all the images of the dataset + label_path: path to the file with the labels + train: whether the subset should be the training one + **kwargs: keyword arguments from `AbstractDataset`. + """ + + def __init__( + self, + img_folder: str, + label_path: str, + train: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(img_folder, **kwargs) + + # File existence check + if not os.path.exists(label_path) or not os.path.exists(img_folder): + raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}") + + self.data: List[Tuple[str, str]] = [] + self.train = train + + with open(label_path) as f: + annotations = f.readlines() + + # Shuffle the dataset otherwise the test set will contain the same labels n times + annotations = sample(annotations, len(annotations)) + train_samples = int(len(annotations) * 0.9) + set_slice = slice(train_samples) if self.train else slice(train_samples, None) + + for annotation in tqdm( + iterable=annotations[set_slice], desc="Unpacking IIITHWS", total=len(annotations[set_slice]) + ): + img_path, label = annotation.split()[0:2] + img_path = os.path.join(img_folder, img_path) + + self.data.append((img_path, label)) + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/imgur5k.py b/doctr/datasets/imgur5k.py new file mode 100644 index 0000000000000000000000000000000000000000..ce70c9f3bc982b75c69ebce7f016bd76adcf6147 --- /dev/null +++ b/doctr/datasets/imgur5k.py @@ -0,0 +1,147 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import glob +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + +import cv2 +import numpy as np +from PIL import Image +from tqdm import tqdm + +from .datasets import AbstractDataset +from .utils import convert_target_to_relative, crop_bboxes_from_image + +__all__ = ["IMGUR5K"] + + +class IMGUR5K(AbstractDataset): + """IMGUR5K dataset from `"TextStyleBrush: Transfer of Text Aesthetics from a Single Example" + `_ | + `repository `_. + + .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/imgur5k-grid.png&src=0 + :align: center + :width: 630 + :height: 400 + + >>> # NOTE: You need to download/generate the dataset from the repository. 
+ >>> from doctr.datasets import IMGUR5K + >>> train_set = IMGUR5K(train=True, img_folder="/path/to/IMGUR5K-Handwriting-Dataset/images", + >>> label_path="/path/to/IMGUR5K-Handwriting-Dataset/dataset_info/imgur5k_annotations.json") + >>> img, target = train_set[0] + >>> test_set = IMGUR5K(train=False, img_folder="/path/to/IMGUR5K-Handwriting-Dataset/images", + >>> label_path="/path/to/IMGUR5K-Handwriting-Dataset/dataset_info/imgur5k_annotations.json") + >>> img, target = test_set[0] + + Args: + ---- + img_folder: folder with all the images of the dataset + label_path: path to the annotations file of the dataset + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + recognition_task: whether the dataset should be used for recognition task + **kwargs: keyword arguments from `AbstractDataset`. + """ + + def __init__( + self, + img_folder: str, + label_path: str, + train: bool = True, + use_polygons: bool = False, + recognition_task: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs + ) + + # File existence check + if not os.path.exists(label_path) or not os.path.exists(img_folder): + raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}") + + self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = [] + self.train = train + np_dtype = np.float32 + + img_names = os.listdir(img_folder) + train_samples = int(len(img_names) * 0.9) + set_slice = slice(train_samples) if self.train else slice(train_samples, None) + + # define folder to write IMGUR5K recognition dataset + reco_folder_name = "IMGUR5K_recognition_train" if self.train else "IMGUR5K_recognition_test" + reco_folder_name = "Poly_" + reco_folder_name if use_polygons else reco_folder_name + reco_folder_path = os.path.join(os.path.dirname(self.root), reco_folder_name) + reco_images_counter = 0 + + if recognition_task and os.path.isdir(reco_folder_path): + self._read_from_folder(reco_folder_path) + return + elif recognition_task and not os.path.isdir(reco_folder_path): + os.makedirs(reco_folder_path, exist_ok=False) + + with open(label_path) as f: + annotation_file = json.load(f) + + for img_name in tqdm(iterable=img_names[set_slice], desc="Unpacking IMGUR5K", total=len(img_names[set_slice])): + img_path = Path(img_folder, img_name) + img_id = img_name.split(".")[0] + + # File existence check + if not os.path.exists(os.path.join(self.root, img_name)): + raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}") + + # some files have no annotations which are marked with only a dot in the 'word' key + # ref: https://github.com/facebookresearch/IMGUR5K-Handwriting-Dataset/blob/main/README.md + if img_id not in annotation_file["index_to_ann_map"].keys(): + continue + ann_ids = annotation_file["index_to_ann_map"][img_id] + annotations = [annotation_file["ann_id"][a_id] for a_id in ann_ids] + + labels = [ann["word"] for ann in annotations if ann["word"] != "."] + # x_center, y_center, width, height, angle + _boxes = [ + list(map(float, ann["bounding_box"].strip("[ ]").split(", "))) + for ann in annotations + if ann["word"] != "." 
+ ] + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + box_targets = [cv2.boxPoints(((box[0], box[1]), (box[2], box[3]), box[4])) for box in _boxes] # type: ignore[arg-type] + + if not use_polygons: + # xmin, ymin, xmax, ymax + box_targets = [np.concatenate((points.min(0), points.max(0)), axis=-1) for points in box_targets] + + # filter images without boxes + if len(box_targets) > 0: + if recognition_task: + crops = crop_bboxes_from_image( + img_path=os.path.join(self.root, img_name), geoms=np.asarray(box_targets, dtype=np_dtype) + ) + for crop, label in zip(crops, labels): + if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0: + # write data to disk + with open(os.path.join(reco_folder_path, f"{reco_images_counter}.txt"), "w") as f: + f.write(label) + tmp_img = Image.fromarray(crop) + tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png")) + reco_images_counter += 1 + else: + self.data.append((img_path, dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=labels))) + + if recognition_task: + self._read_from_folder(reco_folder_path) + + def extra_repr(self) -> str: + return f"train={self.train}" + + def _read_from_folder(self, path: str) -> None: + for img_path in glob.glob(os.path.join(path, "*.png")): + with open(os.path.join(path, f"{os.path.basename(img_path)[:-4]}.txt"), "r") as f: + self.data.append((img_path, f.read())) diff --git a/doctr/datasets/loader.py b/doctr/datasets/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..f08f7496afa155389116b8d3b52922152938b237 --- /dev/null +++ b/doctr/datasets/loader.py @@ -0,0 +1,102 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
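The TensorFlow `DataLoader` defined next shuffles indices, multithreads `__getitem__` and batches samples through a collate function. A hedged sketch of how it pairs with the synthetic `CharacterGenerator` from this diff (TensorFlow backend only; the `Resize` transform is assumed to be the one shipped in `doctr.transforms`, and a fixed image size is needed so stacking into a batch works):

    from doctr.datasets import CharacterGenerator, DataLoader
    from doctr.transforms import Resize  # assumed transform; any resize to a fixed shape would do

    ds = CharacterGenerator(vocab="0123456789", num_samples=128, img_transforms=Resize((32, 32)))
    loader = DataLoader(ds, batch_size=32, shuffle=True, drop_last=False)
    images, targets = next(iter(loader))
    # expected: images of shape (32, 32, 32, 3) and integer targets of shape (32,)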
+ +import math +from typing import Callable, Optional + +import numpy as np +import tensorflow as tf + +from doctr.utils.multithreading import multithread_exec + +__all__ = ["DataLoader"] + + +def default_collate(samples): + """Collate multiple elements into batches + + Args: + ---- + samples: list of N tuples containing M elements + + Returns: + ------- + Tuple of M sequences contianing N elements each + """ + batch_data = zip(*samples) + + tf_data = tuple(tf.stack(elt, axis=0) for elt in batch_data) + + return tf_data + + +class DataLoader: + """Implements a dataset wrapper for fast data loading + + >>> from doctr.datasets import CORD, DataLoader + >>> train_set = CORD(train=True, download=True) + >>> train_loader = DataLoader(train_set, batch_size=32) + >>> train_iter = iter(train_loader) + >>> images, targets = next(train_iter) + + Args: + ---- + dataset: the dataset + shuffle: whether the samples should be shuffled before passing it to the iterator + batch_size: number of elements in each batch + drop_last: if `True`, drops the last batch if it isn't full + num_workers: number of workers to use for data loading + collate_fn: function to merge samples into a batch + """ + + def __init__( + self, + dataset, + shuffle: bool = True, + batch_size: int = 1, + drop_last: bool = False, + num_workers: Optional[int] = None, + collate_fn: Optional[Callable] = None, + ) -> None: + self.dataset = dataset + self.shuffle = shuffle + self.batch_size = batch_size + nb = len(self.dataset) / batch_size + self.num_batches = math.floor(nb) if drop_last else math.ceil(nb) + if collate_fn is None: + self.collate_fn = self.dataset.collate_fn if hasattr(self.dataset, "collate_fn") else default_collate + else: + self.collate_fn = collate_fn + self.num_workers = num_workers + self.reset() + + def __len__(self) -> int: + return self.num_batches + + def reset(self) -> None: + # Updates indices after each epoch + self._num_yielded = 0 + self.indices = np.arange(len(self.dataset)) + if self.shuffle is True: + np.random.shuffle(self.indices) + + def __iter__(self): + self.reset() + return self + + def __next__(self): + if self._num_yielded < self.num_batches: + # Get next indices + idx = self._num_yielded * self.batch_size + indices = self.indices[idx : min(len(self.dataset), idx + self.batch_size)] + + samples = list(multithread_exec(self.dataset.__getitem__, indices, threads=self.num_workers)) + + batch_data = self.collate_fn(samples) + + self._num_yielded += 1 + return batch_data + else: + raise StopIteration diff --git a/doctr/datasets/mjsynth.py b/doctr/datasets/mjsynth.py new file mode 100644 index 0000000000000000000000000000000000000000..a8b16caebe22e0812be38d26ed8d83003533d493 --- /dev/null +++ b/doctr/datasets/mjsynth.py @@ -0,0 +1,106 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import os +from typing import Any, List, Tuple + +from tqdm import tqdm + +from .datasets import AbstractDataset + +__all__ = ["MJSynth"] + + +class MJSynth(AbstractDataset): + """MJSynth dataset from `"Synthetic Data and Artificial Neural Networks for Natural Scene Text Recognition" + `_. + + >>> # NOTE: This is a pure recognition dataset without bounding box labels. + >>> # NOTE: You need to download the dataset. 
+ >>> from doctr.datasets import MJSynth + >>> train_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px", + >>> label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt", + >>> train=True) + >>> img, target = train_set[0] + >>> test_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px", + >>> label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt") + >>> train=False) + >>> img, target = test_set[0] + + Args: + ---- + img_folder: folder with all the images of the dataset + label_path: path to the file with the labels + train: whether the subset should be the training one + **kwargs: keyword arguments from `AbstractDataset`. + """ + + # filter corrupted or missing images + BLACKLIST = [ + "./1881/4/225_Marbling_46673.jpg\n", + "./2069/4/192_whittier_86389.jpg\n", + "./869/4/234_TRIASSIC_80582.jpg\n", + "./173/2/358_BURROWING_10395.jpg\n", + "./913/4/231_randoms_62372.jpg\n", + "./596/2/372_Ump_81662.jpg\n", + "./936/2/375_LOCALITIES_44992.jpg\n", + "./2540/4/246_SQUAMOUS_73902.jpg\n", + "./1332/4/224_TETHERED_78397.jpg\n", + "./627/6/83_PATRIARCHATE_55931.jpg\n", + "./2013/2/370_refract_63890.jpg\n", + "./2911/6/77_heretical_35885.jpg\n", + "./1730/2/361_HEREON_35880.jpg\n", + "./2194/2/334_EFFLORESCENT_24742.jpg\n", + "./2025/2/364_SNORTERS_72304.jpg\n", + "./368/4/232_friar_30876.jpg\n", + "./275/6/96_hackle_34465.jpg\n", + "./384/4/220_bolts_8596.jpg\n", + "./905/4/234_Postscripts_59142.jpg\n", + "./2749/6/101_Chided_13155.jpg\n", + "./495/6/81_MIDYEAR_48332.jpg\n", + "./2852/6/60_TOILSOME_79481.jpg\n", + "./554/2/366_Teleconferences_77948.jpg\n", + "./1696/4/211_Queened_61779.jpg\n", + "./2128/2/369_REDACTED_63458.jpg\n", + "./2557/2/351_DOWN_23492.jpg\n", + "./2489/4/221_snored_72290.jpg\n", + "./1650/2/355_stony_74902.jpg\n", + "./1863/4/223_Diligently_21672.jpg\n", + "./264/2/362_FORETASTE_30276.jpg\n", + "./429/4/208_Mainmasts_46140.jpg\n", + "./1817/2/363_actuating_904.jpg\n", + ] + + def __init__( + self, + img_folder: str, + label_path: str, + train: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(img_folder, **kwargs) + + # File existence check + if not os.path.exists(label_path) or not os.path.exists(img_folder): + raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}") + + self.data: List[Tuple[str, str]] = [] + self.train = train + + with open(label_path) as f: + img_paths = f.readlines() + + train_samples = int(len(img_paths) * 0.9) + set_slice = slice(train_samples) if self.train else slice(train_samples, None) + + for path in tqdm(iterable=img_paths[set_slice], desc="Unpacking MJSynth", total=len(img_paths[set_slice])): + if path not in self.BLACKLIST: + label = path.split("_")[1] + img_path = os.path.join(img_folder, path[2:]).strip() + + self.data.append((img_path, label)) + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/ocr.py b/doctr/datasets/ocr.py new file mode 100644 index 0000000000000000000000000000000000000000..b93c124ce74dcadc143abe3792e447454e391b01 --- /dev/null +++ b/doctr/datasets/ocr.py @@ -0,0 +1,71 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
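`OCRDataset` below reads a single JSON label file in which every image name maps to a `typed_words` list of `{geometry, value}` entries; images with an empty list are kept with zero boxes. A hedged illustration of that structure (file names, words and coordinates are made up, and the coordinate convention must match what the rest of your pipeline expects, since this dataset does not rescale them):

    import json

    labels = {
        "invoice_001.jpg": {
            "typed_words": [
                {"geometry": [0.10, 0.20, 0.35, 0.26], "value": "Total"},
                {"geometry": [0.40, 0.20, 0.55, 0.26], "value": "23.90"},
            ]
        },
        "invoice_002.jpg": {"typed_words": []},  # handled as an image without boxes
    }
    with open("labels.json", "w") as f:
        json.dump(labels, f)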
+ +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import numpy as np + +from .datasets import AbstractDataset + +__all__ = ["OCRDataset"] + + +class OCRDataset(AbstractDataset): + """Implements an OCR dataset + + >>> from doctr.datasets import OCRDataset + >>> train_set = OCRDataset(img_folder="/path/to/images", + >>> label_file="/path/to/labels.json") + >>> img, target = train_set[0] + + Args: + ---- + img_folder: local path to image folder (all jpg at the root) + label_file: local path to the label file + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `AbstractDataset`. + """ + + def __init__( + self, + img_folder: str, + label_file: str, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(img_folder, **kwargs) + + # List images + self.data: List[Tuple[str, Dict[str, Any]]] = [] + np_dtype = np.float32 + with open(label_file, "rb") as f: + data = json.load(f) + + for img_name, annotations in data.items(): + # Get image path + img_name = Path(img_name) + # File existence check + if not os.path.exists(os.path.join(self.root, img_name)): + raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}") + + # handle empty images + if len(annotations["typed_words"]) == 0: + self.data.append((img_name, dict(boxes=np.zeros((0, 4), dtype=np_dtype), labels=[]))) + continue + # Unpack the straight boxes (xmin, ymin, xmax, ymax) + geoms = [list(map(float, obj["geometry"][:4])) for obj in annotations["typed_words"]] + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + geoms = [ + [geom[:2], [geom[2], geom[1]], geom[2:], [geom[0], geom[3]]] # type: ignore[list-item] + for geom in geoms + ] + + text_targets = [obj["value"] for obj in annotations["typed_words"]] + + self.data.append((img_name, dict(boxes=np.asarray(geoms, dtype=np_dtype), labels=text_targets))) diff --git a/doctr/datasets/orientation.py b/doctr/datasets/orientation.py new file mode 100644 index 0000000000000000000000000000000000000000..10bd55444e65cb2770ce9a4d15711a82cba5e06a --- /dev/null +++ b/doctr/datasets/orientation.py @@ -0,0 +1,40 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import os +from typing import Any, List, Tuple + +import numpy as np + +from .datasets import AbstractDataset + +__all__ = ["OrientationDataset"] + + +class OrientationDataset(AbstractDataset): + """Implements a basic image dataset where targets are filled with zeros. + + >>> from doctr.datasets import OrientationDataset + >>> train_set = OrientationDataset(img_folder="/path/to/images") + >>> img, target = train_set[0] + + Args: + ---- + img_folder: folder with all the images of the dataset + **kwargs: keyword arguments from `AbstractDataset`. + """ + + def __init__( + self, + img_folder: str, + **kwargs: Any, + ) -> None: + super().__init__( + img_folder, + **kwargs, + ) + + # initialize dataset with 0 degree rotation targets + self.data: List[Tuple[str, np.ndarray]] = [(img_name, np.array([0])) for img_name in os.listdir(self.root)] diff --git a/doctr/datasets/recognition.py b/doctr/datasets/recognition.py new file mode 100644 index 0000000000000000000000000000000000000000..ebf37a20acfa2c3787bc8e4a1e88692d13fcdd15 --- /dev/null +++ b/doctr/datasets/recognition.py @@ -0,0 +1,56 @@ +# Copyright (C) 2021-2024, Mindee. 
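For the recognition-only `RecognitionDataset` introduced in this file, the label file is flatter: a plain mapping from image name to transcription, with every image expected to exist in `img_folder`. A hedged sketch of such a file (names and words are illustrative):

    import json

    with open("labels.json", "w", encoding="utf-8") as f:
        json.dump({"word_0001.png": "Total", "word_0002.png": "23.90"}, f, ensure_ascii=False)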
+ +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import json +import os +from pathlib import Path +from typing import Any, List, Tuple + +from .datasets import AbstractDataset + +__all__ = ["RecognitionDataset"] + + +class RecognitionDataset(AbstractDataset): + """Dataset implementation for text recognition tasks + + >>> from doctr.datasets import RecognitionDataset + >>> train_set = RecognitionDataset(img_folder="/path/to/images", + >>> labels_path="/path/to/labels.json") + >>> img, target = train_set[0] + + Args: + ---- + img_folder: path to the images folder + labels_path: pathe to the json file containing all labels (character sequences) + **kwargs: keyword arguments from `AbstractDataset`. + """ + + def __init__( + self, + img_folder: str, + labels_path: str, + **kwargs: Any, + ) -> None: + super().__init__(img_folder, **kwargs) + + self.data: List[Tuple[str, str]] = [] + with open(labels_path, encoding="utf-8") as f: + labels = json.load(f) + + for img_name, label in labels.items(): + if not os.path.exists(os.path.join(self.root, img_name)): + raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}") + + self.data.append((img_name, label)) + + def merge_dataset(self, ds: AbstractDataset) -> None: + # Update data with new root for self + self.data = [(str(Path(self.root).joinpath(img_path)), label) for img_path, label in self.data] + # Define new root + self.root = Path("/") + # Merge with ds data + for img_path, label in ds.data: + self.data.append((str(Path(ds.root).joinpath(img_path)), label)) diff --git a/doctr/datasets/sroie.py b/doctr/datasets/sroie.py new file mode 100644 index 0000000000000000000000000000000000000000..e72fde68a1f5e54333b5f7ab68c21069286770a1 --- /dev/null +++ b/doctr/datasets/sroie.py @@ -0,0 +1,103 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import csv +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +from tqdm import tqdm + +from .datasets import VisionDataset +from .utils import convert_target_to_relative, crop_bboxes_from_image + +__all__ = ["SROIE"] + + +class SROIE(VisionDataset): + """SROIE dataset from `"ICDAR2019 Competition on Scanned Receipt OCR and Information Extraction" + `_. + + .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/sroie-grid.png&src=0 + :align: center + + >>> from doctr.datasets import SROIE + >>> train_set = SROIE(train=True, download=True) + >>> img, target = train_set[0] + + Args: + ---- + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + recognition_task: whether the dataset should be used for recognition task + **kwargs: keyword arguments from `VisionDataset`. 
+ """ + + TRAIN = ( + "https://doctr-static.mindee.com/models?id=v0.1.1/sroie2019_train_task1.zip&src=0", + "d4fa9e60abb03500d83299c845b9c87fd9c9430d1aeac96b83c5d0bb0ab27f6f", + "sroie2019_train_task1.zip", + ) + TEST = ( + "https://doctr-static.mindee.com/models?id=v0.1.1/sroie2019_test.zip&src=0", + "41b3c746a20226fddc80d86d4b2a903d43b5be4f521dd1bbe759dbf8844745e2", + "sroie2019_test.zip", + ) + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + recognition_task: bool = False, + **kwargs: Any, + ) -> None: + url, sha256, name = self.TRAIN if train else self.TEST + super().__init__( + url, + name, + sha256, + True, + pre_transforms=convert_target_to_relative if not recognition_task else None, + **kwargs, + ) + self.train = train + + tmp_root = os.path.join(self.root, "images") + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + np_dtype = np.float32 + + for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking SROIE", total=len(os.listdir(tmp_root))): + # File existence check + if not os.path.exists(os.path.join(tmp_root, img_path)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}") + + stem = Path(img_path).stem + with open(os.path.join(self.root, "annotations", f"{stem}.txt"), encoding="latin") as f: + _rows = [row for row in list(csv.reader(f, delimiter=",")) if len(row) > 0] + + labels = [",".join(row[8:]) for row in _rows] + # reorder coordinates (8 -> (4,2) -> + # (x, y) coordinates of top left, top right, bottom right, bottom left corners) and filter empty lines + coords: np.ndarray = np.stack( + [np.array(list(map(int, row[:8])), dtype=np_dtype).reshape((4, 2)) for row in _rows], axis=0 + ) + + if not use_polygons: + # xmin, ymin, xmax, ymax + coords = np.concatenate((coords.min(axis=1), coords.max(axis=1)), axis=1) + + if recognition_task: + crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_path), geoms=coords) + for crop, label in zip(crops, labels): + if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0: + self.data.append((crop, label)) + else: + self.data.append((img_path, dict(boxes=coords, labels=labels))) + + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/svhn.py b/doctr/datasets/svhn.py new file mode 100644 index 0000000000000000000000000000000000000000..57085c5213a549f276858e6623d7e2a91006ad65 --- /dev/null +++ b/doctr/datasets/svhn.py @@ -0,0 +1,131 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import os +from typing import Any, Dict, List, Tuple, Union + +import h5py +import numpy as np +from tqdm import tqdm + +from .datasets import VisionDataset +from .utils import convert_target_to_relative, crop_bboxes_from_image + +__all__ = ["SVHN"] + + +class SVHN(VisionDataset): + """SVHN dataset from `"The Street View House Numbers (SVHN) Dataset" + `_. + + .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/svhn-grid.png&src=0 + :align: center + + >>> from doctr.datasets import SVHN + >>> train_set = SVHN(train=True, download=True) + >>> img, target = train_set[0] + + Args: + ---- + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + recognition_task: whether the dataset should be used for recognition task + **kwargs: keyword arguments from `VisionDataset`. 
+ """ + + TRAIN = ( + "http://ufldl.stanford.edu/housenumbers/train.tar.gz", + "4b17bb33b6cd8f963493168f80143da956f28ec406cc12f8e5745a9f91a51898", + "svhn_train.tar", + ) + + TEST = ( + "http://ufldl.stanford.edu/housenumbers/test.tar.gz", + "57ac9ceb530e4aa85b55d991be8fc49c695b3d71c6f6a88afea86549efde7fb5", + "svhn_test.tar", + ) + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + recognition_task: bool = False, + **kwargs: Any, + ) -> None: + url, sha256, name = self.TRAIN if train else self.TEST + super().__init__( + url, + file_name=name, + file_hash=sha256, + extract_archive=True, + pre_transforms=convert_target_to_relative if not recognition_task else None, + **kwargs, + ) + self.train = train + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + np_dtype = np.float32 + + tmp_root = os.path.join(self.root, "train" if train else "test") + + # Load mat data (matlab v7.3 - can not be loaded with scipy) + with h5py.File(os.path.join(tmp_root, "digitStruct.mat"), "r") as f: + img_refs = f["digitStruct/name"] + box_refs = f["digitStruct/bbox"] + for img_ref, box_ref in tqdm(iterable=zip(img_refs, box_refs), desc="Unpacking SVHN", total=len(img_refs)): + # convert ascii matrix to string + img_name = "".join(map(chr, f[img_ref[0]][()].flatten())) + + # File existence check + if not os.path.exists(os.path.join(tmp_root, img_name)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}") + + # Unpack the information + box = f[box_ref[0]] + if box["left"].shape[0] == 1: + box_dict = {k: [int(vals[0][0])] for k, vals in box.items()} + else: + box_dict = {k: [int(f[v[0]][()].item()) for v in vals] for k, vals in box.items()} + + # Convert it to the right format + coords: np.ndarray = np.array( + [box_dict["left"], box_dict["top"], box_dict["width"], box_dict["height"]], dtype=np_dtype + ).transpose() + label_targets = list(map(str, box_dict["label"])) + + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + box_targets: np.ndarray = np.stack( + [ + np.stack([coords[:, 0], coords[:, 1]], axis=-1), + np.stack([coords[:, 0] + coords[:, 2], coords[:, 1]], axis=-1), + np.stack([coords[:, 0] + coords[:, 2], coords[:, 1] + coords[:, 3]], axis=-1), + np.stack([coords[:, 0], coords[:, 1] + coords[:, 3]], axis=-1), + ], + axis=1, + ) + else: + # x, y, width, height -> xmin, ymin, xmax, ymax + box_targets = np.stack( + [ + coords[:, 0], + coords[:, 1], + coords[:, 0] + coords[:, 2], + coords[:, 1] + coords[:, 3], + ], + axis=-1, + ) + + if recognition_task: + crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_name), geoms=box_targets) + for crop, label in zip(crops, label_targets): + if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0: + self.data.append((crop, label)) + else: + self.data.append((img_name, dict(boxes=box_targets, labels=label_targets))) + + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/svt.py b/doctr/datasets/svt.py new file mode 100644 index 0000000000000000000000000000000000000000..3eb7b6d599e6e2dc5cf4c424da6f9c61a579adf0 --- /dev/null +++ b/doctr/datasets/svt.py @@ -0,0 +1,117 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
+ +import os +from typing import Any, Dict, List, Tuple, Union + +import defusedxml.ElementTree as ET +import numpy as np +from tqdm import tqdm + +from .datasets import VisionDataset +from .utils import convert_target_to_relative, crop_bboxes_from_image + +__all__ = ["SVT"] + + +class SVT(VisionDataset): + """SVT dataset from `"The Street View Text Dataset - UCSD Computer Vision" + `_. + + .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/svt-grid.png&src=0 + :align: center + + >>> from doctr.datasets import SVT + >>> train_set = SVT(train=True, download=True) + >>> img, target = train_set[0] + + Args: + ---- + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + recognition_task: whether the dataset should be used for recognition task + **kwargs: keyword arguments from `VisionDataset`. + """ + + URL = "http://vision.ucsd.edu/~kai/svt/svt.zip" + SHA256 = "63b3d55e6b6d1e036e2a844a20c034fe3af3c32e4d914d6e0c4a3cd43df3bebf" + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + recognition_task: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + self.URL, + None, + self.SHA256, + True, + pre_transforms=convert_target_to_relative if not recognition_task else None, + **kwargs, + ) + self.train = train + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + np_dtype = np.float32 + + # Load xml data + tmp_root = os.path.join(self.root, "svt1") if self.SHA256 else self.root + xml_tree = ( + ET.parse(os.path.join(tmp_root, "train.xml")) + if self.train + else ET.parse(os.path.join(tmp_root, "test.xml")) + ) + xml_root = xml_tree.getroot() + + for image in tqdm(iterable=xml_root, desc="Unpacking SVT", total=len(xml_root)): + name, _, _, _resolution, rectangles = image + + # File existence check + if not os.path.exists(os.path.join(tmp_root, name.text)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, name.text)}") + + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + _boxes = [ + [ + [float(rect.attrib["x"]), float(rect.attrib["y"])], + [float(rect.attrib["x"]) + float(rect.attrib["width"]), float(rect.attrib["y"])], + [ + float(rect.attrib["x"]) + float(rect.attrib["width"]), + float(rect.attrib["y"]) + float(rect.attrib["height"]), + ], + [float(rect.attrib["x"]), float(rect.attrib["y"]) + float(rect.attrib["height"])], + ] + for rect in rectangles + ] + else: + # x_min, y_min, x_max, y_max + _boxes = [ + [ + float(rect.attrib["x"]), # type: ignore[list-item] + float(rect.attrib["y"]), # type: ignore[list-item] + float(rect.attrib["x"]) + float(rect.attrib["width"]), # type: ignore[list-item] + float(rect.attrib["y"]) + float(rect.attrib["height"]), # type: ignore[list-item] + ] + for rect in rectangles + ] + + boxes: np.ndarray = np.asarray(_boxes, dtype=np_dtype) + # Get the labels + labels = [lab.text for rect in rectangles for lab in rect] + + if recognition_task: + crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, name.text), geoms=boxes) + for crop, label in zip(crops, labels): + if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0: + self.data.append((crop, label)) + else: + self.data.append((name.text, dict(boxes=boxes, labels=labels))) + + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/synthtext.py b/doctr/datasets/synthtext.py new file mode 100644 
index 0000000000000000000000000000000000000000..a60e22e83212d2586159612d00e651dd66f82a5f --- /dev/null +++ b/doctr/datasets/synthtext.py @@ -0,0 +1,128 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import glob +import os +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +from PIL import Image +from scipy import io as sio +from tqdm import tqdm + +from .datasets import VisionDataset +from .utils import convert_target_to_relative, crop_bboxes_from_image + +__all__ = ["SynthText"] + + +class SynthText(VisionDataset): + """SynthText dataset from `"Synthetic Data for Text Localisation in Natural Images" + `_ | `"repository" `_ | + `"website" `_. + + .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/svt-grid.png&src=0 + :align: center + + >>> from doctr.datasets import SynthText + >>> train_set = SynthText(train=True, download=True) + >>> img, target = train_set[0] + + Args: + ---- + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + recognition_task: whether the dataset should be used for recognition task + **kwargs: keyword arguments from `VisionDataset`. + """ + + URL = "https://thor.robots.ox.ac.uk/~vgg/data/scenetext/SynthText.zip" + SHA256 = "28ab030485ec8df3ed612c568dd71fb2793b9afbfa3a9d9c6e792aef33265bf1" + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + recognition_task: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + self.URL, + None, + file_hash=None, + extract_archive=True, + pre_transforms=convert_target_to_relative if not recognition_task else None, + **kwargs, + ) + self.train = train + self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = [] + np_dtype = np.float32 + + # Load mat data + tmp_root = os.path.join(self.root, "SynthText") if self.SHA256 else self.root + # define folder to write SynthText recognition dataset + reco_folder_name = "SynthText_recognition_train" if self.train else "SynthText_recognition_test" + reco_folder_name = "Poly_" + reco_folder_name if use_polygons else reco_folder_name + reco_folder_path = os.path.join(tmp_root, reco_folder_name) + reco_images_counter = 0 + + if recognition_task and os.path.isdir(reco_folder_path): + self._read_from_folder(reco_folder_path) + return + elif recognition_task and not os.path.isdir(reco_folder_path): + os.makedirs(reco_folder_path, exist_ok=False) + + mat_data = sio.loadmat(os.path.join(tmp_root, "gt.mat")) + train_samples = int(len(mat_data["imnames"][0]) * 0.9) + set_slice = slice(train_samples) if self.train else slice(train_samples, None) + paths = mat_data["imnames"][0][set_slice] + boxes = mat_data["wordBB"][0][set_slice] + labels = mat_data["txt"][0][set_slice] + del mat_data + + for img_path, word_boxes, txt in tqdm( + iterable=zip(paths, boxes, labels), desc="Unpacking SynthText", total=len(paths) + ): + # File existence check + if not os.path.exists(os.path.join(tmp_root, img_path[0])): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path[0])}") + + labels = [elt for word in txt.tolist() for elt in word.split()] + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + word_boxes = ( + word_boxes.transpose(2, 1, 0) + if word_boxes.ndim == 3 + else np.expand_dims(word_boxes.transpose(1, 0), axis=0) + ) + + if not use_polygons: + # xmin, ymin, xmax, 
ymax + word_boxes = np.concatenate((word_boxes.min(axis=1), word_boxes.max(axis=1)), axis=1) + + if recognition_task: + crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_path[0]), geoms=word_boxes) + for crop, label in zip(crops, labels): + if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0: + # write data to disk + with open(os.path.join(reco_folder_path, f"{reco_images_counter}.txt"), "w") as f: + f.write(label) + tmp_img = Image.fromarray(crop) + tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png")) + reco_images_counter += 1 + else: + self.data.append((img_path[0], dict(boxes=np.asarray(word_boxes, dtype=np_dtype), labels=labels))) + + if recognition_task: + self._read_from_folder(reco_folder_path) + + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" + + def _read_from_folder(self, path: str) -> None: + for img_path in glob.glob(os.path.join(path, "*.png")): + with open(os.path.join(path, f"{os.path.basename(img_path)[:-4]}.txt"), "r") as f: + self.data.append((img_path, f.read())) diff --git a/doctr/datasets/utils.py b/doctr/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..860e19a229c7c758896578659b279c817172187c --- /dev/null +++ b/doctr/datasets/utils.py @@ -0,0 +1,217 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import string +import unicodedata +from collections.abc import Sequence +from functools import partial +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union +from typing import Sequence as SequenceType + +import numpy as np +from PIL import Image + +from doctr.io.image import get_img_shape +from doctr.utils.geometry import convert_to_relative_coords, extract_crops, extract_rcrops + +from .vocabs import VOCABS + +__all__ = ["translate", "encode_string", "decode_sequence", "encode_sequences", "pre_transform_multiclass"] + +ImageTensor = TypeVar("ImageTensor") + + +def translate( + input_string: str, + vocab_name: str, + unknown_char: str = "■", +) -> str: + """Translate a string input in a given vocabulary + + Args: + ---- + input_string: input string to translate + vocab_name: vocabulary to use (french, latin, ...) 
+ unknown_char: unknown character for non-translatable characters + + Returns: + ------- + A string translated in a given vocab + """ + if VOCABS.get(vocab_name) is None: + raise KeyError("output vocabulary must be in vocabs dictionnary") + + translated = "" + for char in input_string: + if char not in VOCABS[vocab_name]: + # we need to translate char into a vocab char + if char in string.whitespace: + # remove whitespaces + continue + # normalize character if it is not in vocab + char = unicodedata.normalize("NFD", char).encode("ascii", "ignore").decode("ascii") + if char == "" or char not in VOCABS[vocab_name]: + # if normalization fails or char still not in vocab, return unknown character) + char = unknown_char + translated += char + return translated + + +def encode_string( + input_string: str, + vocab: str, +) -> List[int]: + """Given a predefined mapping, encode the string to a sequence of numbers + + Args: + ---- + input_string: string to encode + vocab: vocabulary (string), the encoding is given by the indexing of the character sequence + + Returns: + ------- + A list encoding the input_string + """ + try: + return list(map(vocab.index, input_string)) + except ValueError: + raise ValueError( + f"some characters cannot be found in 'vocab'. \ + Please check the input string {input_string} and the vocabulary {vocab}" + ) + + +def decode_sequence( + input_seq: Union[np.ndarray, SequenceType[int]], + mapping: str, +) -> str: + """Given a predefined mapping, decode the sequence of numbers to a string + + Args: + ---- + input_seq: array to decode + mapping: vocabulary (string), the encoding is given by the indexing of the character sequence + + Returns: + ------- + A string, decoded from input_seq + """ + if not isinstance(input_seq, (Sequence, np.ndarray)): + raise TypeError("Invalid sequence type") + if isinstance(input_seq, np.ndarray) and (input_seq.dtype != np.int_ or input_seq.max() >= len(mapping)): + raise AssertionError("Input must be an array of int, with max less than mapping size") + + return "".join(map(mapping.__getitem__, input_seq)) + + +def encode_sequences( + sequences: List[str], + vocab: str, + target_size: Optional[int] = None, + eos: int = -1, + sos: Optional[int] = None, + pad: Optional[int] = None, + dynamic_seq_length: bool = False, +) -> np.ndarray: + """Encode character sequences using a given vocab as mapping + + Args: + ---- + sequences: the list of character sequences of size N + vocab: the ordered vocab to use for encoding + target_size: maximum length of the encoded data + eos: encoding of End Of String + sos: optional encoding of Start Of String + pad: optional encoding for padding. 
In case of padding, all sequences are followed by 1 EOS then PAD + dynamic_seq_length: if `target_size` is specified, uses it as upper bound and enables dynamic sequence size + + Returns: + ------- + the padded encoded data as a tensor + """ + if 0 <= eos < len(vocab): + raise ValueError("argument 'eos' needs to be outside of vocab possible indices") + + if not isinstance(target_size, int) or dynamic_seq_length: + # Maximum string length + EOS + max_length = max(len(w) for w in sequences) + 1 + if isinstance(sos, int): + max_length += 1 + if isinstance(pad, int): + max_length += 1 + target_size = max_length if not isinstance(target_size, int) else min(max_length, target_size) + + # Pad all sequences + if isinstance(pad, int): # pad with padding symbol + if 0 <= pad < len(vocab): + raise ValueError("argument 'pad' needs to be outside of vocab possible indices") + # In that case, add EOS at the end of the word before padding + default_symbol = pad + else: # pad with eos symbol + default_symbol = eos + encoded_data: np.ndarray = np.full([len(sequences), target_size], default_symbol, dtype=np.int32) + + # Encode the strings + for idx, seq in enumerate(map(partial(encode_string, vocab=vocab), sequences)): + if isinstance(pad, int): # add eos at the end of the sequence + seq.append(eos) + encoded_data[idx, : min(len(seq), target_size)] = seq[: min(len(seq), target_size)] + + if isinstance(sos, int): # place sos symbol at the beginning of each sequence + if 0 <= sos < len(vocab): + raise ValueError("argument 'sos' needs to be outside of vocab possible indices") + encoded_data = np.roll(encoded_data, 1) + encoded_data[:, 0] = sos + + return encoded_data + + +def convert_target_to_relative(img: ImageTensor, target: Dict[str, Any]) -> Tuple[ImageTensor, Dict[str, Any]]: + target["boxes"] = convert_to_relative_coords(target["boxes"], get_img_shape(img)) + return img, target + + +def crop_bboxes_from_image(img_path: Union[str, Path], geoms: np.ndarray) -> List[np.ndarray]: + """Crop a set of bounding boxes from an image + + Args: + ---- + img_path: path to the image + geoms: a array of polygons of shape (N, 4, 2) or of straight boxes of shape (N, 4) + + Returns: + ------- + a list of cropped images + """ + with Image.open(img_path) as pil_img: + img: np.ndarray = np.array(pil_img.convert("RGB")) + # Polygon + if geoms.ndim == 3 and geoms.shape[1:] == (4, 2): + return extract_rcrops(img, geoms.astype(dtype=int)) + if geoms.ndim == 2 and geoms.shape[1] == 4: + return extract_crops(img, geoms.astype(dtype=int)) + raise ValueError("Invalid geometry format") + + +def pre_transform_multiclass(img, target: Tuple[np.ndarray, List]) -> Tuple[np.ndarray, Dict[str, List]]: + """Converts multiclass target to relative coordinates. + + Args: + ---- + img: Image + target: tuple of target polygons and their classes names + + Returns: + ------- + Image and dictionary of boxes, with class names as keys + """ + boxes = convert_to_relative_coords(target[0], get_img_shape(img)) + boxes_classes = target[1] + boxes_dict: Dict = {k: [] for k in sorted(set(boxes_classes))} + for k, poly in zip(boxes_classes, boxes): + boxes_dict[k].append(poly) + boxes_dict = {k: np.stack(v, axis=0) for k, v in boxes_dict.items()} + return img, boxes_dict diff --git a/doctr/datasets/vocabs.py b/doctr/datasets/vocabs.py new file mode 100644 index 0000000000000000000000000000000000000000..ddc32d866581c931419072f97b641846b25a18db --- /dev/null +++ b/doctr/datasets/vocabs.py @@ -0,0 +1,71 @@ +# Copyright (C) 2021-2024, Mindee. 
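# --- Illustrative usage sketch (editor's addition, not part of the diff) for the encoding helpers above. ---
# Assumes the "latin" vocab (digits + ASCII letters + punctuation) defined in doctr/datasets/vocabs.py,
# which is added later in this diff.
from doctr.datasets.utils import encode_string, decode_sequence, encode_sequences
from doctr.datasets.vocabs import VOCABS

vocab = VOCABS["latin"]
indices = encode_string("cat", vocab)             # character -> index lookup in the vocab string
assert decode_sequence(indices, vocab) == "cat"   # decoding is the inverse mapping

# Batch encoding: longest word (3 chars) + 1 EOS slot + 1 PAD slot -> shape (2, 5).
# As enforced above, 'eos' and 'pad' must be indices outside the vocab range.
batch = encode_sequences(["cat", "go"], vocab, eos=len(vocab), pad=len(vocab) + 1)
print(batch.shape, batch.dtype)                   # (2, 5) int32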
+ +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import string +from typing import Dict + +__all__ = ["VOCABS"] + + +VOCABS: Dict[str, str] = { + "digits": string.digits, + "ascii_letters": string.ascii_letters, + "punctuation": string.punctuation, + "currency": "£€¥¢฿", + "ancient_greek": "αβγδεζηθικλμνξοπρστυφχψωΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ", + "arabic_letters": "ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىي", + "persian_letters": "پچڢڤگ", + "hindi_digits": "٠١٢٣٤٥٦٧٨٩", + "arabic_diacritics": "ًٌٍَُِّْ", + "arabic_punctuation": "؟؛«»—", +} + +VOCABS["latin"] = VOCABS["digits"] + VOCABS["ascii_letters"] + VOCABS["punctuation"] +VOCABS["english"] = VOCABS["latin"] + "°" + VOCABS["currency"] +VOCABS["legacy_french"] = VOCABS["latin"] + "°" + "àâéèêëîïôùûçÀÂÉÈËÎÏÔÙÛÇ" + VOCABS["currency"] +VOCABS["french"] = VOCABS["english"] + "àâéèêëîïôùûüçÀÂÉÈÊËÎÏÔÙÛÜÇ" +VOCABS["portuguese"] = VOCABS["english"] + "áàâãéêíïóôõúüçÁÀÂÃÉÊÍÏÓÔÕÚÜÇ" +VOCABS["spanish"] = VOCABS["english"] + "áéíóúüñÁÉÍÓÚÜÑ" + "¡¿" +VOCABS["italian"] = VOCABS["english"] + "àèéìíîòóùúÀÈÉÌÍÎÒÓÙÚ" +VOCABS["german"] = VOCABS["english"] + "äöüßÄÖÜẞ" +VOCABS["arabic"] = ( + VOCABS["digits"] + + VOCABS["hindi_digits"] + + VOCABS["arabic_letters"] + + VOCABS["persian_letters"] + + VOCABS["arabic_diacritics"] + + VOCABS["arabic_punctuation"] + + VOCABS["punctuation"] +) +VOCABS["czech"] = VOCABS["english"] + "áčďéěíňóřšťúůýžÁČĎÉĚÍŇÓŘŠŤÚŮÝŽ" +VOCABS["polish"] = VOCABS["english"] + "ąćęłńóśźżĄĆĘŁŃÓŚŹŻ" +VOCABS["dutch"] = VOCABS["english"] + "áéíóúüñÁÉÍÓÚÜÑ" +VOCABS["norwegian"] = VOCABS["english"] + "æøåÆØÅ" +VOCABS["danish"] = VOCABS["english"] + "æøåÆØÅ" +VOCABS["finnish"] = VOCABS["english"] + "äöÄÖ" +VOCABS["swedish"] = VOCABS["english"] + "åäöÅÄÖ" +VOCABS["vietnamese"] = ( + VOCABS["english"] + + "áàảạãăắằẳẵặâấầẩẫậéèẻẽẹêếềểễệóòỏõọôốồổộỗơớờởợỡúùủũụưứừửữựiíìỉĩịýỳỷỹỵ" + + "ÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ" +) +VOCABS["hebrew"] = VOCABS["english"] + "אבגדהוזחטיכלמנסעפצקרשת" + "₪" +VOCABS["multilingual"] = "".join( + dict.fromkeys( + VOCABS["french"] + + VOCABS["portuguese"] + + VOCABS["spanish"] + + VOCABS["german"] + + VOCABS["czech"] + + VOCABS["polish"] + + VOCABS["dutch"] + + VOCABS["italian"] + + VOCABS["norwegian"] + + VOCABS["danish"] + + VOCABS["finnish"] + + VOCABS["swedish"] + + "§" + ) +) diff --git a/doctr/datasets/wildreceipt.py b/doctr/datasets/wildreceipt.py new file mode 100644 index 0000000000000000000000000000000000000000..19108d77612af08cb227750abea1beae938605ff --- /dev/null +++ b/doctr/datasets/wildreceipt.py @@ -0,0 +1,111 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + +import numpy as np + +from .datasets import AbstractDataset +from .utils import convert_target_to_relative, crop_bboxes_from_image + +__all__ = ["WILDRECEIPT"] + + +class WILDRECEIPT(AbstractDataset): + """WildReceipt dataset from `"Spatial Dual-Modality Graph Reasoning for Key Information Extraction" + `_ | + `repository `_. + + .. image:: https://doctr-static.mindee.com/models?id=v0.7.0/wildreceipt-dataset.jpg&src=0 + :align: center + + >>> # NOTE: You need to download the dataset first. 
+ >>> from doctr.datasets import WILDRECEIPT + >>> train_set = WILDRECEIPT(train=True, img_folder="/path/to/wildreceipt/", + >>> label_path="/path/to/wildreceipt/train.txt") + >>> img, target = train_set[0] + >>> test_set = WILDRECEIPT(train=False, img_folder="/path/to/wildreceipt/", + >>> label_path="/path/to/wildreceipt/test.txt") + >>> img, target = test_set[0] + + Args: + ---- + img_folder: folder with all the images of the dataset + label_path: path to the annotations file of the dataset + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + recognition_task: whether the dataset should be used for recognition task + **kwargs: keyword arguments from `AbstractDataset`. + """ + + def __init__( + self, + img_folder: str, + label_path: str, + train: bool = True, + use_polygons: bool = False, + recognition_task: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs + ) + # File existence check + if not os.path.exists(label_path) or not os.path.exists(img_folder): + raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}") + + tmp_root = img_folder + self.train = train + np_dtype = np.float32 + self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = [] + + with open(label_path, "r") as file: + data = file.read() + # Split the text file into separate JSON strings + json_strings = data.strip().split("\n") + box: Union[List[float], np.ndarray] + _targets = [] + for json_string in json_strings: + json_data = json.loads(json_string) + img_path = json_data["file_name"] + annotations = json_data["annotations"] + for annotation in annotations: + coordinates = annotation["box"] + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + box = np.array( + [ + [coordinates[0], coordinates[1]], + [coordinates[2], coordinates[3]], + [coordinates[4], coordinates[5]], + [coordinates[6], coordinates[7]], + ], + dtype=np_dtype, + ) + else: + x, y = coordinates[::2], coordinates[1::2] + box = [min(x), min(y), max(x), max(y)] + _targets.append((annotation["text"], box)) + text_targets, box_targets = zip(*_targets) + + if recognition_task: + crops = crop_bboxes_from_image( + img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0) + ) + for crop, label in zip(crops, list(text_targets)): + if label and " " not in label: + self.data.append((crop, label)) + else: + self.data.append(( + img_path, + dict(boxes=np.asarray(box_targets, dtype=int).clip(min=0), labels=list(text_targets)), + )) + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/file_utils.py b/doctr/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..68e9dfffac321618cf79322ce27a3cf2b6abe4f2 --- /dev/null +++ b/doctr/file_utils.py @@ -0,0 +1,106 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
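# --- Illustrative sketch (editor's addition, not part of the diff): the JSON-lines label format parsed by WILDRECEIPT above. ---
# Field names mirror the parsing code (file_name, annotations[].box, annotations[].text);
# the file name, coordinates and text below are invented placeholders.
import json

line = (
    '{"file_name": "image_files/sample.jpeg",'
    ' "annotations": [{"box": [84, 10, 164, 10, 164, 28, 84, 28], "text": "TOTAL"}]}'
)
record = json.loads(line)
for annotation in record["annotations"]:
    x, y = annotation["box"][::2], annotation["box"][1::2]
    # straight-box variant: xmin, ymin, xmax, ymax, as built in the dataset when use_polygons=False
    print(annotation["text"], [min(x), min(y), max(x), max(y)])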
+ +# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py + +import importlib.metadata +import importlib.util +import logging +import os +from typing import Optional + +CLASS_NAME: str = "words" + + +__all__ = ["is_tf_available", "is_torch_available", "requires_package", "CLASS_NAME"] + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) + +USE_TF = os.environ.get("USE_TF", "AUTO").upper() +USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() + + +if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: + _torch_available = importlib.util.find_spec("torch") is not None + if _torch_available: + try: + _torch_version = importlib.metadata.version("torch") + logging.info(f"PyTorch version {_torch_version} available.") + except importlib.metadata.PackageNotFoundError: # pragma: no cover + _torch_available = False +else: # pragma: no cover + logging.info("Disabling PyTorch because USE_TF is set") + _torch_available = False + + +if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: + _tf_available = importlib.util.find_spec("tensorflow") is not None + if _tf_available: + candidates = ( + "tensorflow", + "tensorflow-cpu", + "tensorflow-gpu", + "tf-nightly", + "tf-nightly-cpu", + "tf-nightly-gpu", + "intel-tensorflow", + "tensorflow-rocm", + "tensorflow-macos", + ) + _tf_version = None + # For the metadata, we have to look for both tensorflow and tensorflow-cpu + for pkg in candidates: + try: + _tf_version = importlib.metadata.version(pkg) + break + except importlib.metadata.PackageNotFoundError: + pass + _tf_available = _tf_version is not None + if _tf_available: + if int(_tf_version.split(".")[0]) < 2: # type: ignore[union-attr] # pragma: no cover + logging.info(f"TensorFlow found but with version {_tf_version}. DocTR requires version 2 minimum.") + _tf_available = False + else: + logging.info(f"TensorFlow version {_tf_version} available.") +else: # pragma: no cover + logging.info("Disabling Tensorflow because USE_TORCH is set") + _tf_available = False + + +if not _torch_available and not _tf_available: # pragma: no cover + raise ModuleNotFoundError( + "DocTR requires either TensorFlow or PyTorch to be installed. Please ensure one of them" + " is installed and that either USE_TF or USE_TORCH is enabled." 
+ ) + + +def requires_package(name: str, extra_message: Optional[str] = None) -> None: # pragma: no cover + """ + package requirement helper + + Args: + ---- + name: name of the package + extra_message: additional message to display if the package is not found + """ + try: + _pkg_version = importlib.metadata.version(name) + logging.info(f"{name} version {_pkg_version} available.") + except importlib.metadata.PackageNotFoundError: + raise ImportError( + f"\n\n{extra_message if extra_message is not None else ''} " + f"\nPlease install it with the following command: pip install {name}\n" + ) + + +def is_torch_available(): + """Whether PyTorch is installed.""" + return _torch_available + + +def is_tf_available(): + """Whether TensorFlow is installed.""" + return _tf_available diff --git a/doctr/io/__init__.py b/doctr/io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6eab8c240615522894f0b8a4ef09aaf59636a811 --- /dev/null +++ b/doctr/io/__init__.py @@ -0,0 +1,5 @@ +from .elements import * +from .html import * +from .image import * +from .pdf import * +from .reader import * diff --git a/doctr/io/__pycache__/__init__.cpython-310.pyc b/doctr/io/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4f0d8dec8d7979fcb33c8139d2e7ff655616d52 Binary files /dev/null and b/doctr/io/__pycache__/__init__.cpython-310.pyc differ diff --git a/doctr/io/__pycache__/__init__.cpython-311.pyc b/doctr/io/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a64fcb3929d747eccba467de8778611dcb42ce46 Binary files /dev/null and b/doctr/io/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/io/__pycache__/__init__.cpython-38.pyc b/doctr/io/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66d641328129c15cb8edf8d5e98affec8998b066 Binary files /dev/null and b/doctr/io/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/io/__pycache__/elements.cpython-310.pyc b/doctr/io/__pycache__/elements.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..595028927590fa6341c09d5481dad7b2cf47494f Binary files /dev/null and b/doctr/io/__pycache__/elements.cpython-310.pyc differ diff --git a/doctr/io/__pycache__/elements.cpython-311.pyc b/doctr/io/__pycache__/elements.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40f632822fe75de9f0d688773caa7e987a57ba9c Binary files /dev/null and b/doctr/io/__pycache__/elements.cpython-311.pyc differ diff --git a/doctr/io/__pycache__/elements.cpython-38.pyc b/doctr/io/__pycache__/elements.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4403bde2036726203e62b93ad5238ac4c468e3ae Binary files /dev/null and b/doctr/io/__pycache__/elements.cpython-38.pyc differ diff --git a/doctr/io/__pycache__/html.cpython-310.pyc b/doctr/io/__pycache__/html.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24386a87e7d69f7cfcb48de59c5ad24600b6517b Binary files /dev/null and b/doctr/io/__pycache__/html.cpython-310.pyc differ diff --git a/doctr/io/__pycache__/html.cpython-311.pyc b/doctr/io/__pycache__/html.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..383501ed25d53bd1b4f2686d7854d435f6ef1738 Binary files /dev/null and b/doctr/io/__pycache__/html.cpython-311.pyc differ diff --git a/doctr/io/__pycache__/html.cpython-38.pyc 
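# --- Illustrative sketch (editor's addition, not part of the diff): choosing the backend handled by doctr/file_utils.py. ---
# USE_TF / USE_TORCH are read at import time, so they must be set before importing doctr;
# any of "1", "ON", "YES", "TRUE" counts as enabled, while "AUTO" lets the availability checks decide.
import os

os.environ["USE_TORCH"] = "1"  # force the PyTorch code paths even if TensorFlow is installed

from doctr.file_utils import is_tf_available, is_torch_available

print(is_torch_available(), is_tf_available())  # expected: True False, assuming PyTorch is installed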
b/doctr/io/__pycache__/html.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e47ad66633eb0ab5518df7ceb6536b08667fd6b1 Binary files /dev/null and b/doctr/io/__pycache__/html.cpython-38.pyc differ diff --git a/doctr/io/__pycache__/pdf.cpython-311.pyc b/doctr/io/__pycache__/pdf.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40a760ec7ce102fe151ae79e6d79fb19a3d30ab4 Binary files /dev/null and b/doctr/io/__pycache__/pdf.cpython-311.pyc differ diff --git a/doctr/io/__pycache__/pdf.cpython-38.pyc b/doctr/io/__pycache__/pdf.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3487185bd67ced72aa905d48b3faf1894fb5656 Binary files /dev/null and b/doctr/io/__pycache__/pdf.cpython-38.pyc differ diff --git a/doctr/io/__pycache__/reader.cpython-311.pyc b/doctr/io/__pycache__/reader.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..569f8346b2bb65433713621f170727d7e05dac50 Binary files /dev/null and b/doctr/io/__pycache__/reader.cpython-311.pyc differ diff --git a/doctr/io/__pycache__/reader.cpython-38.pyc b/doctr/io/__pycache__/reader.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b71fbdc496396f6baf5da8f2a41d9cadff8d9dd1 Binary files /dev/null and b/doctr/io/__pycache__/reader.cpython-38.pyc differ diff --git a/doctr/io/elements.py b/doctr/io/elements.py new file mode 100644 index 0000000000000000000000000000000000000000..b6d2cb9b5b91b5ac69b3d12039ebd3456748324e --- /dev/null +++ b/doctr/io/elements.py @@ -0,0 +1,634 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, Dict, List, Optional, Tuple, Union + +from defusedxml import defuse_stdlib + +defuse_stdlib() +from xml.etree import ElementTree as ET +from xml.etree.ElementTree import Element as ETElement +from xml.etree.ElementTree import SubElement + +import numpy as np + +import doctr +from doctr.file_utils import requires_package +from doctr.utils.common_types import BoundingBox +from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox +from doctr.utils.reconstitution import synthesize_kie_page, synthesize_page +from doctr.utils.repr import NestedObject + +try: # optional dependency for visualization + from doctr.utils.visualization import visualize_kie_page, visualize_page +except ModuleNotFoundError: + pass + +__all__ = ["Element", "Word", "Artefact", "Line", "Prediction", "Block", "Page", "KIEPage", "Document"] + + +class Element(NestedObject): + """Implements an abstract document element with exporting and text rendering capabilities""" + + _children_names: List[str] = [] + _exported_keys: List[str] = [] + + def __init__(self, **kwargs: Any) -> None: + for k, v in kwargs.items(): + if k in self._children_names: + setattr(self, k, v) + else: + raise KeyError(f"{self.__class__.__name__} object does not have any attribute named '{k}'") + + def export(self) -> Dict[str, Any]: + """Exports the object into a nested dict format""" + export_dict = {k: getattr(self, k) for k in self._exported_keys} + for children_name in self._children_names: + if children_name in ["predictions"]: + export_dict[children_name] = { + k: [item.export() for item in c] for k, c in getattr(self, children_name).items() + } + else: + export_dict[children_name] = [c.export() for c in getattr(self, children_name)] + + return export_dict + + @classmethod + def 
from_dict(cls, save_dict: Dict[str, Any], **kwargs): + raise NotImplementedError + + def render(self) -> str: + raise NotImplementedError + + +class Word(Element): + """Implements a word element + + Args: + ---- + value: the text string of the word + confidence: the confidence associated with the text prediction + geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to + the page's size + crop_orientation: the general orientation of the crop in degrees and its confidence + """ + + _exported_keys: List[str] = ["value", "confidence", "geometry", "crop_orientation"] + _children_names: List[str] = [] + + def __init__( + self, + value: str, + confidence: float, + geometry: Union[BoundingBox, np.ndarray], + crop_orientation: Dict[str, Any], + ) -> None: + super().__init__() + self.value = value + self.confidence = confidence + self.geometry = geometry + self.crop_orientation = crop_orientation + + def render(self) -> str: + """Renders the full text of the element""" + return self.value + + def extra_repr(self) -> str: + return f"value='{self.value}', confidence={self.confidence:.2}" + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + return cls(**kwargs) + + +class Artefact(Element): + """Implements a non-textual element + + Args: + ---- + artefact_type: the type of artefact + confidence: the confidence of the type prediction + geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to + the page's size. + """ + + _exported_keys: List[str] = ["geometry", "type", "confidence"] + _children_names: List[str] = [] + + def __init__(self, artefact_type: str, confidence: float, geometry: BoundingBox) -> None: + super().__init__() + self.geometry = geometry + self.type = artefact_type + self.confidence = confidence + + def render(self) -> str: + """Renders the full text of the element""" + return f"[{self.type.upper()}]" + + def extra_repr(self) -> str: + return f"type='{self.type}', confidence={self.confidence:.2}" + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + return cls(**kwargs) + + +class Line(Element): + """Implements a line element as a collection of words + + Args: + ---- + words: list of word elements + geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to + the page's size. If not specified, it will be resolved by default to the smallest bounding box enclosing + all words in it. 
+ """ + + _exported_keys: List[str] = ["geometry"] + _children_names: List[str] = ["words"] + words: List[Word] = [] + + def __init__( + self, + words: List[Word], + geometry: Optional[Union[BoundingBox, np.ndarray]] = None, + ) -> None: + # Resolve the geometry using the smallest enclosing bounding box + if geometry is None: + # Check whether this is a rotated or straight box + box_resolution_fn = resolve_enclosing_rbbox if len(words[0].geometry) == 4 else resolve_enclosing_bbox + geometry = box_resolution_fn([w.geometry for w in words]) # type: ignore[operator] + + super().__init__(words=words) + self.geometry = geometry + + def render(self) -> str: + """Renders the full text of the element""" + return " ".join(w.render() for w in self.words) + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + kwargs.update({ + "words": [Word.from_dict(_dict) for _dict in save_dict["words"]], + }) + return cls(**kwargs) + + +class Prediction(Word): + """Implements a prediction element""" + + def render(self) -> str: + """Renders the full text of the element""" + return self.value + + def extra_repr(self) -> str: + return f"value='{self.value}', confidence={self.confidence:.2}, bounding_box={self.geometry}" + + +class Block(Element): + """Implements a block element as a collection of lines and artefacts + + Args: + ---- + lines: list of line elements + artefacts: list of artefacts + geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to + the page's size. If not specified, it will be resolved by default to the smallest bounding box enclosing + all lines and artefacts in it. + """ + + _exported_keys: List[str] = ["geometry"] + _children_names: List[str] = ["lines", "artefacts"] + lines: List[Line] = [] + artefacts: List[Artefact] = [] + + def __init__( + self, + lines: List[Line] = [], + artefacts: List[Artefact] = [], + geometry: Optional[Union[BoundingBox, np.ndarray]] = None, + ) -> None: + # Resolve the geometry using the smallest enclosing bounding box + if geometry is None: + line_boxes = [word.geometry for line in lines for word in line.words] + artefact_boxes = [artefact.geometry for artefact in artefacts] + box_resolution_fn = ( + resolve_enclosing_rbbox if isinstance(lines[0].geometry, np.ndarray) else resolve_enclosing_bbox + ) + geometry = box_resolution_fn(line_boxes + artefact_boxes) # type: ignore[operator] + + super().__init__(lines=lines, artefacts=artefacts) + self.geometry = geometry + + def render(self, line_break: str = "\n") -> str: + """Renders the full text of the element""" + return line_break.join(line.render() for line in self.lines) + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + kwargs.update({ + "lines": [Line.from_dict(_dict) for _dict in save_dict["lines"]], + "artefacts": [Artefact.from_dict(_dict) for _dict in save_dict["artefacts"]], + }) + return cls(**kwargs) + + +class Page(Element): + """Implements a page element as a collection of blocks + + Args: + ---- + page: image encoded as a numpy array in uint8 + blocks: list of block elements + page_idx: the index of the page in the input raw document + dimensions: the page size in pixels in format (height, width) + orientation: a dictionary with the value of the rotation angle in degress and confidence of the prediction + language: a dictionary with the language value and confidence of the prediction + """ 
+ + _exported_keys: List[str] = ["page_idx", "dimensions", "orientation", "language"] + _children_names: List[str] = ["blocks"] + blocks: List[Block] = [] + + def __init__( + self, + page: np.ndarray, + blocks: List[Block], + page_idx: int, + dimensions: Tuple[int, int], + orientation: Optional[Dict[str, Any]] = None, + language: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(blocks=blocks) + self.page = page + self.page_idx = page_idx + self.dimensions = dimensions + self.orientation = orientation if isinstance(orientation, dict) else dict(value=None, confidence=None) + self.language = language if isinstance(language, dict) else dict(value=None, confidence=None) + + def render(self, block_break: str = "\n\n") -> str: + """Renders the full text of the element""" + return block_break.join(b.render() for b in self.blocks) + + def extra_repr(self) -> str: + return f"dimensions={self.dimensions}" + + def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, **kwargs) -> None: + """Overlay the result on a given image + + Args: + interactive: whether the display should be interactive + preserve_aspect_ratio: pass True if you passed True to the predictor + **kwargs: additional keyword arguments passed to the matplotlib.pyplot.show method + """ + requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed") + requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed") + import matplotlib.pyplot as plt + + visualize_page(self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio) + plt.show(**kwargs) + + def synthesize(self, **kwargs) -> np.ndarray: + """Synthesize the page from the predictions + + Returns + ------- + synthesized page + """ + return synthesize_page(self.export(), **kwargs) + + def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> Tuple[bytes, ET.ElementTree]: + """Export the page as XML (hOCR-format) + convention: https://github.com/kba/hocr-spec/blob/master/1.2/spec.md + + Args: + ---- + file_title: the title of the XML file + + Returns: + ------- + a tuple of the XML byte string, and its ElementTree + """ + p_idx = self.page_idx + block_count: int = 1 + line_count: int = 1 + word_count: int = 1 + height, width = self.dimensions + language = self.language if "language" in self.language.keys() else "en" + # Create the XML root element + page_hocr = ETElement("html", attrib={"xmlns": "http://www.w3.org/1999/xhtml", "xml:lang": str(language)}) + # Create the header / SubElements of the root element + head = SubElement(page_hocr, "head") + SubElement(head, "title").text = file_title + SubElement(head, "meta", attrib={"http-equiv": "Content-Type", "content": "text/html; charset=utf-8"}) + SubElement( + head, + "meta", + attrib={"name": "ocr-system", "content": f"python-doctr {doctr.__version__}"}, # type: ignore[attr-defined] + ) + SubElement( + head, + "meta", + attrib={"name": "ocr-capabilities", "content": "ocr_page ocr_carea ocr_par ocr_line ocrx_word"}, + ) + # Create the body + body = SubElement(page_hocr, "body") + SubElement( + body, + "div", + attrib={ + "class": "ocr_page", + "id": f"page_{p_idx + 1}", + "title": f"image; bbox 0 0 {width} {height}; ppageno 0", + }, + ) + # iterate over the blocks / lines / words and create the XML elements in body line by line with the attributes + for block in self.blocks: + if len(block.geometry) != 2: + raise TypeError("XML export is only available for straight bounding boxes for now.") + (xmin, 
ymin), (xmax, ymax) = block.geometry + block_div = SubElement( + body, + "div", + attrib={ + "class": "ocr_carea", + "id": f"block_{block_count}", + "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \ + {int(round(xmax * width))} {int(round(ymax * height))}", + }, + ) + paragraph = SubElement( + block_div, + "p", + attrib={ + "class": "ocr_par", + "id": f"par_{block_count}", + "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \ + {int(round(xmax * width))} {int(round(ymax * height))}", + }, + ) + block_count += 1 + for line in block.lines: + (xmin, ymin), (xmax, ymax) = line.geometry + # NOTE: baseline, x_size, x_descenders, x_ascenders is currently initalized to 0 + line_span = SubElement( + paragraph, + "span", + attrib={ + "class": "ocr_line", + "id": f"line_{line_count}", + "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \ + {int(round(xmax * width))} {int(round(ymax * height))}; \ + baseline 0 0; x_size 0; x_descenders 0; x_ascenders 0", + }, + ) + line_count += 1 + for word in line.words: + (xmin, ymin), (xmax, ymax) = word.geometry + conf = word.confidence + word_div = SubElement( + line_span, + "span", + attrib={ + "class": "ocrx_word", + "id": f"word_{word_count}", + "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \ + {int(round(xmax * width))} {int(round(ymax * height))}; \ + x_wconf {int(round(conf * 100))}", + }, + ) + # set the text + word_div.text = word.value + word_count += 1 + + return (ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr)) + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + kwargs.update({"blocks": [Block.from_dict(block_dict) for block_dict in save_dict["blocks"]]}) + return cls(**kwargs) + + +class KIEPage(Element): + """Implements a KIE page element as a collection of predictions + + Args: + ---- + predictions: Dictionary with list of block elements for each detection class + page: image encoded as a numpy array in uint8 + page_idx: the index of the page in the input raw document + dimensions: the page size in pixels in format (height, width) + orientation: a dictionary with the value of the rotation angle in degress and confidence of the prediction + language: a dictionary with the language value and confidence of the prediction + """ + + _exported_keys: List[str] = ["page_idx", "dimensions", "orientation", "language"] + _children_names: List[str] = ["predictions"] + predictions: Dict[str, List[Prediction]] = {} + + def __init__( + self, + page: np.ndarray, + predictions: Dict[str, List[Prediction]], + page_idx: int, + dimensions: Tuple[int, int], + orientation: Optional[Dict[str, Any]] = None, + language: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(predictions=predictions) + self.page = page + self.page_idx = page_idx + self.dimensions = dimensions + self.orientation = orientation if isinstance(orientation, dict) else dict(value=None, confidence=None) + self.language = language if isinstance(language, dict) else dict(value=None, confidence=None) + + def render(self, prediction_break: str = "\n\n") -> str: + """Renders the full text of the element""" + return prediction_break.join( + f"{class_name}: {p.render()}" for class_name, predictions in self.predictions.items() for p in predictions + ) + + def extra_repr(self) -> str: + return f"dimensions={self.dimensions}" + + def show(self, interactive: bool = True, preserve_aspect_ratio: 
bool = False, **kwargs) -> None: + """Overlay the result on a given image + + Args: + interactive: whether the display should be interactive + preserve_aspect_ratio: pass True if you passed True to the predictor + **kwargs: keyword arguments passed to the matplotlib.pyplot.show method + """ + requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed") + requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed") + import matplotlib.pyplot as plt + + visualize_kie_page( + self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio + ) + plt.show(**kwargs) + + def synthesize(self, **kwargs) -> np.ndarray: + """Synthesize the page from the predictions + + Args: + ---- + **kwargs: keyword arguments passed to the matplotlib.pyplot.show method + + Returns: + ------- + synthesized page + """ + return synthesize_kie_page(self.export(), **kwargs) + + def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> Tuple[bytes, ET.ElementTree]: + """Export the page as XML (hOCR-format) + convention: https://github.com/kba/hocr-spec/blob/master/1.2/spec.md + + Args: + ---- + file_title: the title of the XML file + + Returns: + ------- + a tuple of the XML byte string, and its ElementTree + """ + p_idx = self.page_idx + prediction_count: int = 1 + height, width = self.dimensions + language = self.language if "language" in self.language.keys() else "en" + # Create the XML root element + page_hocr = ETElement("html", attrib={"xmlns": "http://www.w3.org/1999/xhtml", "xml:lang": str(language)}) + # Create the header / SubElements of the root element + head = SubElement(page_hocr, "head") + SubElement(head, "title").text = file_title + SubElement(head, "meta", attrib={"http-equiv": "Content-Type", "content": "text/html; charset=utf-8"}) + SubElement( + head, + "meta", + attrib={"name": "ocr-system", "content": f"python-doctr {doctr.__version__}"}, # type: ignore[attr-defined] + ) + SubElement( + head, + "meta", + attrib={"name": "ocr-capabilities", "content": "ocr_page ocr_carea ocr_par ocr_line ocrx_word"}, + ) + # Create the body + body = SubElement(page_hocr, "body") + SubElement( + body, + "div", + attrib={ + "class": "ocr_page", + "id": f"page_{p_idx + 1}", + "title": f"image; bbox 0 0 {width} {height}; ppageno 0", + }, + ) + # iterate over the blocks / lines / words and create the XML elements in body line by line with the attributes + for class_name, predictions in self.predictions.items(): + for prediction in predictions: + if len(prediction.geometry) != 2: + raise TypeError("XML export is only available for straight bounding boxes for now.") + (xmin, ymin), (xmax, ymax) = prediction.geometry + prediction_div = SubElement( + body, + "div", + attrib={ + "class": "ocr_carea", + "id": f"{class_name}_prediction_{prediction_count}", + "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \ + {int(round(xmax * width))} {int(round(ymax * height))}", + }, + ) + prediction_div.text = prediction.value + prediction_count += 1 + + return ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr) + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + kwargs.update({ + "predictions": [Prediction.from_dict(predictions_dict) for predictions_dict in save_dict["predictions"]] + }) + return cls(**kwargs) + + +class Document(Element): + """Implements a document element as a collection of pages + + Args: 
+ ---- + pages: list of page elements + """ + + _children_names: List[str] = ["pages"] + pages: List[Page] = [] + + def __init__( + self, + pages: List[Page], + ) -> None: + super().__init__(pages=pages) + + def render(self, page_break: str = "\n\n\n\n") -> str: + """Renders the full text of the element""" + return page_break.join(p.render() for p in self.pages) + + def show(self, **kwargs) -> None: + """Overlay the result on a given image""" + for result in self.pages: + result.show(**kwargs) + + def synthesize(self, **kwargs) -> List[np.ndarray]: + """Synthesize all pages from their predictions + + Returns + ------- + list of synthesized pages + """ + return [page.synthesize() for page in self.pages] + + def export_as_xml(self, **kwargs) -> List[Tuple[bytes, ET.ElementTree]]: + """Export the document as XML (hOCR-format) + + Args: + ---- + **kwargs: additional keyword arguments passed to the Page.export_as_xml method + + Returns: + ------- + list of tuple of (bytes, ElementTree) + """ + return [page.export_as_xml(**kwargs) for page in self.pages] + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + kwargs.update({"pages": [Page.from_dict(page_dict) for page_dict in save_dict["pages"]]}) + return cls(**kwargs) + + +class KIEDocument(Document): + """Implements a document element as a collection of pages + + Args: + ---- + pages: list of page elements + """ + + _children_names: List[str] = ["pages"] + pages: List[KIEPage] = [] # type: ignore[assignment] + + def __init__( + self, + pages: List[KIEPage], + ) -> None: + super().__init__(pages=pages) # type: ignore[arg-type] diff --git a/doctr/io/html.py b/doctr/io/html.py new file mode 100644 index 0000000000000000000000000000000000000000..f8a8da237d2298b0f0b30578a783a2f45ea1be5c --- /dev/null +++ b/doctr/io/html.py @@ -0,0 +1,28 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
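# --- Illustrative sketch (editor's addition, not part of the diff): assembling the element hierarchy from doctr/io/elements.py above. ---
# Geometries are relative ((xmin, ymin), (xmax, ymax)) boxes; all values below are invented placeholders.
import numpy as np
from doctr.io.elements import Word, Line, Block, Page, Document

words = [
    Word("Hello", 0.99, ((0.10, 0.10), (0.30, 0.15)), {"value": 0, "confidence": None}),
    Word("world", 0.98, ((0.32, 0.10), (0.50, 0.15)), {"value": 0, "confidence": None}),
]
page = Page(
    page=np.zeros((64, 128, 3), dtype=np.uint8),  # placeholder image
    blocks=[Block(lines=[Line(words)])],          # line/block geometries resolved from the word boxes
    page_idx=0,
    dimensions=(64, 128),
)
doc = Document(pages=[page])
print(doc.render())                  # "Hello world"
xml_bytes, _ = page.export_as_xml()  # hOCR export, straight boxes only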
+ +from typing import Any + +__all__ = ["read_html"] + + +def read_html(url: str, **kwargs: Any) -> bytes: + """Read a PDF file and convert it into an image in numpy format + + >>> from doctr.io import read_html + >>> doc = read_html("https://www.yoursite.com") + + Args: + ---- + url: URL of the target web page + **kwargs: keyword arguments from `weasyprint.HTML` + + Returns: + ------- + decoded PDF file as a bytes stream + """ + from weasyprint import HTML + + return HTML(url, **kwargs).write_pdf() diff --git a/doctr/io/image/__init__.py b/doctr/io/image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..393c70c359df5bcebea5d8cdcb2277d60923e8e6 --- /dev/null +++ b/doctr/io/image/__init__.py @@ -0,0 +1,8 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +from .base import * + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * diff --git a/doctr/io/image/__pycache__/__init__.cpython-310.pyc b/doctr/io/image/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb878f2e6204bdd3d57754ab78f760c78767c190 Binary files /dev/null and b/doctr/io/image/__pycache__/__init__.cpython-310.pyc differ diff --git a/doctr/io/image/__pycache__/__init__.cpython-311.pyc b/doctr/io/image/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91028227f8173ee748520b23da4c1aa646403f2a Binary files /dev/null and b/doctr/io/image/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/io/image/__pycache__/__init__.cpython-38.pyc b/doctr/io/image/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7932c1d3cfd77fd715a5f74ad98657bd805675c Binary files /dev/null and b/doctr/io/image/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/io/image/__pycache__/base.cpython-310.pyc b/doctr/io/image/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cd7b6c5ac28df3ea620a2e42de5cf48056a0e6b Binary files /dev/null and b/doctr/io/image/__pycache__/base.cpython-310.pyc differ diff --git a/doctr/io/image/__pycache__/base.cpython-311.pyc b/doctr/io/image/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cd23f63ed524a937e93f2c48caaa53bf5231441 Binary files /dev/null and b/doctr/io/image/__pycache__/base.cpython-311.pyc differ diff --git a/doctr/io/image/__pycache__/base.cpython-38.pyc b/doctr/io/image/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58c38a4bfa7e8dae4ce9a049b71982f37a248d30 Binary files /dev/null and b/doctr/io/image/__pycache__/base.cpython-38.pyc differ diff --git a/doctr/io/image/__pycache__/pytorch.cpython-311.pyc b/doctr/io/image/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daee97a71cad1c991485262ce746345c385865f3 Binary files /dev/null and b/doctr/io/image/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/io/image/__pycache__/tensorflow.cpython-310.pyc b/doctr/io/image/__pycache__/tensorflow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b949786cc209386d67c007252e3d10f2ff11a200 Binary files /dev/null and b/doctr/io/image/__pycache__/tensorflow.cpython-310.pyc differ diff --git a/doctr/io/image/__pycache__/tensorflow.cpython-311.pyc 
b/doctr/io/image/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f6c1ffcb47b77ea800b09efb203f1ecb80904a9 Binary files /dev/null and b/doctr/io/image/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/io/image/__pycache__/tensorflow.cpython-38.pyc b/doctr/io/image/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85729c6615058a55971ac4ce91a16ca3f43afbcd Binary files /dev/null and b/doctr/io/image/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/io/image/base.py b/doctr/io/image/base.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c2ed3065bb326c99b70e7d9c52cdb9d36ef809 --- /dev/null +++ b/doctr/io/image/base.py @@ -0,0 +1,56 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from pathlib import Path +from typing import Optional, Tuple + +import cv2 +import numpy as np + +from doctr.utils.common_types import AbstractFile + +__all__ = ["read_img_as_numpy"] + + +def read_img_as_numpy( + file: AbstractFile, + output_size: Optional[Tuple[int, int]] = None, + rgb_output: bool = True, +) -> np.ndarray: + """Read an image file into numpy format + + >>> from doctr.io import read_img_as_numpy + >>> page = read_img_as_numpy("path/to/your/doc.jpg") + + Args: + ---- + file: the path to the image file + output_size: the expected output size of each page in format H x W + rgb_output: whether the output ndarray channel order should be RGB instead of BGR. + + Returns: + ------- + the page decoded as numpy ndarray of shape H x W x 3 + """ + if isinstance(file, (str, Path)): + if not Path(file).is_file(): + raise FileNotFoundError(f"unable to access {file}") + img = cv2.imread(str(file), cv2.IMREAD_COLOR) + elif isinstance(file, bytes): + _file: np.ndarray = np.frombuffer(file, np.uint8) + img = cv2.imdecode(_file, cv2.IMREAD_COLOR) + else: + raise TypeError("unsupported object type for argument 'file'") + + # Validity check + if img is None: + raise ValueError("unable to read file.") + # Resizing + if isinstance(output_size, tuple): + img = cv2.resize(img, output_size[::-1], interpolation=cv2.INTER_LINEAR) + # Switch the channel order + if rgb_output: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img diff --git a/doctr/io/image/pytorch.py b/doctr/io/image/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..67d37e46b35b8efcb04820cfce594edded9cc2f8 --- /dev/null +++ b/doctr/io/image/pytorch.py @@ -0,0 +1,107 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
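# --- Illustrative sketch (editor's addition, not part of the diff): read_img_as_numpy accepts a path or raw bytes. ---
# "sample.jpg" is a placeholder path; output_size is (H, W) and rgb_output toggles the BGR -> RGB conversion.
from doctr.io import read_img_as_numpy

page = read_img_as_numpy("sample.jpg", output_size=(512, 512), rgb_output=True)
print(page.shape, page.dtype)  # (512, 512, 3) uint8

with open("sample.jpg", "rb") as f:
    same_page = read_img_as_numpy(f.read(), output_size=(512, 512))  # decoded via cv2.imdecode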
+ +from io import BytesIO +from typing import Tuple + +import numpy as np +import torch +from PIL import Image +from torchvision.transforms.functional import to_tensor + +from doctr.utils.common_types import AbstractPath + +__all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"] + + +def tensor_from_pil(pil_img: Image.Image, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Convert a PIL Image to a PyTorch tensor + + Args: + ---- + pil_img: a PIL image + dtype: the output tensor data type + + Returns: + ------- + decoded image as tensor + """ + if dtype == torch.float32: + img = to_tensor(pil_img) + else: + img = tensor_from_numpy(np.array(pil_img, np.uint8, copy=True), dtype) + + return img + + +def read_img_as_tensor(img_path: AbstractPath, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Read an image file as a PyTorch tensor + + Args: + ---- + img_path: location of the image file + dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. + + Returns: + ------- + decoded image as a tensor + """ + if dtype not in (torch.uint8, torch.float16, torch.float32): + raise ValueError("insupported value for dtype") + + with Image.open(img_path, mode="r") as pil_img: + return tensor_from_pil(pil_img.convert("RGB"), dtype) + + +def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Read a byte stream as a PyTorch tensor + + Args: + ---- + img_content: bytes of a decoded image + dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. + + Returns: + ------- + decoded image as a tensor + """ + if dtype not in (torch.uint8, torch.float16, torch.float32): + raise ValueError("insupported value for dtype") + + with Image.open(BytesIO(img_content), mode="r") as pil_img: + return tensor_from_pil(pil_img.convert("RGB"), dtype) + + +def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Read an image file as a PyTorch tensor + + Args: + ---- + npy_img: image encoded as a numpy array of shape (H, W, C) in np.uint8 + dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. + + Returns: + ------- + same image as a tensor of shape (C, H, W) + """ + if dtype not in (torch.uint8, torch.float16, torch.float32): + raise ValueError("insupported value for dtype") + + if dtype == torch.float32: + img = to_tensor(npy_img) + else: + img = torch.from_numpy(npy_img) + # put it from HWC to CHW format + img = img.permute((2, 0, 1)).contiguous() + if dtype == torch.float16: + # Switch to FP16 + img = img.to(dtype=torch.float16).div(255) + + return img + + +def get_img_shape(img: torch.Tensor) -> Tuple[int, int]: + """Get the shape of an image""" + return img.shape[-2:] # type: ignore[return-value] diff --git a/doctr/io/image/tensorflow.py b/doctr/io/image/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..28fb2fadd5cb103bb0c581a5fc10083c7570013b --- /dev/null +++ b/doctr/io/image/tensorflow.py @@ -0,0 +1,110 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
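# --- Illustrative sketch (editor's addition, not part of the diff): dtype handling in the PyTorch image helpers above. ---
# float32 goes through torchvision's to_tensor (values scaled to [0, 1]); uint8 keeps raw values, both in CHW layout.
import numpy as np
import torch
from doctr.io.image.pytorch import tensor_from_numpy

npy_img = np.random.randint(0, 255, (32, 64, 3), dtype=np.uint8)  # placeholder H x W x C image
as_float = tensor_from_numpy(npy_img, dtype=torch.float32)        # shape (3, 32, 64), values in [0, 1]
as_uint8 = tensor_from_numpy(npy_img, dtype=torch.uint8)          # shape (3, 32, 64), values in [0, 255]
print(as_float.shape, bool(as_float.max() <= 1), as_uint8.dtype)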
+ +from typing import Tuple + +import numpy as np +import tensorflow as tf +from PIL import Image +from tensorflow.keras.utils import img_to_array + +from doctr.utils.common_types import AbstractPath + +__all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"] + + +def tensor_from_pil(pil_img: Image.Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor: + """Convert a PIL Image to a TensorFlow tensor + + Args: + ---- + pil_img: a PIL image + dtype: the output tensor data type + + Returns: + ------- + decoded image as tensor + """ + npy_img = img_to_array(pil_img) + + return tensor_from_numpy(npy_img, dtype) + + +def read_img_as_tensor(img_path: AbstractPath, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor: + """Read an image file as a TensorFlow tensor + + Args: + ---- + img_path: location of the image file + dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. + + Returns: + ------- + decoded image as a tensor + """ + if dtype not in (tf.uint8, tf.float16, tf.float32): + raise ValueError("insupported value for dtype") + + img = tf.io.read_file(img_path) + img = tf.image.decode_jpeg(img, channels=3) + + if dtype != tf.uint8: + img = tf.image.convert_image_dtype(img, dtype=dtype) + img = tf.clip_by_value(img, 0, 1) + + return img + + +def decode_img_as_tensor(img_content: bytes, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor: + """Read a byte stream as a TensorFlow tensor + + Args: + ---- + img_content: bytes of a decoded image + dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. + + Returns: + ------- + decoded image as a tensor + """ + if dtype not in (tf.uint8, tf.float16, tf.float32): + raise ValueError("insupported value for dtype") + + img = tf.io.decode_image(img_content, channels=3) + + if dtype != tf.uint8: + img = tf.image.convert_image_dtype(img, dtype=dtype) + img = tf.clip_by_value(img, 0, 1) + + return img + + +def tensor_from_numpy(npy_img: np.ndarray, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor: + """Read an image file as a TensorFlow tensor + + Args: + ---- + npy_img: image encoded as a numpy array of shape (H, W, C) in np.uint8 + dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. + + Returns: + ------- + same image as a tensor of shape (H, W, C) + """ + if dtype not in (tf.uint8, tf.float16, tf.float32): + raise ValueError("insupported value for dtype") + + if dtype == tf.uint8: + img = tf.convert_to_tensor(npy_img, dtype=dtype) + else: + img = tf.image.convert_image_dtype(npy_img, dtype=dtype) + img = tf.clip_by_value(img, 0, 1) + + return img + + +def get_img_shape(img: tf.Tensor) -> Tuple[int, int]: + """Get the shape of an image""" + return img.shape[:2] diff --git a/doctr/io/pdf.py b/doctr/io/pdf.py new file mode 100644 index 0000000000000000000000000000000000000000..e91413f7b1b50cf11061f986cc2f4d2a3a9daacf --- /dev/null +++ b/doctr/io/pdf.py @@ -0,0 +1,42 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
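# --- Illustrative sketch (editor's addition, not part of the diff): the TensorFlow helpers keep HWC layout, unlike the PyTorch ones. ---
# Assumes a TensorFlow install; the input array is a placeholder image.
import numpy as np
import tensorflow as tf
from doctr.io.image.tensorflow import tensor_from_numpy, get_img_shape

npy_img = np.random.randint(0, 255, (32, 64, 3), dtype=np.uint8)
img = tensor_from_numpy(npy_img, dtype=tf.float32)  # shape (32, 64, 3), values clipped to [0, 1]
print(get_img_shape(img))                           # (32, 64)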
+ +from typing import Any, List, Optional + +import numpy as np +import pypdfium2 as pdfium + +from doctr.utils.common_types import AbstractFile + +__all__ = ["read_pdf"] + + +def read_pdf( + file: AbstractFile, + scale: float = 2, + rgb_mode: bool = True, + password: Optional[str] = None, + **kwargs: Any, +) -> List[np.ndarray]: + """Read a PDF file and convert it into an image in numpy format + + >>> from doctr.io import read_pdf + >>> doc = read_pdf("path/to/your/doc.pdf") + + Args: + ---- + file: the path to the PDF file + scale: rendering scale (1 corresponds to 72dpi) + rgb_mode: if True, the output will be RGB, otherwise BGR + password: a password to unlock the document, if encrypted + **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render` + + Returns: + ------- + the list of pages decoded as numpy ndarray of shape H x W x C + """ + # Rasterise pages to numpy ndarrays with pypdfium2 + pdf = pdfium.PdfDocument(file, password=password, autoclose=True) + return [page.render(scale=scale, rev_byteorder=rgb_mode, **kwargs).to_numpy() for page in pdf] diff --git a/doctr/io/reader.py b/doctr/io/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..76f7317cb1a82989d75a49b9461f2d423c275a52 --- /dev/null +++ b/doctr/io/reader.py @@ -0,0 +1,85 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from pathlib import Path +from typing import List, Sequence, Union + +import numpy as np + +from doctr.file_utils import requires_package +from doctr.utils.common_types import AbstractFile + +from .html import read_html +from .image import read_img_as_numpy +from .pdf import read_pdf + +__all__ = ["DocumentFile"] + + +class DocumentFile: + """Read a document from multiple extensions""" + + @classmethod + def from_pdf(cls, file: AbstractFile, **kwargs) -> List[np.ndarray]: + """Read a PDF file + + >>> from doctr.io import DocumentFile + >>> doc = DocumentFile.from_pdf("path/to/your/doc.pdf") + + Args: + ---- + file: the path to the PDF file or a binary stream + **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render` + + Returns: + ------- + the list of pages decoded as numpy ndarray of shape H x W x 3 + """ + return read_pdf(file, **kwargs) + + @classmethod + def from_url(cls, url: str, **kwargs) -> List[np.ndarray]: + """Interpret a web page as a PDF document + + >>> from doctr.io import DocumentFile + >>> doc = DocumentFile.from_url("https://www.yoursite.com") + + Args: + ---- + url: the URL of the target web page + **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render` + + Returns: + ------- + the list of pages decoded as numpy ndarray of shape H x W x 3 + """ + requires_package( + "weasyprint", + "`.from_url` requires weasyprint installed.\n" + + "Installation instructions: https://doc.courtbouillon.org/weasyprint/stable/first_steps.html#installation", + ) + pdf_stream = read_html(url) + return cls.from_pdf(pdf_stream, **kwargs) + + @classmethod + def from_images(cls, files: Union[Sequence[AbstractFile], AbstractFile], **kwargs) -> List[np.ndarray]: + """Read an image file (or a collection of image files) and convert it into an image in numpy format + + >>> from doctr.io import DocumentFile + >>> pages = DocumentFile.from_images(["path/to/your/page1.png", "path/to/your/page2.png"]) + + Args: + ---- + files: the path to the image file or a binary stream, or a collection of those + **kwargs: additional parameters to 
:meth:`doctr.io.image.read_img_as_numpy` + + Returns: + ------- + the list of pages decoded as numpy ndarray of shape H x W x 3 + """ + if isinstance(files, (str, Path, bytes)): + files = [files] + + return [read_img_as_numpy(file, **kwargs) for file in files] diff --git a/doctr/models/__init__.py b/doctr/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6db1c0678087074f1277c4748c7d3414419110c --- /dev/null +++ b/doctr/models/__init__.py @@ -0,0 +1,5 @@ +from .classification import * +from .detection import * +from .recognition import * +from .zoo import * +from .factory import * diff --git a/doctr/models/__pycache__/__init__.cpython-311.pyc b/doctr/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8614ac99e603b18243a042cbd704b99e56fc107 Binary files /dev/null and b/doctr/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/__pycache__/__init__.cpython-38.pyc b/doctr/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff76dfba92521edaa291fc8036229dccc4c3ad22 Binary files /dev/null and b/doctr/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/__pycache__/_utils.cpython-311.pyc b/doctr/models/__pycache__/_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab30347435f49a851fcbb9aaa4106aac323aa883 Binary files /dev/null and b/doctr/models/__pycache__/_utils.cpython-311.pyc differ diff --git a/doctr/models/__pycache__/_utils.cpython-38.pyc b/doctr/models/__pycache__/_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43d266879ad887453dc606d2d3da8d100be48450 Binary files /dev/null and b/doctr/models/__pycache__/_utils.cpython-38.pyc differ diff --git a/doctr/models/__pycache__/builder.cpython-311.pyc b/doctr/models/__pycache__/builder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63452efe3ceaed9fbe79ea86f3f62f27156dd08c Binary files /dev/null and b/doctr/models/__pycache__/builder.cpython-311.pyc differ diff --git a/doctr/models/__pycache__/builder.cpython-38.pyc b/doctr/models/__pycache__/builder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4eb7a965628791299cece9fe3b379e8ce664586e Binary files /dev/null and b/doctr/models/__pycache__/builder.cpython-38.pyc differ diff --git a/doctr/models/__pycache__/core.cpython-311.pyc b/doctr/models/__pycache__/core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44a43947b8349df7f43aff863ab9dcba410417f1 Binary files /dev/null and b/doctr/models/__pycache__/core.cpython-311.pyc differ diff --git a/doctr/models/__pycache__/core.cpython-38.pyc b/doctr/models/__pycache__/core.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbb6400a89e22b163f336afdb4f5d8f630d23fa5 Binary files /dev/null and b/doctr/models/__pycache__/core.cpython-38.pyc differ diff --git a/doctr/models/__pycache__/zoo.cpython-311.pyc b/doctr/models/__pycache__/zoo.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06620735766dcbb008587f3f51547a8e102eaa1d Binary files /dev/null and b/doctr/models/__pycache__/zoo.cpython-311.pyc differ diff --git a/doctr/models/__pycache__/zoo.cpython-38.pyc b/doctr/models/__pycache__/zoo.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..354094f236c083ffe685ca39534a74eb74dee569 Binary files /dev/null and b/doctr/models/__pycache__/zoo.cpython-38.pyc differ diff --git a/doctr/models/_utils.py b/doctr/models/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4394363cfd0666d7827883d4ee4ab10634db7d0 --- /dev/null +++ b/doctr/models/_utils.py @@ -0,0 +1,163 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from math import floor +from statistics import median_low +from typing import Any, Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +from langdetect import LangDetectException, detect_langs + +__all__ = ["estimate_orientation", "get_language", "invert_data_structure"] + + +def get_max_width_length_ratio(contour: np.ndarray) -> float: + """Get the maximum shape ratio of a contour. + + Args: + ---- + contour: the contour from cv2.findContour + + Returns: + ------- + the maximum shape ratio + """ + _, (w, h), _ = cv2.minAreaRect(contour) + return max(w / h, h / w) + + +def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> int: + """Estimate the angle of the general document orientation based on the + lines of the document and the assumption that they should be horizontal. + + Args: + ---- + img: the img or bitmap to analyze (H, W, C) + n_ct: the number of contours used for the orientation estimation + ratio_threshold_for_lines: this is the ratio w/h used to discriminates lines + + Returns: + ------- + the angle of the general document orientation + """ + assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported" + max_value = np.max(img) + min_value = np.min(img) + if max_value <= 1 and min_value >= 0 or (max_value <= 255 and min_value >= 0 and img.shape[-1] == 1): + thresh = img.astype(np.uint8) + if max_value <= 255 and min_value >= 0 and img.shape[-1] == 3: + gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + gray_img = cv2.medianBlur(gray_img, 5) + thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1] # type: ignore[assignment] + + # try to merge words in lines + (h, w) = img.shape[:2] + k_x = max(1, (floor(w / 100))) + k_y = max(1, (floor(h / 100))) + kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y)) + thresh = cv2.dilate(thresh, kernel, iterations=1) # type: ignore[assignment] + + # extract contours + contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + + # Sort contours + contours = sorted(contours, key=get_max_width_length_ratio, reverse=True) + + angles = [] + for contour in contours[:n_ct]: + _, (w, h), angle = cv2.minAreaRect(contour) + if w / h > ratio_threshold_for_lines: # select only contours with ratio like lines + angles.append(angle) + elif w / h < 1 / ratio_threshold_for_lines: # if lines are vertical, substract 90 degree + angles.append(angle - 90) + + if len(angles) == 0: + return 0 # in case no angles is found + else: + median = -median_low(angles) + return round(median) if abs(median) != 0 else 0 + + +def rectify_crops( + crops: List[np.ndarray], + orientations: List[int], +) -> List[np.ndarray]: + """Rotate each crop of the list according to the predicted orientation: + 0: already straight, no rotation + 1: 90 ccw, rotate 3 times ccw + 2: 180, rotate 2 times ccw + 3: 270 ccw, rotate 1 time ccw + """ + # Inverse predictions (if angle of +90 is 
detected, rotate by -90) + orientations = [4 - pred if pred != 0 else 0 for pred in orientations] + return ( + [crop if orientation == 0 else np.rot90(crop, orientation) for orientation, crop in zip(orientations, crops)] + if len(orientations) > 0 + else [] + ) + + +def rectify_loc_preds( + page_loc_preds: np.ndarray, + orientations: List[int], +) -> Optional[np.ndarray]: + """Orient the quadrangle (Polygon4P) according to the predicted orientation, + so that the points are in this order: top L, top R, bot R, bot L if the crop is readable + """ + return ( + np.stack( + [ + np.roll(page_loc_pred, orientation, axis=0) + for orientation, page_loc_pred in zip(orientations, page_loc_preds) + ], + axis=0, + ) + if len(orientations) > 0 + else None + ) + + +def get_language(text: str) -> Tuple[str, float]: + """Get languages of a text using langdetect model. + Get the language with the highest probability or no language if only a few words or a low probability + + Args: + ---- + text (str): text + + Returns: + ------- + The detected language in ISO 639 code and confidence score + """ + try: + lang = detect_langs(text.lower())[0] + except LangDetectException: + return "unknown", 0.0 + if len(text) <= 1 or (len(text) <= 5 and lang.prob <= 0.2): + return "unknown", 0.0 + return lang.lang, lang.prob + + +def invert_data_structure( + x: Union[List[Dict[str, Any]], Dict[str, List[Any]]], +) -> Union[List[Dict[str, Any]], Dict[str, List[Any]]]: + """Invert a List of Dict of elements to a Dict of list of elements and the other way around + + Args: + ---- + x: a list of dictionaries with the same keys or a dictionary of lists of the same length + + Returns: + ------- + dictionary of list when x is a list of dictionaries or a list of dictionaries when x is dictionary of lists + """ + if isinstance(x, dict): + assert len({len(v) for v in x.values()}) == 1, "All the lists in the dictionnary should have the same length." + return [dict(zip(x, t)) for t in zip(*x.values())] + elif isinstance(x, list): + return {k: [dic[k] for dic in x] for k in x[0]} + else: + raise TypeError(f"Expected input to be either a dict or a list, got {type(input)} instead.") diff --git a/doctr/models/builder.py b/doctr/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..397b3b4244ff5b9cfe5efee589677c0b5385a291 --- /dev/null +++ b/doctr/models/builder.py @@ -0,0 +1,487 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +from scipy.cluster.hierarchy import fclusterdata + +from doctr.io.elements import Block, Document, KIEDocument, KIEPage, Line, Page, Prediction, Word +from doctr.utils.geometry import estimate_page_angle, resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes +from doctr.utils.repr import NestedObject + +__all__ = ["DocumentBuilder"] + + +class DocumentBuilder(NestedObject): + """Implements a document builder + + Args: + ---- + resolve_lines: whether words should be automatically grouped into lines + resolve_blocks: whether lines should be automatically grouped into blocks + paragraph_break: relative length of the minimum space separating paragraphs + export_as_straight_boxes: if True, force straight boxes in the export (fit a rectangle + box to all rotated boxes). Else, keep the boxes format unchanged, no matter what it is. 
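+ A minimal usage sketch (the boxes, word predictions and orientation dicts below are
+ illustrative assumptions, not outputs of a real predictor):
+
+ >>> import numpy as np
+ >>> builder = DocumentBuilder()
+ >>> page = np.zeros((128, 256, 3), dtype=np.uint8)
+ >>> boxes = [np.array([[0.10, 0.1, 0.30, 0.2, 0.9], [0.32, 0.1, 0.55, 0.2, 0.9]])]
+ >>> text_preds = [[("hello", 0.99), ("world", 0.98)]]
+ >>> crop_orientations = [[{"value": 0, "confidence": 1.0}, {"value": 0, "confidence": 1.0}]]
+ >>> doc = builder(
+ ...     pages=[page],
+ ...     boxes=boxes,
+ ...     text_preds=text_preds,
+ ...     page_shapes=[(128, 256)],
+ ...     crop_orientations=crop_orientations,
+ ... )
+ >>> doc.pages[0].blocks[0].lines[0].words[0].value
+ 'hello'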
+ """ + + def __init__( + self, + resolve_lines: bool = True, + resolve_blocks: bool = True, + paragraph_break: float = 0.035, + export_as_straight_boxes: bool = False, + ) -> None: + self.resolve_lines = resolve_lines + self.resolve_blocks = resolve_blocks + self.paragraph_break = paragraph_break + self.export_as_straight_boxes = export_as_straight_boxes + + @staticmethod + def _sort_boxes(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Sort bounding boxes from top to bottom, left to right + + Args: + ---- + boxes: bounding boxes of shape (N, 4) or (N, 4, 2) (in case of rotated bbox) + + Returns: + ------- + tuple: indices of ordered boxes of shape (N,), boxes + If straight boxes are passed tpo the function, boxes are unchanged + else: boxes returned are straight boxes fitted to the straightened rotated boxes + so that we fit the lines afterwards to the straigthened page + """ + if boxes.ndim == 3: + boxes = rotate_boxes( + loc_preds=boxes, + angle=-estimate_page_angle(boxes), + orig_shape=(1024, 1024), + min_angle=5.0, + ) + boxes = np.concatenate((boxes.min(1), boxes.max(1)), -1) + return (boxes[:, 0] + 2 * boxes[:, 3] / np.median(boxes[:, 3] - boxes[:, 1])).argsort(), boxes + + def _resolve_sub_lines(self, boxes: np.ndarray, word_idcs: List[int]) -> List[List[int]]: + """Split a line in sub_lines + + Args: + ---- + boxes: bounding boxes of shape (N, 4) + word_idcs: list of indexes for the words of the line + + Returns: + ------- + A list of (sub-)lines computed from the original line (words) + """ + lines = [] + # Sort words horizontally + word_idcs = [word_idcs[idx] for idx in boxes[word_idcs, 0].argsort().tolist()] + + # Eventually split line horizontally + if len(word_idcs) < 2: + lines.append(word_idcs) + else: + sub_line = [word_idcs[0]] + for i in word_idcs[1:]: + horiz_break = True + + prev_box = boxes[sub_line[-1]] + # Compute distance between boxes + dist = boxes[i, 0] - prev_box[2] + # If distance between boxes is lower than paragraph break, same sub-line + if dist < self.paragraph_break: + horiz_break = False + + if horiz_break: + lines.append(sub_line) + sub_line = [] + + sub_line.append(i) + lines.append(sub_line) + + return lines + + def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]: + """Order boxes to group them in lines + + Args: + ---- + boxes: bounding boxes of shape (N, 4) or (N, 4, 2) in case of rotated bbox + + Returns: + ------- + nested list of box indices + """ + # Sort boxes, and straighten the boxes if they are rotated + idxs, boxes = self._sort_boxes(boxes) + + # Compute median for boxes heights + y_med = np.median(boxes[:, 3] - boxes[:, 1]) + + lines = [] + words = [idxs[0]] # Assign the top-left word to the first line + # Define a mean y-center for the line + y_center_sum = boxes[idxs[0]][[1, 3]].mean() + + for idx in idxs[1:]: + vert_break = True + + # Compute y_dist + y_dist = abs(boxes[idx][[1, 3]].mean() - y_center_sum / len(words)) + # If y-center of the box is close enough to mean y-center of the line, same line + if y_dist < y_med / 2: + vert_break = False + + if vert_break: + # Compute sub-lines (horizontal split) + lines.extend(self._resolve_sub_lines(boxes, words)) + words = [] + y_center_sum = 0 + + words.append(idx) + y_center_sum += boxes[idx][[1, 3]].mean() + + # Use the remaining words to form the last(s) line(s) + if len(words) > 0: + # Compute sub-lines (horizontal split) + lines.extend(self._resolve_sub_lines(boxes, words)) + + return lines + + @staticmethod + def _resolve_blocks(boxes: np.ndarray, lines: 
List[List[int]]) -> List[List[List[int]]]: + """Order lines to group them in blocks + + Args: + ---- + boxes: bounding boxes of shape (N, 4) or (N, 4, 2) + lines: list of lines, each line is a list of idx + + Returns: + ------- + nested list of box indices + """ + # Resolve enclosing boxes of lines + if boxes.ndim == 3: + box_lines: np.ndarray = np.asarray([ + resolve_enclosing_rbbox([tuple(boxes[idx, :, :]) for idx in line]) # type: ignore[misc] + for line in lines + ]) + else: + _box_lines = [ + resolve_enclosing_bbox([(tuple(boxes[idx, :2]), tuple(boxes[idx, 2:])) for idx in line]) + for line in lines + ] + box_lines = np.asarray([(x1, y1, x2, y2) for ((x1, y1), (x2, y2)) in _box_lines]) + + # Compute geometrical features of lines to clusterize + # Clusterizing only with box centers yield to poor results for complex documents + if boxes.ndim == 3: + box_features: np.ndarray = np.stack( + ( + (box_lines[:, 0, 0] + box_lines[:, 0, 1]) / 2, + (box_lines[:, 0, 0] + box_lines[:, 2, 0]) / 2, + (box_lines[:, 0, 0] + box_lines[:, 2, 1]) / 2, + (box_lines[:, 0, 1] + box_lines[:, 2, 1]) / 2, + (box_lines[:, 0, 1] + box_lines[:, 2, 0]) / 2, + (box_lines[:, 2, 0] + box_lines[:, 2, 1]) / 2, + ), + axis=-1, + ) + else: + box_features = np.stack( + ( + (box_lines[:, 0] + box_lines[:, 3]) / 2, + (box_lines[:, 1] + box_lines[:, 2]) / 2, + (box_lines[:, 0] + box_lines[:, 2]) / 2, + (box_lines[:, 1] + box_lines[:, 3]) / 2, + box_lines[:, 0], + box_lines[:, 1], + ), + axis=-1, + ) + # Compute clusters + clusters = fclusterdata(box_features, t=0.1, depth=4, criterion="distance", metric="euclidean") + + _blocks: Dict[int, List[int]] = {} + # Form clusters + for line_idx, cluster_idx in enumerate(clusters): + if cluster_idx in _blocks.keys(): + _blocks[cluster_idx].append(line_idx) + else: + _blocks[cluster_idx] = [line_idx] + + # Retrieve word-box level to return a fully nested structure + blocks = [[lines[idx] for idx in block] for block in _blocks.values()] + + return blocks + + def _build_blocks( + self, + boxes: np.ndarray, + word_preds: List[Tuple[str, float]], + crop_orientations: List[Dict[str, Any]], + ) -> List[Block]: + """Gather independent words in structured blocks + + Args: + ---- + boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2) + word_preds: list of all detected words of the page, of shape N + crop_orientations: list of dictoinaries containing + the general orientation (orientations + confidences) of the crops + + Returns: + ------- + list of block elements + """ + if boxes.shape[0] != len(word_preds): + raise ValueError(f"Incompatible argument lengths: {boxes.shape[0]}, {len(word_preds)}") + + if boxes.shape[0] == 0: + return [] + + # Decide whether we try to form lines + _boxes = boxes + if self.resolve_lines: + lines = self._resolve_lines(_boxes if _boxes.ndim == 3 else _boxes[:, :4]) + # Decide whether we try to form blocks + if self.resolve_blocks and len(lines) > 1: + _blocks = self._resolve_blocks(_boxes if _boxes.ndim == 3 else _boxes[:, :4], lines) + else: + _blocks = [lines] + else: + # Sort bounding boxes, one line for all boxes, one block for the line + lines = [self._sort_boxes(_boxes if _boxes.ndim == 3 else _boxes[:, :4])[0]] # type: ignore[list-item] + _blocks = [lines] + + blocks = [ + Block([ + Line([ + Word( + *word_preds[idx], + tuple([tuple(pt) for pt in boxes[idx].tolist()]), # type: ignore[arg-type] + crop_orientations[idx], + ) + if boxes.ndim == 3 + else Word( + *word_preds[idx], + ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], 
boxes[idx, 3])), + crop_orientations[idx], + ) + for idx in line + ]) + for line in lines + ]) + for lines in _blocks + ] + + return blocks + + def extra_repr(self) -> str: + return ( + f"resolve_lines={self.resolve_lines}, resolve_blocks={self.resolve_blocks}, " + f"paragraph_break={self.paragraph_break}, " + f"export_as_straight_boxes={self.export_as_straight_boxes}" + ) + + def __call__( + self, + pages: List[np.ndarray], + boxes: List[np.ndarray], + text_preds: List[List[Tuple[str, float]]], + page_shapes: List[Tuple[int, int]], + crop_orientations: List[Dict[str, Any]], + orientations: Optional[List[Dict[str, Any]]] = None, + languages: Optional[List[Dict[str, Any]]] = None, + ) -> Document: + """Re-arrange detected words into structured blocks + + Args: + ---- + pages: list of N elements, where each element represents the page image + boxes: list of N elements, where each element represents the localization predictions, of shape (*, 5) + or (*, 6) for all words for a given page + text_preds: list of N elements, where each element is the list of all word prediction (text + confidence) + page_shapes: shape of each page, of size N + crop_orientations: list of N elements, where each element is + a dictionary containing the general orientation (orientations + confidences) of the crops + orientations: optional, list of N elements, + where each element is a dictionary containing the orientation (orientation + confidence) + languages: optional, list of N elements, + where each element is a dictionary containing the language (language + confidence) + + Returns: + ------- + document object + """ + if len(boxes) != len(text_preds) != len(crop_orientations) or len(boxes) != len(page_shapes) != len( + crop_orientations + ): + raise ValueError("All arguments are expected to be lists of the same size") + + _orientations = ( + orientations if isinstance(orientations, list) else [None] * len(boxes) # type: ignore[list-item] + ) + _languages = languages if isinstance(languages, list) else [None] * len(boxes) # type: ignore[list-item] + if self.export_as_straight_boxes and len(boxes) > 0: + # If boxes are already straight OK, else fit a bounding rect + if boxes[0].ndim == 3: + # Iterate over pages and boxes + boxes = [np.concatenate((p_boxes.min(1), p_boxes.max(1)), 1) for p_boxes in boxes] + + _pages = [ + Page( + page, + self._build_blocks( + page_boxes, + word_preds, + word_crop_orientations, + ), + _idx, + shape, + orientation, + language, + ) + for page, _idx, shape, page_boxes, word_preds, word_crop_orientations, orientation, language in zip( + pages, range(len(boxes)), page_shapes, boxes, text_preds, crop_orientations, _orientations, _languages + ) + ] + + return Document(_pages) + + +class KIEDocumentBuilder(DocumentBuilder): + """Implements a KIE document builder + + Args: + ---- + resolve_lines: whether words should be automatically grouped into lines + resolve_blocks: whether lines should be automatically grouped into blocks + paragraph_break: relative length of the minimum space separating paragraphs + export_as_straight_boxes: if True, force straight boxes in the export (fit a rectangle + box to all rotated boxes). Else, keep the boxes format unchanged, no matter what it is. 
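+ A minimal usage sketch (the class name "dates", the box coordinates and the prediction
+ values are illustrative assumptions):
+
+ >>> import numpy as np
+ >>> builder = KIEDocumentBuilder()
+ >>> page = np.zeros((128, 256, 3), dtype=np.uint8)
+ >>> boxes = [{"dates": np.array([[0.1, 0.1, 0.4, 0.2, 0.9]])}]
+ >>> text_preds = [{"dates": [("2024-01-01", 0.99)]}]
+ >>> crop_orientations = [{"dates": [{"value": 0, "confidence": 1.0}]}]
+ >>> doc = builder([page], boxes, text_preds, [(128, 256)], crop_orientations)
+ >>> doc.pages[0].predictions["dates"][0].value
+ '2024-01-01'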
+ """ + + def __call__( # type: ignore[override] + self, + pages: List[np.ndarray], + boxes: List[Dict[str, np.ndarray]], + text_preds: List[Dict[str, List[Tuple[str, float]]]], + page_shapes: List[Tuple[int, int]], + crop_orientations: List[Dict[str, List[Dict[str, Any]]]], + orientations: Optional[List[Dict[str, Any]]] = None, + languages: Optional[List[Dict[str, Any]]] = None, + ) -> KIEDocument: + """Re-arrange detected words into structured predictions + + Args: + ---- + pages: list of N elements, where each element represents the page image + boxes: list of N dictionaries, where each element represents the localization predictions for a class, + of shape (*, 5) or (*, 6) for all predictions + text_preds: list of N dictionaries, where each element is the list of all word prediction + page_shapes: shape of each page, of size N + crop_orientations: list of N dictonaries, where each element is + a list containing the general crop orientations (orientations + confidences) of the crops + orientations: optional, list of N elements, + where each element is a dictionary containing the orientation (orientation + confidence) + languages: optional, list of N elements, + where each element is a dictionary containing the language (language + confidence) + + Returns: + ------- + document object + """ + if len(boxes) != len(text_preds) != len(crop_orientations) or len(boxes) != len(page_shapes) != len( + crop_orientations + ): + raise ValueError("All arguments are expected to be lists of the same size") + _orientations = ( + orientations if isinstance(orientations, list) else [None] * len(boxes) # type: ignore[list-item] + ) + _languages = languages if isinstance(languages, list) else [None] * len(boxes) # type: ignore[list-item] + if self.export_as_straight_boxes and len(boxes) > 0: + # If boxes are already straight OK, else fit a bounding rect + if next(iter(boxes[0].values())).ndim == 3: + straight_boxes: List[Dict[str, np.ndarray]] = [] + # Iterate over pages + for p_boxes in boxes: + # Iterate over boxes of the pages + straight_boxes_dict = {} + for k, box in p_boxes.items(): + straight_boxes_dict[k] = np.concatenate((box.min(1), box.max(1)), 1) + straight_boxes.append(straight_boxes_dict) + boxes = straight_boxes + + _pages = [ + KIEPage( + page, + { + k: self._build_blocks( + page_boxes[k], + word_preds[k], + word_crop_orientations[k], + ) + for k in page_boxes.keys() + }, + _idx, + shape, + orientation, + language, + ) + for page, _idx, shape, page_boxes, word_preds, word_crop_orientations, orientation, language in zip( + pages, range(len(boxes)), page_shapes, boxes, text_preds, crop_orientations, _orientations, _languages + ) + ] + + return KIEDocument(_pages) + + def _build_blocks( # type: ignore[override] + self, + boxes: np.ndarray, + word_preds: List[Tuple[str, float]], + crop_orientations: List[Dict[str, Any]], + ) -> List[Prediction]: + """Gather independent words in structured blocks + + Args: + ---- + boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2) + word_preds: list of all detected words of the page, of shape N + crop_orientations: list of orientations for each word crop + + Returns: + ------- + list of block elements + """ + if boxes.shape[0] != len(word_preds): + raise ValueError(f"Incompatible argument lengths: {boxes.shape[0]}, {len(word_preds)}") + + if boxes.shape[0] == 0: + return [] + + # Decide whether we try to form lines + _boxes = boxes + idxs, _ = self._sort_boxes(_boxes if _boxes.ndim == 3 else _boxes[:, :4]) + predictions = [ + 
Prediction( + value=word_preds[idx][0], + confidence=word_preds[idx][1], + geometry=tuple([tuple(pt) for pt in boxes[idx].tolist()]), # type: ignore[arg-type] + crop_orientation=crop_orientations[idx], + ) + if boxes.ndim == 3 + else Prediction( + value=word_preds[idx][0], + confidence=word_preds[idx][1], + geometry=((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])), + crop_orientation=crop_orientations[idx], + ) + for idx in idxs + ] + return predictions diff --git a/doctr/models/classification/__init__.py b/doctr/models/classification/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..496079740010bce5f1da5161808ed66ccfb9a898 --- /dev/null +++ b/doctr/models/classification/__init__.py @@ -0,0 +1,7 @@ +from .mobilenet import * +from .resnet import * +from .vgg import * +from .magc_resnet import * +from .vit import * +from .textnet import * +from .zoo import * diff --git a/doctr/models/classification/__pycache__/__init__.cpython-311.pyc b/doctr/models/classification/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13a4acd46b995903c58bc3037ac9b9949510b949 Binary files /dev/null and b/doctr/models/classification/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/classification/__pycache__/__init__.cpython-38.pyc b/doctr/models/classification/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44e0be5db2758c5a9eb7d014de800b891ef64461 Binary files /dev/null and b/doctr/models/classification/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/classification/__pycache__/zoo.cpython-311.pyc b/doctr/models/classification/__pycache__/zoo.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..927649ef89bb371094ce77e64fc5f56030ea3fc7 Binary files /dev/null and b/doctr/models/classification/__pycache__/zoo.cpython-311.pyc differ diff --git a/doctr/models/classification/__pycache__/zoo.cpython-38.pyc b/doctr/models/classification/__pycache__/zoo.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6dd66a03f2020b5d842411cca5b890edd58ea20 Binary files /dev/null and b/doctr/models/classification/__pycache__/zoo.cpython-38.pyc differ diff --git a/doctr/models/classification/magc_resnet/__init__.py b/doctr/models/classification/magc_resnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/classification/magc_resnet/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/classification/magc_resnet/__pycache__/__init__.cpython-311.pyc b/doctr/models/classification/magc_resnet/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ca9b1bf255cee3a366bdee1ddcc8de8a89f2aa2 Binary files /dev/null and b/doctr/models/classification/magc_resnet/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/classification/magc_resnet/__pycache__/__init__.cpython-38.pyc b/doctr/models/classification/magc_resnet/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da99be9c857e8d175cd4359d00314f1af6fdfb0c Binary files /dev/null and 
b/doctr/models/classification/magc_resnet/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/classification/magc_resnet/__pycache__/pytorch.cpython-311.pyc b/doctr/models/classification/magc_resnet/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e14ba0dfb74a3522fc320c183a66381c072018ae Binary files /dev/null and b/doctr/models/classification/magc_resnet/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/classification/magc_resnet/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/classification/magc_resnet/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83d2c5b3a99688b4c41f46b1759d71f944375f40 Binary files /dev/null and b/doctr/models/classification/magc_resnet/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/classification/magc_resnet/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/classification/magc_resnet/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fee6ddbe58f428639a3ccc86fd4273efaecd6bc Binary files /dev/null and b/doctr/models/classification/magc_resnet/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/classification/magc_resnet/pytorch.py b/doctr/models/classification/magc_resnet/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..f503d7c7fadbe0026b2c93032172258eb94145ad --- /dev/null +++ b/doctr/models/classification/magc_resnet/pytorch.py @@ -0,0 +1,177 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +import math +from copy import deepcopy +from functools import partial +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn + +from doctr.datasets import VOCABS + +from ...utils.pytorch import load_pretrained_params +from ..resnet.pytorch import ResNet + +__all__ = ["magc_resnet31"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "magc_resnet31": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.4.1/magc_resnet31-857391d8.pt&src=0", + }, +} + + +class MAGC(nn.Module): + """Implements the Multi-Aspect Global Context Attention, as described in + `_. 
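+ A shape-level sketch (this layer is module-internal and not exported; the 256-channel input
+ is an illustrative assumption, with channels divisible by `headers`):
+
+ >>> import torch
+ >>> attn = MAGC(inplanes=256, headers=8, attn_scale=True)
+ >>> out = attn(torch.rand(1, 256, 8, 32))
+ >>> out.shape  # the residual connection keeps the input shape
+ torch.Size([1, 256, 8, 32])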
+ + Args: + ---- + inplanes: input channels + headers: number of headers to split channels + attn_scale: if True, re-scale attention to counteract the variance distibutions + ratio: bottleneck ratio + **kwargs + """ + + def __init__( + self, + inplanes: int, + headers: int = 8, + attn_scale: bool = False, + ratio: float = 0.0625, # bottleneck ratio of 1/16 as described in paper + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + + self.headers = headers + self.inplanes = inplanes + self.attn_scale = attn_scale + self.planes = int(inplanes * ratio) + + self.single_header_inplanes = int(inplanes / headers) + + self.conv_mask = nn.Conv2d(self.single_header_inplanes, 1, kernel_size=1) + self.softmax = nn.Softmax(dim=1) + + self.transform = nn.Sequential( + nn.Conv2d(self.inplanes, self.planes, kernel_size=1), + nn.LayerNorm([self.planes, 1, 1]), + nn.ReLU(inplace=True), + nn.Conv2d(self.planes, self.inplanes, kernel_size=1), + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + batch, _, height, width = inputs.size() + # (N * headers, C / headers, H , W) + x = inputs.view(batch * self.headers, self.single_header_inplanes, height, width) + shortcut = x + # (N * headers, C / headers, H * W) + shortcut = shortcut.view(batch * self.headers, self.single_header_inplanes, height * width) + + # (N * headers, 1, H, W) + context_mask = self.conv_mask(x) + # (N * headers, H * W) + context_mask = context_mask.view(batch * self.headers, -1) + + # scale variance + if self.attn_scale and self.headers > 1: + context_mask = context_mask / math.sqrt(self.single_header_inplanes) + + # (N * headers, H * W) + context_mask = self.softmax(context_mask) + + # (N * headers, C / headers) + context = (shortcut * context_mask.unsqueeze(1)).sum(-1) + + # (N, C, 1, 1) + context = context.view(batch, self.headers * self.single_header_inplanes, 1, 1) + + # Transform: B, C, 1, 1 -> B, C, 1, 1 + transformed = self.transform(context) + return inputs + transformed + + +def _magc_resnet( + arch: str, + pretrained: bool, + num_blocks: List[int], + output_channels: List[int], + stage_stride: List[int], + stage_conv: List[bool], + stage_pooling: List[Optional[Tuple[int, int]]], + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> ResNet: + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["classes"] = kwargs["classes"] + kwargs.pop("classes") + + # Build the model + model = ResNet( + num_blocks, + output_channels, + stage_stride, + stage_conv, + stage_pooling, + attn_module=partial(MAGC, headers=8, attn_scale=True), + cfg=_cfg, + **kwargs, + ) + # Load pretrained parameters + if pretrained: + # The number of classes is not the same as the number of classes in the pretrained model => + # remove the last layer weights + _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None + load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + return model + + +def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet: + """Resnet31 architecture with Multi-Aspect Global Context Attention as described in + `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition", + `_. 
+ + >>> import torch + >>> from doctr.models import magc_resnet31 + >>> model = magc_resnet31(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 224, 224), dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the ResNet architecture + + Returns: + ------- + A feature extractor model + """ + return _magc_resnet( + "magc_resnet31", + pretrained, + [1, 2, 5, 3], + [256, 256, 512, 512], + [1, 1, 1, 1], + [True] * 4, + [(2, 2), (2, 1), None, None], + origin_stem=False, + stem_channels=128, + ignore_keys=["13.weight", "13.bias"], + **kwargs, + ) diff --git a/doctr/models/classification/magc_resnet/tensorflow.py b/doctr/models/classification/magc_resnet/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..e791e661bfc9653e69302ce9d5f11315cd19ff6e --- /dev/null +++ b/doctr/models/classification/magc_resnet/tensorflow.py @@ -0,0 +1,192 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import math +from copy import deepcopy +from functools import partial +from typing import Any, Dict, List, Optional, Tuple + +import tensorflow as tf +from tensorflow.keras import layers +from tensorflow.keras.models import Sequential + +from doctr.datasets import VOCABS + +from ...utils import load_pretrained_params +from ..resnet.tensorflow import ResNet + +__all__ = ["magc_resnet31"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "magc_resnet31": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 32, 3), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.6.0/magc_resnet31-addbb705.zip&src=0", + }, +} + + +class MAGC(layers.Layer): + """Implements the Multi-Aspect Global Context Attention, as described in + `_. 
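+ A shape-level sketch (channels-last here; the 256-channel input is an illustrative
+ assumption, with channels divisible by `headers`):
+
+ >>> import tensorflow as tf
+ >>> attn = MAGC(inplanes=256, headers=8, attn_scale=True)
+ >>> out = attn(tf.random.uniform((1, 8, 32, 256)))
+ >>> out.shape  # the residual connection keeps the input shape
+ TensorShape([1, 8, 32, 256])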
+ + Args: + ---- + inplanes: input channels + headers: number of headers to split channels + attn_scale: if True, re-scale attention to counteract the variance distibutions + ratio: bottleneck ratio + **kwargs + """ + + def __init__( + self, + inplanes: int, + headers: int = 8, + attn_scale: bool = False, + ratio: float = 0.0625, # bottleneck ratio of 1/16 as described in paper + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.headers = headers # h + self.inplanes = inplanes # C + self.attn_scale = attn_scale + self.planes = int(inplanes * ratio) + + self.single_header_inplanes = int(inplanes / headers) # C / h + + self.conv_mask = layers.Conv2D(filters=1, kernel_size=1, kernel_initializer=tf.initializers.he_normal()) + + self.transform = Sequential( + [ + layers.Conv2D(filters=self.planes, kernel_size=1, kernel_initializer=tf.initializers.he_normal()), + layers.LayerNormalization([1, 2, 3]), + layers.ReLU(), + layers.Conv2D(filters=self.inplanes, kernel_size=1, kernel_initializer=tf.initializers.he_normal()), + ], + name="transform", + ) + + def context_modeling(self, inputs: tf.Tensor) -> tf.Tensor: + b, h, w, c = (tf.shape(inputs)[i] for i in range(4)) + + # B, H, W, C -->> B*h, H, W, C/h + x = tf.reshape(inputs, shape=(b, h, w, self.headers, self.single_header_inplanes)) + x = tf.transpose(x, perm=(0, 3, 1, 2, 4)) + x = tf.reshape(x, shape=(b * self.headers, h, w, self.single_header_inplanes)) + + # Compute shorcut + shortcut = x + # B*h, 1, H*W, C/h + shortcut = tf.reshape(shortcut, shape=(b * self.headers, 1, h * w, self.single_header_inplanes)) + # B*h, 1, C/h, H*W + shortcut = tf.transpose(shortcut, perm=[0, 1, 3, 2]) + + # Compute context mask + # B*h, H, W, 1 + context_mask = self.conv_mask(x) + # B*h, 1, H*W, 1 + context_mask = tf.reshape(context_mask, shape=(b * self.headers, 1, h * w, 1)) + # scale variance + if self.attn_scale and self.headers > 1: + context_mask = context_mask / math.sqrt(self.single_header_inplanes) + # B*h, 1, H*W, 1 + context_mask = tf.keras.activations.softmax(context_mask, axis=2) + + # Compute context + # B*h, 1, C/h, 1 + context = tf.matmul(shortcut, context_mask) + context = tf.reshape(context, shape=(b, 1, c, 1)) + # B, 1, 1, C + context = tf.transpose(context, perm=(0, 1, 3, 2)) + # Set shape to resolve shape when calling this module in the Sequential MAGCResnet + batch, chan = inputs.get_shape().as_list()[0], inputs.get_shape().as_list()[-1] + context.set_shape([batch, 1, 1, chan]) + return context + + def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: + # Context modeling: B, H, W, C -> B, 1, 1, C + context = self.context_modeling(inputs) + # Transform: B, 1, 1, C -> B, 1, 1, C + transformed = self.transform(context) + return inputs + transformed + + +def _magc_resnet( + arch: str, + pretrained: bool, + num_blocks: List[int], + output_channels: List[int], + stage_downsample: List[bool], + stage_conv: List[bool], + stage_pooling: List[Optional[Tuple[int, int]]], + origin_stem: bool = True, + **kwargs: Any, +) -> ResNet: + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"]) + kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["classes"] = kwargs["classes"] + _cfg["input_shape"] = kwargs["input_shape"] + kwargs.pop("classes") + + # Build the model + model = ResNet( + num_blocks, + output_channels, 
+ stage_downsample, + stage_conv, + stage_pooling, + origin_stem, + attn_module=partial(MAGC, headers=8, attn_scale=True), + cfg=_cfg, + **kwargs, + ) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]["url"]) + + return model + + +def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet: + """Resnet31 architecture with Multi-Aspect Global Context Attention as described in + `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition", + `_. + + >>> import tensorflow as tf + >>> from doctr.models import magc_resnet31 + >>> model = magc_resnet31(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 224, 224, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the ResNet architecture + + Returns: + ------- + A feature extractor model + """ + return _magc_resnet( + "magc_resnet31", + pretrained, + [1, 2, 5, 3], + [256, 256, 512, 512], + [False] * 4, + [True] * 4, + [(2, 2), (2, 1), None, None], + False, + stem_channels=128, + **kwargs, + ) diff --git a/doctr/models/classification/mobilenet/__init__.py b/doctr/models/classification/mobilenet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..64556e403a5697432f805a5af28dab812fa8b932 --- /dev/null +++ b/doctr/models/classification/mobilenet/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * diff --git a/doctr/models/classification/mobilenet/__pycache__/__init__.cpython-311.pyc b/doctr/models/classification/mobilenet/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b84c62368857769719531dabd7e6a959e787e582 Binary files /dev/null and b/doctr/models/classification/mobilenet/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/classification/mobilenet/__pycache__/__init__.cpython-38.pyc b/doctr/models/classification/mobilenet/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..562fa6c62dc5b95231a32118adb205ee6eec39b4 Binary files /dev/null and b/doctr/models/classification/mobilenet/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/classification/mobilenet/__pycache__/pytorch.cpython-311.pyc b/doctr/models/classification/mobilenet/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2d3dc0ef1be7500129f1ccc9fda152166ae2d9e Binary files /dev/null and b/doctr/models/classification/mobilenet/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/classification/mobilenet/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/classification/mobilenet/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2734baeda92361813c79232747e01261e7e43eb6 Binary files /dev/null and b/doctr/models/classification/mobilenet/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/classification/mobilenet/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/classification/mobilenet/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22f2686080c2d835c895697e270f006fb7424f5d Binary files /dev/null and b/doctr/models/classification/mobilenet/__pycache__/tensorflow.cpython-38.pyc 
differ diff --git a/doctr/models/classification/mobilenet/pytorch.py b/doctr/models/classification/mobilenet/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..57451816bc3622d260725ddc068d144614250f10 --- /dev/null +++ b/doctr/models/classification/mobilenet/pytorch.py @@ -0,0 +1,273 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +# Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py + +from copy import deepcopy +from typing import Any, Dict, List, Optional + +from torchvision.models import mobilenetv3 + +from doctr.datasets import VOCABS + +from ...utils import load_pretrained_params + +__all__ = [ + "mobilenet_v3_small", + "mobilenet_v3_small_r", + "mobilenet_v3_large", + "mobilenet_v3_large_r", + "mobilenet_v3_small_crop_orientation", + "mobilenet_v3_small_page_orientation", +] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "mobilenet_v3_large": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large-11fc8cb9.pt&src=0", + }, + "mobilenet_v3_large_r": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large_r-74a22066.pt&src=0", + }, + "mobilenet_v3_small": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small-6a4bfa6b.pt&src=0", + }, + "mobilenet_v3_small_r": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-1a8a3530.pt&src=0", + }, + "mobilenet_v3_small_crop_orientation": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 256, 256), + "classes": [0, -90, 180, 90], + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_crop_orientation-f0847a18.pt&src=0", + }, + "mobilenet_v3_small_page_orientation": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 512, 512), + "classes": [0, -90, 180, 90], + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_page_orientation-8e60325c.pt&src=0", + }, +} + + +def _mobilenet_v3( + arch: str, + pretrained: bool, + rect_strides: Optional[List[str]] = None, + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> mobilenetv3.MobileNetV3: + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["classes"] = kwargs["classes"] + kwargs.pop("classes") + + if arch.startswith("mobilenet_v3_small"): + model = mobilenetv3.mobilenet_v3_small(**kwargs, weights=None) + else: + model = mobilenetv3.mobilenet_v3_large(**kwargs, weights=None) + + # Rectangular strides + if isinstance(rect_strides, list): + for layer_name in rect_strides: + m = model + for child in layer_name.split("."): + m = getattr(m, child) + m.stride = (2, 1) + + # Load 
pretrained parameters + if pretrained: + # The number of classes is not the same as the number of classes in the pretrained model => + # remove the last layer weights + _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None + load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + model.cfg = _cfg + + return model + + +def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_. + + >>> import torch + >>> from doctr.models import mobilenet_v3_small + >>> model = mobilenetv3_small(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the MobileNetV3 architecture + + Returns: + ------- + a torch.nn.Module + """ + return _mobilenet_v3( + "mobilenet_v3_small", pretrained, ignore_keys=["classifier.3.weight", "classifier.3.bias"], **kwargs + ) + + +def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_, with rectangular pooling. + + >>> import torch + >>> from doctr.models import mobilenet_v3_small_r + >>> model = mobilenet_v3_small_r(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the MobileNetV3 architecture + + Returns: + ------- + a torch.nn.Module + """ + return _mobilenet_v3( + "mobilenet_v3_small_r", + pretrained, + ["features.2.block.1.0", "features.4.block.1.0", "features.9.block.1.0"], + ignore_keys=["classifier.3.weight", "classifier.3.bias"], + **kwargs, + ) + + +def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3: + """MobileNetV3-Large architecture as described in + `"Searching for MobileNetV3", + `_. + + >>> import torch + >>> from doctr.models import mobilenet_v3_large + >>> model = mobilenet_v3_large(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the MobileNetV3 architecture + + Returns: + ------- + a torch.nn.Module + """ + return _mobilenet_v3( + "mobilenet_v3_large", + pretrained, + ignore_keys=["classifier.3.weight", "classifier.3.bias"], + **kwargs, + ) + + +def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3: + """MobileNetV3-Large architecture as described in + `"Searching for MobileNetV3", + `_, with rectangular pooling. 
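+ The `_r` suffix means the depthwise strides listed in `rect_strides` are patched to (2, 1)
+ by `_mobilenet_v3`; a hedged check, reusing the `model` built in the snippet just below:
+
+ >>> model.features[4].block[1][0].stride  # patched via the "features.4.block.1.0" path
+ (2, 1)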
+ + >>> import torch + >>> from doctr.models import mobilenet_v3_large_r + >>> model = mobilenet_v3_large_r(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the MobileNetV3 architecture + + Returns: + ------- + a torch.nn.Module + """ + return _mobilenet_v3( + "mobilenet_v3_large_r", + pretrained, + ["features.4.block.1.0", "features.7.block.1.0", "features.13.block.1.0"], + ignore_keys=["classifier.3.weight", "classifier.3.bias"], + **kwargs, + ) + + +def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_. + + >>> import torch + >>> from doctr.models import mobilenet_v3_small_crop_orientation + >>> model = mobilenet_v3_small_crop_orientation(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the MobileNetV3 architecture + + Returns: + ------- + a torch.nn.Module + """ + return _mobilenet_v3( + "mobilenet_v3_small_crop_orientation", + pretrained, + ignore_keys=["classifier.3.weight", "classifier.3.bias"], + **kwargs, + ) + + +def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_. + >>> import torch + >>> from doctr.models import mobilenet_v3_small_page_orientation + >>> model = mobilenet_v3_small_page_orientation(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the MobileNetV3 architecture + Returns: + ------- + a torch.nn.Module + """ + return _mobilenet_v3( + "mobilenet_v3_small_page_orientation", + pretrained, + ignore_keys=["classifier.3.weight", "classifier.3.bias"], + **kwargs, + ) diff --git a/doctr/models/classification/mobilenet/tensorflow.py b/doctr/models/classification/mobilenet/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..3e0b99a9ecc80c69df3e0741c26e67dbf72f89d4 --- /dev/null +++ b/doctr/models/classification/mobilenet/tensorflow.py @@ -0,0 +1,437 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
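+ # A worked sketch of the channel-rounding helper `_make_divisible` defined below (values are
+ # illustrative): _make_divisible(17, 8) -> int(17 + 4) // 8 * 8 = 16 and 16 >= 0.9 * 17, so 16
+ # is kept; _make_divisible(18, 8) -> rounds down to 16, but 16 < 0.9 * 18 = 16.2, so the
+ # result is bumped to the next multiple, 24.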
+ +# Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple, Union + +import tensorflow as tf +from tensorflow.keras import layers +from tensorflow.keras.models import Sequential + +from ....datasets import VOCABS +from ...utils import conv_sequence, load_pretrained_params + +__all__ = [ + "MobileNetV3", + "mobilenet_v3_small", + "mobilenet_v3_small_r", + "mobilenet_v3_large", + "mobilenet_v3_large_r", + "mobilenet_v3_small_crop_orientation", + "mobilenet_v3_small_page_orientation", +] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "mobilenet_v3_large": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 32, 3), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large-47d25d7e.zip&src=0", + }, + "mobilenet_v3_large_r": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 32, 3), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large_r-a108e192.zip&src=0", + }, + "mobilenet_v3_small": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 32, 3), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small-8a32c32c.zip&src=0", + }, + "mobilenet_v3_small_r": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 32, 3), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-3d61452e.zip&src=0", + }, + "mobilenet_v3_small_crop_orientation": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (128, 128, 3), + "classes": [0, -90, 180, 90], + "url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-1ea8db03.zip&src=0", + }, + "mobilenet_v3_small_page_orientation": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (512, 512, 3), + "classes": [0, -90, 180, 90], + "url": None, + }, +} + + +def hard_swish(x: tf.Tensor) -> tf.Tensor: + return x * tf.nn.relu6(x + 3.0) / 6.0 + + +def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class SqueezeExcitation(Sequential): + """Squeeze and Excitation.""" + + def __init__(self, chan: int, squeeze_factor: int = 4) -> None: + super().__init__([ + layers.GlobalAveragePooling2D(), + layers.Dense(chan // squeeze_factor, activation="relu"), + layers.Dense(chan, activation="hard_sigmoid"), + layers.Reshape((1, 1, chan)), + ]) + + def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor: + x = super().call(inputs, **kwargs) + x = tf.math.multiply(inputs, x) + return x + + +class InvertedResidualConfig: + def __init__( + self, + input_channels: int, + kernel: int, + expanded_channels: int, + out_channels: int, + use_se: bool, + activation: str, + stride: Union[int, Tuple[int, int]], + width_mult: float = 1, + ) -> None: + self.input_channels = self.adjust_channels(input_channels, width_mult) + self.kernel = kernel + self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) + self.out_channels = self.adjust_channels(out_channels, width_mult) + self.use_se = use_se + self.use_hs = activation == "HS" + self.stride = stride + + @staticmethod + def adjust_channels(channels: int, width_mult: float): + return _make_divisible(channels * width_mult, 8) + + +class InvertedResidual(layers.Layer): + """InvertedResidual for mobilenet + + Args: + ---- + conf: configuration object for inverted residual + """ + + def __init__( + self, + conf: InvertedResidualConfig, + **kwargs: Any, + ) -> None: + _kwargs = {"input_shape": kwargs.pop("input_shape")} if isinstance(kwargs.get("input_shape"), tuple) else {} + super().__init__(**kwargs) + + act_fn = hard_swish if conf.use_hs else tf.nn.relu + + _is_s1 = (isinstance(conf.stride, tuple) and conf.stride == (1, 1)) or conf.stride == 1 + self.use_res_connect = _is_s1 and conf.input_channels == conf.out_channels + + _layers = [] + # expand + if conf.expanded_channels != conf.input_channels: + _layers.extend(conv_sequence(conf.expanded_channels, act_fn, kernel_size=1, bn=True, **_kwargs)) + + # depth-wise + _layers.extend( + conv_sequence( + conf.expanded_channels, + act_fn, + kernel_size=conf.kernel, + strides=conf.stride, + bn=True, + groups=conf.expanded_channels, + ) + ) + + if conf.use_se: + _layers.append(SqueezeExcitation(conf.expanded_channels)) + + # project + _layers.extend( + conv_sequence( + conf.out_channels, + None, + kernel_size=1, + bn=True, + ) + ) + + self.block = Sequential(_layers) + + def call( + self, + inputs: tf.Tensor, + **kwargs: Any, + ) -> tf.Tensor: + out = self.block(inputs, **kwargs) + if self.use_res_connect: + out = tf.add(out, inputs) + + return out + + +class MobileNetV3(Sequential): + """Implements MobileNetV3, inspired from both: + `_. + and `_. 
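+ A minimal construction sketch (the two-block layout is an illustrative assumption, not one
+ of the released configurations):
+
+ >>> import tensorflow as tf
+ >>> layout = [
+ ...     InvertedResidualConfig(16, 3, 16, 16, True, "RE", 2),
+ ...     InvertedResidualConfig(16, 3, 72, 24, False, "RE", 2),
+ ... ]
+ >>> model = MobileNetV3(layout, include_top=True, num_classes=10, input_shape=(32, 32, 3))
+ >>> out = model(tf.random.uniform((1, 32, 32, 3)))  # (1, 10) logits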
+ """ + + def __init__( + self, + layout: List[InvertedResidualConfig], + include_top: bool = True, + head_chans: int = 1024, + num_classes: int = 1000, + cfg: Optional[Dict[str, Any]] = None, + input_shape: Optional[Tuple[int, int, int]] = None, + ) -> None: + _layers = [ + Sequential( + conv_sequence( + layout[0].input_channels, hard_swish, True, kernel_size=3, strides=2, input_shape=input_shape + ), + name="stem", + ) + ] + + for idx, conf in enumerate(layout): + _layers.append( + InvertedResidual(conf, name=f"inverted_{idx}"), + ) + + _layers.append( + Sequential(conv_sequence(6 * layout[-1].out_channels, hard_swish, True, kernel_size=1), name="final_block") + ) + + if include_top: + _layers.extend([ + layers.GlobalAveragePooling2D(), + layers.Dense(head_chans, activation=hard_swish), + layers.Dropout(0.2), + layers.Dense(num_classes), + ]) + + super().__init__(_layers) + self.cfg = cfg + + +def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwargs: Any) -> MobileNetV3: + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"]) + kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["classes"] = kwargs["classes"] + _cfg["input_shape"] = kwargs["input_shape"] + kwargs.pop("classes") + + # cf. Table 1 & 2 of the paper + if arch.startswith("mobilenet_v3_small"): + inverted_residual_setting = [ + InvertedResidualConfig(16, 3, 16, 16, True, "RE", 2), # C1 + InvertedResidualConfig(16, 3, 72, 24, False, "RE", (2, 1) if rect_strides else 2), # C2 + InvertedResidualConfig(24, 3, 88, 24, False, "RE", 1), + InvertedResidualConfig(24, 5, 96, 40, True, "HS", (2, 1) if rect_strides else 2), # C3 + InvertedResidualConfig(40, 5, 240, 40, True, "HS", 1), + InvertedResidualConfig(40, 5, 240, 40, True, "HS", 1), + InvertedResidualConfig(40, 5, 120, 48, True, "HS", 1), + InvertedResidualConfig(48, 5, 144, 48, True, "HS", 1), + InvertedResidualConfig(48, 5, 288, 96, True, "HS", (2, 1) if rect_strides else 2), # C4 + InvertedResidualConfig(96, 5, 576, 96, True, "HS", 1), + InvertedResidualConfig(96, 5, 576, 96, True, "HS", 1), + ] + head_chans = 1024 + else: + inverted_residual_setting = [ + InvertedResidualConfig(16, 3, 16, 16, False, "RE", 1), + InvertedResidualConfig(16, 3, 64, 24, False, "RE", 2), # C1 + InvertedResidualConfig(24, 3, 72, 24, False, "RE", 1), + InvertedResidualConfig(24, 5, 72, 40, True, "RE", (2, 1) if rect_strides else 2), # C2 + InvertedResidualConfig(40, 5, 120, 40, True, "RE", 1), + InvertedResidualConfig(40, 5, 120, 40, True, "RE", 1), + InvertedResidualConfig(40, 3, 240, 80, False, "HS", (2, 1) if rect_strides else 2), # C3 + InvertedResidualConfig(80, 3, 200, 80, False, "HS", 1), + InvertedResidualConfig(80, 3, 184, 80, False, "HS", 1), + InvertedResidualConfig(80, 3, 184, 80, False, "HS", 1), + InvertedResidualConfig(80, 3, 480, 112, True, "HS", 1), + InvertedResidualConfig(112, 3, 672, 112, True, "HS", 1), + InvertedResidualConfig(112, 5, 672, 160, True, "HS", (2, 1) if rect_strides else 2), # C4 + InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1), + InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1), + ] + head_chans = 1280 + + kwargs["num_classes"] = _cfg["num_classes"] + kwargs["input_shape"] = _cfg["input_shape"] + + # Build the model + model = MobileNetV3( + inverted_residual_setting, + head_chans=head_chans, + 
cfg=_cfg, + **kwargs, + ) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]["url"]) + + return model + + +def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_. + + >>> import tensorflow as tf + >>> from doctr.models import mobilenet_v3_small + >>> model = mobilenet_v3_small(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the MobileNetV3 architecture + + Returns: + ------- + a keras.Model + """ + return _mobilenet_v3("mobilenet_v3_small", pretrained, False, **kwargs) + + +def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_, with rectangular pooling. + + >>> import tensorflow as tf + >>> from doctr.models import mobilenet_v3_small_r + >>> model = mobilenet_v3_small_r(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the MobileNetV3 architecture + + Returns: + ------- + a keras.Model + """ + return _mobilenet_v3("mobilenet_v3_small_r", pretrained, True, **kwargs) + + +def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> MobileNetV3: + """MobileNetV3-Large architecture as described in + `"Searching for MobileNetV3", + `_. + + >>> import tensorflow as tf + >>> from doctr.models import mobilenet_v3_large + >>> model = mobilenet_v3_large(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the MobileNetV3 architecture + + Returns: + ------- + a keras.Model + """ + return _mobilenet_v3("mobilenet_v3_large", pretrained, False, **kwargs) + + +def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3: + """MobileNetV3-Large architecture as described in + `"Searching for MobileNetV3", + `_. + + >>> import tensorflow as tf + >>> from doctr.models import mobilenet_v3_large_r + >>> model = mobilenet_v3_large_r(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the MobileNetV3 architecture + + Returns: + ------- + a keras.Model + """ + return _mobilenet_v3("mobilenet_v3_large_r", pretrained, True, **kwargs) + + +def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_. 
+ + >>> import tensorflow as tf + >>> from doctr.models import mobilenet_v3_small_crop_orientation + >>> model = mobilenet_v3_small_crop_orientation(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the MobileNetV3 architecture + + Returns: + ------- + a keras.Model + """ + return _mobilenet_v3("mobilenet_v3_small_crop_orientation", pretrained, include_top=True, **kwargs) + + +def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_. + >>> import tensorflow as tf + >>> from doctr.models import mobilenet_v3_small_page_orientation + >>> model = mobilenet_v3_small_page_orientation(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the MobileNetV3 architecture + Returns: + ------- + a keras.Model + """ + return _mobilenet_v3("mobilenet_v3_small_page_orientation", pretrained, include_top=True, **kwargs) diff --git a/doctr/models/classification/predictor/__init__.py b/doctr/models/classification/predictor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/classification/predictor/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/classification/predictor/__pycache__/__init__.cpython-311.pyc b/doctr/models/classification/predictor/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc9ded7145d2ee3e470b43de0bc805763d3ae996 Binary files /dev/null and b/doctr/models/classification/predictor/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/classification/predictor/__pycache__/__init__.cpython-38.pyc b/doctr/models/classification/predictor/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60688261a6ca4f6d3f97b405f9110bae06188deb Binary files /dev/null and b/doctr/models/classification/predictor/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/classification/predictor/__pycache__/pytorch.cpython-311.pyc b/doctr/models/classification/predictor/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f0af879edf7406908a2b5d12aedd401727107ee Binary files /dev/null and b/doctr/models/classification/predictor/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/classification/predictor/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/classification/predictor/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89ffd955be80360858a5991169de97ef93f89a06 Binary files /dev/null and b/doctr/models/classification/predictor/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/classification/predictor/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/classification/predictor/__pycache__/tensorflow.cpython-38.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..157e278a5580cc973634393f9142e54b06d22101 Binary files /dev/null and b/doctr/models/classification/predictor/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/classification/predictor/pytorch.py b/doctr/models/classification/predictor/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..d0612505652de9a3d82d590c710b9e4710f660d5 --- /dev/null +++ b/doctr/models/classification/predictor/pytorch.py @@ -0,0 +1,63 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import List, Union + +import numpy as np +import torch +from torch import nn + +from doctr.models.preprocessor import PreProcessor +from doctr.models.utils import set_device_and_dtype + +__all__ = ["OrientationPredictor"] + + +class OrientationPredictor(nn.Module): + """Implements an object able to detect the reading direction of a text box or a page. + 4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise. + + Args: + ---- + pre_processor: transform inputs for easier batched model inference + model: core classification architecture (backbone + classification head) + """ + + def __init__( + self, + pre_processor: PreProcessor, + model: nn.Module, + ) -> None: + super().__init__() + self.pre_processor = pre_processor + self.model = model.eval() + + @torch.inference_mode() + def forward( + self, + inputs: List[Union[np.ndarray, torch.Tensor]], + ) -> List[Union[List[int], List[float]]]: + # Dimension check + if any(input.ndim != 3 for input in inputs): + raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.") + + processed_batches = self.pre_processor(inputs) + _params = next(self.model.parameters()) + self.model, processed_batches = set_device_and_dtype( + self.model, processed_batches, _params.device, _params.dtype + ) + predicted_batches = [self.model(batch) for batch in processed_batches] + # confidence + probs = [ + torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches + ] + # Postprocess predictions + predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches] + + class_idxs = [int(pred) for batch in predicted_batches for pred in batch] + classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs] + confs = [round(float(p), 2) for prob in probs for p in prob] + + return [class_idxs, classes, confs] diff --git a/doctr/models/classification/predictor/tensorflow.py b/doctr/models/classification/predictor/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..95295584f1e6abaf59633b9cb6e79be071991538 --- /dev/null +++ b/doctr/models/classification/predictor/tensorflow.py @@ -0,0 +1,58 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import List, Union + +import numpy as np +import tensorflow as tf +from tensorflow import keras + +from doctr.models.preprocessor import PreProcessor +from doctr.utils.repr import NestedObject + +__all__ = ["OrientationPredictor"] + + +class OrientationPredictor(NestedObject): + """Implements an object able to detect the reading direction of a text box or a page. + 4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise. 
+ + Args: + ---- + pre_processor: transform inputs for easier batched model inference + model: core classification architecture (backbone + classification head) + """ + + _children_names: List[str] = ["pre_processor", "model"] + + def __init__( + self, + pre_processor: PreProcessor, + model: keras.Model, + ) -> None: + self.pre_processor = pre_processor + self.model = model + + def __call__( + self, + inputs: List[Union[np.ndarray, tf.Tensor]], + ) -> List[Union[List[int], List[float]]]: + # Dimension check + if any(input.ndim != 3 for input in inputs): + raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.") + + processed_batches = self.pre_processor(inputs) + predicted_batches = [self.model(batch, training=False) for batch in processed_batches] + + # confidence + probs = [tf.math.reduce_max(tf.nn.softmax(batch, axis=1), axis=1).numpy() for batch in predicted_batches] + # Postprocess predictions + predicted_batches = [out_batch.numpy().argmax(1) for out_batch in predicted_batches] + + class_idxs = [int(pred) for batch in predicted_batches for pred in batch] + classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs] + confs = [round(float(p), 2) for prob in probs for p in prob] + + return [class_idxs, classes, confs] diff --git a/doctr/models/classification/resnet/__init__.py b/doctr/models/classification/resnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/classification/resnet/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/classification/resnet/__pycache__/__init__.cpython-311.pyc b/doctr/models/classification/resnet/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f064d606b9272572794279c8eb7ecd3458571fc8 Binary files /dev/null and b/doctr/models/classification/resnet/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/classification/resnet/__pycache__/__init__.cpython-38.pyc b/doctr/models/classification/resnet/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d68b5f93309c9c498153dde645cb7f93dc7b0624 Binary files /dev/null and b/doctr/models/classification/resnet/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/classification/resnet/__pycache__/pytorch.cpython-311.pyc b/doctr/models/classification/resnet/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4132d69accefe31c0635ac3f77382fb0288bfdb Binary files /dev/null and b/doctr/models/classification/resnet/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/classification/resnet/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/classification/resnet/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db3e49a7d315c73b7fef9ea9fd88cb5abf025dc4 Binary files /dev/null and b/doctr/models/classification/resnet/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/classification/resnet/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/classification/resnet/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..ee4188acc188eccd79189f08bf313b614da6b28b Binary files /dev/null and b/doctr/models/classification/resnet/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/classification/resnet/pytorch.py b/doctr/models/classification/resnet/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..7591741c2985f48f3be3b18e766d293d754e04ed --- /dev/null +++ b/doctr/models/classification/resnet/pytorch.py @@ -0,0 +1,366 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Tuple + +from torch import nn +from torchvision.models.resnet import BasicBlock +from torchvision.models.resnet import ResNet as TVResNet +from torchvision.models.resnet import resnet18 as tv_resnet18 +from torchvision.models.resnet import resnet34 as tv_resnet34 +from torchvision.models.resnet import resnet50 as tv_resnet50 + +from doctr.datasets import VOCABS + +from ...utils import conv_sequence_pt, load_pretrained_params + +__all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide", "resnet_stage"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "resnet18": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.4.1/resnet18-244bf390.pt&src=0", + }, + "resnet31": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.4.1/resnet31-1056cc5c.pt&src=0", + }, + "resnet34": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34-bd8725db.pt&src=0", + }, + "resnet50": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet50-1a6c155e.pt&src=0", + }, + "resnet34_wide": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.6.0/resnet34_wide-b4b3e39e.pt&src=0", + }, +} + + +def resnet_stage(in_channels: int, out_channels: int, num_blocks: int, stride: int) -> List[nn.Module]: + """Build a ResNet stage""" + _layers: List[nn.Module] = [] + + in_chan = in_channels + s = stride + for _ in range(num_blocks): + downsample = None + if in_chan != out_channels: + downsample = nn.Sequential(*conv_sequence_pt(in_chan, out_channels, False, True, kernel_size=1, stride=s)) + + _layers.append(BasicBlock(in_chan, out_channels, stride=s, downsample=downsample)) + in_chan = out_channels + # Only the first block can have stride != 1 + s = 1 + + return _layers + + +class ResNet(nn.Sequential): + """Implements a ResNet-31 architecture from `"Show, Attend and Read:A Simple and Strong Baseline for Irregular + Text Recognition" `_. 
+ + Args: + ---- + num_blocks: number of resnet block in each stage + output_channels: number of channels in each stage + stage_conv: whether to add a conv_sequence after each stage + stage_pooling: pooling to add after each stage (if None, no pooling) + origin_stem: whether to use the orginal ResNet stem or ResNet-31's + stem_channels: number of output channels of the stem convolutions + attn_module: attention module to use in each stage + include_top: whether the classifier head should be instantiated + num_classes: number of output classes + """ + + def __init__( + self, + num_blocks: List[int], + output_channels: List[int], + stage_stride: List[int], + stage_conv: List[bool], + stage_pooling: List[Optional[Tuple[int, int]]], + origin_stem: bool = True, + stem_channels: int = 64, + attn_module: Optional[Callable[[int], nn.Module]] = None, + include_top: bool = True, + num_classes: int = 1000, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + _layers: List[nn.Module] + if origin_stem: + _layers = [ + *conv_sequence_pt(3, stem_channels, True, True, kernel_size=7, padding=3, stride=2), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ] + else: + _layers = [ + *conv_sequence_pt(3, stem_channels // 2, True, True, kernel_size=3, padding=1), + *conv_sequence_pt(stem_channels // 2, stem_channels, True, True, kernel_size=3, padding=1), + nn.MaxPool2d(2), + ] + in_chans = [stem_channels] + output_channels[:-1] + for n_blocks, in_chan, out_chan, stride, conv, pool in zip( + num_blocks, in_chans, output_channels, stage_stride, stage_conv, stage_pooling + ): + _stage = resnet_stage(in_chan, out_chan, n_blocks, stride) + if attn_module is not None: + _stage.append(attn_module(out_chan)) + if conv: + _stage.extend(conv_sequence_pt(out_chan, out_chan, True, True, kernel_size=3, padding=1)) + if pool is not None: + _stage.append(nn.MaxPool2d(pool)) + _layers.append(nn.Sequential(*_stage)) + + if include_top: + _layers.extend([ + nn.AdaptiveAvgPool2d(1), + nn.Flatten(1), + nn.Linear(output_channels[-1], num_classes, bias=True), + ]) + + super().__init__(*_layers) + self.cfg = cfg + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + +def _resnet( + arch: str, + pretrained: bool, + num_blocks: List[int], + output_channels: List[int], + stage_stride: List[int], + stage_conv: List[bool], + stage_pooling: List[Optional[Tuple[int, int]]], + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> ResNet: + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["classes"] = kwargs["classes"] + kwargs.pop("classes") + + # Build the model + model = ResNet(num_blocks, output_channels, stage_stride, stage_conv, stage_pooling, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + # The number of classes is not the same as the number of classes in the pretrained model => + # remove the last layer weights + _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None + load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + return model + + +def _tv_resnet( + arch: str, + pretrained: bool, + arch_fn, + ignore_keys: 
Optional[List[str]] = None, + **kwargs: Any, +) -> TVResNet: + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["classes"] = kwargs["classes"] + kwargs.pop("classes") + + # Build the model + model = arch_fn(**kwargs, weights=None) + # Load pretrained parameters + if pretrained: + # The number of classes is not the same as the number of classes in the pretrained model => + # remove the last layer weights + _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None + load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + model.cfg = _cfg + + return model + + +def resnet18(pretrained: bool = False, **kwargs: Any) -> TVResNet: + """ResNet-18 architecture as described in `"Deep Residual Learning for Image Recognition", + `_. + + >>> import torch + >>> from doctr.models import resnet18 + >>> model = resnet18(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the ResNet architecture + + Returns: + ------- + A resnet18 model + """ + return _tv_resnet( + "resnet18", + pretrained, + tv_resnet18, + ignore_keys=["fc.weight", "fc.bias"], + **kwargs, + ) + + +def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet: + """Resnet31 architecture with rectangular pooling windows as described in + `"Show, Attend and Read:A Simple and Strong Baseline for Irregular Text Recognition", + `_. Downsizing: (H, W) --> (H/8, W/4) + + >>> import torch + >>> from doctr.models import resnet31 + >>> model = resnet31(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the ResNet architecture + + Returns: + ------- + A resnet31 model + """ + return _resnet( + "resnet31", + pretrained, + [1, 2, 5, 3], + [256, 256, 512, 512], + [1, 1, 1, 1], + [True] * 4, + [(2, 2), (2, 1), None, None], + origin_stem=False, + stem_channels=128, + ignore_keys=["13.weight", "13.bias"], + **kwargs, + ) + + +def resnet34(pretrained: bool = False, **kwargs: Any) -> TVResNet: + """ResNet-34 architecture as described in `"Deep Residual Learning for Image Recognition", + `_. + + >>> import torch + >>> from doctr.models import resnet34 + >>> model = resnet34(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the ResNet architecture + + Returns: + ------- + A resnet34 model + """ + return _tv_resnet( + "resnet34", + pretrained, + tv_resnet34, + ignore_keys=["fc.weight", "fc.bias"], + **kwargs, + ) + + +def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet: + """ResNet-34 architecture as described in `"Deep Residual Learning for Image Recognition", + `_ with twice as many output channels. 
+ + >>> import torch + >>> from doctr.models import resnet34_wide + >>> model = resnet34_wide(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the ResNet architecture + + Returns: + ------- + A resnet34_wide model + """ + return _resnet( + "resnet34_wide", + pretrained, + [3, 4, 6, 3], + [128, 256, 512, 1024], + [1, 2, 2, 2], + [False] * 4, + [None] * 4, + origin_stem=True, + stem_channels=128, + ignore_keys=["10.weight", "10.bias"], + **kwargs, + ) + + +def resnet50(pretrained: bool = False, **kwargs: Any) -> TVResNet: + """ResNet-50 architecture as described in `"Deep Residual Learning for Image Recognition", + `_. + + >>> import torch + >>> from doctr.models import resnet50 + >>> model = resnet50(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the ResNet architecture + + Returns: + ------- + A resnet50 model + """ + return _tv_resnet( + "resnet50", + pretrained, + tv_resnet50, + ignore_keys=["fc.weight", "fc.bias"], + **kwargs, + ) diff --git a/doctr/models/classification/resnet/tensorflow.py b/doctr/models/classification/resnet/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..7648e5f8d064390adae1f3b9b9386913498ef9a5 --- /dev/null +++ b/doctr/models/classification/resnet/tensorflow.py @@ -0,0 +1,395 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Tuple + +import tensorflow as tf +from tensorflow.keras import layers +from tensorflow.keras.applications import ResNet50 +from tensorflow.keras.models import Sequential + +from doctr.datasets import VOCABS + +from ...utils import conv_sequence, load_pretrained_params + +__all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "resnet18": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 32, 3), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.4.1/resnet18-d4634669.zip&src=0", + }, + "resnet31": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 32, 3), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet31-5a47a60b.zip&src=0", + }, + "resnet34": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 32, 3), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34-5dcc97ca.zip&src=0", + }, + "resnet50": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 32, 3), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet50-e75e4cdf.zip&src=0", + }, + "resnet34_wide": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 32, 3), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34_wide-c1271816.zip&src=0", + }, +} + + +class ResnetBlock(layers.Layer): + """Implements a resnet31 block with shortcut + 
+ Args: + ---- + conv_shortcut: Use of shortcut + output_channels: number of channels to use in Conv2D + kernel_size: size of square kernels + strides: strides to use in the first convolution of the block + """ + + def __init__(self, output_channels: int, conv_shortcut: bool, strides: int = 1, **kwargs) -> None: + super().__init__(**kwargs) + if conv_shortcut: + self.shortcut = Sequential([ + layers.Conv2D( + filters=output_channels, + strides=strides, + padding="same", + kernel_size=1, + use_bias=False, + kernel_initializer="he_normal", + ), + layers.BatchNormalization(), + ]) + else: + self.shortcut = layers.Lambda(lambda x: x) + self.conv_block = Sequential(self.conv_resnetblock(output_channels, 3, strides)) + self.act = layers.Activation("relu") + + @staticmethod + def conv_resnetblock( + output_channels: int, + kernel_size: int, + strides: int = 1, + ) -> List[layers.Layer]: + return [ + *conv_sequence(output_channels, "relu", bn=True, strides=strides, kernel_size=kernel_size), + *conv_sequence(output_channels, None, bn=True, kernel_size=kernel_size), + ] + + def call(self, inputs: tf.Tensor) -> tf.Tensor: + clone = self.shortcut(inputs) + conv_out = self.conv_block(inputs) + out = self.act(clone + conv_out) + + return out + + +def resnet_stage( + num_blocks: int, out_channels: int, shortcut: bool = False, downsample: bool = False +) -> List[layers.Layer]: + _layers: List[layers.Layer] = [ResnetBlock(out_channels, conv_shortcut=shortcut, strides=2 if downsample else 1)] + + for _ in range(1, num_blocks): + _layers.append(ResnetBlock(out_channels, conv_shortcut=False)) + + return _layers + + +class ResNet(Sequential): + """Implements a ResNet architecture + + Args: + ---- + num_blocks: number of resnet block in each stage + output_channels: number of channels in each stage + stage_downsample: whether the first residual block of a stage should downsample + stage_conv: whether to add a conv_sequence after each stage + stage_pooling: pooling to add after each stage (if None, no pooling) + origin_stem: whether to use the orginal ResNet stem or ResNet-31's + stem_channels: number of output channels of the stem convolutions + attn_module: attention module to use in each stage + include_top: whether the classifier head should be instantiated + num_classes: number of output classes + input_shape: shape of inputs + """ + + def __init__( + self, + num_blocks: List[int], + output_channels: List[int], + stage_downsample: List[bool], + stage_conv: List[bool], + stage_pooling: List[Optional[Tuple[int, int]]], + origin_stem: bool = True, + stem_channels: int = 64, + attn_module: Optional[Callable[[int], layers.Layer]] = None, + include_top: bool = True, + num_classes: int = 1000, + cfg: Optional[Dict[str, Any]] = None, + input_shape: Optional[Tuple[int, int, int]] = None, + ) -> None: + inplanes = stem_channels + if origin_stem: + _layers = [ + *conv_sequence(inplanes, "relu", True, kernel_size=7, strides=2, input_shape=input_shape), + layers.MaxPool2D(pool_size=(3, 3), strides=2, padding="same"), + ] + else: + _layers = [ + *conv_sequence(inplanes // 2, "relu", True, kernel_size=3, input_shape=input_shape), + *conv_sequence(inplanes, "relu", True, kernel_size=3), + layers.MaxPool2D(pool_size=2, strides=2, padding="valid"), + ] + + for n_blocks, out_chan, down, conv, pool in zip( + num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling + ): + _layers.extend(resnet_stage(n_blocks, out_chan, out_chan != inplanes, down)) + if attn_module is not None: + 
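+ # optional attention module, built from the stage's output channel count and appended after its residual blocks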
_layers.append(attn_module(out_chan)) + if conv: + _layers.extend(conv_sequence(out_chan, activation="relu", bn=True, kernel_size=3)) + if pool: + _layers.append(layers.MaxPool2D(pool_size=pool, strides=pool, padding="valid")) + inplanes = out_chan + + if include_top: + _layers.extend([ + layers.GlobalAveragePooling2D(), + layers.Dense(num_classes), + ]) + + super().__init__(_layers) + self.cfg = cfg + + +def _resnet( + arch: str, + pretrained: bool, + num_blocks: List[int], + output_channels: List[int], + stage_downsample: List[bool], + stage_conv: List[bool], + stage_pooling: List[Optional[Tuple[int, int]]], + origin_stem: bool = True, + **kwargs: Any, +) -> ResNet: + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"]) + kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["classes"] = kwargs["classes"] + _cfg["input_shape"] = kwargs["input_shape"] + kwargs.pop("classes") + + # Build the model + model = ResNet( + num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs + ) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]["url"]) + + return model + + +def resnet18(pretrained: bool = False, **kwargs: Any) -> ResNet: + """Resnet-18 architecture as described in `"Deep Residual Learning for Image Recognition", + `_. + + >>> import tensorflow as tf + >>> from doctr.models import resnet18 + >>> model = resnet18(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the ResNet architecture + + Returns: + ------- + A classification model + """ + return _resnet( + "resnet18", + pretrained, + [2, 2, 2, 2], + [64, 128, 256, 512], + [False, True, True, True], + [False] * 4, + [None] * 4, + True, + **kwargs, + ) + + +def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet: + """Resnet31 architecture with rectangular pooling windows as described in + `"Show, Attend and Read:A Simple and Strong Baseline for Irregular Text Recognition", + `_. Downsizing: (H, W) --> (H/8, W/4) + + >>> import tensorflow as tf + >>> from doctr.models import resnet31 + >>> model = resnet31(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the ResNet architecture + + Returns: + ------- + A classification model + """ + return _resnet( + "resnet31", + pretrained, + [1, 2, 5, 3], + [256, 256, 512, 512], + [False] * 4, + [True] * 4, + [(2, 2), (2, 1), None, None], + False, + stem_channels=128, + **kwargs, + ) + + +def resnet34(pretrained: bool = False, **kwargs: Any) -> ResNet: + """Resnet-34 architecture as described in `"Deep Residual Learning for Image Recognition", + `_. 
+ + >>> import tensorflow as tf + >>> from doctr.models import resnet34 + >>> model = resnet34(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the ResNet architecture + + Returns: + ------- + A classification model + """ + return _resnet( + "resnet34", + pretrained, + [3, 4, 6, 3], + [64, 128, 256, 512], + [False, True, True, True], + [False] * 4, + [None] * 4, + True, + **kwargs, + ) + + +def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet: + """Resnet-50 architecture as described in `"Deep Residual Learning for Image Recognition", + `_. + + >>> import tensorflow as tf + >>> from doctr.models import resnet50 + >>> model = resnet50(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the ResNet architecture + + Returns: + ------- + A classification model + """ + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs["resnet50"]["classes"])) + kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs["resnet50"]["input_shape"]) + kwargs["classes"] = kwargs.get("classes", default_cfgs["resnet50"]["classes"]) + + _cfg = deepcopy(default_cfgs["resnet50"]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["classes"] = kwargs["classes"] + _cfg["input_shape"] = kwargs["input_shape"] + kwargs.pop("classes") + + model = ResNet50( + weights=None, + include_top=True, + pooling=True, + input_shape=kwargs["input_shape"], + classes=kwargs["num_classes"], + classifier_activation=None, + ) + + model.cfg = _cfg + + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs["resnet50"]["url"]) + + return model + + +def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet: + """Resnet-34 architecture as described in `"Deep Residual Learning for Image Recognition", + `_ with twice as many output channels for each stage. 
+ + >>> import tensorflow as tf + >>> from doctr.models import resnet34_wide + >>> model = resnet34_wide(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the ResNet architecture + + Returns: + ------- + A classification model + """ + return _resnet( + "resnet34_wide", + pretrained, + [3, 4, 6, 3], + [128, 256, 512, 1024], + [False, True, True, True], + [False] * 4, + [None] * 4, + True, + stem_channels=128, + **kwargs, + ) diff --git a/doctr/models/classification/textnet/__init__.py b/doctr/models/classification/textnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/classification/textnet/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/classification/textnet/__pycache__/__init__.cpython-311.pyc b/doctr/models/classification/textnet/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8234234751e774f77d82e54cd8a1635c7adfeb7e Binary files /dev/null and b/doctr/models/classification/textnet/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/classification/textnet/__pycache__/__init__.cpython-38.pyc b/doctr/models/classification/textnet/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f47e7dae4833517b275c838798ba0f143e60cfd3 Binary files /dev/null and b/doctr/models/classification/textnet/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/classification/textnet/__pycache__/pytorch.cpython-311.pyc b/doctr/models/classification/textnet/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5916f4ff9aca7670deb080151aa6610c73baccda Binary files /dev/null and b/doctr/models/classification/textnet/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/classification/textnet/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/classification/textnet/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d3c286b663052ebd01b8b26c1dfc8506ce125fe Binary files /dev/null and b/doctr/models/classification/textnet/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/classification/textnet/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/classification/textnet/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb27f186dee99c399bfca7f50a8dff4fa5199bdd Binary files /dev/null and b/doctr/models/classification/textnet/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/classification/textnet/pytorch.py b/doctr/models/classification/textnet/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..cdbb719f8bad20653ae3337d000db12f1d73168e --- /dev/null +++ b/doctr/models/classification/textnet/pytorch.py @@ -0,0 +1,275 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
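Every classification subpackage in this patch uses the same backend-dispatch pattern in its __init__.py: the TensorFlow implementation is re-exported when TensorFlow is available, otherwise the PyTorch one. As a usage sketch only (not part of the patch, and assuming exactly one backend is installed), the same import then resolves to the matching implementation, e.g. with the resnet18 defined above:

from doctr.file_utils import is_tf_available, is_torch_available
from doctr.models import resnet18  # resolves to the TF or the PyTorch class at import time

model = resnet18(pretrained=False)
if is_tf_available():
    import tensorflow as tf
    out = model(tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32))
elif is_torch_available():
    import torch
    out = model(torch.rand((1, 3, 512, 512), dtype=torch.float32))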
+ + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +from torch import nn + +from doctr.datasets import VOCABS + +from ...modules.layers.pytorch import FASTConvLayer +from ...utils import conv_sequence_pt, load_pretrained_params + +__all__ = ["textnet_tiny", "textnet_small", "textnet_base"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "textnet_tiny": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-27288d12.pt&src=0", + }, + "textnet_small": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-43166ee6.pt&src=0", + }, + "textnet_base": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-7f68d7e0.pt&src=0", + }, +} + + +class TextNet(nn.Sequential): + """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with + Minimalist Kernel Representation" `_. + Implementation based on the official Pytorch implementation: `_. + + Args: + ---- + stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage. + include_top (bool, optional): Whether to include the classifier head. Defaults to True. + num_classes (int, optional): Number of output classes. Defaults to 1000. + cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None. + """ + + def __init__( + self, + stages: List[Dict[str, List[int]]], + input_shape: Tuple[int, int, int] = (3, 32, 32), + num_classes: int = 1000, + include_top: bool = True, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + _layers: List[nn.Module] = [ + *conv_sequence_pt( + in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2, padding=(1, 1) + ), + *[ + nn.Sequential(*[ + FASTConvLayer(**params) # type: ignore[arg-type] + for params in [{key: stage[key][i] for key in stage} for i in range(len(stage["in_channels"]))] + ]) + for stage in stages + ], + ] + + if include_top: + _layers.append( + nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Flatten(1), + nn.Linear(stages[-1]["out_channels"][-1], num_classes), + ) + ) + + super().__init__(*_layers) + self.cfg = cfg + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + +def _textnet( + arch: str, + pretrained: bool, + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> TextNet: + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["classes"] = kwargs["classes"] + kwargs.pop("classes") + + # Build the model + model = TextNet(**kwargs) + # Load pretrained parameters + if pretrained: + # The number of classes is not the same as the number of classes in the pretrained model => + # remove the last layer weights + _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None + 
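+ # ignore_keys name the classification head's parameters in the checkpoint; skipping them keeps the freshly initialised head sized for the requested classes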
load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + model.cfg = _cfg + + return model + + +def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet: + """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with + Minimalist Kernel Representation" `_. + Implementation based on the official Pytorch implementation: `_. + + >>> import torch + >>> from doctr.models import textnet_tiny + >>> model = textnet_tiny(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the TextNet architecture + + Returns: + ------- + A textnet tiny model + """ + return _textnet( + "textnet_tiny", + pretrained, + stages=[ + {"in_channels": [64] * 3, "out_channels": [64] * 3, "kernel_size": [(3, 3)] * 3, "stride": [1, 2, 1]}, + { + "in_channels": [64, 128, 128, 128], + "out_channels": [128] * 4, + "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1)], + "stride": [2, 1, 1, 1], + }, + { + "in_channels": [128, 256, 256, 256], + "out_channels": [256] * 4, + "kernel_size": [(3, 3), (3, 3), (3, 1), (1, 3)], + "stride": [2, 1, 1, 1], + }, + { + "in_channels": [256, 512, 512, 512], + "out_channels": [512] * 4, + "kernel_size": [(3, 3), (3, 1), (1, 3), (3, 3)], + "stride": [2, 1, 1, 1], + }, + ], + ignore_keys=["7.2.weight", "7.2.bias"], + **kwargs, + ) + + +def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet: + """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with + Minimalist Kernel Representation" `_. + Implementation based on the official Pytorch implementation: `_. + + >>> import torch + >>> from doctr.models import textnet_small + >>> model = textnet_small(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the TextNet architecture + + Returns: + ------- + A TextNet small model + """ + return _textnet( + "textnet_small", + pretrained, + stages=[ + {"in_channels": [64] * 2, "out_channels": [64] * 2, "kernel_size": [(3, 3)] * 2, "stride": [1, 2]}, + { + "in_channels": [64, 128, 128, 128, 128, 128, 128, 128], + "out_channels": [128] * 8, + "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1), (1, 3), (3, 3)], + "stride": [2, 1, 1, 1, 1, 1, 1, 1], + }, + { + "in_channels": [128, 256, 256, 256, 256, 256, 256, 256], + "out_channels": [256] * 8, + "kernel_size": [(3, 3), (3, 3), (1, 3), (3, 1), (3, 3), (1, 3), (3, 1), (3, 3)], + "stride": [2, 1, 1, 1, 1, 1, 1, 1], + }, + { + "in_channels": [256, 512, 512, 512, 512], + "out_channels": [512] * 5, + "kernel_size": [(3, 3), (3, 1), (1, 3), (1, 3), (3, 1)], + "stride": [2, 1, 1, 1, 1], + }, + ], + ignore_keys=["7.2.weight", "7.2.bias"], + **kwargs, + ) + + +def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet: + """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with + Minimalist Kernel Representation" `_. + Implementation based on the official Pytorch implementation: `_. 
+ + >>> import torch + >>> from doctr.models import textnet_base + >>> model = textnet_base(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the TextNet architecture + + Returns: + ------- + A TextNet base model + """ + return _textnet( + "textnet_base", + pretrained, + stages=[ + { + "in_channels": [64] * 10, + "out_channels": [64] * 10, + "kernel_size": [(3, 3), (3, 3), (3, 1), (3, 3), (3, 1), (3, 3), (3, 3), (1, 3), (3, 3), (3, 3)], + "stride": [1, 2, 1, 1, 1, 1, 1, 1, 1, 1], + }, + { + "in_channels": [64, 128, 128, 128, 128, 128, 128, 128, 128, 128], + "out_channels": [128] * 10, + "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 3), (3, 1), (3, 1), (3, 3), (3, 3)], + "stride": [2, 1, 1, 1, 1, 1, 1, 1, 1, 1], + }, + { + "in_channels": [128, 256, 256, 256, 256, 256, 256, 256], + "out_channels": [256] * 8, + "kernel_size": [(3, 3), (3, 3), (3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1)], + "stride": [2, 1, 1, 1, 1, 1, 1, 1], + }, + { + "in_channels": [256, 512, 512, 512, 512], + "out_channels": [512] * 5, + "kernel_size": [(3, 3), (1, 3), (3, 1), (3, 1), (1, 3)], + "stride": [2, 1, 1, 1, 1], + }, + ], + ignore_keys=["7.2.weight", "7.2.bias"], + **kwargs, + ) diff --git a/doctr/models/classification/textnet/tensorflow.py b/doctr/models/classification/textnet/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..f30d5d823ce84e358e6d3d497d7091bea9db1b04 --- /dev/null +++ b/doctr/models/classification/textnet/tensorflow.py @@ -0,0 +1,267 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +from tensorflow.keras import Sequential, layers + +from doctr.datasets import VOCABS + +from ...modules.layers.tensorflow import FASTConvLayer +from ...utils import conv_sequence, load_pretrained_params + +__all__ = ["textnet_tiny", "textnet_small", "textnet_base"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "textnet_tiny": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 32, 3), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-fe9cc245.zip&src=0", + }, + "textnet_small": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 32, 3), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-29c39c82.zip&src=0", + }, + "textnet_base": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 32, 3), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-168aa82c.zip&src=0", + }, +} + + +class TextNet(Sequential): + """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with + Minimalist Kernel Representation" `_. + Implementation based on the official Pytorch implementation: `_. + + Args: + ---- + stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage. + include_top (bool, optional): Whether to include the classifier head. Defaults to True. + num_classes (int, optional): Number of output classes. Defaults to 1000. 
+ cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None. + """ + + def __init__( + self, + stages: List[Dict[str, List[int]]], + input_shape: Tuple[int, int, int] = (32, 32, 3), + num_classes: int = 1000, + include_top: bool = True, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + _layers = [ + *conv_sequence( + out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2, input_shape=input_shape + ), + *[ + Sequential( + [ + FASTConvLayer(**params) # type: ignore[arg-type] + for params in [{key: stage[key][i] for key in stage} for i in range(len(stage["in_channels"]))] + ], + name=f"stage_{i}", + ) + for i, stage in enumerate(stages) + ], + ] + + if include_top: + _layers.append( + Sequential( + [ + layers.AveragePooling2D(1), + layers.Flatten(), + layers.Dense(num_classes), + ], + name="classifier", + ) + ) + + super().__init__(_layers) + self.cfg = cfg + + +def _textnet( + arch: str, + pretrained: bool, + **kwargs: Any, +) -> TextNet: + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"]) + kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["input_shape"] = kwargs["input_shape"] + _cfg["classes"] = kwargs["classes"] + kwargs.pop("classes") + + # Build the model + model = TextNet(cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]["url"]) + + return model + + +def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet: + """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with + Minimalist Kernel Representation" `_. + Implementation based on the official Pytorch implementation: `_. + + >>> import tensorflow as tf + >>> from doctr.models import textnet_tiny + >>> model = textnet_tiny(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the TextNet architecture + + Returns: + ------- + A textnet tiny model + """ + return _textnet( + "textnet_tiny", + pretrained, + stages=[ + {"in_channels": [64] * 3, "out_channels": [64] * 3, "kernel_size": [(3, 3)] * 3, "stride": [1, 2, 1]}, + { + "in_channels": [64, 128, 128, 128], + "out_channels": [128] * 4, + "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1)], + "stride": [2, 1, 1, 1], + }, + { + "in_channels": [128, 256, 256, 256], + "out_channels": [256] * 4, + "kernel_size": [(3, 3), (3, 3), (3, 1), (1, 3)], + "stride": [2, 1, 1, 1], + }, + { + "in_channels": [256, 512, 512, 512], + "out_channels": [512] * 4, + "kernel_size": [(3, 3), (3, 1), (1, 3), (3, 3)], + "stride": [2, 1, 1, 1], + }, + ], + **kwargs, + ) + + +def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet: + """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with + Minimalist Kernel Representation" `_. + Implementation based on the official Pytorch implementation: `_. 
+ + >>> import tensorflow as tf + >>> from doctr.models import textnet_small + >>> model = textnet_small(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the TextNet architecture + + Returns: + ------- + A TextNet small model + """ + return _textnet( + "textnet_small", + pretrained, + stages=[ + {"in_channels": [64] * 2, "out_channels": [64] * 2, "kernel_size": [(3, 3)] * 2, "stride": [1, 2]}, + { + "in_channels": [64, 128, 128, 128, 128, 128, 128, 128], + "out_channels": [128] * 8, + "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1), (1, 3), (3, 3)], + "stride": [2, 1, 1, 1, 1, 1, 1, 1], + }, + { + "in_channels": [128, 256, 256, 256, 256, 256, 256, 256], + "out_channels": [256] * 8, + "kernel_size": [(3, 3), (3, 3), (1, 3), (3, 1), (3, 3), (1, 3), (3, 1), (3, 3)], + "stride": [2, 1, 1, 1, 1, 1, 1, 1], + }, + { + "in_channels": [256, 512, 512, 512, 512], + "out_channels": [512] * 5, + "kernel_size": [(3, 3), (3, 1), (1, 3), (1, 3), (3, 1)], + "stride": [2, 1, 1, 1, 1], + }, + ], + **kwargs, + ) + + +def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet: + """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with + Minimalist Kernel Representation" `_. + Implementation based on the official Pytorch implementation: `_. + + >>> import tensorflow as tf + >>> from doctr.models import textnet_base + >>> model = textnet_base(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the TextNet architecture + + Returns: + ------- + A TextNet base model + """ + return _textnet( + "textnet_base", + pretrained, + stages=[ + { + "in_channels": [64] * 10, + "out_channels": [64] * 10, + "kernel_size": [(3, 3), (3, 3), (3, 1), (3, 3), (3, 1), (3, 3), (3, 3), (1, 3), (3, 3), (3, 3)], + "stride": [1, 2, 1, 1, 1, 1, 1, 1, 1, 1], + }, + { + "in_channels": [64, 128, 128, 128, 128, 128, 128, 128, 128, 128], + "out_channels": [128] * 10, + "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 3), (3, 1), (3, 1), (3, 3), (3, 3)], + "stride": [2, 1, 1, 1, 1, 1, 1, 1, 1, 1], + }, + { + "in_channels": [128, 256, 256, 256, 256, 256, 256, 256], + "out_channels": [256] * 8, + "kernel_size": [(3, 3), (3, 3), (3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1)], + "stride": [2, 1, 1, 1, 1, 1, 1, 1], + }, + { + "in_channels": [256, 512, 512, 512, 512], + "out_channels": [512] * 5, + "kernel_size": [(3, 3), (1, 3), (3, 1), (3, 1), (1, 3)], + "stride": [2, 1, 1, 1, 1], + }, + ], + **kwargs, + ) diff --git a/doctr/models/classification/vgg/__init__.py b/doctr/models/classification/vgg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..64556e403a5697432f805a5af28dab812fa8b932 --- /dev/null +++ b/doctr/models/classification/vgg/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * diff --git a/doctr/models/classification/vgg/__pycache__/__init__.cpython-311.pyc b/doctr/models/classification/vgg/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..c9afb00c6dc1f648e82d797eb8ea1725806fb34e Binary files /dev/null and b/doctr/models/classification/vgg/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/classification/vgg/__pycache__/__init__.cpython-38.pyc b/doctr/models/classification/vgg/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b3965bcd4508868c1890448e27ad07b34035024 Binary files /dev/null and b/doctr/models/classification/vgg/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/classification/vgg/__pycache__/pytorch.cpython-311.pyc b/doctr/models/classification/vgg/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..537c884c4cedac8026ec8b69d09056fe61191bc5 Binary files /dev/null and b/doctr/models/classification/vgg/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/classification/vgg/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/classification/vgg/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2be84e733637ff9e36f9672e01d7520fbc4cf6d7 Binary files /dev/null and b/doctr/models/classification/vgg/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/classification/vgg/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/classification/vgg/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f30f87a361cf6210349ba864c34aed386abfa3e5 Binary files /dev/null and b/doctr/models/classification/vgg/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/classification/vgg/pytorch.py b/doctr/models/classification/vgg/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..2e16b1178841f3096daa9acd7c920fea4a178c7f --- /dev/null +++ b/doctr/models/classification/vgg/pytorch.py @@ -0,0 +1,95 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
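In both TextNet implementations above, each entry of stages is a dict of equal-length per-block lists, and the comprehension [{key: stage[key][i] for key in stage} for i in range(len(stage["in_channels"]))] turns it into one kwargs dict per FASTConvLayer. A small illustration of that expansion on a shortened stage (plain Python, no framework required; values are illustrative):

stage = {
    "in_channels": [64, 128],
    "out_channels": [128, 128],
    "kernel_size": [(3, 3), (1, 3)],
    "stride": [2, 1],
}
per_block = [{key: stage[key][i] for key in stage} for i in range(len(stage["in_channels"]))]
# per_block[0] == {"in_channels": 64, "out_channels": 128, "kernel_size": (3, 3), "stride": 2}
# per_block[1] == {"in_channels": 128, "out_channels": 128, "kernel_size": (1, 3), "stride": 1}
# each dict is then unpacked as FASTConvLayer(**params)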
+ +from copy import deepcopy +from typing import Any, Dict, List, Optional + +from torch import nn +from torchvision.models import vgg as tv_vgg + +from doctr.datasets import VOCABS + +from ...utils import load_pretrained_params + +__all__ = ["vgg16_bn_r"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "vgg16_bn_r": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.4.1/vgg16_bn_r-d108c19c.pt&src=0", + }, +} + + +def _vgg( + arch: str, + pretrained: bool, + tv_arch: str, + num_rect_pools: int = 3, + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> tv_vgg.VGG: + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["classes"] = kwargs["classes"] + kwargs.pop("classes") + + # Build the model + model = tv_vgg.__dict__[tv_arch](**kwargs, weights=None) + # List the MaxPool2d + pool_idcs = [idx for idx, m in enumerate(model.features) if isinstance(m, nn.MaxPool2d)] + # Replace their kernel with rectangular ones + for idx in pool_idcs[-num_rect_pools:]: + model.features[idx] = nn.MaxPool2d((2, 1)) + # Patch average pool & classification head + model.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + model.classifier = nn.Linear(512, kwargs["num_classes"]) + # Load pretrained parameters + if pretrained: + # The number of classes is not the same as the number of classes in the pretrained model => + # remove the last layer weights + _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None + load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + model.cfg = _cfg + + return model + + +def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> tv_vgg.VGG: + """VGG-16 architecture as described in `"Very Deep Convolutional Networks for Large-Scale Image Recognition" + `_, modified by adding batch normalization, rectangular pooling and a simpler + classification head. + + >>> import torch + >>> from doctr.models import vgg16_bn_r + >>> model = vgg16_bn_r(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on ImageNet + **kwargs: keyword arguments of the VGG architecture + + Returns: + ------- + VGG feature extractor + """ + return _vgg( + "vgg16_bn_r", + pretrained, + "vgg16_bn", + 3, + ignore_keys=["classifier.weight", "classifier.bias"], + **kwargs, + ) diff --git a/doctr/models/classification/vgg/tensorflow.py b/doctr/models/classification/vgg/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..259ed9f88875fb00c6c006a314f0210e77857bd8 --- /dev/null +++ b/doctr/models/classification/vgg/tensorflow.py @@ -0,0 +1,113 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
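To see the effect of the rectangular pooling that `_vgg` applies above, here is a small standalone sketch (plain torchvision, not doctr) that reproduces the patch and checks the resulting feature-map shape; the 32x128 input size is only an illustrative choice:

```python
import torch
from torch import nn
from torchvision.models import vgg as tv_vgg

# Re-create the rectangular-pooling patch from _vgg() above on a bare torchvision VGG-16 BN
model = tv_vgg.vgg16_bn(weights=None)
pool_idcs = [idx for idx, m in enumerate(model.features) if isinstance(m, nn.MaxPool2d)]
for idx in pool_idcs[-3:]:  # num_rect_pools = 3
    model.features[idx] = nn.MaxPool2d((2, 1))  # halve the height only

x = torch.rand(1, 3, 32, 128)
print(model.features(x).shape)  # torch.Size([1, 512, 1, 32]): height collapsed, width only divided by 4
```

Keeping the horizontal resolution is what makes this backbone usable as a text-recognition feature extractor.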
+ +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +from tensorflow.keras import layers +from tensorflow.keras.models import Sequential + +from doctr.datasets import VOCABS + +from ...utils import conv_sequence, load_pretrained_params + +__all__ = ["VGG", "vgg16_bn_r"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "vgg16_bn_r": { + "mean": (0.5, 0.5, 0.5), + "std": (1.0, 1.0, 1.0), + "input_shape": (32, 32, 3), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.4.1/vgg16_bn_r-c5836cea.zip&src=0", + }, +} + + +class VGG(Sequential): + """Implements the VGG architecture from `"Very Deep Convolutional Networks for Large-Scale Image Recognition" + `_. + + Args: + ---- + num_blocks: number of convolutional block in each stage + planes: number of output channels in each stage + rect_pools: whether pooling square kernels should be replace with rectangular ones + include_top: whether the classifier head should be instantiated + num_classes: number of output classes + input_shape: shapes of the input tensor + """ + + def __init__( + self, + num_blocks: List[int], + planes: List[int], + rect_pools: List[bool], + include_top: bool = False, + num_classes: int = 1000, + input_shape: Optional[Tuple[int, int, int]] = None, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + _layers = [] + # Specify input_shape only for the first layer + kwargs = {"input_shape": input_shape} + for nb_blocks, out_chan, rect_pool in zip(num_blocks, planes, rect_pools): + for _ in range(nb_blocks): + _layers.extend(conv_sequence(out_chan, "relu", True, kernel_size=3, **kwargs)) # type: ignore[arg-type] + kwargs = {} + _layers.append(layers.MaxPooling2D((2, 1 if rect_pool else 2))) + + if include_top: + _layers.extend([layers.GlobalAveragePooling2D(), layers.Dense(num_classes)]) + super().__init__(_layers) + self.cfg = cfg + + +def _vgg( + arch: str, pretrained: bool, num_blocks: List[int], planes: List[int], rect_pools: List[bool], **kwargs: Any +) -> VGG: + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"]) + kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["classes"] = kwargs["classes"] + _cfg["input_shape"] = kwargs["input_shape"] + kwargs.pop("classes") + + # Build the model + model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]["url"]) + + return model + + +def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> VGG: + """VGG-16 architecture as described in `"Very Deep Convolutional Networks for Large-Scale Image Recognition" + `_, modified by adding batch normalization, rectangular pooling and a simpler + classification head. 
+ + >>> import tensorflow as tf + >>> from doctr.models import vgg16_bn_r + >>> model = vgg16_bn_r(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on ImageNet + **kwargs: keyword arguments of the VGG architecture + + Returns: + ------- + VGG feature extractor + """ + return _vgg( + "vgg16_bn_r", pretrained, [2, 2, 3, 3, 3], [64, 128, 256, 512, 512], [False, False, True, True, True], **kwargs + ) diff --git a/doctr/models/classification/vit/__init__.py b/doctr/models/classification/vit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/classification/vit/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/classification/vit/__pycache__/__init__.cpython-311.pyc b/doctr/models/classification/vit/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90035e6ee01d1bfc281097bcbe498a20f4fa36cc Binary files /dev/null and b/doctr/models/classification/vit/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/classification/vit/__pycache__/__init__.cpython-38.pyc b/doctr/models/classification/vit/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1b80603d9e3e6fa230c29e465ca807f6a254833 Binary files /dev/null and b/doctr/models/classification/vit/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/classification/vit/__pycache__/pytorch.cpython-311.pyc b/doctr/models/classification/vit/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ea387a3ffb58c51706133ccabff91aed4c04d51 Binary files /dev/null and b/doctr/models/classification/vit/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/classification/vit/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/classification/vit/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..493d7f6cb5de935be8579274d32be500a325babd Binary files /dev/null and b/doctr/models/classification/vit/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/classification/vit/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/classification/vit/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32b51bfb889ad89ddbc6938838dc0c77fdad77fe Binary files /dev/null and b/doctr/models/classification/vit/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/classification/vit/pytorch.py b/doctr/models/classification/vit/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..335e92559fd9fa44de136323a6b50456eb605c7a --- /dev/null +++ b/doctr/models/classification/vit/pytorch.py @@ -0,0 +1,195 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
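Assuming the package built from this diff is installed alongside TensorFlow, the TF `vgg16_bn_r` above behaves the same way on channels-last inputs; the 32x128 input is again only an illustrative size:

```python
import tensorflow as tf
from doctr.models import vgg16_bn_r  # resolves to the TensorFlow implementation when TF is available

model = vgg16_bn_r(pretrained=False, input_shape=(32, 128, 3))
out = model(tf.random.uniform((1, 32, 128, 3), maxval=1, dtype=tf.float32))
print(out.shape)  # (1, 1, 32, 512): the three rectangular pools preserve the width resolution
```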
+ +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn + +from doctr.datasets import VOCABS +from doctr.models.modules.transformer import EncoderBlock +from doctr.models.modules.vision_transformer.pytorch import PatchEmbedding + +from ...utils.pytorch import load_pretrained_params + +__all__ = ["vit_s", "vit_b"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "vit_s": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_s-5d05442d.pt&src=0", + }, + "vit_b": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_b-0fbef167.pt&src=0", + }, +} + + +class ClassifierHead(nn.Module): + """Classifier head for Vision Transformer + + Args: + ---- + in_channels: number of input channels + num_classes: number of output classes + """ + + def __init__( + self, + in_channels: int, + num_classes: int, + ) -> None: + super().__init__() + + self.head = nn.Linear(in_channels, num_classes) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # (batch_size, num_classes) cls token + return self.head(x[:, 0]) + + +class VisionTransformer(nn.Sequential): + """VisionTransformer architecture as described in + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", + `_. + + Args: + ---- + d_model: dimension of the transformer layers + num_layers: number of transformer layers + num_heads: number of attention heads + ffd_ratio: multiplier for the hidden dimension of the feedforward layer + patch_size: size of the patches + input_shape: size of the input image + dropout: dropout rate + num_classes: number of output classes + include_top: whether the classifier head should be instantiated + """ + + def __init__( + self, + d_model: int, + num_layers: int, + num_heads: int, + ffd_ratio: int, + patch_size: Tuple[int, int] = (4, 4), + input_shape: Tuple[int, int, int] = (3, 32, 32), + dropout: float = 0.0, + num_classes: int = 1000, + include_top: bool = True, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + _layers: List[nn.Module] = [ + PatchEmbedding(input_shape, d_model, patch_size), + EncoderBlock(num_layers, num_heads, d_model, d_model * ffd_ratio, dropout, nn.GELU()), + ] + if include_top: + _layers.append(ClassifierHead(d_model, num_classes)) + + super().__init__(*_layers) + self.cfg = cfg + + +def _vit( + arch: str, + pretrained: bool, + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> VisionTransformer: + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"]) + kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["input_shape"] = kwargs["input_shape"] + _cfg["classes"] = kwargs["classes"] + kwargs.pop("classes") + + # Build the model + model = VisionTransformer(cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + # The number of classes is not the same as the number of classes in the pretrained model => + # remove the last layer weights + _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None + load_pretrained_params(model, 
default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + return model + + +def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer: + """VisionTransformer-S architecture + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", + `_. Patches: (H, W) -> (H/8, W/8) + + NOTE: unofficial config used in ViTSTR and ParSeq + + >>> import torch + >>> from doctr.models import vit_s + >>> model = vit_s(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the VisionTransformer architecture + + Returns: + ------- + A feature extractor model + """ + return _vit( + "vit_s", + pretrained, + d_model=384, + num_layers=12, + num_heads=6, + ffd_ratio=4, + ignore_keys=["2.head.weight", "2.head.bias"], + **kwargs, + ) + + +def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer: + """VisionTransformer-B architecture as described in + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", + `_. Patches: (H, W) -> (H/8, W/8) + + >>> import torch + >>> from doctr.models import vit_b + >>> model = vit_b(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the VisionTransformer architecture + + Returns: + ------- + A feature extractor model + """ + return _vit( + "vit_b", + pretrained, + d_model=768, + num_layers=12, + num_heads=12, + ffd_ratio=4, + ignore_keys=["2.head.weight", "2.head.bias"], + **kwargs, + ) diff --git a/doctr/models/classification/vit/tensorflow.py b/doctr/models/classification/vit/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..4b73b49ac9c11dd430fe7b8ca1f612b745122f14 --- /dev/null +++ b/doctr/models/classification/vit/tensorflow.py @@ -0,0 +1,192 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details.
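Assuming doctr from this diff is installed with PyTorch as the active backend, a quick check of the PyTorch ViT variants above: the head only reads the CLS token, so the output is one logit vector per image (`num_classes=10` is an arbitrary illustrative value):

```python
import torch
from doctr.models import vit_s  # PyTorch implementation when torch is the active backend

model = vit_s(pretrained=False, num_classes=10)  # ClassifierHead re-initialised with 10 outputs
x = torch.rand(1, 3, 32, 32, dtype=torch.float32)
print(model(x).shape)  # torch.Size([1, 10]): ClassifierHead.forward only uses x[:, 0] (the CLS token)
```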
+ +from copy import deepcopy +from typing import Any, Dict, Optional, Tuple + +import tensorflow as tf +from tensorflow.keras import Sequential, layers + +from doctr.datasets import VOCABS +from doctr.models.modules.transformer import EncoderBlock +from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding +from doctr.utils.repr import NestedObject + +from ...utils import load_pretrained_params + +__all__ = ["vit_s", "vit_b"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "vit_s": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_s-6300fcc9.zip&src=0", + }, + "vit_b": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 32, 3), + "classes": list(VOCABS["french"]), + "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_b-57158446.zip&src=0", + }, +} + + +class ClassifierHead(layers.Layer, NestedObject): + """Classifier head for Vision Transformer + + Args: + ---- + num_classes: number of output classes + """ + + def __init__(self, num_classes: int) -> None: + super().__init__() + + self.head = layers.Dense(num_classes, kernel_initializer="he_normal", name="dense") + + def call(self, x: tf.Tensor) -> tf.Tensor: + # (batch_size, num_classes) cls token + return self.head(x[:, 0]) + + +class VisionTransformer(Sequential): + """VisionTransformer architecture as described in + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", + `_. + + Args: + ---- + d_model: dimension of the transformer layers + num_layers: number of transformer layers + num_heads: number of attention heads + ffd_ratio: multiplier for the hidden dimension of the feedforward layer + patch_size: size of the patches + input_shape: size of the input image + dropout: dropout rate + num_classes: number of output classes + include_top: whether the classifier head should be instantiated + """ + + def __init__( + self, + d_model: int, + num_layers: int, + num_heads: int, + ffd_ratio: int, + patch_size: Tuple[int, int] = (4, 4), + input_shape: Tuple[int, int, int] = (32, 32, 3), + dropout: float = 0.0, + num_classes: int = 1000, + include_top: bool = True, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + _layers = [ + PatchEmbedding(input_shape, d_model, patch_size), + EncoderBlock( + num_layers, + num_heads, + d_model, + d_model * ffd_ratio, + dropout, + activation_fct=layers.Activation("gelu"), + ), + ] + if include_top: + _layers.append(ClassifierHead(num_classes)) + + super().__init__(_layers) + self.cfg = cfg + + +def _vit( + arch: str, + pretrained: bool, + **kwargs: Any, +) -> VisionTransformer: + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"]) + kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["input_shape"] = kwargs["input_shape"] + _cfg["classes"] = kwargs["classes"] + kwargs.pop("classes") + + # Build the model + model = VisionTransformer(cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]["url"]) + + return model + + +def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer: + """VisionTransformer-S architecture + `"An Image is Worth 16x16 Words: Transformers for 
Image Recognition at Scale", + `_. Patches: (H, W) -> (H/8, W/8) + + NOTE: unofficial config used in ViTSTR and ParSeq + + >>> import tensorflow as tf + >>> from doctr.models import vit_s + >>> model = vit_s(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the VisionTransformer architecture + + Returns: + ------- + A feature extractor model + """ + return _vit( + "vit_s", + pretrained, + d_model=384, + num_layers=12, + num_heads=6, + ffd_ratio=4, + **kwargs, + ) + + +def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer: + """VisionTransformer-B architecture as described in + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", + `_. Patches: (H, W) -> (H/8, W/8) + + >>> import tensorflow as tf + >>> from doctr.models import vit_b + >>> model = vit_b(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the VisionTransformer architecture + + Returns: + ------- + A feature extractor model + """ + return _vit( + "vit_b", + pretrained, + d_model=768, + num_layers=12, + num_heads=12, + ffd_ratio=4, + **kwargs, + ) diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..9368bb225d85ccc1a04190fb599b439497dc14f2 --- /dev/null +++ b/doctr/models/classification/zoo.py @@ -0,0 +1,98 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, List + +from doctr.file_utils import is_tf_available + +from .. import classification +from ..preprocessor import PreProcessor +from .predictor import OrientationPredictor + +__all__ = ["crop_orientation_predictor", "page_orientation_predictor"] + +ARCHS: List[str] = [ + "magc_resnet31", + "mobilenet_v3_small", + "mobilenet_v3_small_r", + "mobilenet_v3_large", + "mobilenet_v3_large_r", + "resnet18", + "resnet31", + "resnet34", + "resnet50", + "resnet34_wide", + "textnet_tiny", + "textnet_small", + "textnet_base", + "vgg16_bn_r", + "vit_s", + "vit_b", +] +ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"] + + +def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> OrientationPredictor: + if arch not in ORIENTATION_ARCHS: + raise ValueError(f"unknown architecture '{arch}'") + + # Load directly classifier from backbone + _model = classification.__dict__[arch](pretrained=pretrained) + kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) + kwargs["std"] = kwargs.get("std", _model.cfg["std"]) + kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4) + input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:] + predictor = OrientationPredictor( + PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model + ) + return predictor + + +def crop_orientation_predictor( + arch: str = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any +) -> OrientationPredictor: + """Crop orientation classification architecture. 
+ + >>> import numpy as np + >>> from doctr.models import crop_orientation_predictor + >>> model = crop_orientation_predictor(arch='mobilenet_v3_small_crop_orientation', pretrained=True) + >>> input_crop = (255 * np.random.rand(256, 256, 3)).astype(np.uint8) + >>> out = model([input_crop]) + + Args: + ---- + arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation') + pretrained: If True, returns a model pre-trained on our recognition crops dataset + **kwargs: keyword arguments to be passed to the OrientationPredictor + + Returns: + ------- + OrientationPredictor + """ + return _orientation_predictor(arch, pretrained, **kwargs) + + +def page_orientation_predictor( + arch: str = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any +) -> OrientationPredictor: + """Page orientation classification architecture. + + >>> import numpy as np + >>> from doctr.models import page_orientation_predictor + >>> model = page_orientation_predictor(arch='mobilenet_v3_small_page_orientation', pretrained=True) + >>> input_page = (255 * np.random.rand(512, 512, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + ---- + arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation') + pretrained: If True, returns a model pre-trained on our page orientation dataset + **kwargs: keyword arguments to be passed to the OrientationPredictor + + Returns: + ------- + OrientationPredictor + """ + return _orientation_predictor(arch, pretrained, **kwargs) diff --git a/doctr/models/core.py b/doctr/models/core.py new file mode 100644 index 0000000000000000000000000000000000000000..a05aee7aa9f22dbe2a56511699e9977129f1bd99 --- /dev/null +++ b/doctr/models/core.py @@ -0,0 +1,19 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details.
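A short usage note on the two factory functions above: `_orientation_predictor` forwards `**kwargs` to the `PreProcessor`, so defaults such as `batch_size` (128 for crop models, 4 for page models) can be overridden at construction time. A hedged sketch (`pretrained=False` keeps it offline; the crop sizes are arbitrary):

```python
import numpy as np
from doctr.models import crop_orientation_predictor

# Override the default batch_size (128 for crop orientation) at construction time
predictor = crop_orientation_predictor(pretrained=False, batch_size=32)

crops = [(255 * np.random.rand(64, 128, 3)).astype(np.uint8) for _ in range(4)]
out = predictor(crops)  # per-crop orientation predictions from the underlying OrientationPredictor
```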
+ + +from typing import Any, Dict, Optional + +from doctr.utils.repr import NestedObject + +__all__ = ["BaseModel"] + + +class BaseModel(NestedObject): + """Implements abstract DetectionModel class""" + + def __init__(self, cfg: Optional[Dict[str, Any]] = None) -> None: + super().__init__() + self.cfg = cfg diff --git a/doctr/models/detection/__init__.py b/doctr/models/detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b09e4395eb6d15a992960894fddfb582dbbd64db --- /dev/null +++ b/doctr/models/detection/__init__.py @@ -0,0 +1,4 @@ +from .differentiable_binarization import * +from .linknet import * +from .fast import * +from .zoo import * diff --git a/doctr/models/detection/__pycache__/__init__.cpython-311.pyc b/doctr/models/detection/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ab5e42e4da73decd8a6394888b8cffb4cdc4ea8 Binary files /dev/null and b/doctr/models/detection/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/detection/__pycache__/__init__.cpython-38.pyc b/doctr/models/detection/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ab585509e5a3e5add389a900c457fb895640491 Binary files /dev/null and b/doctr/models/detection/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/detection/__pycache__/core.cpython-311.pyc b/doctr/models/detection/__pycache__/core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f93714a36edcfc5d6911af3c4d8145b7419357d3 Binary files /dev/null and b/doctr/models/detection/__pycache__/core.cpython-311.pyc differ diff --git a/doctr/models/detection/__pycache__/core.cpython-38.pyc b/doctr/models/detection/__pycache__/core.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac91eb0517fd56c698b2ce7c56b654373b73f37e Binary files /dev/null and b/doctr/models/detection/__pycache__/core.cpython-38.pyc differ diff --git a/doctr/models/detection/__pycache__/zoo.cpython-311.pyc b/doctr/models/detection/__pycache__/zoo.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6188a41c2dd8c09d2f5e9fb604ba0c54f9524e93 Binary files /dev/null and b/doctr/models/detection/__pycache__/zoo.cpython-311.pyc differ diff --git a/doctr/models/detection/__pycache__/zoo.cpython-38.pyc b/doctr/models/detection/__pycache__/zoo.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97d77cfc00c11adfa7f2323104063b0481733e5a Binary files /dev/null and b/doctr/models/detection/__pycache__/zoo.cpython-38.pyc differ diff --git a/doctr/models/detection/_utils/__init__.py b/doctr/models/detection/_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dc2bfef5e652335569d8da45965b4c64fe56c141 --- /dev/null +++ b/doctr/models/detection/_utils/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available + +if is_tf_available(): + from .tensorflow import * +else: + from .pytorch import * diff --git a/doctr/models/detection/_utils/pytorch.py b/doctr/models/detection/_utils/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..0ac99f4690993c1d4ec6d3563cecfccb07064c85 --- /dev/null +++ b/doctr/models/detection/_utils/pytorch.py @@ -0,0 +1,43 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
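`BaseModel` above only stores an optional config dict and inherits the readable repr behaviour of `NestedObject`; model classes can pass their resolved config so that normalization values and input shapes stay attached to the instance. A minimal hedged sketch (the subclass and cfg values are illustrative only, not part of the diff):

```python
from doctr.models.core import BaseModel  # added by this diff

class DummyDetector(BaseModel):
    def extra_repr(self) -> str:  # hook used by NestedObject when printing the object
        return f"cfg_keys={sorted(self.cfg) if self.cfg else []}"

model = DummyDetector(cfg={"mean": (0.798, 0.785, 0.772), "input_shape": (3, 1024, 1024)})
print(model.cfg["input_shape"])  # (3, 1024, 1024)
```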
+ +from torch import Tensor +from torch.nn.functional import max_pool2d + +__all__ = ["erode", "dilate"] + + +def erode(x: Tensor, kernel_size: int) -> Tensor: + """Performs erosion on a given tensor + + Args: + ---- + x: boolean tensor of shape (N, C, H, W) + kernel_size: the size of the kernel to use for erosion + + Returns: + ------- + the eroded tensor + """ + _pad = (kernel_size - 1) // 2 + + return 1 - max_pool2d(1 - x, kernel_size, stride=1, padding=_pad) + + +def dilate(x: Tensor, kernel_size: int) -> Tensor: + """Performs dilation on a given tensor + + Args: + ---- + x: boolean tensor of shape (N, C, H, W) + kernel_size: the size of the kernel to use for dilation + + Returns: + ------- + the dilated tensor + """ + _pad = (kernel_size - 1) // 2 + + return max_pool2d(x, kernel_size, stride=1, padding=_pad) diff --git a/doctr/models/detection/_utils/tensorflow.py b/doctr/models/detection/_utils/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..6f5ec217493d0c1c6d96b834a83e45af65d8481e --- /dev/null +++ b/doctr/models/detection/_utils/tensorflow.py @@ -0,0 +1,38 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import tensorflow as tf + +__all__ = ["erode", "dilate"] + + +def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor: + """Performs erosion on a given tensor + + Args: + ---- + x: boolean tensor of shape (N, H, W, C) + kernel_size: the size of the kernel to use for erosion + + Returns: + ------- + the eroded tensor + """ + return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME") + + +def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor: + """Performs dilation on a given tensor + + Args: + ---- + x: boolean tensor of shape (N, H, W, C) + kernel_size: the size of the kernel to use for dilation + + Returns: + ------- + the dilated tensor + """ + return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME") diff --git a/doctr/models/detection/core.py b/doctr/models/detection/core.py new file mode 100644 index 0000000000000000000000000000000000000000..63fa78615162ceca0dd11d71a64dc2c8edee4af5 --- /dev/null +++ b/doctr/models/detection/core.py @@ -0,0 +1,101 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
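The max-pool based `erode`/`dilate` above implement binary morphology directly on tensors. A small sketch of what they do to a single foreground pixel (shown with the PyTorch version, assuming a torch-only environment so the dispatcher picks it; the TensorFlow version behaves the same on NHWC tensors):

```python
import torch
from doctr.models.detection._utils import erode, dilate  # resolves to the PyTorch version in a torch-only env

x = torch.zeros(1, 1, 5, 5)
x[0, 0, 2, 2] = 1.0  # single foreground pixel

d = dilate(x, 3)
print(d[0, 0].int())            # 3x3 block of ones around the centre
print(erode(d, 3)[0, 0].int())  # eroding the dilation recovers the single centre pixel
```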
+ +from typing import List + +import cv2 +import numpy as np + +from doctr.utils.repr import NestedObject + +__all__ = ["DetectionPostProcessor"] + + +class DetectionPostProcessor(NestedObject): + """Abstract class to postprocess the raw output of the model + + Args: + ---- + box_thresh (float): minimal objectness score to consider a box + bin_thresh (float): threshold to apply to segmentation raw heatmap + assume straight_pages (bool): if True, fit straight boxes only + """ + + def __init__(self, box_thresh: float = 0.5, bin_thresh: float = 0.5, assume_straight_pages: bool = True) -> None: + self.box_thresh = box_thresh + self.bin_thresh = bin_thresh + self.assume_straight_pages = assume_straight_pages + self._opening_kernel: np.ndarray = np.ones((3, 3), dtype=np.uint8) + + def extra_repr(self) -> str: + return f"bin_thresh={self.bin_thresh}, box_thresh={self.box_thresh}" + + @staticmethod + def box_score(pred: np.ndarray, points: np.ndarray, assume_straight_pages: bool = True) -> float: + """Compute the confidence score for a polygon : mean of the p values on the polygon + + Args: + ---- + pred (np.ndarray): p map returned by the model + points: coordinates of the polygon + assume_straight_pages: if True, fit straight boxes only + + Returns: + ------- + polygon objectness + """ + h, w = pred.shape[:2] + + if assume_straight_pages: + xmin = np.clip(np.floor(points[:, 0].min()).astype(np.int32), 0, w - 1) + xmax = np.clip(np.ceil(points[:, 0].max()).astype(np.int32), 0, w - 1) + ymin = np.clip(np.floor(points[:, 1].min()).astype(np.int32), 0, h - 1) + ymax = np.clip(np.ceil(points[:, 1].max()).astype(np.int32), 0, h - 1) + return pred[ymin : ymax + 1, xmin : xmax + 1].mean() + + else: + mask: np.ndarray = np.zeros((h, w), np.int32) + cv2.fillPoly(mask, [points.astype(np.int32)], 1.0) # type: ignore[call-overload] + product = pred * mask + return np.sum(product) / np.count_nonzero(product) + + def bitmap_to_boxes( + self, + pred: np.ndarray, + bitmap: np.ndarray, + ) -> np.ndarray: + raise NotImplementedError + + def __call__( + self, + proba_map, + ) -> List[List[np.ndarray]]: + """Performs postprocessing for a list of model outputs + + Args: + ---- + proba_map: probability map of shape (N, H, W, C) + + Returns: + ------- + list of N class predictions (for each input sample), where each class predictions is a list of C tensors + of shape (*, 5) or (*, 6) + """ + if proba_map.ndim != 4: + raise AssertionError(f"arg `proba_map` is expected to be 4-dimensional, got {proba_map.ndim}.") + + # Erosion + dilation on the binary map + bin_map = [ + [ + cv2.morphologyEx(bmap[..., idx], cv2.MORPH_OPEN, self._opening_kernel) + for idx in range(proba_map.shape[-1]) + ] + for bmap in (proba_map >= self.bin_thresh).astype(np.uint8) + ] + + return [ + [self.bitmap_to_boxes(pmaps[..., idx], bmaps[idx]) for idx in range(proba_map.shape[-1])] + for pmaps, bmaps in zip(proba_map, bin_map) + ] diff --git a/doctr/models/detection/differentiable_binarization/__init__.py b/doctr/models/detection/differentiable_binarization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/detection/differentiable_binarization/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git 
a/doctr/models/detection/differentiable_binarization/__pycache__/__init__.cpython-311.pyc b/doctr/models/detection/differentiable_binarization/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..843c3359b3d91b0ee2fd658a91a0a0514119074c Binary files /dev/null and b/doctr/models/detection/differentiable_binarization/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/detection/differentiable_binarization/__pycache__/__init__.cpython-38.pyc b/doctr/models/detection/differentiable_binarization/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb57913b20a95d1d4787fba01cae0b7177241947 Binary files /dev/null and b/doctr/models/detection/differentiable_binarization/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/detection/differentiable_binarization/__pycache__/base.cpython-311.pyc b/doctr/models/detection/differentiable_binarization/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c7cc12c1971abd55901d0cf0fc27745022c5b90 Binary files /dev/null and b/doctr/models/detection/differentiable_binarization/__pycache__/base.cpython-311.pyc differ diff --git a/doctr/models/detection/differentiable_binarization/__pycache__/base.cpython-38.pyc b/doctr/models/detection/differentiable_binarization/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f57dc706d4aea4206701aa93ec314bb6222d246 Binary files /dev/null and b/doctr/models/detection/differentiable_binarization/__pycache__/base.cpython-38.pyc differ diff --git a/doctr/models/detection/differentiable_binarization/__pycache__/pytorch.cpython-311.pyc b/doctr/models/detection/differentiable_binarization/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfdff5d1eba776ddbcd0ec27adb24db784252612 Binary files /dev/null and b/doctr/models/detection/differentiable_binarization/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/detection/differentiable_binarization/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/detection/differentiable_binarization/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5db0606c601f8e3d84311a9c0352327c911f7d43 Binary files /dev/null and b/doctr/models/detection/differentiable_binarization/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/detection/differentiable_binarization/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/detection/differentiable_binarization/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6129bfbc8f1a55f46caca081a0c3a2dee16407b Binary files /dev/null and b/doctr/models/detection/differentiable_binarization/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py new file mode 100644 index 0000000000000000000000000000000000000000..5f03a2e1bf52861f8e237e9ac63b007f93dd379e --- /dev/null +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -0,0 +1,375 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
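`DetectionPostProcessor.box_score` above is a static method, so its behaviour is easy to check in isolation: for straight pages it simply averages the probability map inside the box. A small sketch with synthetic values:

```python
import numpy as np
from doctr.models.detection.core import DetectionPostProcessor

pred = np.zeros((8, 8), dtype=np.float32)
pred[2:6, 2:6] = 0.9  # a confident 4x4 blob

# polygon given as (x, y) points around the blob
points = np.array([[2, 2], [5, 2], [5, 5], [2, 5]])
score = DetectionPostProcessor.box_score(pred, points, assume_straight_pages=True)
print(round(float(score), 3))  # 0.9 -> mean objectness inside the rectangle
```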
+ +# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization + +from typing import Dict, List, Tuple, Union + +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon + +from ..core import DetectionPostProcessor + +__all__ = ["DBPostProcessor"] + + +class DBPostProcessor(DetectionPostProcessor): + """Implements a post processor for DBNet adapted from the implementation of `xuannianz + `_. + + Args: + ---- + unclip ratio: ratio used to unshrink polygons + min_size_box: minimal length (pix) to keep a box + max_candidates: maximum boxes to consider in a single page + box_thresh: minimal objectness score to consider a box + bin_thresh: threshold used to binzarized p_map at inference time + + """ + + def __init__( + self, + box_thresh: float = 0.1, + bin_thresh: float = 0.3, + assume_straight_pages: bool = True, + ) -> None: + super().__init__(box_thresh, bin_thresh, assume_straight_pages) + self.unclip_ratio = 1.5 + + def polygon_to_box( + self, + points: np.ndarray, + ) -> np.ndarray: + """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon + + Args: + ---- + points: The first parameter. + + Returns: + ------- + a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle) + """ + if not self.assume_straight_pages: + # Compute the rectangle polygon enclosing the raw polygon + rect = cv2.minAreaRect(points) + points = cv2.boxPoints(rect) + # Add 1 pixel to correct cv2 approx + area = (rect[1][0] + 1) * (1 + rect[1][1]) + length = 2 * (rect[1][0] + rect[1][1]) + 2 + else: + poly = Polygon(points) + area = poly.area + length = poly.length + distance = area * self.unclip_ratio / length # compute distance to expand polygon + offset = pyclipper.PyclipperOffset() + offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + _points = offset.Execute(distance) + # Take biggest stack of points + idx = 0 + if len(_points) > 1: + max_size = 0 + for _idx, p in enumerate(_points): + if len(p) > max_size: + idx = _idx + max_size = len(p) + # We ensure that _points can be correctly casted to a ndarray + _points = [_points[idx]] + expanded_points: np.ndarray = np.asarray(_points) # expand polygon + if len(expanded_points) < 1: + return None # type: ignore[return-value] + return ( + cv2.boundingRect(expanded_points) # type: ignore[return-value] + if self.assume_straight_pages + else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0) + ) + + def bitmap_to_boxes( + self, + pred: np.ndarray, + bitmap: np.ndarray, + ) -> np.ndarray: + """Compute boxes from a bitmap/pred_map: find connected components then filter boxes + + Args: + ---- + pred: Pred map from differentiable binarization output + bitmap: Bitmap map computed from pred (binarized) + angle_tol: Comparison tolerance of the angle with the median angle across the page + ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop + + Returns: + ------- + np tensor boxes for the bitmap, each box is a 5-element list + containing x, y, w, h, score for the box + """ + height, width = bitmap.shape[:2] + min_size_box = 2 + boxes: List[Union[np.ndarray, List[float]]] = [] + # get contours from connected components on the bitmap + contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + # Check whether smallest enclosing bounding box is not too small + if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 
min_size_box): + continue + # Compute objectness + if self.assume_straight_pages: + x, y, w, h = cv2.boundingRect(contour) + points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]) + score = self.box_score(pred, points, assume_straight_pages=True) + else: + score = self.box_score(pred, contour, assume_straight_pages=False) + + if score < self.box_thresh: # remove polygons with a weak objectness + continue + + if self.assume_straight_pages: + _box = self.polygon_to_box(points) + else: + _box = self.polygon_to_box(np.squeeze(contour)) + + # Remove too small boxes + if self.assume_straight_pages: + if _box is None or _box[2] < min_size_box or _box[3] < min_size_box: + continue + elif np.linalg.norm(_box[2, :] - _box[0, :], axis=-1) < min_size_box: + continue + + if self.assume_straight_pages: + x, y, w, h = _box + # compute relative polygon to get rid of img shape + xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height + boxes.append([xmin, ymin, xmax, ymax, score]) + else: + # compute relative box to get rid of img shape, in that case _box is a 4pt polygon + if not isinstance(_box, np.ndarray) and _box.shape == (4, 2): + raise AssertionError("When assume straight pages is false a box is a (4, 2) array (polygon)") + _box[:, 0] /= width + _box[:, 1] /= height + boxes.append(_box) + + if not self.assume_straight_pages: + return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype) + else: + return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype) + + +class _DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_. + + Args: + ---- + feature extractor: the backbone serving as feature extractor + fpn_channels: number of channels each extracted feature maps is mapped to + """ + + shrink_ratio = 0.4 + thresh_min = 0.3 + thresh_max = 0.7 + min_size_box = 3 + assume_straight_pages: bool = True + + @staticmethod + def compute_distance( + xs: np.ndarray, + ys: np.ndarray, + a: np.ndarray, + b: np.ndarray, + eps: float = 1e-6, + ) -> float: + """Compute the distance for each point of the map (xs, ys) to the (a, b) segment + + Args: + ---- + xs : map of x coordinates (height, width) + ys : map of y coordinates (height, width) + a: first point defining the [ab] segment + b: second point defining the [ab] segment + eps: epsilon to avoid division by zero + + Returns: + ------- + The computed distance + + """ + square_dist_1 = np.square(xs - a[0]) + np.square(ys - a[1]) + square_dist_2 = np.square(xs - b[0]) + np.square(ys - b[1]) + square_dist = np.square(a[0] - b[0]) + np.square(a[1] - b[1]) + cosin = (square_dist - square_dist_1 - square_dist_2) / (2 * np.sqrt(square_dist_1 * square_dist_2) + eps) + cosin = np.clip(cosin, -1.0, 1.0) + square_sin = 1 - np.square(cosin) + square_sin = np.nan_to_num(square_sin) + result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist + eps) + result[cosin < 0] = np.sqrt(np.fmin(square_dist_1, square_dist_2))[cosin < 0] + return result + + def draw_thresh_map( + self, + polygon: np.ndarray, + canvas: np.ndarray, + mask: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Draw a polygon treshold map on a canvas, as described in the DB paper + + Args: + ---- + polygon : array of coord., to draw the boundary of the polygon + canvas : threshold map to fill with polygons + mask : mask for training on threshold polygons + """ + if polygon.ndim != 2 or polygon.shape[1] 
!= 2: + raise AttributeError("polygon should be a 2 dimensional array of coords") + + # Augment polygon by shrink_ratio + polygon_shape = Polygon(polygon) + distance = polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length + subject = [tuple(coor) for coor in polygon] # Get coord as list of tuples + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + padded_polygon: np.ndarray = np.array(padding.Execute(distance)[0]) + + # Fill the mask with 1 on the new padded polygon + cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) # type: ignore[call-overload] + + # Get min/max to recover polygon after distance computation + xmin = padded_polygon[:, 0].min() + xmax = padded_polygon[:, 0].max() + ymin = padded_polygon[:, 1].min() + ymax = padded_polygon[:, 1].max() + width = xmax - xmin + 1 + height = ymax - ymin + 1 + # Get absolute polygon for distance computation + polygon[:, 0] = polygon[:, 0] - xmin + polygon[:, 1] = polygon[:, 1] - ymin + # Get absolute padded polygon + xs: np.ndarray = np.broadcast_to(np.linspace(0, width - 1, num=width).reshape(1, width), (height, width)) + ys: np.ndarray = np.broadcast_to(np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width)) + + # Compute distance map to fill the padded polygon + distance_map = np.zeros((polygon.shape[0], height, width), dtype=polygon.dtype) + for i in range(polygon.shape[0]): + j = (i + 1) % polygon.shape[0] + absolute_distance = self.compute_distance(xs, ys, polygon[i], polygon[j]) + distance_map[i] = np.clip(absolute_distance / distance, 0, 1) + distance_map = np.min(distance_map, axis=0) + + # Clip the padded polygon inside the canvas + xmin_valid = min(max(0, xmin), canvas.shape[1] - 1) + xmax_valid = min(max(0, xmax), canvas.shape[1] - 1) + ymin_valid = min(max(0, ymin), canvas.shape[0] - 1) + ymax_valid = min(max(0, ymax), canvas.shape[0] - 1) + + # Fill the canvas with the distances computed inside the valid padded polygon + canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax( + 1 + - distance_map[ + ymin_valid - ymin : ymax_valid - ymax + height, xmin_valid - xmin : xmax_valid - xmax + width + ], + canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1], + ) + + return polygon, canvas, mask + + def build_target( + self, + target: List[Dict[str, np.ndarray]], + output_shape: Tuple[int, int, int], + channels_last: bool = True, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + if any(t.dtype != np.float32 for tgt in target for t in tgt.values()): + raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.") + if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()): + raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.") + + input_dtype = next(iter(target[0].values())).dtype if len(target) > 0 else np.float32 + + h: int + w: int + if channels_last: + h, w, num_classes = output_shape + else: + num_classes, h, w = output_shape + target_shape = (len(target), num_classes, h, w) + + seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8) + seg_mask: np.ndarray = np.ones(target_shape, dtype=bool) + thresh_target: np.ndarray = np.zeros(target_shape, dtype=np.float32) + thresh_mask: np.ndarray = np.zeros(target_shape, dtype=np.uint8) + + for idx, tgt in enumerate(target): + for class_idx, _tgt in enumerate(tgt.values()): + # Draw each polygon on gt + if _tgt.shape[0] == 0: + # 
Empty image, full masked + seg_mask[idx, class_idx] = False + + # Absolute bounding boxes + abs_boxes = _tgt.copy() + if abs_boxes.ndim == 3: + abs_boxes[:, :, 0] *= w + abs_boxes[:, :, 1] *= h + polys = abs_boxes + boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1) + abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32) + else: + abs_boxes[:, [0, 2]] *= w + abs_boxes[:, [1, 3]] *= h + abs_boxes = abs_boxes.round().astype(np.int32) + polys = np.stack( + [ + abs_boxes[:, [0, 1]], + abs_boxes[:, [0, 3]], + abs_boxes[:, [2, 3]], + abs_boxes[:, [2, 1]], + ], + axis=1, + ) + boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1]) + + for poly, box, box_size in zip(polys, abs_boxes, boxes_size): + # Mask boxes that are too small + if box_size < self.min_size_box: + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + continue + + # Negative shrink for gt, as described in paper + polygon = Polygon(poly) + distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length + subject = [tuple(coor) for coor in poly] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + shrunken = padding.Execute(-distance) + + # Draw polygon on gt if it is valid + if len(shrunken) == 0: + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + continue + shrunken = np.array(shrunken[0]).reshape(-1, 2) + if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid: + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + continue + cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] + + # Draw on both thresh map and thresh mask + poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map( + poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] + ) + if channels_last: + seg_target = seg_target.transpose((0, 2, 3, 1)) + seg_mask = seg_mask.transpose((0, 2, 3, 1)) + thresh_target = thresh_target.transpose((0, 2, 3, 1)) + thresh_mask = thresh_mask.transpose((0, 2, 3, 1)) + + thresh_target = thresh_target.astype(input_dtype) * (self.thresh_max - self.thresh_min) + self.thresh_min + + seg_target = seg_target.astype(input_dtype) + seg_mask = seg_mask.astype(bool) + thresh_target = thresh_target.astype(input_dtype) + thresh_mask = thresh_mask.astype(bool) + + return seg_target, seg_mask, thresh_target, thresh_mask diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..7f7bdd8efe204a22476eb4048108a2d6c6c459ea --- /dev/null +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -0,0 +1,435 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
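One detail of `build_target` / `draw_thresh_map` above worth making concrete: both shrink (for the segmentation target) and pad (for the threshold target) polygons by the same offset, `distance = area * (1 - shrink_ratio**2) / perimeter` with `shrink_ratio = 0.4`. A worked example on a 100x20 px box:

```python
# Offset used to shrink the gt polygon / pad the threshold polygon (see build_target above)
shrink_ratio = 0.4
w, h = 100, 20                       # a 100 x 20 px text box
area, perimeter = w * h, 2 * (w + h)

distance = area * (1 - shrink_ratio**2) / perimeter
print(distance)  # 7.0 -> the polygon is offset by 7 px before rasterisation
```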
+ +from typing import Any, Callable, Dict, List, Optional + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from torchvision.models import resnet34, resnet50 +from torchvision.models._utils import IntermediateLayerGetter +from torchvision.ops.deform_conv import DeformConv2d + +from doctr.file_utils import CLASS_NAME + +from ...classification import mobilenet_v3_large +from ...utils import _bf16_to_float32, load_pretrained_params +from .base import DBPostProcessor, _DBNet + +__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "db_resnet50": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-79bd7d70.pt&src=0", + }, + "db_resnet34": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet34-cb6aed9e.pt&src=0", + }, + "db_mobilenet_v3_large": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-81e9b152.pt&src=0", + }, +} + + +class FeaturePyramidNetwork(nn.Module): + def __init__( + self, + in_channels: List[int], + out_channels: int, + deform_conv: bool = False, + ) -> None: + super().__init__() + + out_chans = out_channels // len(in_channels) + + conv_layer = DeformConv2d if deform_conv else nn.Conv2d + + self.in_branches = nn.ModuleList([ + nn.Sequential( + conv_layer(chans, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + for idx, chans in enumerate(in_channels) + ]) + self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + self.out_branches = nn.ModuleList([ + nn.Sequential( + conv_layer(out_channels, out_chans, 3, padding=1, bias=False), + nn.BatchNorm2d(out_chans), + nn.ReLU(inplace=True), + nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True), + ) + for idx, chans in enumerate(in_channels) + ]) + + def forward(self, x: List[torch.Tensor]) -> torch.Tensor: + if len(x) != len(self.out_branches): + raise AssertionError + # Conv1x1 to get the same number of channels + _x: List[torch.Tensor] = [branch(t) for branch, t in zip(self.in_branches, x)] + out: List[torch.Tensor] = [_x[-1]] + for t in _x[:-1][::-1]: + out.append(self.upsample(out[-1]) + t) + + # Conv and final upsampling + out = [branch(t) for branch, t in zip(self.out_branches, out[::-1])] + + return torch.cat(out, dim=1) + + +class DBNet(_DBNet, nn.Module): + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_. 
+ + Args: + ---- + feature extractor: the backbone serving as feature extractor + head_chans: the number of channels in the head + deform_conv: whether to use deformable convolution + bin_thresh: threshold for binarization + box_thresh: minimal objectness score to consider a box + assume_straight_pages: if True, fit straight bounding boxes only + exportable: onnx exportable returns only logits + cfg: the configuration dict of the model + class_names: list of class names + """ + + def __init__( + self, + feat_extractor: IntermediateLayerGetter, + head_chans: int = 256, + deform_conv: bool = False, + bin_thresh: float = 0.3, + box_thresh: float = 0.1, + assume_straight_pages: bool = True, + exportable: bool = False, + cfg: Optional[Dict[str, Any]] = None, + class_names: List[str] = [CLASS_NAME], + ) -> None: + super().__init__() + self.class_names = class_names + num_classes: int = len(self.class_names) + self.cfg = cfg + + conv_layer = DeformConv2d if deform_conv else nn.Conv2d + + self.exportable = exportable + self.assume_straight_pages = assume_straight_pages + + self.feat_extractor = feat_extractor + # Identify the number of channels for the head initialization + _is_training = self.feat_extractor.training + self.feat_extractor = self.feat_extractor.eval() + with torch.no_grad(): + out = self.feat_extractor(torch.zeros((1, 3, 224, 224))) + fpn_channels = [v.shape[1] for _, v in out.items()] + + if _is_training: + self.feat_extractor = self.feat_extractor.train() + + self.fpn = FeaturePyramidNetwork(fpn_channels, head_chans, deform_conv) + # Conv1 map to channels + + self.prob_head = nn.Sequential( + conv_layer(head_chans, head_chans // 4, 3, padding=1, bias=False), + nn.BatchNorm2d(head_chans // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(head_chans // 4, head_chans // 4, 2, stride=2, bias=False), + nn.BatchNorm2d(head_chans // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2), + ) + self.thresh_head = nn.Sequential( + conv_layer(head_chans, head_chans // 4, 3, padding=1, bias=False), + nn.BatchNorm2d(head_chans // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(head_chans // 4, head_chans // 4, 2, stride=2, bias=False), + nn.BatchNorm2d(head_chans // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2), + ) + + self.postprocessor = DBPostProcessor( + assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh + ) + + for n, m in self.named_modules(): + # Don't override the initialization of the backbone + if n.startswith("feat_extractor."): + continue + if isinstance(m, (nn.Conv2d, DeformConv2d)): + nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + target: Optional[List[np.ndarray]] = None, + return_model_output: bool = False, + return_preds: bool = False, + ) -> Dict[str, torch.Tensor]: + # Extract feature maps at different stages + feats = self.feat_extractor(x) + feats = [feats[str(idx)] for idx in range(len(feats))] + # Pass through the FPN + feat_concat = self.fpn(feats) + logits = self.prob_head(feat_concat) + + out: Dict[str, Any] = {} + if self.exportable: + out["logits"] = logits + return out + + if return_model_output or target is None or return_preds: + prob_map = _bf16_to_float32(torch.sigmoid(logits)) + + if return_model_output: + out["out_map"] = 
prob_map + + if target is None or return_preds: + # Post-process boxes (keep only text predictions) + out["preds"] = [ + dict(zip(self.class_names, preds)) + for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) + ] + + if target is not None: + thresh_map = self.thresh_head(feat_concat) + loss = self.compute_loss(logits, thresh_map, target) + out["loss"] = loss + + return out + + def compute_loss( + self, + out_map: torch.Tensor, + thresh_map: torch.Tensor, + target: List[np.ndarray], + gamma: float = 2.0, + alpha: float = 0.5, + eps: float = 1e-8, + ) -> torch.Tensor: + """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes + and a list of masks for each image. From there it computes the loss with the model output + + Args: + ---- + out_map: output feature map of the model of shape (N, C, H, W) + thresh_map: threshold map of shape (N, C, H, W) + target: list of dictionary where each dict has a `boxes` and a `flags` entry + gamma: modulating factor in the focal loss formula + alpha: balancing factor in the focal loss formula + eps: epsilon factor in dice loss + + Returns: + ------- + A loss tensor + """ + if gamma < 0: + raise ValueError("Value of gamma should be greater than or equal to zero.") + + prob_map = torch.sigmoid(out_map) + thresh_map = torch.sigmoid(thresh_map) + + targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type] + + seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1]) + seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device) + thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3]) + thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device) + + if torch.any(seg_mask): + # Focal loss + focal_scale = 10.0 + bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none") + + p_t = prob_map * seg_target + (1 - prob_map) * (1 - seg_target) + alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target) + # Unreduced version + focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss + # Class reduced + focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / seg_mask.sum((0, 1, 2, 3)) + + # Compute dice loss for each class or for approx binary_map + if len(self.class_names) > 1: + dice_map = torch.softmax(out_map, dim=1) + else: + # compute binary map instead + dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map))) + # Class reduced + inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3)) + cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3)) + dice_loss = (1 - 2 * inter / (cardinality + eps)).mean() + + # Compute l1 loss for thresh_map + if torch.any(thresh_mask): + l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps) + + return l1_loss + focal_scale * focal_loss + dice_loss + + +def _dbnet( + arch: str, + pretrained: bool, + backbone_fn: Callable[[bool], nn.Module], + fpn_layers: List[str], + backbone_submodule: Optional[str] = None, + pretrained_backbone: bool = True, + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> DBNet: + pretrained_backbone = pretrained_backbone and not pretrained + + # Feature extractor + backbone = ( + backbone_fn(pretrained_backbone) + if not arch.split("_")[1].startswith("resnet") + # Starting with Imagenet pretrained params introduces some NaNs in layer3 & layer4 of resnet50 + else backbone_fn(weights=None) # type: ignore[call-arg] + ) + 
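+    # Illustrative note: for db_resnet50 the layer getter below exposes "layer1" to "layer4",
+    # which for a 1024 x 1024 input yield feature maps of spatial size 256, 128, 64 and 32 with
+    # 256, 512, 1024 and 2048 channels respectively; the channel counts are re-measured in
+    # DBNet.__init__ through a dummy forward pass, so other backbones are handled the same way.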
if isinstance(backbone_submodule, str): + backbone = getattr(backbone, backbone_submodule) + feat_extractor = IntermediateLayerGetter( + backbone, + {layer_name: str(idx) for idx, layer_name in enumerate(fpn_layers)}, + ) + + if not kwargs.get("class_names", None): + kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME]) + else: + kwargs["class_names"] = sorted(kwargs["class_names"]) + # Build the model + model = DBNet(feat_extractor, cfg=default_cfgs[arch], **kwargs) + # Load pretrained parameters + if pretrained: + # The number of class_names is not the same as the number of classes in the pretrained model => + # remove the layer weights + _ignore_keys = ( + ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None + ) + load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + return model + + +def db_resnet34(pretrained: bool = False, **kwargs: Any) -> DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_, using a ResNet-34 backbone. + + >>> import torch + >>> from doctr.models import db_resnet34 + >>> model = db_resnet34(pretrained=True) + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _dbnet( + "db_resnet34", + pretrained, + resnet34, + ["layer1", "layer2", "layer3", "layer4"], + None, + ignore_keys=[ + "prob_head.6.weight", + "prob_head.6.bias", + "thresh_head.6.weight", + "thresh_head.6.bias", + ], + **kwargs, + ) + + +def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_, using a ResNet-50 backbone. + + >>> import torch + >>> from doctr.models import db_resnet50 + >>> model = db_resnet50(pretrained=True) + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _dbnet( + "db_resnet50", + pretrained, + resnet50, + ["layer1", "layer2", "layer3", "layer4"], + None, + ignore_keys=[ + "prob_head.6.weight", + "prob_head.6.bias", + "thresh_head.6.weight", + "thresh_head.6.bias", + ], + **kwargs, + ) + + +def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_, using a MobileNet V3 Large backbone. 
+ + >>> import torch + >>> from doctr.models import db_mobilenet_v3_large + >>> model = db_mobilenet_v3_large(pretrained=True) + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _dbnet( + "db_mobilenet_v3_large", + pretrained, + mobilenet_v3_large, + ["3", "6", "12", "16"], + "features", + ignore_keys=[ + "prob_head.6.weight", + "prob_head.6.bias", + "thresh_head.6.weight", + "thresh_head.6.bias", + ], + **kwargs, + ) diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..df9935b04259ce7912459cda4e3e046589dbeb45 --- /dev/null +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -0,0 +1,402 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import layers +from tensorflow.keras.applications import ResNet50 + +from doctr.file_utils import CLASS_NAME +from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params +from doctr.utils.repr import NestedObject + +from ...classification import mobilenet_v3_large +from .base import DBPostProcessor, _DBNet + +__all__ = ["DBNet", "db_resnet50", "db_mobilenet_v3_large"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "db_resnet50": { + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "input_shape": (1024, 1024, 3), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-84171458.zip&src=0", + }, + "db_mobilenet_v3_large": { + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "input_shape": (1024, 1024, 3), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-da524564.zip&src=0", + }, +} + + +class FeaturePyramidNetwork(layers.Layer, NestedObject): + """Feature Pyramid Network as described in `"Feature Pyramid Networks for Object Detection" + `_. 
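+
+    A shape-only usage sketch (the 64 input channels and the 256 px base resolution are arbitrary
+    illustrations): the layer expects a list of 4 feature maps with halving resolutions and returns
+    a single map at the resolution of the first one.
+
+    >>> import tensorflow as tf
+    >>> fpn = FeaturePyramidNetwork(channels=128)
+    >>> feats = [tf.random.uniform((1, 256 // 2**idx, 256 // 2**idx, 64)) for idx in range(4)]
+    >>> out = fpn(feats)  # shape: (1, 256, 256, 4 * 128)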
+ + Args: + ---- + channels: number of channel to output + """ + + def __init__( + self, + channels: int, + ) -> None: + super().__init__() + self.channels = channels + self.upsample = layers.UpSampling2D(size=(2, 2), interpolation="nearest") + self.inner_blocks = [layers.Conv2D(channels, 1, strides=1, kernel_initializer="he_normal") for _ in range(4)] + self.layer_blocks = [self.build_upsampling(channels, dilation_factor=2**idx) for idx in range(4)] + + @staticmethod + def build_upsampling( + channels: int, + dilation_factor: int = 1, + ) -> layers.Layer: + """Module which performs a 3x3 convolution followed by up-sampling + + Args: + ---- + channels: number of output channels + dilation_factor (int): dilation factor to scale the convolution output before concatenation + + Returns: + ------- + a keras.layers.Layer object, wrapping these operations in a sequential module + + """ + _layers = conv_sequence(channels, "relu", True, kernel_size=3) + + if dilation_factor > 1: + _layers.append(layers.UpSampling2D(size=(dilation_factor, dilation_factor), interpolation="nearest")) + + module = keras.Sequential(_layers) + + return module + + def extra_repr(self) -> str: + return f"channels={self.channels}" + + def call( + self, + x: List[tf.Tensor], + **kwargs: Any, + ) -> tf.Tensor: + # Channel mapping + results = [block(fmap, **kwargs) for block, fmap in zip(self.inner_blocks, x)] + # Upsample & sum + for idx in range(len(results) - 1, -1): + results[idx] += self.upsample(results[idx + 1]) + # Conv & upsample + results = [block(fmap, **kwargs) for block, fmap in zip(self.layer_blocks, results)] + + return layers.concatenate(results) + + +class DBNet(_DBNet, keras.Model, NestedObject): + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_. 
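+
+    A minimal instantiation sketch (illustrative only: it assumes a Keras ResNet50 backbone and
+    mirrors the wiring that the ``db_resnet50`` factory below performs for you):
+
+    >>> import tensorflow as tf
+    >>> from tensorflow.keras.applications import ResNet50
+    >>> from doctr.models.utils import IntermediateLayerGetter
+    >>> feat_extractor = IntermediateLayerGetter(
+    ...     ResNet50(weights=None, include_top=False, pooling=None, input_shape=(1024, 1024, 3)),
+    ...     ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"],
+    ... )
+    >>> model = DBNet(feat_extractor)
+    >>> out = model(tf.random.uniform((1, 1024, 1024, 3), maxval=1, dtype=tf.float32), return_preds=True)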
+ + Args: + ---- + feature extractor: the backbone serving as feature extractor + fpn_channels: number of channels each extracted feature maps is mapped to + bin_thresh: threshold for binarization + box_thresh: minimal objectness score to consider a box + assume_straight_pages: if True, fit straight bounding boxes only + exportable: onnx exportable returns only logits + cfg: the configuration dict of the model + class_names: list of class names + """ + + _children_names: List[str] = ["feat_extractor", "fpn", "probability_head", "threshold_head", "postprocessor"] + + def __init__( + self, + feature_extractor: IntermediateLayerGetter, + fpn_channels: int = 128, # to be set to 256 to represent the author's initial idea + bin_thresh: float = 0.3, + box_thresh: float = 0.1, + assume_straight_pages: bool = True, + exportable: bool = False, + cfg: Optional[Dict[str, Any]] = None, + class_names: List[str] = [CLASS_NAME], + ) -> None: + super().__init__() + self.class_names = class_names + num_classes: int = len(self.class_names) + self.cfg = cfg + + self.feat_extractor = feature_extractor + self.exportable = exportable + self.assume_straight_pages = assume_straight_pages + + self.fpn = FeaturePyramidNetwork(channels=fpn_channels) + # Initialize kernels + _inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape] + output_shape = tuple(self.fpn(_inputs).shape) + + self.probability_head = keras.Sequential([ + *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]), + layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"), + layers.BatchNormalization(), + layers.Activation("relu"), + layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"), + ]) + self.threshold_head = keras.Sequential([ + *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]), + layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"), + layers.BatchNormalization(), + layers.Activation("relu"), + layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"), + ]) + + self.postprocessor = DBPostProcessor( + assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh + ) + + def compute_loss( + self, + out_map: tf.Tensor, + thresh_map: tf.Tensor, + target: List[Dict[str, np.ndarray]], + gamma: float = 2.0, + alpha: float = 0.5, + eps: float = 1e-8, + ) -> tf.Tensor: + """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes + and a list of masks for each image. 
From there it computes the loss with the model output + + Args: + ---- + out_map: output feature map of the model of shape (N, H, W, C) + thresh_map: threshold map of shape (N, H, W, C) + target: list of dictionary where each dict has a `boxes` and a `flags` entry + gamma: modulating factor in the focal loss formula + alpha: balancing factor in the focal loss formula + eps: epsilon factor in dice loss + + Returns: + ------- + A loss tensor + """ + if gamma < 0: + raise ValueError("Value of gamma should be greater than or equal to zero.") + + prob_map = tf.math.sigmoid(out_map) + thresh_map = tf.math.sigmoid(thresh_map) + + seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[1:], True) + seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype) + seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool) + seg_mask = tf.cast(seg_mask, tf.float32) + thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype) + thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool) + + # Focal loss + focal_scale = 10.0 + bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) + + # Convert logits to prob, compute gamma factor + p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map)) + alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha) + # Unreduced loss + focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss + # Class reduced + focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3)) + + # Compute dice loss for each class or for approx binary_map + if len(self.class_names) > 1: + dice_map = tf.nn.softmax(out_map, axis=-1) + else: + # compute binary map instead + dice_map = 1.0 / (1.0 + tf.exp(-50 * (prob_map - thresh_map))) + # Class-reduced dice loss + inter = tf.reduce_sum(seg_mask * dice_map * seg_target, axis=[0, 1, 2]) + cardinality = tf.reduce_sum(seg_mask * (dice_map + seg_target), axis=[0, 1, 2]) + dice_loss = tf.reduce_mean(1 - 2 * inter / (cardinality + eps)) + + # Compute l1 loss for thresh_map + if tf.reduce_any(thresh_mask): + thresh_mask = tf.cast(thresh_mask, tf.float32) + l1_loss = tf.reduce_sum(tf.abs(thresh_map - thresh_target) * thresh_mask) / ( + tf.reduce_sum(thresh_mask) + eps + ) + else: + l1_loss = tf.constant(0.0) + + return l1_loss + focal_scale * focal_loss + dice_loss + + def call( + self, + x: tf.Tensor, + target: Optional[List[Dict[str, np.ndarray]]] = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + feat_maps = self.feat_extractor(x, **kwargs) + feat_concat = self.fpn(feat_maps, **kwargs) + logits = self.probability_head(feat_concat, **kwargs) + + out: Dict[str, tf.Tensor] = {} + if self.exportable: + out["logits"] = logits + return out + + if return_model_output or target is None or return_preds: + prob_map = _bf16_to_float32(tf.math.sigmoid(logits)) + + if return_model_output: + out["out_map"] = prob_map + + if target is None or return_preds: + # Post-process boxes (keep only text predictions) + out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())] + + if target is not None: + thresh_map = self.threshold_head(feat_concat, **kwargs) + loss = self.compute_loss(logits, thresh_map, target) + out["loss"] = loss + + return out + + +def _db_resnet( + arch: str, + pretrained: bool, + backbone_fn, + fpn_layers: List[str], + pretrained_backbone: bool = True, + input_shape: 
Optional[Tuple[int, int, int]] = None, + **kwargs: Any, +) -> DBNet: + pretrained_backbone = pretrained_backbone and not pretrained + + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["input_shape"] = input_shape or _cfg["input_shape"] + if not kwargs.get("class_names", None): + kwargs["class_names"] = _cfg.get("class_names", [CLASS_NAME]) + else: + kwargs["class_names"] = sorted(kwargs["class_names"]) + + # Feature extractor + feat_extractor = IntermediateLayerGetter( + backbone_fn( + weights="imagenet" if pretrained_backbone else None, + include_top=False, + pooling=None, + input_shape=_cfg["input_shape"], + ), + fpn_layers, + ) + + # Build the model + model = DBNet(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, _cfg["url"]) + + return model + + +def _db_mobilenet( + arch: str, + pretrained: bool, + backbone_fn, + fpn_layers: List[str], + pretrained_backbone: bool = True, + input_shape: Optional[Tuple[int, int, int]] = None, + **kwargs: Any, +) -> DBNet: + pretrained_backbone = pretrained_backbone and not pretrained + + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["input_shape"] = input_shape or _cfg["input_shape"] + + # Feature extractor + feat_extractor = IntermediateLayerGetter( + backbone_fn( + input_shape=_cfg["input_shape"], + include_top=False, + pretrained=pretrained_backbone, + ), + fpn_layers, + ) + + # Build the model + model = DBNet(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, _cfg["url"]) + + return model + + +def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_, using a ResNet-50 backbone. + + >>> import tensorflow as tf + >>> from doctr.models import db_resnet50 + >>> model = db_resnet50(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _db_resnet( + "db_resnet50", + pretrained, + ResNet50, + ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"], + **kwargs, + ) + + +def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_, using a mobilenet v3 large backbone. 
+ + >>> import tensorflow as tf + >>> from doctr.models import db_mobilenet_v3_large + >>> model = db_mobilenet_v3_large(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _db_mobilenet( + "db_mobilenet_v3_large", + pretrained, + mobilenet_v3_large, + ["inverted_2", "inverted_5", "inverted_11", "final_block"], + **kwargs, + ) diff --git a/doctr/models/detection/fast/__init__.py b/doctr/models/detection/fast/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/detection/fast/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/detection/fast/__pycache__/__init__.cpython-311.pyc b/doctr/models/detection/fast/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8ca62e2f6180ee500248c50775a652e60744bf9 Binary files /dev/null and b/doctr/models/detection/fast/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/detection/fast/__pycache__/__init__.cpython-38.pyc b/doctr/models/detection/fast/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dc8663d45979b6ef25d2fbd3dfe009f18e82537 Binary files /dev/null and b/doctr/models/detection/fast/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/detection/fast/__pycache__/base.cpython-311.pyc b/doctr/models/detection/fast/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d9266df32708d73db283b03f677a7d91549388e Binary files /dev/null and b/doctr/models/detection/fast/__pycache__/base.cpython-311.pyc differ diff --git a/doctr/models/detection/fast/__pycache__/base.cpython-38.pyc b/doctr/models/detection/fast/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21890b9bb8284e996bee5458bd09b1bd047c701a Binary files /dev/null and b/doctr/models/detection/fast/__pycache__/base.cpython-38.pyc differ diff --git a/doctr/models/detection/fast/__pycache__/pytorch.cpython-311.pyc b/doctr/models/detection/fast/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e6cd855a98e21374a0ebdb0355a84f218a5d54e Binary files /dev/null and b/doctr/models/detection/fast/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/detection/fast/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/detection/fast/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff32e91334f8d2cd7c4f510b49da2f71ad25f496 Binary files /dev/null and b/doctr/models/detection/fast/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/detection/fast/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/detection/fast/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3999bf864cb9135f1260ccfb971871cb173d98d Binary files /dev/null and 
b/doctr/models/detection/fast/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/detection/fast/base.py b/doctr/models/detection/fast/base.py new file mode 100644 index 0000000000000000000000000000000000000000..868c3eadec4c1a284c1f348c117fca96a0476291 --- /dev/null +++ b/doctr/models/detection/fast/base.py @@ -0,0 +1,256 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization + +from typing import Dict, List, Tuple, Union + +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon + +from doctr.models.core import BaseModel + +from ..core import DetectionPostProcessor + +__all__ = ["_FAST", "FASTPostProcessor"] + + +class FASTPostProcessor(DetectionPostProcessor): + """Implements a post processor for FAST model. + + Args: + ---- + bin_thresh: threshold used to binzarized p_map at inference time + box_thresh: minimal objectness score to consider a box + assume_straight_pages: whether the inputs were expected to have horizontal text elements + """ + + def __init__( + self, + bin_thresh: float = 0.1, + box_thresh: float = 0.1, + assume_straight_pages: bool = True, + ) -> None: + super().__init__(box_thresh, bin_thresh, assume_straight_pages) + self.unclip_ratio = 1.0 + + def polygon_to_box( + self, + points: np.ndarray, + ) -> np.ndarray: + """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon + + Args: + ---- + points: The first parameter. + + Returns: + ------- + a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle) + """ + if not self.assume_straight_pages: + # Compute the rectangle polygon enclosing the raw polygon + rect = cv2.minAreaRect(points) + points = cv2.boxPoints(rect) + # Add 1 pixel to correct cv2 approx + area = (rect[1][0] + 1) * (1 + rect[1][1]) + length = 2 * (rect[1][0] + rect[1][1]) + 2 + else: + poly = Polygon(points) + area = poly.area + length = poly.length + distance = area * self.unclip_ratio / length # compute distance to expand polygon + offset = pyclipper.PyclipperOffset() + offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + _points = offset.Execute(distance) + # Take biggest stack of points + idx = 0 + if len(_points) > 1: + max_size = 0 + for _idx, p in enumerate(_points): + if len(p) > max_size: + idx = _idx + max_size = len(p) + # We ensure that _points can be correctly casted to a ndarray + _points = [_points[idx]] + expanded_points: np.ndarray = np.asarray(_points) # expand polygon + if len(expanded_points) < 1: + return None # type: ignore[return-value] + return ( + cv2.boundingRect(expanded_points) # type: ignore[return-value] + if self.assume_straight_pages + else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0) + ) + + def bitmap_to_boxes( + self, + pred: np.ndarray, + bitmap: np.ndarray, + ) -> np.ndarray: + """Compute boxes from a bitmap/pred_map: find connected components then filter boxes + + Args: + ---- + pred: Pred map from differentiable linknet output + bitmap: Bitmap map computed from pred (binarized) + angle_tol: Comparison tolerance of the angle with the median angle across the page + ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop + + Returns: + ------- + np tensor boxes for the bitmap, each box is a 6-element list + containing x, y, w, h, alpha, score for the box + """ + 
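+        # Note on the returned array (see the two return statements below): with straight pages it
+        # is a (N, 5) array of relative (xmin, ymin, xmax, ymax, score) rows, otherwise a (N, 4, 2)
+        # array of relative quadrangles.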
height, width = bitmap.shape[:2] + boxes: List[Union[np.ndarray, List[float]]] = [] + # get contours from connected components on the bitmap + contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + # Check whether smallest enclosing bounding box is not too small + if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2): + continue + # Compute objectness + if self.assume_straight_pages: + x, y, w, h = cv2.boundingRect(contour) + points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]) + score = self.box_score(pred, points, assume_straight_pages=True) + else: + score = self.box_score(pred, contour, assume_straight_pages=False) + + if score < self.box_thresh: # remove polygons with a weak objectness + continue + + if self.assume_straight_pages: + _box = self.polygon_to_box(points) + else: + _box = self.polygon_to_box(np.squeeze(contour)) + + if self.assume_straight_pages: + # compute relative polygon to get rid of img shape + x, y, w, h = _box + xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height + boxes.append([xmin, ymin, xmax, ymax, score]) + else: + # compute relative box to get rid of img shape + _box[:, 0] /= width + _box[:, 1] /= height + boxes.append(_box) + + if not self.assume_straight_pages: + return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype) + else: + return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype) + + +class _FAST(BaseModel): + """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" + `_. + """ + + min_size_box: int = 3 + assume_straight_pages: bool = True + shrink_ratio = 0.4 + + def build_target( + self, + target: List[Dict[str, np.ndarray]], + output_shape: Tuple[int, int, int], + channels_last: bool = True, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Build the target, and it's mask to be used from loss computation. 
+ + Args: + ---- + target: target coming from dataset + output_shape: shape of the output of the model without batch_size + channels_last: whether channels are last or not + + Returns: + ------- + the new formatted target, mask and shrunken text kernel + """ + if any(t.dtype != np.float32 for tgt in target for t in tgt.values()): + raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.") + if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()): + raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.") + + h: int + w: int + if channels_last: + h, w, num_classes = output_shape + else: + num_classes, h, w = output_shape + target_shape = (len(target), num_classes, h, w) + + seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8) + seg_mask: np.ndarray = np.ones(target_shape, dtype=bool) + shrunken_kernel: np.ndarray = np.zeros(target_shape, dtype=np.uint8) + + for idx, tgt in enumerate(target): + for class_idx, _tgt in enumerate(tgt.values()): + # Draw each polygon on gt + if _tgt.shape[0] == 0: + # Empty image, full masked + seg_mask[idx, class_idx] = False + + # Absolute bounding boxes + abs_boxes = _tgt.copy() + + if abs_boxes.ndim == 3: + abs_boxes[:, :, 0] *= w + abs_boxes[:, :, 1] *= h + polys = abs_boxes + boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1) + abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32) + else: + abs_boxes[:, [0, 2]] *= w + abs_boxes[:, [1, 3]] *= h + abs_boxes = abs_boxes.round().astype(np.int32) + polys = np.stack( + [ + abs_boxes[:, [0, 1]], + abs_boxes[:, [0, 3]], + abs_boxes[:, [2, 3]], + abs_boxes[:, [2, 1]], + ], + axis=1, + ) + boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1]) + + for poly, box, box_size in zip(polys, abs_boxes, boxes_size): + # Mask boxes that are too small + if box_size < self.min_size_box: + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + continue + + # Negative shrink for gt, as described in paper + polygon = Polygon(poly) + distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length + subject = [tuple(coor) for coor in poly] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + shrunken = padding.Execute(-distance) + + # Draw polygon on gt if it is valid + if len(shrunken) == 0: + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + continue + shrunken = np.array(shrunken[0]).reshape(-1, 2) + if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid: + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + continue + cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] + # draw the original polygon on the segmentation target + cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0) # type: ignore[call-overload] + + # Don't forget to switch back to channel last if Tensorflow is used + if channels_last: + seg_target = seg_target.transpose((0, 2, 3, 1)) + seg_mask = seg_mask.transpose((0, 2, 3, 1)) + shrunken_kernel = shrunken_kernel.transpose((0, 2, 3, 1)) + + return seg_target, seg_mask, shrunken_kernel diff --git a/doctr/models/detection/fast/pytorch.py b/doctr/models/detection/fast/pytorch.py new file mode 100644 index 
0000000000000000000000000000000000000000..5ac44b182567f890b48e090e292fbb2cb48f9944 --- /dev/null +++ b/doctr/models/detection/fast/pytorch.py @@ -0,0 +1,442 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from torchvision.models._utils import IntermediateLayerGetter + +from doctr.file_utils import CLASS_NAME + +from ...classification import textnet_base, textnet_small, textnet_tiny +from ...modules.layers import FASTConvLayer +from ...utils import _bf16_to_float32, load_pretrained_params +from .base import _FAST, FASTPostProcessor + +__all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "fast_tiny": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_tiny-1acac421.pt&src=0", + }, + "fast_small": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_small-10952cc1.pt&src=0", + }, + "fast_base": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_base-688a8b34.pt&src=0", + }, +} + + +class FastNeck(nn.Module): + """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layers. + + Args: + ---- + in_channels: number of input channels + out_channels: number of output channels + """ + + def __init__( + self, + in_channels: int, + out_channels: int = 128, + ) -> None: + super().__init__() + self.reduction = nn.ModuleList([ + FASTConvLayer(in_channels * scale, out_channels, kernel_size=3) for scale in [1, 2, 4, 8] + ]) + + def _upsample(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return F.interpolate(x, size=y.shape[-2:], mode="bilinear") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + f1, f2, f3, f4 = x + f1, f2, f3, f4 = [reduction(f) for reduction, f in zip(self.reduction, (f1, f2, f3, f4))] + f2, f3, f4 = [self._upsample(f, f1) for f in (f2, f3, f4)] + f = torch.cat((f1, f2, f3, f4), 1) + return f + + +class FastHead(nn.Sequential): + """Head of the FAST architecture + + Args: + ---- + in_channels: number of input channels + num_classes: number of output classes + out_channels: number of output channels + dropout: dropout probability + """ + + def __init__( + self, + in_channels: int, + num_classes: int, + out_channels: int = 128, + dropout: float = 0.1, + ) -> None: + _layers: List[nn.Module] = [ + FASTConvLayer(in_channels, out_channels, kernel_size=3), + nn.Dropout(dropout), + nn.Conv2d(out_channels, num_classes, kernel_size=1, bias=False), + ] + super().__init__(*_layers) + + +class FAST(_FAST, nn.Module): + """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" + `_. 
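+
+    A minimal instantiation sketch (illustrative only: it assumes a TextNet-Tiny backbone and
+    mirrors the wiring that the ``fast_tiny`` factory below performs for you; the stage names
+    "3" to "6" are the ones that factory requests):
+
+    >>> import torch
+    >>> from torchvision.models._utils import IntermediateLayerGetter
+    >>> from doctr.models.classification import textnet_tiny
+    >>> feat_extractor = IntermediateLayerGetter(
+    ...     textnet_tiny(False), {"3": "0", "4": "1", "5": "2", "6": "3"}
+    ... )
+    >>> model = FAST(feat_extractor)
+    >>> out = model(torch.rand((1, 3, 1024, 1024)), return_preds=True)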
+ + Args: + ---- + feat extractor: the backbone serving as feature extractor + bin_thresh: threshold for binarization + box_thresh: minimal objectness score to consider a box + dropout_prob: dropout probability + pooling_size: size of the pooling layer + assume_straight_pages: if True, fit straight bounding boxes only + exportable: onnx exportable returns only logits + cfg: the configuration dict of the model + class_names: list of class names + """ + + def __init__( + self, + feat_extractor: IntermediateLayerGetter, + bin_thresh: float = 0.1, + box_thresh: float = 0.1, + dropout_prob: float = 0.1, + pooling_size: int = 4, # different from paper performs better on close text-rich images + assume_straight_pages: bool = True, + exportable: bool = False, + cfg: Optional[Dict[str, Any]] = {}, + class_names: List[str] = [CLASS_NAME], + ) -> None: + super().__init__() + self.class_names = class_names + num_classes: int = len(self.class_names) + self.cfg = cfg + + self.exportable = exportable + self.assume_straight_pages = assume_straight_pages + + self.feat_extractor = feat_extractor + # Identify the number of channels for the neck & head initialization + _is_training = self.feat_extractor.training + self.feat_extractor = self.feat_extractor.eval() + with torch.no_grad(): + out = self.feat_extractor(torch.zeros((1, 3, 32, 32))) + feat_out_channels = [v.shape[1] for _, v in out.items()] + + if _is_training: + self.feat_extractor = self.feat_extractor.train() + + # Initialize neck & head + self.neck = FastNeck(feat_out_channels[0], feat_out_channels[1]) + self.prob_head = FastHead(feat_out_channels[-1], num_classes, feat_out_channels[1], dropout_prob) + + # NOTE: The post processing from the paper works not well for text-rich images + # so we use a modified version from DBNet + self.postprocessor = FASTPostProcessor( + assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh + ) + + # Pooling layer as erosion reversal as described in the paper + self.pooling = nn.MaxPool2d(kernel_size=pooling_size // 2 + 1, stride=1, padding=(pooling_size // 2) // 2) + + for n, m in self.named_modules(): + # Don't override the initialization of the backbone + if n.startswith("feat_extractor."): + continue + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + target: Optional[List[np.ndarray]] = None, + return_model_output: bool = False, + return_preds: bool = False, + ) -> Dict[str, torch.Tensor]: + # Extract feature maps at different stages + feats = self.feat_extractor(x) + feats = [feats[str(idx)] for idx in range(len(feats))] + # Pass through the Neck & Head & Upsample + feat_concat = self.neck(feats) + logits = F.interpolate(self.prob_head(feat_concat), size=x.shape[-2:], mode="bilinear") + + out: Dict[str, Any] = {} + if self.exportable: + out["logits"] = logits + return out + + if return_model_output or target is None or return_preds: + prob_map = _bf16_to_float32(torch.sigmoid(self.pooling(logits))) + + if return_model_output: + out["out_map"] = prob_map + + if target is None or return_preds: + # Post-process boxes (keep only text predictions) + out["preds"] = [ + dict(zip(self.class_names, preds)) + for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) + ] + + if target is not None: + loss = 
self.compute_loss(logits, target) + out["loss"] = loss + + return out + + def compute_loss( + self, + out_map: torch.Tensor, + target: List[np.ndarray], + eps: float = 1e-6, + ) -> torch.Tensor: + """Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5. + + Args: + ---- + out_map: output feature map of the model of shape (N, num_classes, H, W) + target: list of dictionary where each dict has a `boxes` and a `flags` entry + eps: epsilon factor in dice loss + + Returns: + ------- + A loss tensor + """ + targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type] + + seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1]) + shrunken_kernel = torch.from_numpy(targets[2]).to(out_map.device) + seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device) + + def ohem_sample(score: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + masks = [] + for class_idx in range(gt.shape[0]): + pos_num = int(torch.sum(gt[class_idx] > 0.5)) - int( + torch.sum((gt[class_idx] > 0.5) & (mask[class_idx] <= 0.5)) + ) + neg_num = int(torch.sum(gt[class_idx] <= 0.5)) + neg_num = int(min(pos_num * 3, neg_num)) + + if neg_num == 0 or pos_num == 0: + masks.append(mask[class_idx]) + continue + + neg_score_sorted, _ = torch.sort(-score[class_idx][gt[class_idx] <= 0.5]) + threshold = -neg_score_sorted[neg_num - 1] + + selected_mask = ((score[class_idx] >= threshold) | (gt[class_idx] > 0.5)) & (mask[class_idx] > 0.5) + masks.append(selected_mask) + # combine all masks to shape (len(masks), H, W) + return torch.stack(masks).unsqueeze(0).float() + + if len(self.class_names) > 1: + kernels = torch.softmax(out_map, dim=1) + prob_map = torch.softmax(self.pooling(out_map), dim=1) + else: + kernels = torch.sigmoid(out_map) + prob_map = torch.sigmoid(self.pooling(out_map)) + + # As described in the paper, we use the Dice loss for the text segmentation map and the Dice loss scaled by 0.5. + selected_masks = torch.cat( + [ohem_sample(score, gt, mask) for score, gt, mask in zip(prob_map, seg_target, seg_mask)], 0 + ).float() + inter = (selected_masks * prob_map * seg_target).sum((0, 2, 3)) + cardinality = (selected_masks * (prob_map + seg_target)).sum((0, 2, 3)) + text_loss = (1 - 2 * inter / (cardinality + eps)).mean() * 0.5 + + # As described in the paper, we use the Dice loss for the text kernel map. 
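+        # Both terms are standard soft Dice losses, i.e. 1 - 2 * |P * G| / (|P| + |G| + eps), computed
+        # per class and then averaged. Here P is the predicted kernel map (`kernels`) and G the
+        # shrunken-kernel target, both restricted to valid text pixels (`seg_target * seg_mask`).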
+ selected_masks = seg_target * seg_mask + inter = (selected_masks * kernels * shrunken_kernel).sum((0, 2, 3)) # noqa + cardinality = (selected_masks * (kernels + shrunken_kernel)).sum((0, 2, 3)) # noqa + kernel_loss = (1 - 2 * inter / (cardinality + eps)).mean() + + return text_loss + kernel_loss + + +def reparameterize(model: Union[FAST, nn.Module]) -> FAST: + """Fuse batchnorm and conv layers and reparameterize the model + + args: + ---- + model: the FAST model to reparameterize + + Returns: + ------- + the reparameterized model + """ + last_conv = None + last_conv_name = None + + for module in model.modules(): + if hasattr(module, "reparameterize_layer"): + module.reparameterize_layer() + + for name, child in model.named_children(): + if isinstance(child, nn.BatchNorm2d): + # fuse batchnorm only if it is followed by a conv layer + if last_conv is None: + continue + conv_w = last_conv.weight + conv_b = last_conv.bias if last_conv.bias is not None else torch.zeros_like(child.running_mean) + + factor = child.weight / torch.sqrt(child.running_var + child.eps) + last_conv.weight = nn.Parameter(conv_w * factor.reshape([last_conv.out_channels, 1, 1, 1])) + last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias) + model._modules[last_conv_name] = last_conv + model._modules[name] = nn.Identity() + last_conv = None + elif isinstance(child, nn.Conv2d): + last_conv = child + last_conv_name = name + else: + reparameterize(child) + + return model # type: ignore[return-value] + + +def _fast( + arch: str, + pretrained: bool, + backbone_fn: Callable[[bool], nn.Module], + feat_layers: List[str], + pretrained_backbone: bool = True, + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> FAST: + pretrained_backbone = pretrained_backbone and not pretrained + + # Build the feature extractor + feat_extractor = IntermediateLayerGetter( + backbone_fn(pretrained_backbone), + {layer_name: str(idx) for idx, layer_name in enumerate(feat_layers)}, + ) + + if not kwargs.get("class_names", None): + kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME]) + else: + kwargs["class_names"] = sorted(kwargs["class_names"]) + # Build the model + model = FAST(feat_extractor, cfg=default_cfgs[arch], **kwargs) + # Load pretrained parameters + if pretrained: + # The number of class_names is not the same as the number of classes in the pretrained model => + # remove the layer weights + _ignore_keys = ( + ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None + ) + load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + return model + + +def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST: + """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" + `_, using a tiny TextNet backbone. 
+ + >>> import torch + >>> from doctr.models import fast_tiny + >>> model = fast_tiny(pretrained=True) + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _fast( + "fast_tiny", + pretrained, + textnet_tiny, + ["3", "4", "5", "6"], + ignore_keys=["prob_head.2.weight"], + **kwargs, + ) + + +def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST: + """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" + `_, using a small TextNet backbone. + + >>> import torch + >>> from doctr.models import fast_small + >>> model = fast_small(pretrained=True) + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _fast( + "fast_small", + pretrained, + textnet_small, + ["3", "4", "5", "6"], + ignore_keys=["prob_head.2.weight"], + **kwargs, + ) + + +def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST: + """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" + `_, using a base TextNet backbone. + + >>> import torch + >>> from doctr.models import fast_base + >>> model = fast_base(pretrained=True) + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _fast( + "fast_base", + pretrained, + textnet_base, + ["3", "4", "5", "6"], + ignore_keys=["prob_head.2.weight"], + **kwargs, + ) diff --git a/doctr/models/detection/fast/tensorflow.py b/doctr/models/detection/fast/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..69998a2303e0c6edf6d321d738afc32cd3f310b1 --- /dev/null +++ b/doctr/models/detection/fast/tensorflow.py @@ -0,0 +1,428 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
+ +# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import Sequential, layers + +from doctr.file_utils import CLASS_NAME +from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params +from doctr.utils.repr import NestedObject + +from ...classification import textnet_base, textnet_small, textnet_tiny +from ...modules.layers import FASTConvLayer +from .base import _FAST, FASTPostProcessor + +__all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "fast_tiny": { + "input_shape": (1024, 1024, 3), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_tiny-959daecb.zip&src=0", + }, + "fast_small": { + "input_shape": (1024, 1024, 3), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_small-f1617503.zip&src=0", + }, + "fast_base": { + "input_shape": (1024, 1024, 3), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_base-255e2ac3.zip&src=0", + }, +} + + +class FastNeck(layers.Layer, NestedObject): + """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layer. + + Args: + ---- + in_channels: number of input channels + out_channels: number of output channels + """ + + def __init__( + self, + in_channels: int, + out_channels: int = 128, + ) -> None: + super().__init__() + self.reduction = [FASTConvLayer(in_channels * scale, out_channels, kernel_size=3) for scale in [1, 2, 4, 8]] + + def _upsample(self, x: tf.Tensor, y: tf.Tensor) -> tf.Tensor: + return tf.image.resize(x, size=y.shape[1:3], method="bilinear") + + def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor: + f1, f2, f3, f4 = x + f1, f2, f3, f4 = [reduction(f, **kwargs) for reduction, f in zip(self.reduction, (f1, f2, f3, f4))] + f2, f3, f4 = [self._upsample(f, f1) for f in (f2, f3, f4)] + f = tf.concat((f1, f2, f3, f4), axis=-1) + return f + + +class FastHead(Sequential): + """Head of the FAST architecture + + Args: + ---- + in_channels: number of input channels + num_classes: number of output classes + out_channels: number of output channels + dropout: dropout probability + """ + + def __init__( + self, + in_channels: int, + num_classes: int, + out_channels: int = 128, + dropout: float = 0.1, + ) -> None: + _layers = [ + FASTConvLayer(in_channels, out_channels, kernel_size=3), + layers.Dropout(dropout), + layers.Conv2D(num_classes, kernel_size=1, use_bias=False), + ] + super().__init__(_layers) + + +class FAST(_FAST, keras.Model, NestedObject): + """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" + `_. 
+ + Args: + ---- + feature extractor: the backbone serving as feature extractor + bin_thresh: threshold for binarization + box_thresh: minimal objectness score to consider a box + dropout_prob: dropout probability + pooling_size: size of the pooling layer + assume_straight_pages: if True, fit straight bounding boxes only + exportable: onnx exportable returns only logits + cfg: the configuration dict of the model + class_names: list of class names + """ + + _children_names: List[str] = ["feat_extractor", "neck", "head", "postprocessor"] + + def __init__( + self, + feature_extractor: IntermediateLayerGetter, + bin_thresh: float = 0.1, + box_thresh: float = 0.1, + dropout_prob: float = 0.1, + pooling_size: int = 4, # different from paper performs better on close text-rich images + assume_straight_pages: bool = True, + exportable: bool = False, + cfg: Optional[Dict[str, Any]] = {}, + class_names: List[str] = [CLASS_NAME], + ) -> None: + super().__init__() + self.class_names = class_names + num_classes: int = len(self.class_names) + self.cfg = cfg + + self.feat_extractor = feature_extractor + self.exportable = exportable + self.assume_straight_pages = assume_straight_pages + + # Identify the number of channels for the neck & head initialization + feat_out_channels = [ + layers.Input(shape=in_shape[1:]).shape[-1] for in_shape in self.feat_extractor.output_shape + ] + # Initialize neck & head + self.neck = FastNeck(feat_out_channels[0], feat_out_channels[1]) + self.head = FastHead(feat_out_channels[-1], num_classes, feat_out_channels[1], dropout_prob) + + # NOTE: The post processing from the paper works not well for text-rich images + # so we use a modified version from DBNet + self.postprocessor = FASTPostProcessor( + assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh + ) + + # Pooling layer as erosion reversal as described in the paper + self.pooling = layers.MaxPooling2D(pool_size=pooling_size // 2 + 1, strides=1, padding="same") + + def compute_loss( + self, + out_map: tf.Tensor, + target: List[Dict[str, np.ndarray]], + eps: float = 1e-6, + ) -> tf.Tensor: + """Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5. 
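+
+        (Note: in the implementation below it is the text-map Dice term that is weighted by 0.5,
+        while the shrunken-kernel Dice term keeps a weight of 1.)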
+ + Args: + ---- + out_map: output feature map of the model of shape (N, num_classes, H, W) + target: list of dictionary where each dict has a `boxes` and a `flags` entry + eps: epsilon factor in dice loss + + Returns: + ------- + A loss tensor + """ + targets = self.build_target(target, out_map.shape[1:], True) + + seg_target = tf.convert_to_tensor(targets[0], dtype=out_map.dtype) + seg_mask = tf.convert_to_tensor(targets[1], dtype=out_map.dtype) + shrunken_kernel = tf.convert_to_tensor(targets[2], dtype=out_map.dtype) + + def ohem(score: tf.Tensor, gt: tf.Tensor, mask: tf.Tensor) -> tf.Tensor: + pos_num = tf.reduce_sum(tf.cast(gt > 0.5, dtype=tf.int32)) - tf.reduce_sum( + tf.cast((gt > 0.5) & (mask <= 0.5), dtype=tf.int32) + ) + neg_num = tf.reduce_sum(tf.cast(gt <= 0.5, dtype=tf.int32)) + neg_num = tf.minimum(pos_num * 3, neg_num) + + if neg_num == 0 or pos_num == 0: + return mask + + neg_score_sorted, _ = tf.nn.top_k(-tf.boolean_mask(score, gt <= 0.5), k=neg_num) + threshold = -neg_score_sorted[-1] + + selected_mask = tf.math.logical_and((score >= threshold) | (gt > 0.5), (mask > 0.5)) + return tf.cast(selected_mask, dtype=tf.float32) + + if len(self.class_names) > 1: + kernels = tf.nn.softmax(out_map, axis=-1) + prob_map = tf.nn.softmax(self.pooling(out_map), axis=-1) + else: + kernels = tf.sigmoid(out_map) + prob_map = tf.sigmoid(self.pooling(out_map)) + + # As described in the paper, we use the Dice loss for the text segmentation map and the Dice loss scaled by 0.5. + selected_masks = tf.stack( + [ohem(score, gt, mask) for score, gt, mask in zip(prob_map, seg_target, seg_mask)], axis=0 + ) + inter = tf.reduce_sum(selected_masks * prob_map * seg_target, axis=(0, 1, 2)) + cardinality = tf.reduce_sum(selected_masks * (prob_map + seg_target), axis=(0, 1, 2)) + text_loss = tf.reduce_mean((1 - 2 * inter / (cardinality + eps))) * 0.5 + + # As described in the paper, we use the Dice loss for the text kernel map. 
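+        # Note: `kernels` is the raw (un-pooled) prediction supervised with the shrunken kernels,
+        # whereas `prob_map` above is its max-pooled ("dilated") counterpart used for the text-map term.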
+ selected_masks = seg_target * seg_mask + inter = tf.reduce_sum(selected_masks * kernels * shrunken_kernel, axis=(0, 1, 2)) + cardinality = tf.reduce_sum(selected_masks * (kernels + shrunken_kernel), axis=(0, 1, 2)) + kernel_loss = tf.reduce_mean((1 - 2 * inter / (cardinality + eps))) + + return text_loss + kernel_loss + + def call( + self, + x: tf.Tensor, + target: Optional[List[Dict[str, np.ndarray]]] = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + feat_maps = self.feat_extractor(x, **kwargs) + # Pass through the Neck & Head & Upsample + feat_concat = self.neck(feat_maps, **kwargs) + logits: tf.Tensor = self.head(feat_concat, **kwargs) + logits = layers.UpSampling2D(size=x.shape[-2] // logits.shape[-2], interpolation="bilinear")(logits, **kwargs) + + out: Dict[str, tf.Tensor] = {} + if self.exportable: + out["logits"] = logits + return out + + if return_model_output or target is None or return_preds: + prob_map = _bf16_to_float32(tf.math.sigmoid(self.pooling(logits, **kwargs))) + + if return_model_output: + out["out_map"] = prob_map + + if target is None or return_preds: + # Post-process boxes (keep only text predictions) + out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())] + + if target is not None: + loss = self.compute_loss(logits, target) + out["loss"] = loss + + return out + + +def reparameterize(model: Union[FAST, layers.Layer]) -> FAST: + """Fuse batchnorm and conv layers and reparameterize the model + + args: + ---- + model: the FAST model to reparameterize + + Returns: + ------- + the reparameterized model + """ + last_conv = None + last_conv_idx = None + + for idx, layer in enumerate(model.layers): + if hasattr(layer, "layers") or isinstance( + layer, (FASTConvLayer, FastNeck, FastHead, layers.BatchNormalization, layers.Conv2D) + ): + if isinstance(layer, layers.BatchNormalization): + # fuse batchnorm only if it is followed by a conv layer + if last_conv is None: + continue + conv_w = last_conv.kernel + conv_b = last_conv.bias if last_conv.use_bias else tf.zeros_like(layer.moving_mean) + + factor = layer.gamma / tf.sqrt(layer.moving_variance + layer.epsilon) + last_conv.kernel = conv_w * factor.numpy().reshape([1, 1, 1, -1]) + if last_conv.use_bias: + last_conv.bias.assign((conv_b - layer.moving_mean) * factor + layer.beta) + model.layers[last_conv_idx] = last_conv # Replace the last conv layer with the fused version + model.layers[idx] = layers.Lambda(lambda x: x) + last_conv = None + elif isinstance(layer, layers.Conv2D): + last_conv = layer + last_conv_idx = idx + elif isinstance(layer, FASTConvLayer): + layer.reparameterize_layer() + elif isinstance(layer, FastNeck): + for reduction in layer.reduction: + reduction.reparameterize_layer() + elif isinstance(layer, FastHead): + reparameterize(layer) + else: + reparameterize(layer) + return model + + +def _fast( + arch: str, + pretrained: bool, + backbone_fn, + feat_layers: List[str], + pretrained_backbone: bool = True, + input_shape: Optional[Tuple[int, int, int]] = None, + **kwargs: Any, +) -> FAST: + pretrained_backbone = pretrained_backbone and not pretrained + + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["input_shape"] = input_shape or _cfg["input_shape"] + if not kwargs.get("class_names", None): + kwargs["class_names"] = _cfg.get("class_names", [CLASS_NAME]) + else: + kwargs["class_names"] = sorted(kwargs["class_names"]) + + # Feature extractor + feat_extractor = 
IntermediateLayerGetter( + backbone_fn( + input_shape=_cfg["input_shape"], + include_top=False, + pretrained=pretrained_backbone, + ), + feat_layers, + ) + + # Build the model + model = FAST(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, _cfg["url"]) + + # Build the model for reparameterization to access the layers + _ = model(tf.random.uniform(shape=[1, *_cfg["input_shape"]], maxval=1, dtype=tf.float32), training=False) + + return model + + +def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST: + """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" + `_, using a tiny TextNet backbone. + + >>> import tensorflow as tf + >>> from doctr.models import fast_tiny + >>> model = fast_tiny(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _fast( + "fast_tiny", + pretrained, + textnet_tiny, + ["stage_0", "stage_1", "stage_2", "stage_3"], + **kwargs, + ) + + +def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST: + """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" + `_, using a small TextNet backbone. + + >>> import tensorflow as tf + >>> from doctr.models import fast_small + >>> model = fast_small(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _fast( + "fast_small", + pretrained, + textnet_small, + ["stage_0", "stage_1", "stage_2", "stage_3"], + **kwargs, + ) + + +def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST: + """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" + `_, using a base TextNet backbone. 
+ + >>> import tensorflow as tf + >>> from doctr.models import fast_base + >>> model = fast_base(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the DBNet architecture + + Returns: + ------- + text detection architecture + """ + return _fast( + "fast_base", + pretrained, + textnet_base, + ["stage_0", "stage_1", "stage_2", "stage_3"], + **kwargs, + ) diff --git a/doctr/models/detection/linknet/__init__.py b/doctr/models/detection/linknet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/detection/linknet/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/detection/linknet/__pycache__/__init__.cpython-311.pyc b/doctr/models/detection/linknet/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7277c1d57240d6fbc818ef3664c582379ca703b4 Binary files /dev/null and b/doctr/models/detection/linknet/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/detection/linknet/__pycache__/__init__.cpython-38.pyc b/doctr/models/detection/linknet/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84afee1e4e494894e0f3646e030f9bdf420885bf Binary files /dev/null and b/doctr/models/detection/linknet/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/detection/linknet/__pycache__/base.cpython-311.pyc b/doctr/models/detection/linknet/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de7942bcc7547daae9bdf225f2b2aec0cc9e99d8 Binary files /dev/null and b/doctr/models/detection/linknet/__pycache__/base.cpython-311.pyc differ diff --git a/doctr/models/detection/linknet/__pycache__/base.cpython-38.pyc b/doctr/models/detection/linknet/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d930ac8740c5d4d6067008da0977eaf1e11376b0 Binary files /dev/null and b/doctr/models/detection/linknet/__pycache__/base.cpython-38.pyc differ diff --git a/doctr/models/detection/linknet/__pycache__/pytorch.cpython-311.pyc b/doctr/models/detection/linknet/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e508ddc5958d5293c0524722847e9db998cc6c59 Binary files /dev/null and b/doctr/models/detection/linknet/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/detection/linknet/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/detection/linknet/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f52c54f89e648626c0793dcc41c2ec36daa65f7 Binary files /dev/null and b/doctr/models/detection/linknet/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/detection/linknet/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/detection/linknet/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b5c23fc3e156047c6bc7cf5d557851df5c0dc6f Binary files /dev/null and 
b/doctr/models/detection/linknet/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/detection/linknet/base.py b/doctr/models/detection/linknet/base.py new file mode 100644 index 0000000000000000000000000000000000000000..986f57d6ad2fbdc6503dff21901875617e04c50d --- /dev/null +++ b/doctr/models/detection/linknet/base.py @@ -0,0 +1,256 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization + +from typing import Dict, List, Tuple, Union + +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon + +from doctr.models.core import BaseModel + +from ..core import DetectionPostProcessor + +__all__ = ["_LinkNet", "LinkNetPostProcessor"] + + +class LinkNetPostProcessor(DetectionPostProcessor): + """Implements a post processor for LinkNet model. + + Args: + ---- + bin_thresh: threshold used to binzarized p_map at inference time + box_thresh: minimal objectness score to consider a box + assume_straight_pages: whether the inputs were expected to have horizontal text elements + """ + + def __init__( + self, + bin_thresh: float = 0.1, + box_thresh: float = 0.1, + assume_straight_pages: bool = True, + ) -> None: + super().__init__(box_thresh, bin_thresh, assume_straight_pages) + self.unclip_ratio = 1.5 + + def polygon_to_box( + self, + points: np.ndarray, + ) -> np.ndarray: + """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon + + Args: + ---- + points: The first parameter. + + Returns: + ------- + a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle) + """ + if not self.assume_straight_pages: + # Compute the rectangle polygon enclosing the raw polygon + rect = cv2.minAreaRect(points) + points = cv2.boxPoints(rect) + # Add 1 pixel to correct cv2 approx + area = (rect[1][0] + 1) * (1 + rect[1][1]) + length = 2 * (rect[1][0] + rect[1][1]) + 2 + else: + poly = Polygon(points) + area = poly.area + length = poly.length + distance = area * self.unclip_ratio / length # compute distance to expand polygon + offset = pyclipper.PyclipperOffset() + offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + _points = offset.Execute(distance) + # Take biggest stack of points + idx = 0 + if len(_points) > 1: + max_size = 0 + for _idx, p in enumerate(_points): + if len(p) > max_size: + idx = _idx + max_size = len(p) + # We ensure that _points can be correctly casted to a ndarray + _points = [_points[idx]] + expanded_points: np.ndarray = np.asarray(_points) # expand polygon + if len(expanded_points) < 1: + return None # type: ignore[return-value] + return ( + cv2.boundingRect(expanded_points) # type: ignore[return-value] + if self.assume_straight_pages + else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0) + ) + + def bitmap_to_boxes( + self, + pred: np.ndarray, + bitmap: np.ndarray, + ) -> np.ndarray: + """Compute boxes from a bitmap/pred_map: find connected components then filter boxes + + Args: + ---- + pred: Pred map from differentiable linknet output + bitmap: Bitmap map computed from pred (binarized) + angle_tol: Comparison tolerance of the angle with the median angle across the page + ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop + + Returns: + ------- + np tensor boxes for the bitmap, each box is a 6-element list + containing x, y, w, h, alpha, 
score for the box + """ + height, width = bitmap.shape[:2] + boxes: List[Union[np.ndarray, List[float]]] = [] + # get contours from connected components on the bitmap + contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + # Check whether smallest enclosing bounding box is not too small + if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2): + continue + # Compute objectness + if self.assume_straight_pages: + x, y, w, h = cv2.boundingRect(contour) + points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]) + score = self.box_score(pred, points, assume_straight_pages=True) + else: + score = self.box_score(pred, contour, assume_straight_pages=False) + + if score < self.box_thresh: # remove polygons with a weak objectness + continue + + if self.assume_straight_pages: + _box = self.polygon_to_box(points) + else: + _box = self.polygon_to_box(np.squeeze(contour)) + + if self.assume_straight_pages: + # compute relative polygon to get rid of img shape + x, y, w, h = _box + xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height + boxes.append([xmin, ymin, xmax, ymax, score]) + else: + # compute relative box to get rid of img shape + _box[:, 0] /= width + _box[:, 1] /= height + boxes.append(_box) + + if not self.assume_straight_pages: + return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype) + else: + return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype) + + +class _LinkNet(BaseModel): + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. + + Args: + ---- + out_chan: number of channels for the output + """ + + min_size_box: int = 3 + assume_straight_pages: bool = True + shrink_ratio = 0.5 + + def build_target( + self, + target: List[Dict[str, np.ndarray]], + output_shape: Tuple[int, int, int], + channels_last: bool = True, + ) -> Tuple[np.ndarray, np.ndarray]: + """Build the target, and it's mask to be used from loss computation. 
+ + Args: + ---- + target: target coming from dataset + output_shape: shape of the output of the model without batch_size + channels_last: whether channels are last or not + + Returns: + ------- + the new formatted target and the mask + """ + if any(t.dtype != np.float32 for tgt in target for t in tgt.values()): + raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.") + if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()): + raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.") + + h: int + w: int + if channels_last: + h, w, num_classes = output_shape + else: + num_classes, h, w = output_shape + target_shape = (len(target), num_classes, h, w) + + seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8) + seg_mask: np.ndarray = np.ones(target_shape, dtype=bool) + + for idx, tgt in enumerate(target): + for class_idx, _tgt in enumerate(tgt.values()): + # Draw each polygon on gt + if _tgt.shape[0] == 0: + # Empty image, full masked + seg_mask[idx, class_idx] = False + + # Absolute bounding boxes + abs_boxes = _tgt.copy() + + if abs_boxes.ndim == 3: + abs_boxes[:, :, 0] *= w + abs_boxes[:, :, 1] *= h + polys = abs_boxes + boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1) + abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32) + else: + abs_boxes[:, [0, 2]] *= w + abs_boxes[:, [1, 3]] *= h + abs_boxes = abs_boxes.round().astype(np.int32) + polys = np.stack( + [ + abs_boxes[:, [0, 1]], + abs_boxes[:, [0, 3]], + abs_boxes[:, [2, 3]], + abs_boxes[:, [2, 1]], + ], + axis=1, + ) + boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1]) + + for poly, box, box_size in zip(polys, abs_boxes, boxes_size): + # Mask boxes that are too small + if box_size < self.min_size_box: + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + continue + + # Negative shrink for gt, as described in paper + polygon = Polygon(poly) + distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length + subject = [tuple(coor) for coor in poly] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + shrunken = padding.Execute(-distance) + + # Draw polygon on gt if it is valid + if len(shrunken) == 0: + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + continue + shrunken = np.array(shrunken[0]).reshape(-1, 2) + if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid: + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + continue + cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] + + # Don't forget to switch back to channel last if Tensorflow is used + if channels_last: + seg_target = seg_target.transpose((0, 2, 3, 1)) + seg_mask = seg_mask.transpose((0, 2, 3, 1)) + + return seg_target, seg_mask diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..537fd57256526a8964c4feda8c39c05805b697f4 --- /dev/null +++ b/doctr/models/detection/linknet/pytorch.py @@ -0,0 +1,380 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
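(Aside on `_LinkNet.build_target` above: the shrink step delegates to pyclipper. Below is a minimal, self-contained sketch of that offset, using a toy rectangle and the same `D = A * (1 - r^2) / L` formula; the coordinates and the 0.5 shrink ratio are illustrative values only.)

```python
import numpy as np
import pyclipper
from shapely.geometry import Polygon

# Toy quadrangle in absolute pixel coordinates
poly = np.array([[10, 10], [110, 10], [110, 40], [10, 40]])

shrink_ratio = 0.5  # mirrors _LinkNet.shrink_ratio
polygon = Polygon(poly)
# Offset distance: D = Area * (1 - r^2) / Perimeter
distance = polygon.area * (1 - shrink_ratio**2) / polygon.length

padding = pyclipper.PyclipperOffset()
padding.AddPath([tuple(map(int, coor)) for coor in poly], pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrunken = padding.Execute(-distance)  # negative offset shrinks the polygon
print(np.array(shrunken[0]).reshape(-1, 2))  # smaller quadrangle, drawn on the target map
```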
+ +from typing import Any, Callable, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from torchvision.models._utils import IntermediateLayerGetter + +from doctr.file_utils import CLASS_NAME +from doctr.models.classification import resnet18, resnet34, resnet50 + +from ...utils import _bf16_to_float32, load_pretrained_params +from .base import LinkNetPostProcessor, _LinkNet + +__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "linknet_resnet18": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-e47a14dc.pt&src=0", + }, + "linknet_resnet34": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-9ca2df3e.pt&src=0", + }, + "linknet_resnet50": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-6cf565c1.pt&src=0", + }, +} + + +class LinkNetFPN(nn.Module): + def __init__(self, layer_shapes: List[Tuple[int, int, int]]) -> None: + super().__init__() + strides = [ + 1 if (in_shape[-1] == out_shape[-1]) else 2 + for in_shape, out_shape in zip(layer_shapes[:-1], layer_shapes[1:]) + ] + + chans = [shape[0] for shape in layer_shapes] + + _decoder_layers = [ + self.decoder_block(ochan, ichan, stride) for ichan, ochan, stride in zip(chans[:-1], chans[1:], strides) + ] + + self.decoders = nn.ModuleList(_decoder_layers) + + @staticmethod + def decoder_block(in_chan: int, out_chan: int, stride: int) -> nn.Sequential: + """Creates a LinkNet decoder block""" + mid_chan = in_chan // 4 + return nn.Sequential( + nn.Conv2d(in_chan, mid_chan, kernel_size=1, bias=False), + nn.BatchNorm2d(mid_chan), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(mid_chan, mid_chan, 3, padding=1, output_padding=stride - 1, stride=stride, bias=False), + nn.BatchNorm2d(mid_chan), + nn.ReLU(inplace=True), + nn.Conv2d(mid_chan, out_chan, kernel_size=1, bias=False), + nn.BatchNorm2d(out_chan), + nn.ReLU(inplace=True), + ) + + def forward(self, feats: List[torch.Tensor]) -> torch.Tensor: + out = feats[-1] + for decoder, fmap in zip(self.decoders[::-1], feats[:-1][::-1]): + out = decoder(out) + fmap + + out = self.decoders[0](out) + + return out + + +class LinkNet(nn.Module, _LinkNet): + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. 
+ + Args: + ---- + feature extractor: the backbone serving as feature extractor + bin_thresh: threshold for binarization of the output feature map + box_thresh: minimal objectness score to consider a box + head_chans: number of channels in the head layers + assume_straight_pages: if True, fit straight bounding boxes only + exportable: onnx exportable returns only logits + cfg: the configuration dict of the model + class_names: list of class names + """ + + def __init__( + self, + feat_extractor: IntermediateLayerGetter, + bin_thresh: float = 0.1, + box_thresh: float = 0.1, + head_chans: int = 32, + assume_straight_pages: bool = True, + exportable: bool = False, + cfg: Optional[Dict[str, Any]] = None, + class_names: List[str] = [CLASS_NAME], + ) -> None: + super().__init__() + self.class_names = class_names + num_classes: int = len(self.class_names) + self.cfg = cfg + self.exportable = exportable + self.assume_straight_pages = assume_straight_pages + + self.feat_extractor = feat_extractor + # Identify the number of channels for the FPN initialization + self.feat_extractor.eval() + with torch.no_grad(): + in_shape = (3, 512, 512) + out = self.feat_extractor(torch.zeros((1, *in_shape))) + # Get the shapes of the extracted feature maps + _shapes = [v.shape[1:] for _, v in out.items()] + # Prepend the expected shapes of the first encoder + _shapes = [(_shapes[0][0], in_shape[1] // 4, in_shape[2] // 4)] + _shapes + self.feat_extractor.train() + + self.fpn = LinkNetFPN(_shapes) + + self.classifier = nn.Sequential( + nn.ConvTranspose2d( + _shapes[0][0], head_chans, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False + ), + nn.BatchNorm2d(head_chans), + nn.ReLU(inplace=True), + nn.Conv2d(head_chans, head_chans, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(head_chans), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(head_chans, num_classes, kernel_size=2, stride=2), + ) + + self.postprocessor = LinkNetPostProcessor( + assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh + ) + + for n, m in self.named_modules(): + # Don't override the initialization of the backbone + if n.startswith("feat_extractor."): + continue + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + target: Optional[List[np.ndarray]] = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + feats = self.feat_extractor(x) + logits = self.fpn([feats[str(idx)] for idx in range(len(feats))]) + logits = self.classifier(logits) + + out: Dict[str, Any] = {} + if self.exportable: + out["logits"] = logits + return out + + if return_model_output or target is None or return_preds: + prob_map = _bf16_to_float32(torch.sigmoid(logits)) + if return_model_output: + out["out_map"] = prob_map + + if target is None or return_preds: + # Post-process boxes + out["preds"] = [ + dict(zip(self.class_names, preds)) + for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) + ] + + if target is not None: + loss = self.compute_loss(logits, target) + out["loss"] = loss + + return out + + def compute_loss( + self, + out_map: torch.Tensor, + target: List[np.ndarray], + gamma: float = 2.0, + alpha: float = 0.5, + eps: float = 1e-8, + ) -> torch.Tensor: + 
"""Compute linknet loss, BCE with boosted box edges or focal loss. Focal loss implementation based on + `_. + + Args: + ---- + out_map: output feature map of the model of shape (N, num_classes, H, W) + target: list of dictionary where each dict has a `boxes` and a `flags` entry + gamma: modulating factor in the focal loss formula + alpha: balancing factor in the focal loss formula + eps: epsilon factor in dice loss + + Returns: + ------- + A loss tensor + """ + _target, _mask = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type] + + seg_target, seg_mask = torch.from_numpy(_target).to(dtype=out_map.dtype), torch.from_numpy(_mask) + seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device) + seg_mask = seg_mask.to(dtype=torch.float32) + + bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none") + proba_map = torch.sigmoid(out_map) + + # Focal loss + if gamma < 0: + raise ValueError("Value of gamma should be greater than or equal to zero.") + p_t = proba_map * seg_target + (1 - proba_map) * (1 - seg_target) + alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target) + # Unreduced version + focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss + # Class reduced + focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / seg_mask.sum((0, 1, 2, 3)) + + # Compute dice loss for each class + dice_map = torch.softmax(out_map, dim=1) if len(self.class_names) > 1 else proba_map + # Class reduced + inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3)) + cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3)) + dice_loss = (1 - 2 * inter / (cardinality + eps)).mean() + + # Return the full loss (equal sum of focal loss and dice loss) + return focal_loss + dice_loss + + +def _linknet( + arch: str, + pretrained: bool, + backbone_fn: Callable[[bool], nn.Module], + fpn_layers: List[str], + pretrained_backbone: bool = True, + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> LinkNet: + pretrained_backbone = pretrained_backbone and not pretrained + + # Build the feature extractor + backbone = backbone_fn(pretrained_backbone) + feat_extractor = IntermediateLayerGetter( + backbone, + {layer_name: str(idx) for idx, layer_name in enumerate(fpn_layers)}, + ) + if not kwargs.get("class_names", None): + kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME]) + else: + kwargs["class_names"] = sorted(kwargs["class_names"]) + + # Build the model + model = LinkNet(feat_extractor, cfg=default_cfgs[arch], **kwargs) + # Load pretrained parameters + if pretrained: + # The number of class_names is not the same as the number of classes in the pretrained model => + # remove the layer weights + _ignore_keys = ( + ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None + ) + load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + return model + + +def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet: + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. 
+ + >>> import torch + >>> from doctr.models import linknet_resnet18 + >>> model = linknet_resnet18(pretrained=True).eval() + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the LinkNet architecture + + Returns: + ------- + text detection architecture + """ + return _linknet( + "linknet_resnet18", + pretrained, + resnet18, + ["layer1", "layer2", "layer3", "layer4"], + ignore_keys=[ + "classifier.6.weight", + "classifier.6.bias", + ], + **kwargs, + ) + + +def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet: + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. + + >>> import torch + >>> from doctr.models import linknet_resnet34 + >>> model = linknet_resnet34(pretrained=True).eval() + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the LinkNet architecture + + Returns: + ------- + text detection architecture + """ + return _linknet( + "linknet_resnet34", + pretrained, + resnet34, + ["layer1", "layer2", "layer3", "layer4"], + ignore_keys=[ + "classifier.6.weight", + "classifier.6.bias", + ], + **kwargs, + ) + + +def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet: + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. + + >>> import torch + >>> from doctr.models import linknet_resnet50 + >>> model = linknet_resnet50(pretrained=True).eval() + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the LinkNet architecture + + Returns: + ------- + text detection architecture + """ + return _linknet( + "linknet_resnet50", + pretrained, + resnet50, + ["layer1", "layer2", "layer3", "layer4"], + ignore_keys=[ + "classifier.6.weight", + "classifier.6.bias", + ], + **kwargs, + ) diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..ff11dbe4778d9fda973a72b3692cc014ccef70d5 --- /dev/null +++ b/doctr/models/detection/linknet/tensorflow.py @@ -0,0 +1,366 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
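(Aside on `LinkNet.compute_loss`: the focal term, in both the PyTorch version above and the TensorFlow version further down, is `alpha_t * (1 - p_t) ** gamma * BCE`. Here is a tiny standalone PyTorch sketch with toy shapes, not tied to any real model output.)

```python
import torch
import torch.nn.functional as F

gamma, alpha = 2.0, 0.5  # same defaults as compute_loss
out_map = torch.randn(2, 1, 8, 8)                       # fake logits (N, C, H, W)
seg_target = torch.randint(0, 2, (2, 1, 8, 8)).float()  # fake binary target
seg_mask = torch.ones_like(seg_target)                  # keep every pixel

bce = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")
proba = torch.sigmoid(out_map)
p_t = proba * seg_target + (1 - proba) * (1 - seg_target)      # prob of the true class
alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target)  # class balancing weight
focal = alpha_t * (1 - p_t) ** gamma * bce                     # down-weights easy pixels
focal_loss = (seg_mask * focal).sum() / seg_mask.sum()         # mask-weighted mean
print(float(focal_loss))
```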
+ +# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import Model, Sequential, layers + +from doctr.file_utils import CLASS_NAME +from doctr.models.classification import resnet18, resnet34, resnet50 +from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params +from doctr.utils.repr import NestedObject + +from .base import LinkNetPostProcessor, _LinkNet + +__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "linknet_resnet18": { + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "input_shape": (1024, 1024, 3), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-b9ee56e6.zip&src=0", + }, + "linknet_resnet34": { + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "input_shape": (1024, 1024, 3), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-51909c56.zip&src=0", + }, + "linknet_resnet50": { + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "input_shape": (1024, 1024, 3), + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-ac9f3829.zip&src=0", + }, +} + + +def decoder_block(in_chan: int, out_chan: int, stride: int, **kwargs: Any) -> Sequential: + """Creates a LinkNet decoder block""" + return Sequential([ + *conv_sequence(in_chan // 4, "relu", True, kernel_size=1, **kwargs), + layers.Conv2DTranspose( + filters=in_chan // 4, + kernel_size=3, + strides=stride, + padding="same", + use_bias=False, + kernel_initializer="he_normal", + ), + layers.BatchNormalization(), + layers.Activation("relu"), + *conv_sequence(out_chan, "relu", True, kernel_size=1), + ]) + + +class LinkNetFPN(Model, NestedObject): + """LinkNet Decoder module""" + + def __init__( + self, + out_chans: int, + in_shapes: List[Tuple[int, ...]], + ) -> None: + super().__init__() + self.out_chans = out_chans + strides = [2] * (len(in_shapes) - 1) + [1] + i_chans = [s[-1] for s in in_shapes[::-1]] + o_chans = i_chans[1:] + [out_chans] + self.decoders = [ + decoder_block(in_chan, out_chan, s, input_shape=in_shape) + for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1]) + ] + + def call(self, x: List[tf.Tensor]) -> tf.Tensor: + out = 0 + for decoder, fmap in zip(self.decoders, x[::-1]): + out = decoder(out + fmap) + return out + + def extra_repr(self) -> str: + return f"out_chans={self.out_chans}" + + +class LinkNet(_LinkNet, keras.Model): + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. 
+ + Args: + ---- + feature extractor: the backbone serving as feature extractor + fpn_channels: number of channels each extracted feature maps is mapped to + bin_thresh: threshold for binarization of the output feature map + box_thresh: minimal objectness score to consider a box + assume_straight_pages: if True, fit straight bounding boxes only + exportable: onnx exportable returns only logits + cfg: the configuration dict of the model + class_names: list of class names + """ + + _children_names: List[str] = ["feat_extractor", "fpn", "classifier", "postprocessor"] + + def __init__( + self, + feat_extractor: IntermediateLayerGetter, + fpn_channels: int = 64, + bin_thresh: float = 0.1, + box_thresh: float = 0.1, + assume_straight_pages: bool = True, + exportable: bool = False, + cfg: Optional[Dict[str, Any]] = None, + class_names: List[str] = [CLASS_NAME], + ) -> None: + super().__init__(cfg=cfg) + + self.class_names = class_names + num_classes: int = len(self.class_names) + + self.exportable = exportable + self.assume_straight_pages = assume_straight_pages + + self.feat_extractor = feat_extractor + + self.fpn = LinkNetFPN(fpn_channels, [_shape[1:] for _shape in self.feat_extractor.output_shape]) + self.fpn.build(self.feat_extractor.output_shape) + + self.classifier = Sequential([ + layers.Conv2DTranspose( + filters=32, + kernel_size=3, + strides=2, + padding="same", + use_bias=False, + kernel_initializer="he_normal", + input_shape=self.fpn.decoders[-1].output_shape[1:], + ), + layers.BatchNormalization(), + layers.Activation("relu"), + *conv_sequence(32, "relu", True, kernel_size=3, strides=1), + layers.Conv2DTranspose( + filters=num_classes, + kernel_size=2, + strides=2, + padding="same", + use_bias=True, + kernel_initializer="he_normal", + ), + ]) + + self.postprocessor = LinkNetPostProcessor( + assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh + ) + + def compute_loss( + self, + out_map: tf.Tensor, + target: List[Dict[str, np.ndarray]], + gamma: float = 2.0, + alpha: float = 0.5, + eps: float = 1e-8, + ) -> tf.Tensor: + """Compute linknet loss, BCE with boosted box edges or focal loss. Focal loss implementation based on + `_. 
+ + Args: + ---- + out_map: output feature map of the model of shape N x H x W x 1 + target: list of dictionary where each dict has a `boxes` and a `flags` entry + gamma: modulating factor in the focal loss formula + alpha: balancing factor in the focal loss formula + eps: epsilon factor in dice loss + + Returns: + ------- + A loss tensor + """ + seg_target, seg_mask = self.build_target(target, out_map.shape[1:], True) + seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype) + seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool) + seg_mask = tf.cast(seg_mask, tf.float32) + + bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) + proba_map = tf.sigmoid(out_map) + + # Focal loss + if gamma < 0: + raise ValueError("Value of gamma should be greater than or equal to zero.") + # Convert logits to prob, compute gamma factor + p_t = (seg_target * proba_map) + ((1 - seg_target) * (1 - proba_map)) + alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha) + # Unreduced loss + focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss + # Class reduced + focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3)) + + # Compute dice loss for each class + dice_map = tf.nn.softmax(out_map, axis=-1) if len(self.class_names) > 1 else proba_map + # Class-reduced dice loss + inter = tf.reduce_sum(seg_mask * dice_map * seg_target, axis=[0, 1, 2]) + cardinality = tf.reduce_sum(seg_mask * (dice_map + seg_target), axis=[0, 1, 2]) + dice_loss = tf.reduce_mean(1 - 2 * inter / (cardinality + eps)) + + return focal_loss + dice_loss + + def call( + self, + x: tf.Tensor, + target: Optional[List[Dict[str, np.ndarray]]] = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + feat_maps = self.feat_extractor(x, **kwargs) + logits = self.fpn(feat_maps, **kwargs) + logits = self.classifier(logits, **kwargs) + + out: Dict[str, tf.Tensor] = {} + if self.exportable: + out["logits"] = logits + return out + + if return_model_output or target is None or return_preds: + prob_map = _bf16_to_float32(tf.math.sigmoid(logits)) + + if return_model_output: + out["out_map"] = prob_map + + if target is None or return_preds: + # Post-process boxes + out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())] + + if target is not None: + loss = self.compute_loss(logits, target) + out["loss"] = loss + + return out + + +def _linknet( + arch: str, + pretrained: bool, + backbone_fn, + fpn_layers: List[str], + pretrained_backbone: bool = True, + input_shape: Optional[Tuple[int, int, int]] = None, + **kwargs: Any, +) -> LinkNet: + pretrained_backbone = pretrained_backbone and not pretrained + + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["input_shape"] = input_shape or default_cfgs[arch]["input_shape"] + if not kwargs.get("class_names", None): + kwargs["class_names"] = _cfg.get("class_names", [CLASS_NAME]) + else: + kwargs["class_names"] = sorted(kwargs["class_names"]) + + # Feature extractor + feat_extractor = IntermediateLayerGetter( + backbone_fn( + pretrained=pretrained_backbone, + include_top=False, + input_shape=_cfg["input_shape"], + ), + fpn_layers, + ) + + # Build the model + model = LinkNet(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, _cfg["url"]) + + return model + + +def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> 
LinkNet: + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. + + >>> import tensorflow as tf + >>> from doctr.models import linknet_resnet18 + >>> model = linknet_resnet18(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the LinkNet architecture + + Returns: + ------- + text detection architecture + """ + return _linknet( + "linknet_resnet18", + pretrained, + resnet18, + ["resnet_block_1", "resnet_block_3", "resnet_block_5", "resnet_block_7"], + **kwargs, + ) + + +def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet: + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. + + >>> import tensorflow as tf + >>> from doctr.models import linknet_resnet34 + >>> model = linknet_resnet34(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the LinkNet architecture + + Returns: + ------- + text detection architecture + """ + return _linknet( + "linknet_resnet34", + pretrained, + resnet34, + ["resnet_block_2", "resnet_block_6", "resnet_block_12", "resnet_block_15"], + **kwargs, + ) + + +def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet: + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. 
+ + >>> import tensorflow as tf + >>> from doctr.models import linknet_resnet50 + >>> model = linknet_resnet50(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + **kwargs: keyword arguments of the LinkNet architecture + + Returns: + ------- + text detection architecture + """ + return _linknet( + "linknet_resnet50", + pretrained, + resnet50, + ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"], + **kwargs, + ) diff --git a/doctr/models/detection/predictor/__init__.py b/doctr/models/detection/predictor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff30c3b2e7d34bf85e30291e39f9d3206c0f4bdd --- /dev/null +++ b/doctr/models/detection/predictor/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available + +if is_tf_available(): + from .tensorflow import * +else: + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/detection/predictor/__pycache__/__init__.cpython-311.pyc b/doctr/models/detection/predictor/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2b1e12edb10fa70825bd7cb90b93962b5c1b673 Binary files /dev/null and b/doctr/models/detection/predictor/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/detection/predictor/__pycache__/__init__.cpython-38.pyc b/doctr/models/detection/predictor/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2ce6b531e860d3ad44407a25f927a6066c55fb2 Binary files /dev/null and b/doctr/models/detection/predictor/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/detection/predictor/__pycache__/pytorch.cpython-311.pyc b/doctr/models/detection/predictor/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c09a447c2d5ce8c4a2fd8b3c1921ef7383d7947 Binary files /dev/null and b/doctr/models/detection/predictor/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/detection/predictor/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/detection/predictor/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bda1405798824f9ba2fea37e2304af816728a453 Binary files /dev/null and b/doctr/models/detection/predictor/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/detection/predictor/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/detection/predictor/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1617e75ae2947c51754e4116f0c9aac027234c5c Binary files /dev/null and b/doctr/models/detection/predictor/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/detection/predictor/pytorch.py b/doctr/models/detection/predictor/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..7b83d61906cd5631ac3638dc18c4bb418bee6556 --- /dev/null +++ b/doctr/models/detection/predictor/pytorch.py @@ -0,0 +1,61 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
+ +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +import torch +from torch import nn + +from doctr.models.preprocessor import PreProcessor +from doctr.models.utils import set_device_and_dtype + +__all__ = ["DetectionPredictor"] + + +class DetectionPredictor(nn.Module): + """Implements an object able to localize text elements in a document + + Args: + ---- + pre_processor: transform inputs for easier batched model inference + model: core detection architecture + """ + + def __init__( + self, + pre_processor: PreProcessor, + model: nn.Module, + ) -> None: + super().__init__() + self.pre_processor = pre_processor + self.model = model.eval() + + @torch.inference_mode() + def forward( + self, + pages: List[Union[np.ndarray, torch.Tensor]], + return_maps: bool = False, + **kwargs: Any, + ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]: + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + processed_batches = self.pre_processor(pages) + _params = next(self.model.parameters()) + self.model, processed_batches = set_device_and_dtype( + self.model, processed_batches, _params.device, _params.dtype + ) + predicted_batches = [ + self.model(batch, return_preds=True, return_model_output=True, **kwargs) for batch in processed_batches + ] + preds = [pred for batch in predicted_batches for pred in batch["preds"]] + if return_maps: + seg_maps = [ + pred.permute(1, 2, 0).detach().cpu().numpy() for batch in predicted_batches for pred in batch["out_map"] + ] + return preds, seg_maps + return preds diff --git a/doctr/models/detection/predictor/tensorflow.py b/doctr/models/detection/predictor/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff3f388bd6bcc60824d31b0e8af806f5aaffaf5 --- /dev/null +++ b/doctr/models/detection/predictor/tensorflow.py @@ -0,0 +1,57 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
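(Aside on the PyTorch `DetectionPredictor` above: a rough sketch of how it is typically assembled from a model and a `PreProcessor`, mirroring what the detection zoo's `_predictor` helper does later in this diff. It assumes the PyTorch backend and a pretrained `db_resnet50`; the batch size and the use of `model.cfg` follow the zoo code.)

```python
import numpy as np
from doctr.models import detection
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.preprocessor import PreProcessor

model = detection.db_resnet50(pretrained=True)
predictor = DetectionPredictor(
    PreProcessor(model.cfg["input_shape"][1:], batch_size=2, mean=model.cfg["mean"], std=model.cfg["std"]),
    model,
)

page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
preds, seg_maps = predictor([page], return_maps=True)  # list of {class_name: boxes} dicts + raw maps
```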
+ +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +import tensorflow as tf +from tensorflow import keras + +from doctr.models.preprocessor import PreProcessor +from doctr.utils.repr import NestedObject + +__all__ = ["DetectionPredictor"] + + +class DetectionPredictor(NestedObject): + """Implements an object able to localize text elements in a document + + Args: + ---- + pre_processor: transform inputs for easier batched model inference + model: core detection architecture + """ + + _children_names: List[str] = ["pre_processor", "model"] + + def __init__( + self, + pre_processor: PreProcessor, + model: keras.Model, + ) -> None: + self.pre_processor = pre_processor + self.model = model + + def __call__( + self, + pages: List[Union[np.ndarray, tf.Tensor]], + return_maps: bool = False, + **kwargs: Any, + ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]: + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + processed_batches = self.pre_processor(pages) + predicted_batches = [ + self.model(batch, return_preds=True, return_model_output=True, training=False, **kwargs) + for batch in processed_batches + ] + + preds = [pred for batch in predicted_batches for pred in batch["preds"]] + if return_maps: + seg_maps = [pred.numpy() for batch in predicted_batches for pred in batch["out_map"]] + return preds, seg_maps + return preds diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..3cab59e3810bd4d979664e0e3282516de293f9ff --- /dev/null +++ b/doctr/models/detection/zoo.py @@ -0,0 +1,102 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, List + +from doctr.file_utils import is_tf_available, is_torch_available + +from .. 
import detection +from ..detection.fast import reparameterize +from ..preprocessor import PreProcessor +from .predictor import DetectionPredictor + +__all__ = ["detection_predictor"] + +ARCHS: List[str] + + +if is_tf_available(): + ARCHS = [ + "db_resnet50", + "db_mobilenet_v3_large", + "linknet_resnet18", + "linknet_resnet34", + "linknet_resnet50", + "fast_tiny", + "fast_small", + "fast_base", + ] +elif is_torch_available(): + ARCHS = [ + "db_resnet34", + "db_resnet50", + "db_mobilenet_v3_large", + "linknet_resnet18", + "linknet_resnet34", + "linknet_resnet50", + "fast_tiny", + "fast_small", + "fast_base", + ] + + +def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor: + if isinstance(arch, str): + if arch not in ARCHS: + raise ValueError(f"unknown architecture '{arch}'") + + _model = detection.__dict__[arch]( + pretrained=pretrained, + pretrained_backbone=kwargs.get("pretrained_backbone", True), + assume_straight_pages=assume_straight_pages, + ) + # Reparameterize FAST models by default to lower inference latency and memory usage + if isinstance(_model, detection.FAST): + _model = reparameterize(_model) + else: + if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)): + raise ValueError(f"unknown architecture: {type(arch)}") + + _model = arch + _model.assume_straight_pages = assume_straight_pages + + kwargs.pop("pretrained_backbone", None) + + kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) + kwargs["std"] = kwargs.get("std", _model.cfg["std"]) + kwargs["batch_size"] = kwargs.get("batch_size", 2) + predictor = DetectionPredictor( + PreProcessor(_model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs), + _model, + ) + return predictor + + +def detection_predictor( + arch: Any = "fast_base", + pretrained: bool = False, + assume_straight_pages: bool = True, + **kwargs: Any, +) -> DetectionPredictor: + """Text detection architecture. + + >>> import numpy as np + >>> from doctr.models import detection_predictor + >>> model = detection_predictor(arch='db_resnet50', pretrained=True) + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + ---- + arch: name of the architecture or model itself to use (e.g. 
'db_resnet50') + pretrained: If True, returns a model pre-trained on our text detection dataset + assume_straight_pages: If True, fit straight boxes to the page + **kwargs: optional keyword arguments passed to the architecture + + Returns: + ------- + Detection predictor + """ + return _predictor(arch, pretrained, assume_straight_pages, **kwargs) diff --git a/doctr/models/factory/__init__.py b/doctr/models/factory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b5b25a325e8a5ecc11546832996a78de1407ce4 --- /dev/null +++ b/doctr/models/factory/__init__.py @@ -0,0 +1 @@ +from .hub import * diff --git a/doctr/models/factory/__pycache__/__init__.cpython-311.pyc b/doctr/models/factory/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08ae2f36f976d9db0226a4c5557b2270038ce1ff Binary files /dev/null and b/doctr/models/factory/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/factory/__pycache__/__init__.cpython-38.pyc b/doctr/models/factory/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7585d663ad5a11de03992ca1b099747727b648ee Binary files /dev/null and b/doctr/models/factory/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/factory/__pycache__/hub.cpython-311.pyc b/doctr/models/factory/__pycache__/hub.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae97a1b2acaf043cf75143dfdb447c873ec5c996 Binary files /dev/null and b/doctr/models/factory/__pycache__/hub.cpython-311.pyc differ diff --git a/doctr/models/factory/__pycache__/hub.cpython-38.pyc b/doctr/models/factory/__pycache__/hub.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f55f44978bad0dd2f9ed722c49f66ab033239f0 Binary files /dev/null and b/doctr/models/factory/__pycache__/hub.cpython-38.pyc differ diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py new file mode 100644 index 0000000000000000000000000000000000000000..a6c3f8932223498599a0cc3127593cd28f62703b --- /dev/null +++ b/doctr/models/factory/hub.py @@ -0,0 +1,231 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
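(Aside on the detection zoo above: `_predictor` only fuses/reparameterizes FAST models when they are built from an architecture name, so a caller passing a model instance can apply `reparameterize` from the FAST module explicitly. A hedged sketch follows; the architecture names and the `pretrained` flag are illustrative.)

```python
import numpy as np
from doctr.models import detection, detection_predictor
from doctr.models.detection.fast import reparameterize

# Built from a name: FAST conv/batchnorm pairs are fused automatically by _predictor
predictor = detection_predictor(arch="fast_base", pretrained=True)

# Built from an instance: fuse the layers yourself before wrapping
model = reparameterize(detection.fast_base(pretrained=True))
predictor = detection_predictor(arch=model)

out = predictor([(255 * np.random.rand(600, 800, 3)).astype(np.uint8)])
```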
+ +# Inspired by: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/hub.py + +import json +import logging +import os +import subprocess +import textwrap +from pathlib import Path +from typing import Any + +from huggingface_hub import ( + HfApi, + Repository, + get_token, + get_token_permission, + hf_hub_download, + login, + snapshot_download, +) + +from doctr import models +from doctr.file_utils import is_tf_available, is_torch_available + +if is_torch_available(): + import torch + +__all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"] + + +AVAILABLE_ARCHS = { + "classification": models.classification.zoo.ARCHS, + "detection": models.detection.zoo.ARCHS, + "recognition": models.recognition.zoo.ARCHS, +} + + +def login_to_hub() -> None: # pragma: no cover + """Login to huggingface hub""" + access_token = get_token() + if access_token is not None and get_token_permission(access_token): + logging.info("Huggingface Hub token found and valid") + login(token=access_token, write_permission=True) + else: + login() + # check if git lfs is installed + try: + subprocess.call(["git", "lfs", "version"]) + except FileNotFoundError: + raise OSError( + "Looks like you do not have git-lfs installed, please install. \ + You can install from https://git-lfs.github.com/. \ + Then run `git lfs install` (you only have to do this once)." + ) + + +def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task: str) -> None: + """Save model and config to disk for pushing to huggingface hub + + Args: + ---- + model: TF or PyTorch model to be saved + save_dir: directory to save model and config + arch: architecture name + task: task name + """ + save_directory = Path(save_dir) + + if is_torch_available(): + weights_path = save_directory / "pytorch_model.bin" + torch.save(model.state_dict(), weights_path) + elif is_tf_available(): + weights_path = save_directory / "tf_model" / "weights" + model.save_weights(str(weights_path)) + + config_path = save_directory / "config.json" + + # add model configuration + model_config = model.cfg + model_config["arch"] = arch + model_config["task"] = task + + with config_path.open("w") as f: + json.dump(model_config, f, indent=2, ensure_ascii=False) + + +def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: # pragma: no cover + """Save model and its configuration on HF hub + + >>> from doctr.models import login_to_hub, push_to_hf_hub + >>> from doctr.models.recognition import crnn_mobilenet_v3_small + >>> login_to_hub() + >>> model = crnn_mobilenet_v3_small(pretrained=True) + >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small') + + Args: + ---- + model: TF or PyTorch model to be saved + model_name: name of the model which is also the repository name + task: task name + **kwargs: keyword arguments for push_to_hf_hub + """ + run_config = kwargs.get("run_config", None) + arch = kwargs.get("arch", None) + + if run_config is None and arch is None: + raise ValueError("run_config or arch must be specified") + if task not in ["classification", "detection", "recognition"]: + raise ValueError("task must be one of classification, detection, recognition") + + # default readme + readme = textwrap.dedent( + f""" + --- + language: en + --- + +

+ +

+ + **Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch** + + ## Task: {task} + + https://github.com/mindee/doctr + + ### Example usage: + + ```python + >>> from doctr.io import DocumentFile + >>> from doctr.models import ocr_predictor, from_hub + + >>> img = DocumentFile.from_images(['']) + >>> # Load your model from the hub + >>> model = from_hub('mindee/my-model') + + >>> # Pass it to the predictor + >>> # If your model is a recognition model: + >>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large', + >>> reco_arch=model, + >>> pretrained=True) + + >>> # If your model is a detection model: + >>> predictor = ocr_predictor(det_arch=model, + >>> reco_arch='crnn_mobilenet_v3_small', + >>> pretrained=True) + + >>> # Get your predictions + >>> res = predictor(img) + ``` + """ + ) + + # add run configuration to readme if available + if run_config is not None: + arch = run_config.arch + readme += textwrap.dedent( + f"""### Run Configuration + \n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}""" + ) + + if arch not in AVAILABLE_ARCHS[task]: + raise ValueError( + f"Architecture: {arch} for task: {task} not found.\ + \nAvailable architectures: {AVAILABLE_ARCHS}" + ) + + commit_message = f"Add {model_name} model" + + local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name) + repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False) + repo = Repository(local_dir=local_cache_dir, clone_from=repo_url, use_auth_token=True) + + with repo.commit(commit_message): + _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task) + readme_path = Path(repo.local_dir) / "README.md" + readme_path.write_text(readme) + + repo.git_push() + + +def from_hub(repo_id: str, **kwargs: Any): + """Instantiate & load a pretrained model from HF hub. 
+ + >>> from doctr.models import from_hub + >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn") + + Args: + ---- + repo_id: HuggingFace model hub repo + kwargs: kwargs of `hf_hub_download` or `snapshot_download` + + Returns: + ------- + Model loaded with the checkpoint + """ + # Get the config + with open(hf_hub_download(repo_id, filename="config.json", **kwargs), "rb") as f: + cfg = json.load(f) + + arch = cfg["arch"] + task = cfg["task"] + cfg.pop("arch") + cfg.pop("task") + + if task == "classification": + model = models.classification.__dict__[arch]( + pretrained=False, classes=cfg["classes"], num_classes=cfg["num_classes"] + ) + elif task == "detection": + model = models.detection.__dict__[arch](pretrained=False) + elif task == "recognition": + model = models.recognition.__dict__[arch](pretrained=False, input_shape=cfg["input_shape"], vocab=cfg["vocab"]) + + # update model cfg + model.cfg = cfg + + # Load checkpoint + if is_torch_available(): + state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu") + model.load_state_dict(state_dict) + else: # tf + repo_path = snapshot_download(repo_id, **kwargs) + model.load_weights(os.path.join(repo_path, "tf_model", "weights")) + + return model diff --git a/doctr/models/kie_predictor/__init__.py b/doctr/models/kie_predictor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff30c3b2e7d34bf85e30291e39f9d3206c0f4bdd --- /dev/null +++ b/doctr/models/kie_predictor/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available + +if is_tf_available(): + from .tensorflow import * +else: + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/kie_predictor/__pycache__/__init__.cpython-311.pyc b/doctr/models/kie_predictor/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdec97860fa3f674d46a601f317a8b6132fca5f9 Binary files /dev/null and b/doctr/models/kie_predictor/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/kie_predictor/__pycache__/__init__.cpython-38.pyc b/doctr/models/kie_predictor/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e769b0bb6cac33e64405200482010e0d7fdcef6f Binary files /dev/null and b/doctr/models/kie_predictor/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/kie_predictor/__pycache__/base.cpython-311.pyc b/doctr/models/kie_predictor/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..342332d1bf6f04ac63dae2f44ce3b91524e88e13 Binary files /dev/null and b/doctr/models/kie_predictor/__pycache__/base.cpython-311.pyc differ diff --git a/doctr/models/kie_predictor/__pycache__/base.cpython-38.pyc b/doctr/models/kie_predictor/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a13625ecacfdef1bd6fdf8706dc118348ca2cc0b Binary files /dev/null and b/doctr/models/kie_predictor/__pycache__/base.cpython-38.pyc differ diff --git a/doctr/models/kie_predictor/__pycache__/pytorch.cpython-311.pyc b/doctr/models/kie_predictor/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97eaa3fc0970fcf3fd15e0b50e9eeeadcdeda903 Binary files /dev/null and b/doctr/models/kie_predictor/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/kie_predictor/__pycache__/tensorflow.cpython-311.pyc 
b/doctr/models/kie_predictor/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d70cd5b0a4dc16e798c370b2112d03ca53ca4d46 Binary files /dev/null and b/doctr/models/kie_predictor/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/kie_predictor/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/kie_predictor/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0aaddd608cc04be78e8db60f9f3be3068f1b03a9 Binary files /dev/null and b/doctr/models/kie_predictor/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/kie_predictor/base.py b/doctr/models/kie_predictor/base.py new file mode 100644 index 0000000000000000000000000000000000000000..63a87f5900b0dbf2b9e1620f1a87bd2be0564478 --- /dev/null +++ b/doctr/models/kie_predictor/base.py @@ -0,0 +1,43 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, Optional + +from doctr.models.builder import KIEDocumentBuilder + +from ..classification.predictor import OrientationPredictor +from ..predictor.base import _OCRPredictor + +__all__ = ["_KIEPredictor"] + + +class _KIEPredictor(_OCRPredictor): + """Implements an object able to localize and identify text elements in a set of documents + + Args: + ---- + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + straighten_pages: if True, estimates the page general orientation based on the median line orientation. + Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped + accordingly. Doing so will improve performances for documents with page-uniform rotations. + preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding) + symmetric_pad: if True and preserve_aspect_ratio is True, pas the image symmetrically. + kwargs: keyword args of `DocumentBuilder` + """ + + crop_orientation_predictor: Optional[OrientationPredictor] + + def __init__( + self, + assume_straight_pages: bool = True, + straighten_pages: bool = False, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs) + + self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs) diff --git a/doctr/models/kie_predictor/pytorch.py b/doctr/models/kie_predictor/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..0fa183f5b930ae8cdc217d2af0b5283aa0f69c82 --- /dev/null +++ b/doctr/models/kie_predictor/pytorch.py @@ -0,0 +1,176 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
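+# Usage sketch (illustrative only, assuming the usual doctr zoo helper `kie_predictor`
+# is available as in upstream doctr; it is not defined in this module):
+# >>> import numpy as np
+# >>> from doctr.models import kie_predictor
+# >>> predictor = kie_predictor(pretrained=True)
+# >>> out = predictor([np.zeros((1024, 1024, 3), dtype=np.uint8)])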
+ +from typing import Any, Dict, List, Union + +import numpy as np +import torch +from torch import nn + +from doctr.io.elements import Document +from doctr.models._utils import estimate_orientation, get_language, invert_data_structure +from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.recognition.predictor import RecognitionPredictor +from doctr.utils.geometry import rotate_image + +from .base import _KIEPredictor + +__all__ = ["KIEPredictor"] + + +class KIEPredictor(nn.Module, _KIEPredictor): + """Implements an object able to localize and identify text elements in a set of documents + + Args: + ---- + det_predictor: detection module + reco_predictor: recognition module + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + straighten_pages: if True, estimates the page general orientation based on the median line orientation. + Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped + accordingly. Doing so will improve performances for documents with page-uniform rotations. + detect_orientation: if True, the estimated general page orientation will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + detect_language: if True, the language prediction will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + **kwargs: keyword args of `DocumentBuilder` + """ + + def __init__( + self, + det_predictor: DetectionPredictor, + reco_predictor: RecognitionPredictor, + assume_straight_pages: bool = True, + straighten_pages: bool = False, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + detect_orientation: bool = False, + detect_language: bool = False, + **kwargs: Any, + ) -> None: + nn.Module.__init__(self) + self.det_predictor = det_predictor.eval() # type: ignore[attr-defined] + self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined] + _KIEPredictor.__init__( + self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs + ) + self.detect_orientation = detect_orientation + self.detect_language = detect_language + + @torch.inference_mode() + def forward( + self, + pages: List[Union[np.ndarray, torch.Tensor]], + **kwargs: Any, + ) -> Document: + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages] + + # Localize text elements + loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs) + + # Detect document rotation and rotate pages + seg_maps = [ + np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > kwargs.get("bin_thresh", 0.3), 255, 0).astype( + np.uint8 + ) + for out_map in out_maps + ] + if self.detect_orientation: + origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps] + orientations = [ + {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations + ] + else: + orientations = None + if self.straighten_pages: + origin_page_orientations = ( + origin_page_orientations + if self.detect_orientation + else [estimate_orientation(seq_map) for seq_map in seg_maps] + ) + pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, 
origin_page_orientations)] # type: ignore[arg-type] + # Forward again to get predictions on straight pages + loc_preds = self.det_predictor(pages, **kwargs) + + dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore[assignment] + # Check whether crop mode should be switched to channels first + channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray) + + # Rectify crops if aspect ratio + dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()} # type: ignore[arg-type] + + # Apply hooks to loc_preds if any + for hook in self.hooks: + dict_loc_preds = hook(dict_loc_preds) + + # Crop images + crops = {} + for class_name in dict_loc_preds.keys(): + crops[class_name], dict_loc_preds[class_name] = self._prepare_crops( + pages, # type: ignore[arg-type] + dict_loc_preds[class_name], + channels_last=channels_last, + assume_straight_pages=self.assume_straight_pages, + ) + # Rectify crop orientation + crop_orientations: Any = {} + if not self.assume_straight_pages: + for class_name in dict_loc_preds.keys(): + crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops( + crops[class_name], dict_loc_preds[class_name] + ) + crop_orientations[class_name] = [ + {"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations + ] + + # Identify character sequences + word_preds = { + k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs) + for k, crop_value in crops.items() + } + if not crop_orientations: + crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds} + + boxes: Dict = {} + text_preds: Dict = {} + word_crop_orientations: Dict = {} + for class_name in dict_loc_preds.keys(): + boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions( + dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name] + ) + + boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment] + text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment] + crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment] + + if self.detect_language: + languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page] + languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages] + else: + languages_dict = None + + out = self.doc_builder( + pages, # type: ignore[arg-type] + boxes_per_page, + text_preds_per_page, + origin_page_shapes, # type: ignore[arg-type] + crop_orientations_per_page, + orientations, + languages_dict, + ) + return out + + @staticmethod + def get_text(text_pred: Dict) -> str: + text = [] + for value in text_pred.values(): + text += [item[0] for item in value] + + return " ".join(text) diff --git a/doctr/models/kie_predictor/tensorflow.py b/doctr/models/kie_predictor/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..52b1211dd070e43dfa5f5385b0d326a38608a576 --- /dev/null +++ b/doctr/models/kie_predictor/tensorflow.py @@ -0,0 +1,171 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
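+# Note on `invert_data_structure` (judging from its use in `__call__` below): it flips the
+# nesting of the per-class predictions, e.g. with hypothetical class names
+#     {"words": [p1, p2], "logos": [q1, q2]} -> [{"words": p1, "logos": q1}, {"words": p2, "logos": q2}]
+# so that class-wise results can be regrouped page by page before the document is built.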
+ +from typing import Any, Dict, List, Union + +import numpy as np +import tensorflow as tf + +from doctr.io.elements import Document +from doctr.models._utils import estimate_orientation, get_language, invert_data_structure +from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.recognition.predictor import RecognitionPredictor +from doctr.utils.geometry import rotate_image +from doctr.utils.repr import NestedObject + +from .base import _KIEPredictor + +__all__ = ["KIEPredictor"] + + +class KIEPredictor(NestedObject, _KIEPredictor): + """Implements an object able to localize and identify text elements in a set of documents + + Args: + ---- + det_predictor: detection module + reco_predictor: recognition module + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + straighten_pages: if True, estimates the page general orientation based on the median line orientation. + Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped + accordingly. Doing so will improve performances for documents with page-uniform rotations. + detect_orientation: if True, the estimated general page orientation will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + detect_language: if True, the language prediction will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + **kwargs: keyword args of `DocumentBuilder` + """ + + _children_names = ["det_predictor", "reco_predictor", "doc_builder"] + + def __init__( + self, + det_predictor: DetectionPredictor, + reco_predictor: RecognitionPredictor, + assume_straight_pages: bool = True, + straighten_pages: bool = False, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + detect_orientation: bool = False, + detect_language: bool = False, + **kwargs: Any, + ) -> None: + self.det_predictor = det_predictor + self.reco_predictor = reco_predictor + _KIEPredictor.__init__( + self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs + ) + self.detect_orientation = detect_orientation + self.detect_language = detect_language + + def __call__( + self, + pages: List[Union[np.ndarray, tf.Tensor]], + **kwargs: Any, + ) -> Document: + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + origin_page_shapes = [page.shape[:2] for page in pages] + + # Localize text elements + loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs) + + # Detect document rotation and rotate pages + seg_maps = [ + np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > kwargs.get("bin_thresh", 0.3), 255, 0).astype( + np.uint8 + ) + for out_map in out_maps + ] + if self.detect_orientation: + origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps] + orientations = [ + {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations + ] + else: + orientations = None + if self.straighten_pages: + origin_page_orientations = ( + origin_page_orientations + if self.detect_orientation + else [estimate_orientation(seq_map) for seq_map in seg_maps] + ) + pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)] + # Forward again to get predictions on straight pages + 
loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment] + + dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore + # Rectify crops if aspect ratio + dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()} + + # Apply hooks to loc_preds if any + for hook in self.hooks: + dict_loc_preds = hook(dict_loc_preds) + + # Crop images + crops = {} + for class_name in dict_loc_preds.keys(): + crops[class_name], dict_loc_preds[class_name] = self._prepare_crops( + pages, dict_loc_preds[class_name], channels_last=True, assume_straight_pages=self.assume_straight_pages + ) + + # Rectify crop orientation + crop_orientations: Any = {} + if not self.assume_straight_pages: + for class_name in dict_loc_preds.keys(): + crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops( + crops[class_name], dict_loc_preds[class_name] + ) + crop_orientations[class_name] = [ + {"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations + ] + + # Identify character sequences + word_preds = { + k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs) + for k, crop_value in crops.items() + } + if not crop_orientations: + crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds} + + boxes: Dict = {} + text_preds: Dict = {} + word_crop_orientations: Dict = {} + for class_name in dict_loc_preds.keys(): + boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions( + dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name] + ) + + boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment] + text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment] + crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment] + + if self.detect_language: + languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page] + languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages] + else: + languages_dict = None + + out = self.doc_builder( + pages, + boxes_per_page, + text_preds_per_page, + origin_page_shapes, # type: ignore[arg-type] + crop_orientations_per_page, + orientations, + languages_dict, + ) + return out + + @staticmethod + def get_text(text_pred: Dict) -> str: + text = [] + for value in text_pred.values(): + text += [item[0] for item in value] + + return " ".join(text) diff --git a/doctr/models/modules/__init__.py b/doctr/models/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d659f1cdb7d675f32c1b504965bd453cc7a9a4d8 --- /dev/null +++ b/doctr/models/modules/__init__.py @@ -0,0 +1,3 @@ +from .layers import * +from .transformer import * +from .vision_transformer import * diff --git a/doctr/models/modules/__pycache__/__init__.cpython-311.pyc b/doctr/models/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df58cd7c135fd01148fd1a08d769d336b33936dc Binary files /dev/null and b/doctr/models/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/modules/__pycache__/__init__.cpython-38.pyc b/doctr/models/modules/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3d59190954aa4e56bec89e631cde298491b843db Binary files /dev/null and b/doctr/models/modules/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/modules/layers/__init__.py b/doctr/models/modules/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/modules/layers/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/modules/layers/__pycache__/__init__.cpython-311.pyc b/doctr/models/modules/layers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..040a5e7aa8b388feec5adceade93d36aa642f9df Binary files /dev/null and b/doctr/models/modules/layers/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/modules/layers/__pycache__/__init__.cpython-38.pyc b/doctr/models/modules/layers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ab1c6a9facc94bd54583e50ad891ac780c4ae90 Binary files /dev/null and b/doctr/models/modules/layers/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/modules/layers/__pycache__/pytorch.cpython-311.pyc b/doctr/models/modules/layers/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2eafd742bea5d39bebbaa1215505eef02b61c6a0 Binary files /dev/null and b/doctr/models/modules/layers/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/modules/layers/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/modules/layers/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4c0a13f400a49e358a3980f670fabb886e431a0 Binary files /dev/null and b/doctr/models/modules/layers/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/modules/layers/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/modules/layers/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d999ffca11cf0823c8b094bf811d50f657dbb82 Binary files /dev/null and b/doctr/models/modules/layers/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..b7ad119ec9ba00caaaabaaf673b8d442014da258 --- /dev/null +++ b/doctr/models/modules/layers/pytorch.py @@ -0,0 +1,165 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
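+# Reparameterization note: at inference time the parallel conv/BN branches defined below can
+# be folded into a single convolution, since for each conv + BN pair
+#     W_fused = W * gamma / sqrt(running_var + eps)
+#     b_fused = beta - running_mean * gamma / sqrt(running_var + eps)
+# which is what `_fuse_bn_tensor` computes; `reparameterize_layer` then sums the fused
+# branches into a single `fused_conv`.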
+ +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +__all__ = ["FASTConvLayer"] + + +class FASTConvLayer(nn.Module): + """Convolutional layer used in the TextNet and FAST architectures""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = False, + ) -> None: + super().__init__() + + self.groups = groups + self.in_channels = in_channels + self.converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + self.hor_conv, self.hor_bn = None, None + self.ver_conv, self.ver_bn = None, None + + padding = (int(((self.converted_ks[0] - 1) * dilation) / 2), int(((self.converted_ks[1] - 1) * dilation) / 2)) + + self.activation = nn.ReLU(inplace=True) + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=self.converted_ks, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + self.bn = nn.BatchNorm2d(out_channels) + + if self.converted_ks[1] != 1: + self.ver_conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=(self.converted_ks[0], 1), + padding=(int(((self.converted_ks[0] - 1) * dilation) / 2), 0), + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + ) + self.ver_bn = nn.BatchNorm2d(out_channels) + + if self.converted_ks[0] != 1: + self.hor_conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=(1, self.converted_ks[1]), + padding=(0, int(((self.converted_ks[1] - 1) * dilation) / 2)), + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + ) + self.hor_bn = nn.BatchNorm2d(out_channels) + + self.rbr_identity = nn.BatchNorm2d(in_channels) if out_channels == in_channels and stride == 1 else None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "fused_conv"): + return self.activation(self.fused_conv(x)) + + main_outputs = self.bn(self.conv(x)) + vertical_outputs = self.ver_bn(self.ver_conv(x)) if self.ver_conv is not None and self.ver_bn is not None else 0 + horizontal_outputs = ( + self.hor_bn(self.hor_conv(x)) if self.hor_bn is not None and self.hor_conv is not None else 0 + ) + id_out = self.rbr_identity(x) if self.rbr_identity is not None else 0 + + return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out) + + # The following logic is used to reparametrize the layer + # Borrowed from: https://github.com/czczup/FAST/blob/main/models/utils/nas_utils.py + def _identity_to_conv( + self, identity: Union[nn.BatchNorm2d, None] + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]: + if identity is None or identity.running_var is None: + return 0, 0 + if not hasattr(self, "id_tensor"): + input_dim = self.in_channels // self.groups + kernel_value = np.zeros((self.in_channels, input_dim, 1, 1), dtype=np.float32) + for i in range(self.in_channels): + kernel_value[i, i % input_dim, 0, 0] = 1 + id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device) + self.id_tensor = self._pad_to_mxn_tensor(id_tensor) + kernel = self.id_tensor + std = (identity.running_var + identity.eps).sqrt() + t = (identity.weight / std).reshape(-1, 1, 1, 1) + return kernel * t, identity.bias - identity.running_mean * identity.weight / std + + def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> Tuple[torch.Tensor, torch.Tensor]: + kernel = conv.weight + kernel = self._pad_to_mxn_tensor(kernel) + std = (bn.running_var + bn.eps).sqrt() # 
type: ignore + t = (bn.weight / std).reshape(-1, 1, 1, 1) + return kernel * t, bn.bias - bn.running_mean * bn.weight / std + + def _get_equivalent_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: + kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn) + if self.ver_conv is not None: + kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn) # type: ignore[arg-type] + else: + kernel_mx1, bias_mx1 = 0, 0 # type: ignore[assignment] + if self.hor_conv is not None: + kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn) # type: ignore[arg-type] + else: + kernel_1xn, bias_1xn = 0, 0 # type: ignore[assignment] + kernel_id, bias_id = self._identity_to_conv(self.rbr_identity) + kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id + bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id + return kernel_mxn, bias_mxn + + def _pad_to_mxn_tensor(self, kernel: torch.Tensor) -> torch.Tensor: + kernel_height, kernel_width = self.converted_ks + height, width = kernel.shape[2:] + pad_left_right = (kernel_width - width) // 2 + pad_top_down = (kernel_height - height) // 2 + return torch.nn.functional.pad(kernel, [pad_left_right, pad_left_right, pad_top_down, pad_top_down], value=0) + + def reparameterize_layer(self): + if hasattr(self, "fused_conv"): + return + kernel, bias = self._get_equivalent_kernel_bias() + self.fused_conv = nn.Conv2d( + in_channels=self.conv.in_channels, + out_channels=self.conv.out_channels, + kernel_size=self.conv.kernel_size, # type: ignore[arg-type] + stride=self.conv.stride, # type: ignore[arg-type] + padding=self.conv.padding, # type: ignore[arg-type] + dilation=self.conv.dilation, # type: ignore[arg-type] + groups=self.conv.groups, + bias=True, + ) + self.fused_conv.weight.data = kernel + self.fused_conv.bias.data = bias # type: ignore[union-attr] + for para in self.parameters(): + para.detach_() + for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]: + if hasattr(self, attr): + self.__delattr__(attr) + + if hasattr(self, "rbr_identity"): + self.__delattr__("rbr_identity") diff --git a/doctr/models/modules/layers/tensorflow.py b/doctr/models/modules/layers/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..68849fbf6e5ad5f08a34ff90403dc44612b53834 --- /dev/null +++ b/doctr/models/modules/layers/tensorflow.py @@ -0,0 +1,173 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
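+# Layout note: Keras Conv2D kernels store the spatial dims first and the output channel last,
+# whereas the PyTorch variant stores the output channel first and the spatial dims last. This
+# is why `_pad_to_mxn_tensor` below pads the first two axes (and the PyTorch version the last two).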
+ +from typing import Any, Tuple, Union + +import numpy as np +import tensorflow as tf +from tensorflow.keras import layers + +from doctr.utils.repr import NestedObject + +__all__ = ["FASTConvLayer"] + + +class FASTConvLayer(layers.Layer, NestedObject): + """Convolutional layer used in the TextNet and FAST architectures""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = False, + ) -> None: + super().__init__() + + self.groups = groups + self.in_channels = in_channels + self.converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + self.hor_conv, self.hor_bn = None, None + self.ver_conv, self.ver_bn = None, None + + padding = ((self.converted_ks[0] - 1) * dilation // 2, (self.converted_ks[1] - 1) * dilation // 2) + + self.activation = layers.ReLU() + self.conv_pad = layers.ZeroPadding2D(padding=padding) + + self.conv = layers.Conv2D( + filters=out_channels, + kernel_size=self.converted_ks, + strides=stride, + dilation_rate=dilation, + groups=groups, + use_bias=bias, + ) + + self.bn = layers.BatchNormalization() + + if self.converted_ks[1] != 1: + self.ver_pad = layers.ZeroPadding2D( + padding=(int(((self.converted_ks[0] - 1) * dilation) / 2), 0), + ) + self.ver_conv = layers.Conv2D( + filters=out_channels, + kernel_size=(self.converted_ks[0], 1), + strides=stride, + dilation_rate=dilation, + groups=groups, + use_bias=bias, + ) + self.ver_bn = layers.BatchNormalization() + + if self.converted_ks[0] != 1: + self.hor_pad = layers.ZeroPadding2D( + padding=(0, int(((self.converted_ks[1] - 1) * dilation) / 2)), + ) + self.hor_conv = layers.Conv2D( + filters=out_channels, + kernel_size=(1, self.converted_ks[1]), + strides=stride, + dilation_rate=dilation, + groups=groups, + use_bias=bias, + ) + self.hor_bn = layers.BatchNormalization() + + self.rbr_identity = layers.BatchNormalization() if out_channels == in_channels and stride == 1 else None + + def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor: + if hasattr(self, "fused_conv"): + return self.activation(self.fused_conv(self.conv_pad(x, **kwargs), **kwargs)) + + main_outputs = self.bn(self.conv(self.conv_pad(x, **kwargs), **kwargs), **kwargs) + vertical_outputs = ( + self.ver_bn(self.ver_conv(self.ver_pad(x, **kwargs), **kwargs), **kwargs) + if self.ver_conv is not None and self.ver_bn is not None + else 0 + ) + horizontal_outputs = ( + self.hor_bn(self.hor_conv(self.hor_pad(x, **kwargs), **kwargs), **kwargs) + if self.hor_bn is not None and self.hor_conv is not None + else 0 + ) + id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0 + + return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out) + + # The following logic is used to reparametrize the layer + # Adapted from: https://github.com/mindee/doctr/blob/main/doctr/models/modules/layers/pytorch.py + def _identity_to_conv( + self, identity: layers.BatchNormalization + ) -> Union[Tuple[tf.Tensor, tf.Tensor], Tuple[int, int]]: + if identity is None or not hasattr(identity, "moving_mean") or not hasattr(identity, "moving_variance"): + return 0, 0 + if not hasattr(self, "id_tensor"): + input_dim = self.in_channels // self.groups + kernel_value = np.zeros((1, 1, input_dim, self.in_channels), dtype=np.float32) + for i in range(self.in_channels): + kernel_value[0, 0, i % input_dim, i] = 1 + id_tensor = tf.constant(kernel_value, dtype=tf.float32) + self.id_tensor = 
self._pad_to_mxn_tensor(id_tensor) + kernel = self.id_tensor + std = tf.sqrt(identity.moving_variance + identity.epsilon) + t = tf.reshape(identity.gamma / std, (1, 1, 1, -1)) + return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std + + def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> Tuple[tf.Tensor, tf.Tensor]: + kernel = conv.kernel + kernel = self._pad_to_mxn_tensor(kernel) + std = tf.sqrt(bn.moving_variance + bn.epsilon) + t = tf.reshape(bn.gamma / std, (1, 1, 1, -1)) + return kernel * t, bn.beta - bn.moving_mean * bn.gamma / std + + def _get_equivalent_kernel_bias(self): + kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn) + if self.ver_conv is not None: + kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn) + else: + kernel_mx1, bias_mx1 = 0, 0 + if self.hor_conv is not None: + kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn) + else: + kernel_1xn, bias_1xn = 0, 0 + kernel_id, bias_id = self._identity_to_conv(self.rbr_identity) + kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id + bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id + return kernel_mxn, bias_mxn + + def _pad_to_mxn_tensor(self, kernel: tf.Tensor) -> tf.Tensor: + kernel_height, kernel_width = self.converted_ks + height, width = kernel.shape[:2] + pad_left_right = tf.maximum(0, (kernel_width - width) // 2) + pad_top_down = tf.maximum(0, (kernel_height - height) // 2) + return tf.pad(kernel, [[pad_top_down, pad_top_down], [pad_left_right, pad_left_right], [0, 0], [0, 0]]) + + def reparameterize_layer(self): + kernel, bias = self._get_equivalent_kernel_bias() + self.fused_conv = layers.Conv2D( + filters=self.conv.filters, + kernel_size=self.conv.kernel_size, + strides=self.conv.strides, + padding=self.conv.padding, + dilation_rate=self.conv.dilation_rate, + groups=self.conv.groups, + use_bias=True, + ) + # build layer to initialize weights and biases + self.fused_conv.build(input_shape=(None, None, None, kernel.shape[-2])) + self.fused_conv.set_weights([kernel.numpy(), bias.numpy()]) + for para in self.trainable_variables: + para._trainable = False + for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]: + if hasattr(self, attr): + delattr(self, attr) + + if hasattr(self, "rbr_identity"): + delattr(self, "rbr_identity") diff --git a/doctr/models/modules/transformer/__init__.py b/doctr/models/modules/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/modules/transformer/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/modules/transformer/__pycache__/__init__.cpython-311.pyc b/doctr/models/modules/transformer/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89d09b571e758ff0150784623ac0eb162fce65a4 Binary files /dev/null and b/doctr/models/modules/transformer/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/modules/transformer/__pycache__/__init__.cpython-38.pyc b/doctr/models/modules/transformer/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0916dfb541fb1a2f985e4d532ac83f236c176f15 Binary files /dev/null and 
b/doctr/models/modules/transformer/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/modules/transformer/__pycache__/pytorch.cpython-311.pyc b/doctr/models/modules/transformer/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4574d4f5e2b131511a4dc453942fb50075bfd1d5 Binary files /dev/null and b/doctr/models/modules/transformer/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/modules/transformer/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/modules/transformer/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7abadeb09603c33f6e5be8ba2feb8ecaf020247 Binary files /dev/null and b/doctr/models/modules/transformer/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/modules/transformer/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/modules/transformer/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b16a0b8d66cededd8e3e0874557f0c6787764f8 Binary files /dev/null and b/doctr/models/modules/transformer/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/modules/transformer/pytorch.py b/doctr/models/modules/transformer/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..6f1c61297813b1a6b3d7124fd61dd3e37601dcfb --- /dev/null +++ b/doctr/models/modules/transformer/pytorch.py @@ -0,0 +1,202 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +# This module 'transformer.py' is inspired by https://github.com/wenwenyu/MASTER-pytorch and Decoder is borrowed + +import math +from typing import Any, Callable, Optional, Tuple + +import torch +from torch import nn + +__all__ = ["Decoder", "PositionalEncoding", "EncoderBlock", "MultiHeadAttention", "PositionwiseFeedForward"] + + +class PositionalEncoding(nn.Module): + """Compute positional encoding""" + + def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None: + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + # Compute the positional encodings once in log space. 
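+        # PE(pos, 2i)     = sin(pos / 10000^(2i / d_model))
+        # PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model))
+        # (`div_term` below is 10000^(-2i / d_model), written in exp/log form for numerical convenience)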
+ pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len).unsqueeze(1).float() + div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe.unsqueeze(0)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass + + Args: + ---- + x: embeddings (batch, max_len, d_model) + + Returns + ------- + positional embeddings (batch, max_len, d_model) + """ + x = x + self.pe[:, : x.size(1)] + return self.dropout(x) + + +def scaled_dot_product_attention( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """Scaled Dot-Product Attention""" + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1)) + if mask is not None: + # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition + scores = scores.masked_fill(mask == 0, float("-inf")) + p_attn = torch.softmax(scores, dim=-1) + return torch.matmul(p_attn, value), p_attn + + +class PositionwiseFeedForward(nn.Sequential): + """Position-wise Feed-Forward Network""" + + def __init__( + self, d_model: int, ffd: int, dropout: float = 0.1, activation_fct: Callable[[Any], Any] = nn.ReLU() + ) -> None: + super().__init__( # type: ignore[call-overload] + nn.Linear(d_model, ffd), + activation_fct, + nn.Dropout(p=dropout), + nn.Linear(ffd, d_model), + nn.Dropout(p=dropout), + ) + + +class MultiHeadAttention(nn.Module): + """Multi-Head Attention""" + + def __init__(self, num_heads: int, d_model: int, dropout: float = 0.1) -> None: + super().__init__() + assert d_model % num_heads == 0, "d_model must be divisible by num_heads" + + self.d_k = d_model // num_heads + self.num_heads = num_heads + + self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)]) + self.output_linear = nn.Linear(d_model, d_model) + + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask=None) -> torch.Tensor: + batch_size = query.size(0) + + # linear projections of Q, K, V + query, key, value = [ + linear(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) + for linear, x in zip(self.linear_layers, (query, key, value)) + ] + + # apply attention on all the projected vectors in batch + x, attn = scaled_dot_product_attention(query, key, value, mask=mask) + + # Concat attention heads + x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k) + + return self.output_linear(x) + + +class EncoderBlock(nn.Module): + """Transformer Encoder Block""" + + def __init__( + self, + num_layers: int, + num_heads: int, + d_model: int, + dff: int, # hidden dimension of the feedforward network + dropout: float, + activation_fct: Callable[[Any], Any] = nn.ReLU(), + ) -> None: + super().__init__() + + self.num_layers = num_layers + + self.layer_norm_input = nn.LayerNorm(d_model, eps=1e-5) + self.layer_norm_attention = nn.LayerNorm(d_model, eps=1e-5) + self.layer_norm_output = nn.LayerNorm(d_model, eps=1e-5) + self.dropout = nn.Dropout(dropout) + + self.attention = nn.ModuleList([ + MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers) + ]) + self.position_feed_forward = nn.ModuleList([ + PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers) + ]) + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> 
torch.Tensor: + output = x + + for i in range(self.num_layers): + normed_output = self.layer_norm_input(output) + output = output + self.dropout(self.attention[i](normed_output, normed_output, normed_output, mask)) + normed_output = self.layer_norm_attention(output) + output = output + self.dropout(self.position_feed_forward[i](normed_output)) + + # (batch_size, seq_len, d_model) + return self.layer_norm_output(output) + + +class Decoder(nn.Module): + """Transformer Decoder""" + + def __init__( + self, + num_layers: int, + num_heads: int, + d_model: int, + vocab_size: int, + dropout: float = 0.2, + dff: int = 2048, # hidden dimension of the feedforward network + maximum_position_encoding: int = 50, + ) -> None: + super(Decoder, self).__init__() + self.num_layers = num_layers + self.d_model = d_model + + self.layer_norm_input = nn.LayerNorm(d_model, eps=1e-5) + self.layer_norm_masked_attention = nn.LayerNorm(d_model, eps=1e-5) + self.layer_norm_attention = nn.LayerNorm(d_model, eps=1e-5) + self.layer_norm_output = nn.LayerNorm(d_model, eps=1e-5) + + self.dropout = nn.Dropout(dropout) + self.embed = nn.Embedding(vocab_size, d_model) + self.positional_encoding = PositionalEncoding(d_model, dropout, maximum_position_encoding) + + self.attention = nn.ModuleList([ + MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers) + ]) + self.source_attention = nn.ModuleList([ + MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers) + ]) + self.position_feed_forward = nn.ModuleList([ + PositionwiseFeedForward(d_model, dff, dropout) for _ in range(self.num_layers) + ]) + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + source_mask: Optional[torch.Tensor] = None, + target_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + tgt = self.embed(tgt) * math.sqrt(self.d_model) + pos_enc_tgt = self.positional_encoding(tgt) + output = pos_enc_tgt + + for i in range(self.num_layers): + normed_output = self.layer_norm_input(output) + output = output + self.dropout(self.attention[i](normed_output, normed_output, normed_output, target_mask)) + normed_output = self.layer_norm_masked_attention(output) + output = output + self.dropout(self.source_attention[i](normed_output, memory, memory, source_mask)) + normed_output = self.layer_norm_attention(output) + output = output + self.dropout(self.position_feed_forward[i](normed_output)) + + # (batch_size, seq_len, d_model) + return self.layer_norm_output(output) diff --git a/doctr/models/modules/transformer/tensorflow.py b/doctr/models/modules/transformer/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..403f99117dccc90c051f9cddbc3a97baf0c94d7d --- /dev/null +++ b/doctr/models/modules/transformer/tensorflow.py @@ -0,0 +1,238 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
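+# The `scaled_dot_product_attention` helper below implements
+#     Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
+# with an optional boolean mask that sets disallowed positions to -inf before the softmax.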
+ +import math +from typing import Any, Callable, Optional, Tuple + +import tensorflow as tf +from tensorflow.keras import layers + +from doctr.utils.repr import NestedObject + +__all__ = ["Decoder", "PositionalEncoding", "EncoderBlock", "PositionwiseFeedForward", "MultiHeadAttention"] + +tf.config.run_functions_eagerly(True) + + +class PositionalEncoding(layers.Layer, NestedObject): + """Compute positional encoding""" + + def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None: + super(PositionalEncoding, self).__init__() + self.dropout = layers.Dropout(rate=dropout) + + # Compute the positional encodings once in log space. + pe = tf.Variable(tf.zeros((max_len, d_model))) + position = tf.cast( + tf.expand_dims(tf.experimental.numpy.arange(start=0, stop=max_len), axis=1), dtype=tf.float32 + ) + div_term = tf.math.exp( + tf.cast(tf.experimental.numpy.arange(start=0, stop=d_model, step=2), dtype=tf.float32) + * -(math.log(10000.0) / d_model) + ) + pe = pe.numpy() + pe[:, 0::2] = tf.math.sin(position * div_term) + pe[:, 1::2] = tf.math.cos(position * div_term) + self.pe = tf.expand_dims(tf.convert_to_tensor(pe), axis=0) + + def call( + self, + x: tf.Tensor, + **kwargs: Any, + ) -> tf.Tensor: + """Forward pass + + Args: + ---- + x: embeddings (batch, max_len, d_model) + **kwargs: additional arguments + + Returns + ------- + positional embeddings (batch, max_len, d_model) + """ + if x.dtype == tf.float16: # amp fix: cast to half + x = x + tf.cast(self.pe[:, : x.shape[1]], dtype=tf.half) + else: + x = x + self.pe[:, : x.shape[1]] + return self.dropout(x, **kwargs) + + +@tf.function +def scaled_dot_product_attention( + query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask: Optional[tf.Tensor] = None +) -> Tuple[tf.Tensor, tf.Tensor]: + """Scaled Dot-Product Attention""" + scores = tf.matmul(query, tf.transpose(key, perm=[0, 1, 3, 2])) / math.sqrt(query.shape[-1]) + if mask is not None: + # NOTE: to ensure the ONNX compatibility, tf.where works only with bool type condition + scores = tf.where(mask == False, float("-inf"), scores) # noqa: E712 + p_attn = tf.nn.softmax(scores, axis=-1) + return tf.matmul(p_attn, value), p_attn + + +class PositionwiseFeedForward(layers.Layer, NestedObject): + """Position-wise Feed-Forward Network""" + + def __init__( + self, d_model: int, ffd: int, dropout=0.1, activation_fct: Callable[[Any], Any] = layers.ReLU() + ) -> None: + super(PositionwiseFeedForward, self).__init__() + self.activation_fct = activation_fct + + self.first_linear = layers.Dense(ffd, kernel_initializer=tf.initializers.he_uniform()) + self.sec_linear = layers.Dense(d_model, kernel_initializer=tf.initializers.he_uniform()) + self.dropout = layers.Dropout(rate=dropout) + + def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor: + x = self.first_linear(x, **kwargs) + x = self.activation_fct(x) + x = self.dropout(x, **kwargs) + x = self.sec_linear(x, **kwargs) + x = self.dropout(x, **kwargs) + return x + + +class MultiHeadAttention(layers.Layer, NestedObject): + """Multi-Head Attention""" + + def __init__(self, num_heads: int, d_model: int, dropout: float = 0.1) -> None: + super().__init__() + assert d_model % num_heads == 0, "d_model must be divisible by num_heads" + + self.d_k = d_model // num_heads + self.num_heads = num_heads + + self.linear_layers = [layers.Dense(d_model, kernel_initializer=tf.initializers.he_uniform()) for _ in range(3)] + self.output_linear = layers.Dense(d_model, kernel_initializer=tf.initializers.he_uniform()) + + def call( + self, + query: 
tf.Tensor, + key: tf.Tensor, + value: tf.Tensor, + mask: tf.Tensor = None, + **kwargs: Any, + ) -> tf.Tensor: + batch_size = query.shape[0] + + # linear projections of Q, K, V + query, key, value = [ + tf.transpose( + tf.reshape(linear(x, **kwargs), shape=[batch_size, -1, self.num_heads, self.d_k]), perm=[0, 2, 1, 3] + ) + for linear, x in zip(self.linear_layers, (query, key, value)) + ] + + # apply attention on all the projected vectors in batch + x, attn = scaled_dot_product_attention(query, key, value, mask=mask) + + # Concat attention heads + x = tf.transpose(x, perm=[0, 2, 1, 3]) + x = tf.reshape(x, shape=[batch_size, -1, self.num_heads * self.d_k]) + + return self.output_linear(x, **kwargs) + + +class EncoderBlock(layers.Layer, NestedObject): + """Transformer Encoder Block""" + + def __init__( + self, + num_layers: int, + num_heads: int, + d_model: int, + dff: int, # hidden dimension of the feedforward network + dropout: float, + activation_fct: Callable[[Any], Any] = layers.ReLU(), + ) -> None: + super().__init__() + + self.num_layers = num_layers + + self.layer_norm_input = layers.LayerNormalization(epsilon=1e-5) + self.layer_norm_attention = layers.LayerNormalization(epsilon=1e-5) + self.layer_norm_output = layers.LayerNormalization(epsilon=1e-5) + self.dropout = layers.Dropout(rate=dropout) + + self.attention = [MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)] + self.position_feed_forward = [ + PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers) + ] + + def call(self, x: tf.Tensor, mask: Optional[tf.Tensor] = None, **kwargs: Any) -> tf.Tensor: + output = x + + for i in range(self.num_layers): + normed_output = self.layer_norm_input(output, **kwargs) + output = output + self.dropout( + self.attention[i](normed_output, normed_output, normed_output, mask, **kwargs), + **kwargs, + ) + normed_output = self.layer_norm_attention(output, **kwargs) + output = output + self.dropout(self.position_feed_forward[i](normed_output, **kwargs), **kwargs) + + # (batch_size, seq_len, d_model) + return self.layer_norm_output(output, **kwargs) + + +class Decoder(layers.Layer, NestedObject): + """Transformer Decoder""" + + def __init__( + self, + num_layers: int, + num_heads: int, + d_model: int, + vocab_size: int, + dropout: float = 0.2, + dff: int = 2048, # hidden dimension of the feedforward network + maximum_position_encoding: int = 50, + ) -> None: + super(Decoder, self).__init__() + self.num_layers = num_layers + self.d_model = d_model + + self.layer_norm_input = layers.LayerNormalization(epsilon=1e-5) + self.layer_norm_masked_attention = layers.LayerNormalization(epsilon=1e-5) + self.layer_norm_attention = layers.LayerNormalization(epsilon=1e-5) + self.layer_norm_output = layers.LayerNormalization(epsilon=1e-5) + + self.dropout = layers.Dropout(rate=dropout) + self.embed = layers.Embedding(vocab_size, d_model) + self.positional_encoding = PositionalEncoding(d_model, dropout, maximum_position_encoding) + + self.attention = [MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)] + self.source_attention = [MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)] + self.position_feed_forward = [PositionwiseFeedForward(d_model, dff, dropout) for _ in range(self.num_layers)] + + def call( + self, + tgt: tf.Tensor, + memory: tf.Tensor, + source_mask: Optional[tf.Tensor] = None, + target_mask: Optional[tf.Tensor] = None, + **kwargs: Any, + ) -> tf.Tensor: + tgt = self.embed(tgt, 
**kwargs) * math.sqrt(self.d_model) + pos_enc_tgt = self.positional_encoding(tgt, **kwargs) + output = pos_enc_tgt + + for i in range(self.num_layers): + normed_output = self.layer_norm_input(output, **kwargs) + output = output + self.dropout( + self.attention[i](normed_output, normed_output, normed_output, target_mask, **kwargs), + **kwargs, + ) + normed_output = self.layer_norm_masked_attention(output, **kwargs) + output = output + self.dropout( + self.source_attention[i](normed_output, memory, memory, source_mask, **kwargs), + **kwargs, + ) + normed_output = self.layer_norm_attention(output, **kwargs) + output = output + self.dropout(self.position_feed_forward[i](normed_output, **kwargs), **kwargs) + + # (batch_size, seq_len, d_model) + return self.layer_norm_output(output, **kwargs) diff --git a/doctr/models/modules/vision_transformer/__init__.py b/doctr/models/modules/vision_transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/modules/vision_transformer/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/modules/vision_transformer/__pycache__/__init__.cpython-311.pyc b/doctr/models/modules/vision_transformer/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0664674505df773add817242a3061ad7fda759e Binary files /dev/null and b/doctr/models/modules/vision_transformer/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/modules/vision_transformer/__pycache__/__init__.cpython-38.pyc b/doctr/models/modules/vision_transformer/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d9777a6a3df7055bf84fffe24db6d3081338258 Binary files /dev/null and b/doctr/models/modules/vision_transformer/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/modules/vision_transformer/__pycache__/pytorch.cpython-311.pyc b/doctr/models/modules/vision_transformer/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22a3de28fae96fa67f5be4b0155c86264fc91d29 Binary files /dev/null and b/doctr/models/modules/vision_transformer/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/modules/vision_transformer/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/modules/vision_transformer/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..354d4b9ffb93ec8d9683da1179690a25e63360f3 Binary files /dev/null and b/doctr/models/modules/vision_transformer/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/modules/vision_transformer/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/modules/vision_transformer/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..499f2ea32f384bc51d0e2832740a0b112b200868 Binary files /dev/null and b/doctr/models/modules/vision_transformer/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/modules/vision_transformer/pytorch.py b/doctr/models/modules/vision_transformer/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..4ff07ed4ffe4ea6fd769d87fae35c085d927e971 --- /dev/null +++ 
b/doctr/models/modules/vision_transformer/pytorch.py @@ -0,0 +1,84 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import math +from typing import Tuple + +import torch +from torch import nn + +__all__ = ["PatchEmbedding"] + + +class PatchEmbedding(nn.Module): + """Compute 2D patch embeddings with cls token and positional encoding""" + + def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int, patch_size: Tuple[int, int]) -> None: + super().__init__() + channels, height, width = input_shape + self.patch_size = patch_size + self.interpolate = True if patch_size[0] == patch_size[1] else False + self.grid_size = tuple([s // p for s, p in zip((height, width), self.patch_size)]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + + self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) + self.positions = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim)) + self.projection = nn.Conv2d(channels, embed_dim, kernel_size=self.patch_size, stride=self.patch_size) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """100 % borrowed from: + https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py + + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.positions.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.positions + class_pos_embed = self.positions[:, 0] + patch_pos_embed = self.positions[:, 1:] + dim = embeddings.shape[-1] + h0 = float(height // self.patch_size[0]) + w0 = float(width // self.patch_size[1]) + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bilinear", + align_corners=False, + recompute_scale_factor=True, + ) + assert int(h0) == patch_pos_embed.shape[-2], "height of interpolated patch embedding doesn't match" + assert int(w0) == patch_pos_embed.shape[-1], "width of interpolated patch embedding doesn't match" + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, H, W = x.shape + assert H % self.patch_size[0] == 0, "Image height must be divisible by patch height" + assert W % self.patch_size[1] == 0, "Image width must be divisible by patch width" + + # patchify image + patches = self.projection(x).flatten(2).transpose(1, 2) + + cls_tokens = self.cls_token.expand(B, -1, -1) # (batch_size, 1, d_model) + # concate cls_tokens to patches + embeddings = torch.cat([cls_tokens, patches], dim=1) # (batch_size, num_patches + 1, d_model) + # add positions to embeddings + if self.interpolate: + embeddings += self.interpolate_pos_encoding(embeddings, H, W) + else: + 
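+            # non-square patches: no interpolation is done, so the patch grid of the input
+            # must match the one derived from `input_shape` for the learned positions to broadcast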
embeddings += self.positions + + return embeddings # (batch_size, num_patches + 1, d_model) diff --git a/doctr/models/modules/vision_transformer/tensorflow.py b/doctr/models/modules/vision_transformer/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..a78f0da3fbdb035859bc4db4cb27b055e1330e4a --- /dev/null +++ b/doctr/models/modules/vision_transformer/tensorflow.py @@ -0,0 +1,100 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import math +from typing import Any, Tuple + +import tensorflow as tf +from tensorflow.keras import layers + +from doctr.utils.repr import NestedObject + +__all__ = ["PatchEmbedding"] + + +class PatchEmbedding(layers.Layer, NestedObject): + """Compute 2D patch embeddings with cls token and positional encoding""" + + def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int, patch_size: Tuple[int, int]) -> None: + super().__init__() + height, width, _ = input_shape + self.patch_size = patch_size + self.interpolate = True if patch_size[0] == patch_size[1] else False + self.grid_size = tuple([s // p for s, p in zip((height, width), self.patch_size)]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + + self.cls_token = self.add_weight(shape=(1, 1, embed_dim), initializer="zeros", trainable=True, name="cls_token") + self.positions = self.add_weight( + shape=(1, self.num_patches + 1, embed_dim), + initializer="zeros", + trainable=True, + name="positions", + ) + self.projection = layers.Conv2D( + filters=embed_dim, + kernel_size=self.patch_size, + strides=self.patch_size, + padding="valid", + data_format="channels_last", + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + name="projection", + ) + + def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor: + """100 % borrowed from: + https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_tf_vit.py + + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py + """ + seq_len, dim = embeddings.shape[1:] + num_patches = seq_len - 1 + + num_positions = self.positions.shape[1] - 1 + + if num_patches == num_positions and height == width: + return self.positions + class_pos_embed = self.positions[:, :1] + patch_pos_embed = self.positions[:, 1:] + h0 = height // self.patch_size[0] + w0 = width // self.patch_size[1] + patch_pos_embed = tf.image.resize( + images=tf.reshape( + patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + ), + size=(h0, w0), + method="bilinear", + ) + + shape = patch_pos_embed.shape + assert h0 == shape[-3], "height of interpolated patch embedding doesn't match" + assert w0 == shape[-2], "width of interpolated patch embedding doesn't match" + + patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim)) + return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1) + + def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor: + B, H, W, C = x.shape + assert H % self.patch_size[0] == 0, "Image height must be divisible by patch height" + assert W % self.patch_size[1] == 0, "Image width must be divisible by patch width" + # patchify image + patches = self.projection(x, **kwargs) # (batch_size, num_patches, d_model) + patches = tf.reshape(patches, (B, self.num_patches, -1)) # (batch_size, num_patches, d_model) + + cls_tokens = tf.repeat(self.cls_token, B, axis=0) # (batch_size, 1, d_model) + # concate cls_tokens to patches + embeddings = tf.concat([cls_tokens, patches], axis=1) # (batch_size, num_patches + 1, d_model) + # add positions to embeddings + if self.interpolate: + embeddings += self.interpolate_pos_encoding(embeddings, H, W) + else: + embeddings += self.positions + + return embeddings # (batch_size, num_patches + 1, d_model) diff --git a/doctr/models/predictor/__init__.py b/doctr/models/predictor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff30c3b2e7d34bf85e30291e39f9d3206c0f4bdd --- /dev/null +++ b/doctr/models/predictor/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available + +if is_tf_available(): + from .tensorflow import * +else: + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/predictor/__pycache__/__init__.cpython-311.pyc b/doctr/models/predictor/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6856b4e80e2cf490e6df12c2d4868502a04982d9 Binary files /dev/null and b/doctr/models/predictor/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/predictor/__pycache__/__init__.cpython-38.pyc b/doctr/models/predictor/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..377c103650a6207bef1fcb1f5c72a0c1a70ee81a Binary files /dev/null and b/doctr/models/predictor/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/predictor/__pycache__/base.cpython-311.pyc b/doctr/models/predictor/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dcaf5bf8577c71d67234bc132a61766fe5d7cf1 Binary files /dev/null and b/doctr/models/predictor/__pycache__/base.cpython-311.pyc differ diff --git a/doctr/models/predictor/__pycache__/base.cpython-38.pyc b/doctr/models/predictor/__pycache__/base.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..12e7398892c25fe6b47da9887fee5c79d6fbaf77 Binary files /dev/null and b/doctr/models/predictor/__pycache__/base.cpython-38.pyc differ diff --git a/doctr/models/predictor/__pycache__/pytorch.cpython-311.pyc b/doctr/models/predictor/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d6ebb7f8a8d9227e85f1f41a21a46fef4684c9f Binary files /dev/null and b/doctr/models/predictor/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/predictor/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/predictor/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11d96cd0397a277760f7917b683743501d640767 Binary files /dev/null and b/doctr/models/predictor/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/predictor/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/predictor/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25350ee6e3a0b70b4390900b57f00d5565927235 Binary files /dev/null and b/doctr/models/predictor/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/predictor/base.py b/doctr/models/predictor/base.py new file mode 100644 index 0000000000000000000000000000000000000000..0033b2cbf4d99261aaa048f63322135665fc0283 --- /dev/null +++ b/doctr/models/predictor/base.py @@ -0,0 +1,170 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, Callable, Dict, List, Optional, Tuple + +import numpy as np + +from doctr.models.builder import DocumentBuilder +from doctr.utils.geometry import extract_crops, extract_rcrops + +from .._utils import rectify_crops, rectify_loc_preds +from ..classification import crop_orientation_predictor +from ..classification.predictor import OrientationPredictor + +__all__ = ["_OCRPredictor"] + + +class _OCRPredictor: + """Implements an object able to localize and identify text elements in a set of documents + + Args: + ---- + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + straighten_pages: if True, estimates the page general orientation based on the median line orientation. + Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped + accordingly. Doing so will improve performances for documents with page-uniform rotations. + preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding) + symmetric_pad: if True and preserve_aspect_ratio is True, pas the image symmetrically. 
+ **kwargs: keyword args of `DocumentBuilder` + """ + + crop_orientation_predictor: Optional[OrientationPredictor] + + def __init__( + self, + assume_straight_pages: bool = True, + straighten_pages: bool = False, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + **kwargs: Any, + ) -> None: + self.assume_straight_pages = assume_straight_pages + self.straighten_pages = straighten_pages + self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True) + self.doc_builder = DocumentBuilder(**kwargs) + self.preserve_aspect_ratio = preserve_aspect_ratio + self.symmetric_pad = symmetric_pad + self.hooks: List[Callable] = [] + + @staticmethod + def _generate_crops( + pages: List[np.ndarray], + loc_preds: List[np.ndarray], + channels_last: bool, + assume_straight_pages: bool = False, + ) -> List[List[np.ndarray]]: + extraction_fn = extract_crops if assume_straight_pages else extract_rcrops + + crops = [ + extraction_fn(page, _boxes[:, :4], channels_last=channels_last) # type: ignore[operator] + for page, _boxes in zip(pages, loc_preds) + ] + return crops + + @staticmethod + def _prepare_crops( + pages: List[np.ndarray], + loc_preds: List[np.ndarray], + channels_last: bool, + assume_straight_pages: bool = False, + ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]: + crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages) + + # Avoid sending zero-sized crops + is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops] + crops = [ + [crop for crop, _kept in zip(page_crops, page_kept) if _kept] + for page_crops, page_kept in zip(crops, is_kept) + ] + loc_preds = [_boxes[_kept] for _boxes, _kept in zip(loc_preds, is_kept)] + + return crops, loc_preds + + def _rectify_crops( + self, + crops: List[List[np.ndarray]], + loc_preds: List[np.ndarray], + ) -> Tuple[List[List[np.ndarray]], List[np.ndarray], List[Tuple[int, float]]]: + # Work at a page level + orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops]) # type: ignore[misc] + rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)] + rect_loc_preds = [ + rectify_loc_preds(page_loc_preds, orientation) if len(page_loc_preds) > 0 else page_loc_preds + for page_loc_preds, orientation in zip(loc_preds, orientations) + ] + # Flatten to list of tuples with (value, confidence) + crop_orientations = [ + (orientation, prob) + for page_classes, page_probs in zip(classes, probs) + for orientation, prob in zip(page_classes, page_probs) + ] + return rect_crops, rect_loc_preds, crop_orientations # type: ignore[return-value] + + def _remove_padding( + self, + pages: List[np.ndarray], + loc_preds: List[np.ndarray], + ) -> List[np.ndarray]: + if self.preserve_aspect_ratio: + # Rectify loc_preds to remove padding + rectified_preds = [] + for page, loc_pred in zip(pages, loc_preds): + h, w = page.shape[0], page.shape[1] + if h > w: + # y unchanged, dilate x coord + if self.symmetric_pad: + if self.assume_straight_pages: + loc_pred[:, [0, 2]] = np.clip((loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1) + else: + loc_pred[:, :, 0] = np.clip((loc_pred[:, :, 0] - 0.5) * h / w + 0.5, 0, 1) + else: + if self.assume_straight_pages: + loc_pred[:, [0, 2]] *= h / w + else: + loc_pred[:, :, 0] *= h / w + elif w > h: + # x unchanged, dilate y coord + if self.symmetric_pad: + if self.assume_straight_pages: + loc_pred[:, [1, 3]] = 
np.clip((loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5, 0, 1) + else: + loc_pred[:, :, 1] = np.clip((loc_pred[:, :, 1] - 0.5) * w / h + 0.5, 0, 1) + else: + if self.assume_straight_pages: + loc_pred[:, [1, 3]] *= w / h + else: + loc_pred[:, :, 1] *= w / h + rectified_preds.append(loc_pred) + return rectified_preds + return loc_preds + + @staticmethod + def _process_predictions( + loc_preds: List[np.ndarray], + word_preds: List[Tuple[str, float]], + crop_orientations: List[Dict[str, Any]], + ) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]], List[List[Dict[str, Any]]]]: + text_preds = [] + crop_orientation_preds = [] + if len(loc_preds) > 0: + # Text & crop orientation predictions at page level + _idx = 0 + for page_boxes in loc_preds: + text_preds.append(word_preds[_idx : _idx + page_boxes.shape[0]]) + crop_orientation_preds.append(crop_orientations[_idx : _idx + page_boxes.shape[0]]) + _idx += page_boxes.shape[0] + + return loc_preds, text_preds, crop_orientation_preds + + def add_hook(self, hook: Callable) -> None: + """Add a hook to the predictor + + Args: + ---- + hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds` + """ + self.hooks.append(hook) diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..91640371e49874eb0ab88c87675709e27b3641be --- /dev/null +++ b/doctr/models/predictor/pytorch.py @@ -0,0 +1,152 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, List, Union + +import numpy as np +import torch +from torch import nn + +from doctr.io.elements import Document +from doctr.models._utils import estimate_orientation, get_language +from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.recognition.predictor import RecognitionPredictor +from doctr.utils.geometry import rotate_image + +from .base import _OCRPredictor + +__all__ = ["OCRPredictor"] + + +class OCRPredictor(nn.Module, _OCRPredictor): + """Implements an object able to localize and identify text elements in a set of documents + + Args: + ---- + det_predictor: detection module + reco_predictor: recognition module + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + straighten_pages: if True, estimates the page general orientation based on the median line orientation. + Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped + accordingly. Doing so will improve performances for documents with page-uniform rotations. + detect_orientation: if True, the estimated general page orientation will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + detect_language: if True, the language prediction will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. 
+ **kwargs: keyword args of `DocumentBuilder` + """ + + def __init__( + self, + det_predictor: DetectionPredictor, + reco_predictor: RecognitionPredictor, + assume_straight_pages: bool = True, + straighten_pages: bool = False, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + detect_orientation: bool = False, + detect_language: bool = False, + **kwargs: Any, + ) -> None: + nn.Module.__init__(self) + self.det_predictor = det_predictor.eval() # type: ignore[attr-defined] + self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined] + _OCRPredictor.__init__( + self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs + ) + self.detect_orientation = detect_orientation + self.detect_language = detect_language + + @torch.inference_mode() + def forward( + self, + pages: List[Union[np.ndarray, torch.Tensor]], + **kwargs: Any, + ) -> Document: + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages] + + # Localize text elements + loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs) + + # Detect document rotation and rotate pages + seg_maps = [ + np.where(out_map > getattr(self.det_predictor.model.postprocessor, "bin_thresh"), 255, 0).astype(np.uint8) + for out_map in out_maps + ] + if self.detect_orientation: + origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps] + orientations = [ + {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations + ] + else: + orientations = None + if self.straighten_pages: + origin_page_orientations = ( + origin_page_orientations + if self.detect_orientation + else [estimate_orientation(seq_map) for seq_map in seg_maps] + ) + pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)] # type: ignore[arg-type] + # Forward again to get predictions on straight pages + loc_preds = self.det_predictor(pages, **kwargs) + + assert all( + len(loc_pred) == 1 for loc_pred in loc_preds + ), "Detection Model in ocr_predictor should output only one class" + + loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds] + # Check whether crop mode should be switched to channels first + channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray) + + # Rectify crops if aspect ratio + loc_preds = self._remove_padding(pages, loc_preds) # type: ignore[arg-type] + + # Apply hooks to loc_preds if any + for hook in self.hooks: + loc_preds = hook(loc_preds) + + # Crop images + crops, loc_preds = self._prepare_crops( + pages, # type: ignore[arg-type] + loc_preds, + channels_last=channels_last, + assume_straight_pages=self.assume_straight_pages, + ) + # Rectify crop orientation and get crop orientation predictions + crop_orientations: Any = [] + if not self.assume_straight_pages: + crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds) + crop_orientations = [ + {"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations + ] + + # Identify character sequences + word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs) + if not crop_orientations: + crop_orientations = [{"value": 0, "confidence": None} for _ in word_preds] + + boxes, text_preds, crop_orientations = 
self._process_predictions(loc_preds, word_preds, crop_orientations) + + if self.detect_language: + languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds] + languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages] + else: + languages_dict = None + + out = self.doc_builder( + pages, # type: ignore[arg-type] + boxes, + text_preds, + origin_page_shapes, # type: ignore[arg-type] + crop_orientations, + orientations, + languages_dict, + ) + return out diff --git a/doctr/models/predictor/tensorflow.py b/doctr/models/predictor/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..f424b7b50b4527235d9a243a5981db4b3858b103 --- /dev/null +++ b/doctr/models/predictor/tensorflow.py @@ -0,0 +1,146 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, List, Union + +import numpy as np +import tensorflow as tf + +from doctr.io.elements import Document +from doctr.models._utils import estimate_orientation, get_language +from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.recognition.predictor import RecognitionPredictor +from doctr.utils.geometry import rotate_image +from doctr.utils.repr import NestedObject + +from .base import _OCRPredictor + +__all__ = ["OCRPredictor"] + + +class OCRPredictor(NestedObject, _OCRPredictor): + """Implements an object able to localize and identify text elements in a set of documents + + Args: + ---- + det_predictor: detection module + reco_predictor: recognition module + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + straighten_pages: if True, estimates the page general orientation based on the median line orientation. + Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped + accordingly. Doing so will improve performances for documents with page-uniform rotations. + detect_orientation: if True, the estimated general page orientation will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + detect_language: if True, the language prediction will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. 
+ **kwargs: keyword args of `DocumentBuilder` + """ + + _children_names = ["det_predictor", "reco_predictor", "doc_builder"] + + def __init__( + self, + det_predictor: DetectionPredictor, + reco_predictor: RecognitionPredictor, + assume_straight_pages: bool = True, + straighten_pages: bool = False, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + detect_orientation: bool = False, + detect_language: bool = False, + **kwargs: Any, + ) -> None: + self.det_predictor = det_predictor + self.reco_predictor = reco_predictor + _OCRPredictor.__init__( + self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs + ) + self.detect_orientation = detect_orientation + self.detect_language = detect_language + + def __call__( + self, + pages: List[Union[np.ndarray, tf.Tensor]], + **kwargs: Any, + ) -> Document: + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + origin_page_shapes = [page.shape[:2] for page in pages] + + # Localize text elements + loc_preds_dict, out_maps = self.det_predictor(pages, return_maps=True, **kwargs) + + # Detect document rotation and rotate pages + seg_maps = [ + np.where(out_map > getattr(self.det_predictor.model.postprocessor, "bin_thresh"), 255, 0).astype(np.uint8) + for out_map in out_maps + ] + if self.detect_orientation: + origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps] + orientations = [ + {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations + ] + else: + orientations = None + if self.straighten_pages: + origin_page_orientations = ( + origin_page_orientations + if self.detect_orientation + else [estimate_orientation(seq_map) for seq_map in seg_maps] + ) + pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)] + # forward again to get predictions on straight pages + loc_preds_dict = self.det_predictor(pages, **kwargs) # type: ignore[assignment] + + assert all( + len(loc_pred) == 1 for loc_pred in loc_preds_dict + ), "Detection Model in ocr_predictor should output only one class" + loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict] # type: ignore[union-attr] + + # Rectify crops if aspect ratio + loc_preds = self._remove_padding(pages, loc_preds) + + # Apply hooks to loc_preds if any + for hook in self.hooks: + loc_preds = hook(loc_preds) + + # Crop images + crops, loc_preds = self._prepare_crops( + pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages + ) + # Rectify crop orientation and get crop orientation predictions + crop_orientations: Any = [] + if not self.assume_straight_pages: + crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds) + crop_orientations = [ + {"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations + ] + + # Identify character sequences + word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs) + if not crop_orientations: + crop_orientations = [{"value": 0, "confidence": None} for _ in word_preds] + + boxes, text_preds, crop_orientations = self._process_predictions(loc_preds, word_preds, crop_orientations) + + if self.detect_language: + languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds] + languages_dict = [{"value": lang[0], "confidence": 
lang[1]} for lang in languages] + else: + languages_dict = None + + out = self.doc_builder( + pages, + boxes, + text_preds, + origin_page_shapes, # type: ignore[arg-type] + crop_orientations, + orientations, + languages_dict, + ) + return out diff --git a/doctr/models/preprocessor/__init__.py b/doctr/models/preprocessor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/preprocessor/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/preprocessor/__pycache__/__init__.cpython-311.pyc b/doctr/models/preprocessor/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e87c91940a23ba635b61e4618e34cea5797bed60 Binary files /dev/null and b/doctr/models/preprocessor/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/preprocessor/__pycache__/__init__.cpython-38.pyc b/doctr/models/preprocessor/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dde84db40b648089e8e2894e95103a0c124441df Binary files /dev/null and b/doctr/models/preprocessor/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/preprocessor/__pycache__/pytorch.cpython-311.pyc b/doctr/models/preprocessor/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a216d8c1a0fa75e6f59c73cc4599f366453ce035 Binary files /dev/null and b/doctr/models/preprocessor/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/preprocessor/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/preprocessor/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43196239ade4cc343a562363f8315afb5311e841 Binary files /dev/null and b/doctr/models/preprocessor/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/preprocessor/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/preprocessor/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a029ee0c6b48f8939f999fa9d026fb2d44d818c Binary files /dev/null and b/doctr/models/preprocessor/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/preprocessor/pytorch.py b/doctr/models/preprocessor/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..58a236bd08a7314885b26b816209879d2457bc37 --- /dev/null +++ b/doctr/models/preprocessor/pytorch.py @@ -0,0 +1,128 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import math +from typing import Any, List, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torchvision.transforms import functional as F +from torchvision.transforms import transforms as T + +from doctr.transforms import Resize +from doctr.utils.multithreading import multithread_exec + +__all__ = ["PreProcessor"] + + +class PreProcessor(nn.Module): + """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization. 
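+
+ A minimal usage sketch (the page sizes and output size below are illustrative assumptions):
+
+ >>> import numpy as np
+ >>> from doctr.models.preprocessor.pytorch import PreProcessor
+ >>> processor = PreProcessor(output_size=(1024, 1024), batch_size=2)
+ >>> pages = [np.zeros((512, 768, 3), dtype=np.uint8) for _ in range(3)]
+ >>> batches = processor(pages)  # list of float32 tensors of shape (*, 3, 1024, 1024)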
+ + Args: + ---- + output_size: expected size of each page in format (H, W) + batch_size: the size of page batches + mean: mean value of the training distribution by channel + std: standard deviation of the training distribution by channel + """ + + def __init__( + self, + output_size: Tuple[int, int], + batch_size: int, + mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), + std: Tuple[float, float, float] = (1.0, 1.0, 1.0), + **kwargs: Any, + ) -> None: + super().__init__() + self.batch_size = batch_size + self.resize: T.Resize = Resize(output_size, **kwargs) + # Perform the division by 255 at the same time + self.normalize = T.Normalize(mean, std) + + def batch_inputs(self, samples: List[torch.Tensor]) -> List[torch.Tensor]: + """Gather samples into batches for inference purposes + + Args: + ---- + samples: list of samples of shape (C, H, W) + + Returns: + ------- + list of batched samples (*, C, H, W) + """ + num_batches = int(math.ceil(len(samples) / self.batch_size)) + batches = [ + torch.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], dim=0) + for idx in range(int(num_batches)) + ] + + return batches + + def sample_transforms(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + if x.ndim != 3: + raise AssertionError("expected list of 3D Tensors") + if isinstance(x, np.ndarray): + if x.dtype not in (np.uint8, np.float32): + raise TypeError("unsupported data type for numpy.ndarray") + x = torch.from_numpy(x.copy()).permute(2, 0, 1) + elif x.dtype not in (torch.uint8, torch.float16, torch.float32): + raise TypeError("unsupported data type for torch.Tensor") + # Resizing + x = self.resize(x) + # Data type + if x.dtype == torch.uint8: + x = x.to(dtype=torch.float32).div(255).clip(0, 1) # type: ignore[union-attr] + else: + x = x.to(dtype=torch.float32) # type: ignore[union-attr] + + return x + + def __call__(self, x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]) -> List[torch.Tensor]: + """Prepare document data for model forwarding + + Args: + ---- + x: list of images (np.array) or tensors (already resized and batched) + + Returns: + ------- + list of page batches + """ + # Input type check + if isinstance(x, (np.ndarray, torch.Tensor)): + if x.ndim != 4: + raise AssertionError("expected 4D Tensor") + if isinstance(x, np.ndarray): + if x.dtype not in (np.uint8, np.float32): + raise TypeError("unsupported data type for numpy.ndarray") + x = torch.from_numpy(x.copy()).permute(0, 3, 1, 2) + elif x.dtype not in (torch.uint8, torch.float16, torch.float32): + raise TypeError("unsupported data type for torch.Tensor") + # Resizing + if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]: + x = F.resize( + x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias + ) + # Data type + if x.dtype == torch.uint8: # type: ignore[union-attr] + x = x.to(dtype=torch.float32).div(255).clip(0, 1) # type: ignore[union-attr] + else: + x = x.to(dtype=torch.float32) # type: ignore[union-attr] + batches = [x] + + elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, torch.Tensor)) for sample in x): + # Sample transform (to tensor, resize) + samples = list(multithread_exec(self.sample_transforms, x)) + # Batching + batches = self.batch_inputs(samples) + else: + raise TypeError(f"invalid input type: {type(x)}") + + # Batch transforms (normalize) + batches = list(multithread_exec(self.normalize, batches)) + + return batches diff --git a/doctr/models/preprocessor/tensorflow.py 
b/doctr/models/preprocessor/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..431f95b11fff641bb99d2e8e13e2b74ce36b57e5 --- /dev/null +++ b/doctr/models/preprocessor/tensorflow.py @@ -0,0 +1,125 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import math +from typing import Any, List, Tuple, Union + +import numpy as np +import tensorflow as tf + +from doctr.transforms import Normalize, Resize +from doctr.utils.multithreading import multithread_exec +from doctr.utils.repr import NestedObject + +__all__ = ["PreProcessor"] + + +class PreProcessor(NestedObject): + """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization. + + Args: + ---- + output_size: expected size of each page in format (H, W) + batch_size: the size of page batches + mean: mean value of the training distribution by channel + std: standard deviation of the training distribution by channel + """ + + _children_names: List[str] = ["resize", "normalize"] + + def __init__( + self, + output_size: Tuple[int, int], + batch_size: int, + mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), + std: Tuple[float, float, float] = (1.0, 1.0, 1.0), + **kwargs: Any, + ) -> None: + self.batch_size = batch_size + self.resize = Resize(output_size, **kwargs) + # Perform the division by 255 at the same time + self.normalize = Normalize(mean, std) + + def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]: + """Gather samples into batches for inference purposes + + Args: + ---- + samples: list of samples (tf.Tensor) + + Returns: + ------- + list of batched samples + """ + num_batches = int(math.ceil(len(samples) / self.batch_size)) + batches = [ + tf.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], axis=0) + for idx in range(int(num_batches)) + ] + + return batches + + def sample_transforms(self, x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor: + if x.ndim != 3: + raise AssertionError("expected list of 3D Tensors") + if isinstance(x, np.ndarray): + if x.dtype not in (np.uint8, np.float32): + raise TypeError("unsupported data type for numpy.ndarray") + x = tf.convert_to_tensor(x) + elif x.dtype not in (tf.uint8, tf.float16, tf.float32): + raise TypeError("unsupported data type for torch.Tensor") + # Data type & 255 division + if x.dtype == tf.uint8: + x = tf.image.convert_image_dtype(x, dtype=tf.float32) + # Resizing + x = self.resize(x) + + return x + + def __call__(self, x: Union[tf.Tensor, np.ndarray, List[Union[tf.Tensor, np.ndarray]]]) -> List[tf.Tensor]: + """Prepare document data for model forwarding + + Args: + ---- + x: list of images (np.array) or tensors (already resized and batched) + + Returns: + ------- + list of page batches + """ + # Input type check + if isinstance(x, (np.ndarray, tf.Tensor)): + if x.ndim != 4: + raise AssertionError("expected 4D Tensor") + if isinstance(x, np.ndarray): + if x.dtype not in (np.uint8, np.float32): + raise TypeError("unsupported data type for numpy.ndarray") + x = tf.convert_to_tensor(x) + elif x.dtype not in (tf.uint8, tf.float16, tf.float32): + raise TypeError("unsupported data type for torch.Tensor") + + # Data type & 255 division + if x.dtype == tf.uint8: + x = tf.image.convert_image_dtype(x, dtype=tf.float32) + # Resizing + if (x.shape[1], x.shape[2]) != self.resize.output_size: + x = tf.image.resize( + x, self.resize.output_size, method=self.resize.method, 
antialias=self.resize.antialias + ) + + batches = [x] + + elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, tf.Tensor)) for sample in x): + # Sample transform (to tensor, resize) + samples = list(multithread_exec(self.sample_transforms, x)) + # Batching + batches = self.batch_inputs(samples) + else: + raise TypeError(f"invalid input type: {type(x)}") + + # Batch transforms (normalize) + batches = list(multithread_exec(self.normalize, batches)) + + return batches diff --git a/doctr/models/recognition/__init__.py b/doctr/models/recognition/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f2f723a722dc7c8fab2477fa7aaf2e83afdc8e6 --- /dev/null +++ b/doctr/models/recognition/__init__.py @@ -0,0 +1,6 @@ +from .crnn import * +from .master import * +from .sar import * +from .vitstr import * +from .parseq import * +from .zoo import * diff --git a/doctr/models/recognition/__pycache__/__init__.cpython-311.pyc b/doctr/models/recognition/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..253b88573670fbb9c41a4ef9aaad640557e6c00c Binary files /dev/null and b/doctr/models/recognition/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/recognition/__pycache__/__init__.cpython-38.pyc b/doctr/models/recognition/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46f3f43e8636d521483619a26ee6826d5feaed22 Binary files /dev/null and b/doctr/models/recognition/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/recognition/__pycache__/core.cpython-311.pyc b/doctr/models/recognition/__pycache__/core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c80eeec71dc40544ae94b451087bdcbde51c8fff Binary files /dev/null and b/doctr/models/recognition/__pycache__/core.cpython-311.pyc differ diff --git a/doctr/models/recognition/__pycache__/core.cpython-38.pyc b/doctr/models/recognition/__pycache__/core.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eed32d4d3dee5a9e38bf4948b1b4f1d0e680d45d Binary files /dev/null and b/doctr/models/recognition/__pycache__/core.cpython-38.pyc differ diff --git a/doctr/models/recognition/__pycache__/utils.cpython-311.pyc b/doctr/models/recognition/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecc39de4a3d74d0c8eb597087dcc78b69bc0cc17 Binary files /dev/null and b/doctr/models/recognition/__pycache__/utils.cpython-311.pyc differ diff --git a/doctr/models/recognition/__pycache__/utils.cpython-38.pyc b/doctr/models/recognition/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..660aef05c5a2a87cdb159e8a13433b295373fa7d Binary files /dev/null and b/doctr/models/recognition/__pycache__/utils.cpython-38.pyc differ diff --git a/doctr/models/recognition/__pycache__/zoo.cpython-311.pyc b/doctr/models/recognition/__pycache__/zoo.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e83d588ebc41f1f2d0f32c3fa133d39bac95c1a8 Binary files /dev/null and b/doctr/models/recognition/__pycache__/zoo.cpython-311.pyc differ diff --git a/doctr/models/recognition/__pycache__/zoo.cpython-38.pyc b/doctr/models/recognition/__pycache__/zoo.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a666c63df5aad6a1d94bb3305e6bb9f02eaa6e1c Binary files /dev/null and 
b/doctr/models/recognition/__pycache__/zoo.cpython-38.pyc differ diff --git a/doctr/models/recognition/core.py b/doctr/models/recognition/core.py new file mode 100644 index 0000000000000000000000000000000000000000..ab82218cce1fa62325013606c02587e5617650c8 --- /dev/null +++ b/doctr/models/recognition/core.py @@ -0,0 +1,58 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import List, Tuple + +import numpy as np + +from doctr.datasets import encode_sequences +from doctr.utils.repr import NestedObject + +__all__ = ["RecognitionPostProcessor", "RecognitionModel"] + + +class RecognitionModel(NestedObject): + """Implements abstract RecognitionModel class""" + + vocab: str + max_length: int + + def build_target( + self, + gts: List[str], + ) -> Tuple[np.ndarray, List[int]]: + """Encode a list of gts sequences into a np array and gives the corresponding* + sequence lengths. + + Args: + ---- + gts: list of ground-truth labels + + Returns: + ------- + A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) + """ + encoded = encode_sequences(sequences=gts, vocab=self.vocab, target_size=self.max_length, eos=len(self.vocab)) + seq_len = [len(word) for word in gts] + return encoded, seq_len + + +class RecognitionPostProcessor(NestedObject): + """Abstract class to postprocess the raw output of the model + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __init__( + self, + vocab: str, + ) -> None: + self.vocab = vocab + self._embedding = list(self.vocab) + [""] + + def extra_repr(self) -> str: + return f"vocab_size={len(self.vocab)}" diff --git a/doctr/models/recognition/crnn/__init__.py b/doctr/models/recognition/crnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/recognition/crnn/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/recognition/crnn/__pycache__/__init__.cpython-311.pyc b/doctr/models/recognition/crnn/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e1fc9489a32e4e6c01a47531d96a7a15b403e92 Binary files /dev/null and b/doctr/models/recognition/crnn/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/recognition/crnn/__pycache__/__init__.cpython-38.pyc b/doctr/models/recognition/crnn/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f8357dade06d33d78fb13d3fe6f9de41d806fc8 Binary files /dev/null and b/doctr/models/recognition/crnn/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/recognition/crnn/__pycache__/pytorch.cpython-311.pyc b/doctr/models/recognition/crnn/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82664e23a044d05881f6ddf70ff8be8841a86873 Binary files /dev/null and b/doctr/models/recognition/crnn/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/recognition/crnn/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/recognition/crnn/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..ac8148e472c8b0272dbadd581125c6fc96ed23e5 Binary files /dev/null and b/doctr/models/recognition/crnn/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/recognition/crnn/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/recognition/crnn/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97f95b1c15f0448a886cd591dcd1d89a640c6476 Binary files /dev/null and b/doctr/models/recognition/crnn/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/recognition/crnn/pytorch.py b/doctr/models/recognition/crnn/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..4c4b891f9a236b2e138dbae6ae77f5e0a6799c47 --- /dev/null +++ b/doctr/models/recognition/crnn/pytorch.py @@ -0,0 +1,339 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from copy import deepcopy +from itertools import groupby +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import functional as F + +from doctr.datasets import VOCABS, decode_sequence + +from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r +from ...utils.pytorch import load_pretrained_params +from ..core import RecognitionModel, RecognitionPostProcessor + +__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "crnn_vgg16_bn": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["legacy_french"], + "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_vgg16_bn-9762b0b0.pt&src=0", + }, + "crnn_mobilenet_v3_small": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_mobilenet_v3_small_pt-3b919a02.pt&src=0", + }, + "crnn_mobilenet_v3_large": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_mobilenet_v3_large_pt-f5259ec2.pt&src=0", + }, +} + + +class CTCPostProcessor(RecognitionPostProcessor): + """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + @staticmethod + def ctc_best_path( + logits: torch.Tensor, + vocab: str = VOCABS["french"], + blank: int = 0, + ) -> List[Tuple[str, float]]: + """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from + `_. 
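+
+ A toy illustration of the collapse step (the 2-character vocab and blank index 0 are assumed here
+ purely for the sketch; the model itself calls this method with blank = len(vocab)):
+
+ >>> from itertools import groupby
+ >>> seq = [1, 1, 0, 2, 2, 0, 0, 1]  # per-timestep argmax indices
+ >>> "".join("ab"[k - 1] for k, _ in groupby(seq) if k != 0)
+ 'aba'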
+ + Args: + ---- + logits: model output, shape: N x T x C + vocab: vocabulary to use + blank: index of blank label + + Returns: + ------- + A list of tuples: (word, confidence) + """ + # Gather the most confident characters, and assign the smallest conf among those to the sequence prob + probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values + + # collapse best path (using itertools.groupby), map to chars, join char list to string + words = [ + decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab) + for seq in torch.argmax(logits, dim=-1) + ] + + return list(zip(words, probs.tolist())) + + def __call__(self, logits: torch.Tensor) -> List[Tuple[str, float]]: + """Performs decoding of raw output with CTC and decoding of CTC predictions + with label_to_idx mapping dictionnary + + Args: + ---- + logits: raw output of the model, shape (N, C + 1, seq_len) + + Returns: + ------- + A tuple of 2 lists: a list of str (words) and a list of float (probs) + + """ + # Decode CTC + return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab)) + + +class CRNN(RecognitionModel, nn.Module): + """Implements a CRNN architecture as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. + + Args: + ---- + feature_extractor: the backbone serving as feature extractor + vocab: vocabulary used for encoding + rnn_units: number of units in the LSTM layers + exportable: onnx exportable returns only logits + cfg: configuration dictionary + """ + + _children_names: List[str] = ["feat_extractor", "decoder", "linear", "postprocessor"] + + def __init__( + self, + feature_extractor: nn.Module, + vocab: str, + rnn_units: int = 128, + input_shape: Tuple[int, int, int] = (3, 32, 128), + exportable: bool = False, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + self.vocab = vocab + self.cfg = cfg + self.max_length = 32 + self.exportable = exportable + self.feat_extractor = feature_extractor + + # Resolve the input_size of the LSTM + with torch.inference_mode(): + out_shape = self.feat_extractor(torch.zeros((1, *input_shape))).shape + lstm_in = out_shape[1] * out_shape[2] + + self.decoder = nn.LSTM( + input_size=lstm_in, + hidden_size=rnn_units, + batch_first=True, + num_layers=2, + bidirectional=True, + ) + + # features units = 2 * rnn_units because bidirectional layers + self.linear = nn.Linear(in_features=2 * rnn_units, out_features=len(vocab) + 1) + + self.postprocessor = CTCPostProcessor(vocab=vocab) + + for n, m in self.named_modules(): + # Don't override the initialization of the backbone + if n.startswith("feat_extractor."): + continue + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def compute_loss( + self, + model_output: torch.Tensor, + target: List[str], + ) -> torch.Tensor: + """Compute CTC loss for the model. 
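+
+ Shape sketch of the underlying torch call (batch size, sequence length and vocab size are
+ illustrative assumptions):
+
+ >>> import torch
+ >>> from torch.nn import functional as F
+ >>> log_probs = F.log_softmax(torch.randn(32, 4, 27), dim=-1)  # T x N x C, with C = len(vocab) + 1
+ >>> targets = torch.randint(0, 26, (4, 10), dtype=torch.long)
+ >>> input_lengths = torch.full((4,), 32, dtype=torch.long)
+ >>> target_lengths = torch.full((4,), 10, dtype=torch.long)
+ >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=26, zero_infinity=True)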
+ + Args: + ---- + model_output: predicted logits of the model + target: list of target strings + + Returns: + ------- + The loss of the model on the batch + """ + gt, seq_len = self.build_target(target) + batch_len = model_output.shape[0] + input_length = model_output.shape[1] * torch.ones(size=(batch_len,), dtype=torch.int32) + # N x T x C -> T x N x C + logits = model_output.permute(1, 0, 2) + probs = F.log_softmax(logits, dim=-1) + ctc_loss = F.ctc_loss( + probs, + torch.from_numpy(gt), + input_length, + torch.tensor(seq_len, dtype=torch.int), + len(self.vocab), + zero_infinity=True, + ) + + return ctc_loss + + def forward( + self, + x: torch.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + ) -> Dict[str, Any]: + if self.training and target is None: + raise ValueError("Need to provide labels during training") + + features = self.feat_extractor(x) + # B x C x H x W --> B x C*H x W --> B x W x C*H + c, h, w = features.shape[1], features.shape[2], features.shape[3] + features_seq = torch.reshape(features, shape=(-1, h * c, w)) + features_seq = torch.transpose(features_seq, 1, 2) + logits, _ = self.decoder(features_seq) + logits = self.linear(logits) + + out: Dict[str, Any] = {} + if self.exportable: + out["logits"] = logits + return out + + if return_model_output: + out["out_map"] = logits + + if target is None or return_preds: + # Post-process boxes + out["preds"] = self.postprocessor(logits) + + if target is not None: + out["loss"] = self.compute_loss(logits, target) + + return out + + +def _crnn( + arch: str, + pretrained: bool, + backbone_fn: Callable[[Any], nn.Module], + pretrained_backbone: bool = True, + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> CRNN: + pretrained_backbone = pretrained_backbone and not pretrained + + # Feature extractor + feat_extractor = backbone_fn(pretrained=pretrained_backbone).features # type: ignore[call-arg] + + kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"]) + kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["vocab"] = kwargs["vocab"] + _cfg["input_shape"] = kwargs["input_shape"] + + # Build the model + model = CRNN(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + # The number of classes is not the same as the number of classes in the pretrained model => + # remove the last layer weights + _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None + load_pretrained_params(model, _cfg["url"], ignore_keys=_ignore_keys) + + return model + + +def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN: + """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. 
+ + >>> import torch + >>> from doctr.models import crnn_vgg16_bn + >>> model = crnn_vgg16_bn(pretrained=True) + >>> input_tensor = torch.rand(1, 3, 32, 128) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + **kwargs: keyword arguments of the CRNN architecture + + Returns: + ------- + text recognition architecture + """ + return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, ignore_keys=["linear.weight", "linear.bias"], **kwargs) + + +def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN: + """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. + + >>> import torch + >>> from doctr.models import crnn_mobilenet_v3_small + >>> model = crnn_mobilenet_v3_small(pretrained=True) + >>> input_tensor = torch.rand(1, 3, 32, 128) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + **kwargs: keyword arguments of the CRNN architecture + + Returns: + ------- + text recognition architecture + """ + return _crnn( + "crnn_mobilenet_v3_small", + pretrained, + mobilenet_v3_small_r, + ignore_keys=["linear.weight", "linear.bias"], + **kwargs, + ) + + +def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN: + """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. + + >>> import torch + >>> from doctr.models import crnn_mobilenet_v3_large + >>> model = crnn_mobilenet_v3_large(pretrained=True) + >>> input_tensor = torch.rand(1, 3, 32, 128) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + **kwargs: keyword arguments of the CRNN architecture + + Returns: + ------- + text recognition architecture + """ + return _crnn( + "crnn_mobilenet_v3_large", + pretrained, + mobilenet_v3_large_r, + ignore_keys=["linear.weight", "linear.bias"], + **kwargs, + ) diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec48c4f0e90be4100b5dc700eb661df957a8a25 --- /dev/null +++ b/doctr/models/recognition/crnn/tensorflow.py @@ -0,0 +1,318 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
+ +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple, Union + +import tensorflow as tf +from tensorflow.keras import layers +from tensorflow.keras.models import Model, Sequential + +from doctr.datasets import VOCABS + +from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r +from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ..core import RecognitionModel, RecognitionPostProcessor + +__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "crnn_vgg16_bn": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 128, 3), + "vocab": VOCABS["legacy_french"], + "url": "https://doctr-static.mindee.com/models?id=v0.3.0/crnn_vgg16_bn-76b7f2c6.zip&src=0", + }, + "crnn_mobilenet_v3_small": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 128, 3), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_mobilenet_v3_small-7f36edec.zip&src=0", + }, + "crnn_mobilenet_v3_large": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 128, 3), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.6.0/crnn_mobilenet_v3_large-cccc50b1.zip&src=0", + }, +} + + +class CTCPostProcessor(RecognitionPostProcessor): + """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + ignore_case: if True, ignore case of letters + ignore_accents: if True, ignore accents of letters + """ + + def __call__( + self, + logits: tf.Tensor, + beam_width: int = 1, + top_paths: int = 1, + ) -> Union[List[Tuple[str, float]], List[Tuple[List[str], List[float]]]]: + """Performs decoding of raw output with CTC and decoding of CTC predictions + with label_to_idx mapping dictionnary + + Args: + ---- + logits: raw output of the model, shape BATCH_SIZE X SEQ_LEN X NUM_CLASSES + 1 + beam_width: An int scalar >= 0 (beam search beam width). + top_paths: An int scalar >= 0, <= beam_width (controls output size). 
+ + Returns: + ------- + A list of decoded words of length BATCH_SIZE + + + """ + # Decode CTC + _decoded, _log_prob = tf.nn.ctc_beam_search_decoder( + tf.transpose(logits, perm=[1, 0, 2]), + tf.fill(tf.shape(logits)[:1], tf.shape(logits)[1]), + beam_width=beam_width, + top_paths=top_paths, + ) + + _decoded = tf.sparse.concat( + 1, + [tf.sparse.expand_dims(dec, axis=1) for dec in _decoded], + expand_nonconcat_dims=True, + ) # dim : batchsize x beamwidth x actual_max_len_predictions + out_idxs = tf.sparse.to_dense(_decoded, default_value=len(self.vocab)) + + # Map it to characters + _decoded_strings_pred = tf.strings.reduce_join( + inputs=tf.nn.embedding_lookup(tf.constant(self._embedding, dtype=tf.string), out_idxs), + axis=-1, + ) + _decoded_strings_pred = tf.strings.split(_decoded_strings_pred, "") + decoded_strings_pred = tf.sparse.to_dense(_decoded_strings_pred.to_sparse(), default_value="not valid")[ + :, :, 0 + ] # dim : batch_size x beam_width + + if top_paths == 1: + probs = tf.math.exp(tf.squeeze(_log_prob, axis=1)) # dim : batchsize + decoded_strings_pred = tf.squeeze(decoded_strings_pred, axis=1) + word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()] + else: + probs = tf.math.exp(_log_prob) # dim : batchsize x beamwidth + word_values = [[word.decode() for word in words] for words in decoded_strings_pred.numpy().tolist()] + return list(zip(word_values, probs.numpy().tolist())) + + +class CRNN(RecognitionModel, Model): + """Implements a CRNN architecture as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. + + Args: + ---- + feature_extractor: the backbone serving as feature extractor + vocab: vocabulary used for encoding + rnn_units: number of units in the LSTM layers + exportable: onnx exportable returns only logits + beam_width: beam width for beam search decoding + top_paths: number of top paths for beam search decoding + cfg: configuration dictionary + """ + + _children_names: List[str] = ["feat_extractor", "decoder", "postprocessor"] + + def __init__( + self, + feature_extractor: tf.keras.Model, + vocab: str, + rnn_units: int = 128, + exportable: bool = False, + beam_width: int = 1, + top_paths: int = 1, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + # Initialize kernels + h, w, c = feature_extractor.output_shape[1:] + + super().__init__() + self.vocab = vocab + self.max_length = w + self.cfg = cfg + self.exportable = exportable + self.feat_extractor = feature_extractor + + self.decoder = Sequential([ + layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)), + layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)), + layers.Dense(units=len(vocab) + 1), + ]) + self.decoder.build(input_shape=(None, w, h * c)) + + self.postprocessor = CTCPostProcessor(vocab=vocab) + + self.beam_width = beam_width + self.top_paths = top_paths + + def compute_loss( + self, + model_output: tf.Tensor, + target: List[str], + ) -> tf.Tensor: + """Compute CTC loss for the model. 
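+
+ Shape sketch of the underlying TensorFlow call (batch size, sequence length and vocab size are
+ illustrative assumptions):
+
+ >>> import tensorflow as tf
+ >>> logits = tf.random.normal((4, 32, 27))  # N x T x C, with C = len(vocab) + 1
+ >>> labels = tf.random.uniform((4, 10), maxval=26, dtype=tf.int32)
+ >>> loss = tf.nn.ctc_loss(labels, logits, tf.fill((4,), 10), tf.fill((4,), 32),
+ ...     logits_time_major=False, blank_index=26)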
+ + Args: + ---- + model_output: predicted logits of the model + target: lengths of each gt word inside the batch + + Returns: + ------- + The loss of the model on the batch + """ + gt, seq_len = self.build_target(target) + batch_len = model_output.shape[0] + input_length = tf.fill((batch_len,), model_output.shape[1]) + ctc_loss = tf.nn.ctc_loss( + gt, model_output, seq_len, input_length, logits_time_major=False, blank_index=len(self.vocab) + ) + return ctc_loss + + def call( + self, + x: tf.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + beam_width: int = 1, + top_paths: int = 1, + **kwargs: Any, + ) -> Dict[str, Any]: + if kwargs.get("training", False) and target is None: + raise ValueError("Need to provide labels during training") + + features = self.feat_extractor(x, **kwargs) + # B x H x W x C --> B x W x H x C + transposed_feat = tf.transpose(features, perm=[0, 2, 1, 3]) + w, h, c = transposed_feat.get_shape().as_list()[1:] + # B x W x H x C --> B x W x H * C + features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c)) + logits = _bf16_to_float32(self.decoder(features_seq, **kwargs)) + + out: Dict[str, tf.Tensor] = {} + if self.exportable: + out["logits"] = logits + return out + + if return_model_output: + out["out_map"] = logits + + if target is None or return_preds: + # Post-process boxes + out["preds"] = self.postprocessor(logits, beam_width=beam_width, top_paths=top_paths) + + if target is not None: + out["loss"] = self.compute_loss(logits, target) + + return out + + +def _crnn( + arch: str, + pretrained: bool, + backbone_fn, + pretrained_backbone: bool = True, + input_shape: Optional[Tuple[int, int, int]] = None, + **kwargs: Any, +) -> CRNN: + pretrained_backbone = pretrained_backbone and not pretrained + + kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["vocab"] = kwargs["vocab"] + _cfg["input_shape"] = input_shape or default_cfgs[arch]["input_shape"] + + feat_extractor = backbone_fn( + input_shape=_cfg["input_shape"], + include_top=False, + pretrained=pretrained_backbone, + ) + + # Build the model + model = CRNN(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, _cfg["url"]) + + return model + + +def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN: + """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. + + >>> import tensorflow as tf + >>> from doctr.models import crnn_vgg16_bn + >>> model = crnn_vgg16_bn(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + **kwargs: keyword arguments of the CRNN architecture + + Returns: + ------- + text recognition architecture + """ + return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, **kwargs) + + +def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN: + """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. 
+ + >>> import tensorflow as tf + >>> from doctr.models import crnn_mobilenet_v3_small + >>> model = crnn_mobilenet_v3_small(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + **kwargs: keyword arguments of the CRNN architecture + + Returns: + ------- + text recognition architecture + """ + return _crnn("crnn_mobilenet_v3_small", pretrained, mobilenet_v3_small_r, **kwargs) + + +def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN: + """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. + + >>> import tensorflow as tf + >>> from doctr.models import crnn_mobilenet_v3_large + >>> model = crnn_mobilenet_v3_large(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + **kwargs: keyword arguments of the CRNN architecture + + Returns: + ------- + text recognition architecture + """ + return _crnn("crnn_mobilenet_v3_large", pretrained, mobilenet_v3_large_r, **kwargs) diff --git a/doctr/models/recognition/master/__init__.py b/doctr/models/recognition/master/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/recognition/master/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/recognition/master/__pycache__/__init__.cpython-311.pyc b/doctr/models/recognition/master/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62f54d8d5c0e24d8a72cf0d67cceb2230b1e966c Binary files /dev/null and b/doctr/models/recognition/master/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/recognition/master/__pycache__/__init__.cpython-38.pyc b/doctr/models/recognition/master/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f5826edd3b9934417a14e6073793fa698e63676 Binary files /dev/null and b/doctr/models/recognition/master/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/recognition/master/__pycache__/base.cpython-311.pyc b/doctr/models/recognition/master/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98c571b817622960619abca5b8c8f495ea2e9a27 Binary files /dev/null and b/doctr/models/recognition/master/__pycache__/base.cpython-311.pyc differ diff --git a/doctr/models/recognition/master/__pycache__/base.cpython-38.pyc b/doctr/models/recognition/master/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45adfdbbe4a54a76bdb3d65f2efc86be652830f7 Binary files /dev/null and b/doctr/models/recognition/master/__pycache__/base.cpython-38.pyc differ diff --git a/doctr/models/recognition/master/__pycache__/pytorch.cpython-311.pyc b/doctr/models/recognition/master/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..4cb56cbecb84e1f10be1297bf08256673122be8a Binary files /dev/null and b/doctr/models/recognition/master/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/recognition/master/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/recognition/master/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60fda3ff45c32b08a3e9180bbfef9860c8e208d8 Binary files /dev/null and b/doctr/models/recognition/master/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/recognition/master/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/recognition/master/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90f2254214e74edea3d04f02728f247ea298600e Binary files /dev/null and b/doctr/models/recognition/master/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/recognition/master/base.py b/doctr/models/recognition/master/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4d3002893ee26b4aa41838e41e63cd3f17c35779 --- /dev/null +++ b/doctr/models/recognition/master/base.py @@ -0,0 +1,58 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import List, Tuple + +import numpy as np + +from ....datasets import encode_sequences +from ..core import RecognitionPostProcessor + + +class _MASTER: + vocab: str + max_length: int + + def build_target( + self, + gts: List[str], + ) -> Tuple[np.ndarray, List[int]]: + """Encode a list of gts sequences into a np array and gives the corresponding* + sequence lengths. + + Args: + ---- + gts: list of ground-truth labels + + Returns: + ------- + A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) + """ + encoded = encode_sequences( + sequences=gts, + vocab=self.vocab, + target_size=self.max_length, + eos=len(self.vocab), + sos=len(self.vocab) + 1, + pad=len(self.vocab) + 2, + ) + seq_len = [len(word) for word in gts] + return encoded, seq_len + + +class _MASTERPostProcessor(RecognitionPostProcessor): + """Abstract class to postprocess the raw output of the model + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __init__( + self, + vocab: str, + ) -> None: + super().__init__(vocab) + self._embedding = list(vocab) + [""] + [""] + [""] diff --git a/doctr/models/recognition/master/pytorch.py b/doctr/models/recognition/master/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..875fcbd687d6e644623632eba39da6f57f26edf0 --- /dev/null +++ b/doctr/models/recognition/master/pytorch.py @@ -0,0 +1,338 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
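`build_target` above delegates to `encode_sequences` with the three special tokens placed right after the character set (EOS at `len(vocab)`, SOS at `len(vocab) + 1`, PAD at `len(vocab) + 2`). A toy sketch of that index layout; whether SOS/EOS are inserted exactly like this is an assumption, the point is only the index convention:

```python
vocab = "abc"
eos, sos, pad = len(vocab), len(vocab) + 1, len(vocab) + 2  # 3, 4, 5

def toy_encode(word: str, target_size: int) -> list:
    # Hypothetical encoder mirroring the convention above, not doctr's encode_sequences.
    idxs = [sos] + [vocab.index(c) for c in word] + [eos]
    return idxs + [pad] * (target_size - len(idxs))

print(toy_encode("cab", 8))  # [4, 2, 0, 1, 3, 5, 5, 5]
```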
+ +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import functional as F +from torchvision.models._utils import IntermediateLayerGetter + +from doctr.datasets import VOCABS +from doctr.models.classification import magc_resnet31 +from doctr.models.modules.transformer import Decoder, PositionalEncoding + +from ...utils.pytorch import _bf16_to_float32, load_pretrained_params +from .base import _MASTER, _MASTERPostProcessor + +__all__ = ["MASTER", "master"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "master": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/master-fde31e4a.pt&src=0", + }, +} + + +class MASTER(_MASTER, nn.Module): + """Implements MASTER as described in paper: `_. + Implementation based on the official Pytorch implementation: `_. + + Args: + ---- + feature_extractor: the backbone serving as feature extractor + vocab: vocabulary, (without EOS, SOS, PAD) + d_model: d parameter for the transformer decoder + dff: depth of the pointwise feed-forward layer + num_heads: number of heads for the mutli-head attention module + num_layers: number of decoder layers to stack + max_length: maximum length of character sequence handled by the model + dropout: dropout probability of the decoder + input_shape: size of the image inputs + exportable: onnx exportable returns only logits + cfg: dictionary containing information about the model + """ + + def __init__( + self, + feature_extractor: nn.Module, + vocab: str, + d_model: int = 512, + dff: int = 2048, + num_heads: int = 8, # number of heads in the transformer decoder + num_layers: int = 3, + max_length: int = 50, + dropout: float = 0.2, + input_shape: Tuple[int, int, int] = (3, 32, 128), # different from the paper + exportable: bool = False, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + + self.exportable = exportable + self.max_length = max_length + self.d_model = d_model + self.vocab = vocab + self.cfg = cfg + self.vocab_size = len(vocab) + + self.feat_extractor = feature_extractor + self.positional_encoding = PositionalEncoding(self.d_model, dropout, max_len=input_shape[1] * input_shape[2]) + + self.decoder = Decoder( + num_layers=num_layers, + d_model=self.d_model, + num_heads=num_heads, + vocab_size=self.vocab_size + 3, # EOS, SOS, PAD + dff=dff, + dropout=dropout, + maximum_position_encoding=self.max_length, + ) + + self.linear = nn.Linear(self.d_model, self.vocab_size + 3) + self.postprocessor = MASTERPostProcessor(vocab=self.vocab) + + for n, m in self.named_modules(): + # Don't override the initialization of the backbone + if n.startswith("feat_extractor."): + continue + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def make_source_and_target_mask( + self, source: torch.Tensor, target: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # borrowed and slightly modified from https://github.com/wenwenyu/MASTER-pytorch + # NOTE: nn.TransformerDecoder takes the inverse from this implementation + # [True, True, True, ..., False, False, False] -> False is masked + # (N, 1, 1, max_length) + target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1) + target_length = target.size(1) 
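`make_source_and_target_mask` combines a padding mask (PAD positions cannot be attended to) with a causal mask (no peeking at future positions). A minimal PyTorch stand-in with assumed token indices (`PAD = vocab_size + 2`); `True` marks positions the decoder may attend to:

```python
import torch

vocab_size = 10
pad_idx = vocab_size + 2
target = torch.tensor([[11, 0, 1, 12, 12]])               # SOS, two characters, then padding
pad_mask = (target != pad_idx).unsqueeze(1).unsqueeze(1)  # (N, 1, 1, L)
causal = torch.tril(torch.ones(5, 5, dtype=torch.bool))   # (L, L) lower-triangular
target_mask = pad_mask & causal                           # broadcasts to (N, 1, L, L)
print(target_mask.int()[0, 0])
```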
+ # sub mask filled diagonal with True = see and False = masked (max_length, max_length) + # NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14) + target_sub_mask = torch.tril(torch.ones((target_length, target_length), device=source.device), diagonal=0).to( + dtype=torch.bool + ) + # source mask filled with ones (max_length, positional_encoded_seq_len) + source_mask = torch.ones((target_length, source.size(1)), dtype=torch.uint8, device=source.device) + # combine the two masks into one (N, 1, max_length, max_length) + target_mask = target_pad_mask & target_sub_mask + return source_mask, target_mask.int() + + @staticmethod + def compute_loss( + model_output: torch.Tensor, + gt: torch.Tensor, + seq_len: torch.Tensor, + ) -> torch.Tensor: + """Compute categorical cross-entropy loss for the model. + Sequences are masked after the EOS character. + + Args: + ---- + gt: the encoded tensor with gt labels + model_output: predicted logits of the model + seq_len: lengths of each gt word inside the batch + + Returns: + ------- + The loss of the model on the batch + """ + # Input length : number of timesteps + input_len = model_output.shape[1] + # Add one for additional token (sos disappear in shift!) + seq_len = seq_len + 1 + # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]! + # The "masked" first gt char is . Delete last logit of the model output. + cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none") + # Compute mask, remove 1 timestep here as well + mask_2d = torch.arange(input_len - 1, device=model_output.device)[None, :] >= seq_len[:, None] + cce[mask_2d] = 0 + + ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype) + return ce_loss.mean() + + def forward( + self, + x: torch.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + ) -> Dict[str, Any]: + """Call function for training + + Args: + ---- + x: images + target: list of str labels + return_model_output: if True, return logits + return_preds: if True, decode logits + + Returns: + ------- + A dictionnary containing eventually loss, logits and predictions. 
+ """ + # Encode + features = self.feat_extractor(x)["features"] + b, c, h, w = features.shape + # (N, C, H, W) --> (N, H * W, C) + features = features.view(b, c, h * w).permute((0, 2, 1)) + # add positional encoding to features + encoded = self.positional_encoding(features) + + out: Dict[str, Any] = {} + + if self.training and target is None: + raise ValueError("Need to provide labels during training") + + if target is not None: + # Compute target: tensor of gts and sequence lengths + _gt, _seq_len = self.build_target(target) + gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len) + gt, seq_len = gt.to(x.device), seq_len.to(x.device) + + # Compute source mask and target mask + source_mask, target_mask = self.make_source_and_target_mask(encoded, gt) + output = self.decoder(gt, encoded, source_mask, target_mask) + # Compute logits + logits = self.linear(output) + else: + logits = self.decode(encoded) + + logits = _bf16_to_float32(logits) + + if self.exportable: + out["logits"] = logits + return out + + if target is not None: + out["loss"] = self.compute_loss(logits, gt, seq_len) + + if return_model_output: + out["out_map"] = logits + + if return_preds: + out["preds"] = self.postprocessor(logits) + + return out + + def decode(self, encoded: torch.Tensor) -> torch.Tensor: + """Decode function for prediction + + Args: + ---- + encoded: input tensor + + Returns: + ------- + A Tuple of torch.Tensor: predictions, logits + """ + b = encoded.size(0) + + # Padding symbol + SOS at the beginning + ys = torch.full((b, self.max_length), self.vocab_size + 2, dtype=torch.long, device=encoded.device) # pad + ys[:, 0] = self.vocab_size + 1 # sos + + # Final dimension include EOS/SOS/PAD + for i in range(self.max_length - 1): + source_mask, target_mask = self.make_source_and_target_mask(encoded, ys) + output = self.decoder(ys, encoded, source_mask, target_mask) + logits = self.linear(output) + prob = torch.softmax(logits, dim=-1) + next_token = torch.max(prob, dim=-1).indices + # update ys with the next token and ignore the first token (SOS) + ys[:, i + 1] = next_token[:, i] + + # Shape (N, max_length, vocab_size + 1) + return logits + + +class MASTERPostProcessor(_MASTERPostProcessor): + """Post processor for MASTER architectures""" + + def __call__( + self, + logits: torch.Tensor, + ) -> List[Tuple[str, float]]: + # compute pred with argmax for attention models + out_idxs = logits.argmax(-1) + # N x L + probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1) + # Take the minimum confidence of the sequence + probs = probs.min(dim=1).values.detach().cpu() + + # Manual decoding + word_values = [ + "".join(self._embedding[idx] for idx in encoded_seq).split("")[0] + for encoded_seq in out_idxs.cpu().numpy() + ] + + return list(zip(word_values, probs.numpy().clip(0, 1).tolist())) + + +def _master( + arch: str, + pretrained: bool, + backbone_fn: Callable[[bool], nn.Module], + layer: str, + pretrained_backbone: bool = True, + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> MASTER: + pretrained_backbone = pretrained_backbone and not pretrained + + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"]) + _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"]) + + kwargs["vocab"] = _cfg["vocab"] + kwargs["input_shape"] = _cfg["input_shape"] + + # Build the model + feat_extractor = IntermediateLayerGetter( + backbone_fn(pretrained_backbone), + {layer: "features"}, + ) + model = 
MASTER(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + # The number of classes is not the same as the number of classes in the pretrained model => + # remove the last layer weights + _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None + load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + return model + + +def master(pretrained: bool = False, **kwargs: Any) -> MASTER: + """MASTER as described in paper: `_. + + >>> import torch + >>> from doctr.models import master + >>> model = master(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 32, 128)) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + **kwargs: keywoard arguments passed to the MASTER architecture + + Returns: + ------- + text recognition architecture + """ + return _master( + "master", + pretrained, + magc_resnet31, + "10", + ignore_keys=[ + "decoder.embed.weight", + "linear.weight", + "linear.bias", + ], + **kwargs, + ) diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..a3ecadcc15b071c781d38280b0904901ba365619 --- /dev/null +++ b/doctr/models/recognition/master/tensorflow.py @@ -0,0 +1,318 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +import tensorflow as tf +from tensorflow.keras import Model, layers + +from doctr.datasets import VOCABS +from doctr.models.classification import magc_resnet31 +from doctr.models.modules.transformer import Decoder, PositionalEncoding + +from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from .base import _MASTER, _MASTERPostProcessor + +__all__ = ["MASTER", "master"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "master": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 128, 3), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.6.0/master-a8232e9f.zip&src=0", + }, +} + + +class MASTER(_MASTER, Model): + """Implements MASTER as described in paper: `_. + Implementation based on the official TF implementation: `_. 
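Both MASTER variants (the PyTorch `compute_loss` above and the TensorFlow one below) use a shifted cross-entropy: the token at step t is predicted from tokens before t, so the last logit and the first ground-truth token (SOS) are dropped, and timesteps after EOS are masked out. A toy PyTorch sketch of that bookkeeping, with assumed token indices (EOS = vocab_size, SOS = vocab_size + 1, PAD = vocab_size + 2):

```python
import torch
import torch.nn.functional as F

vocab_size = 10
gt = torch.tensor([[11, 0, 1, 10, 12, 12]])       # SOS, 'a', 'b', EOS, PAD, PAD (toy indices)
logits = torch.randn(1, 6, vocab_size + 3)

# shift by one position: compare logits[:, :-1] with gt[:, 1:]
cce = F.cross_entropy(logits[:, :-1].permute(0, 2, 1), gt[:, 1:], reduction="none")  # (1, 5)
seq_len = torch.tensor([2]) + 1                   # word length + EOS
mask = torch.arange(5)[None, :] >= seq_len[:, None]
cce[mask] = 0.0                                   # ignore everything after EOS
print((cce.sum(1) / seq_len.float()).mean())
```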
+ + Args: + ---- + feature_extractor: the backbone serving as feature extractor + vocab: vocabulary, (without EOS, SOS, PAD) + d_model: d parameter for the transformer decoder + dff: depth of the pointwise feed-forward layer + num_heads: number of heads for the mutli-head attention module + num_layers: number of decoder layers to stack + max_length: maximum length of character sequence handled by the model + dropout: dropout probability of the decoder + input_shape: size of the image inputs + exportable: onnx exportable returns only logits + cfg: dictionary containing information about the model + """ + + def __init__( + self, + feature_extractor: tf.keras.Model, + vocab: str, + d_model: int = 512, + dff: int = 2048, + num_heads: int = 8, # number of heads in the transformer decoder + num_layers: int = 3, + max_length: int = 50, + dropout: float = 0.2, + input_shape: Tuple[int, int, int] = (32, 128, 3), # different from the paper + exportable: bool = False, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + + self.exportable = exportable + self.max_length = max_length + self.d_model = d_model + self.vocab = vocab + self.cfg = cfg + self.vocab_size = len(vocab) + + self.feat_extractor = feature_extractor + self.positional_encoding = PositionalEncoding(self.d_model, dropout, max_len=input_shape[0] * input_shape[1]) + + self.decoder = Decoder( + num_layers=num_layers, + d_model=self.d_model, + num_heads=num_heads, + vocab_size=self.vocab_size + 3, # EOS, SOS, PAD + dff=dff, + dropout=dropout, + maximum_position_encoding=self.max_length, + ) + + self.linear = layers.Dense(self.vocab_size + 3, kernel_initializer=tf.initializers.he_uniform()) + self.postprocessor = MASTERPostProcessor(vocab=self.vocab) + + @tf.function + def make_source_and_target_mask(self, source: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + # [1, 1, 1, ..., 0, 0, 0] -> 0 is masked + # (N, 1, 1, max_length) + target_pad_mask = tf.cast(tf.math.not_equal(target, self.vocab_size + 2), dtype=tf.uint8) + target_pad_mask = target_pad_mask[:, tf.newaxis, tf.newaxis, :] + target_length = target.shape[1] + # sub mask filled diagonal with 1 = see 0 = masked (max_length, max_length) + target_sub_mask = tf.linalg.band_part(tf.ones((target_length, target_length)), -1, 0) + # source mask filled with ones (max_length, positional_encoded_seq_len) + source_mask = tf.ones((target_length, source.shape[1])) + # combine the two masks into one boolean mask where False is masked (N, 1, max_length, max_length) + target_mask = tf.math.logical_and( + tf.cast(target_sub_mask, dtype=tf.bool), tf.cast(target_pad_mask, dtype=tf.bool) + ) + return source_mask, target_mask + + @staticmethod + def compute_loss( + model_output: tf.Tensor, + gt: tf.Tensor, + seq_len: List[int], + ) -> tf.Tensor: + """Compute categorical cross-entropy loss for the model. + Sequences are masked after the EOS character. + + Args: + ---- + gt: the encoded tensor with gt labels + model_output: predicted logits of the model + seq_len: lengths of each gt word inside the batch + + Returns: + ------- + The loss of the model on the batch + """ + # Input length : number of timesteps + input_len = tf.shape(model_output)[1] + # Add one for additional token (sos disappear in shift!) + seq_len = tf.cast(seq_len, tf.int32) + 1 + # One-hot gt labels + oh_gt = tf.one_hot(gt, depth=model_output.shape[2]) + # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]! + # The "masked" first gt char is . 
Delete last logit of the model output. + cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt[:, 1:, :], model_output[:, :-1, :]) + # Compute mask + mask_values = tf.zeros_like(cce) + mask_2d = tf.sequence_mask(seq_len, input_len - 1) # delete the last mask timestep as well + masked_loss = tf.where(mask_2d, cce, mask_values) + ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype)) + + return tf.expand_dims(ce_loss, axis=1) + + def call( + self, + x: tf.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + """Call function for training + + Args: + ---- + x: images + target: list of str labels + return_model_output: if True, return logits + return_preds: if True, decode logits + **kwargs: keyword arguments passed to the decoder + + Returns: + ------- + A dictionnary containing eventually loss, logits and predictions. + """ + # Encode + feature = self.feat_extractor(x, **kwargs) + b, h, w, c = feature.get_shape() + # (N, H, W, C) --> (N, H * W, C) + feature = tf.reshape(feature, shape=(b, h * w, c)) + # add positional encoding to features + encoded = self.positional_encoding(feature, **kwargs) + + out: Dict[str, tf.Tensor] = {} + + if kwargs.get("training", False) and target is None: + raise ValueError("Need to provide labels during training") + + if target is not None: + # Compute target: tensor of gts and sequence lengths + gt, seq_len = self.build_target(target) + # Compute decoder masks + source_mask, target_mask = self.make_source_and_target_mask(encoded, gt) + # Compute logits + output = self.decoder(gt, encoded, source_mask, target_mask, **kwargs) + logits = self.linear(output, **kwargs) + else: + logits = self.decode(encoded, **kwargs) + + logits = _bf16_to_float32(logits) + + if self.exportable: + out["logits"] = logits + return out + + if target is not None: + out["loss"] = self.compute_loss(logits, gt, seq_len) + + if return_model_output: + out["out_map"] = logits + + if return_preds: + out["preds"] = self.postprocessor(logits) + + return out + + @tf.function + def decode(self, encoded: tf.Tensor, **kwargs: Any) -> tf.Tensor: + """Decode function for prediction + + Args: + ---- + encoded: encoded features + **kwargs: keyword arguments passed to the decoder + + Returns: + ------- + A Tuple of tf.Tensor: predictions, logits + """ + b = encoded.shape[0] + + start_symbol = tf.constant(self.vocab_size + 1, dtype=tf.int32) # SOS + padding_symbol = tf.constant(self.vocab_size + 2, dtype=tf.int32) # PAD + + ys = tf.fill(dims=(b, self.max_length - 1), value=padding_symbol) + start_vector = tf.fill(dims=(b, 1), value=start_symbol) + ys = tf.concat([start_vector, ys], axis=-1) + + # Final dimension include EOS/SOS/PAD + for i in range(self.max_length - 1): + source_mask, target_mask = self.make_source_and_target_mask(encoded, ys) + output = self.decoder(ys, encoded, source_mask, target_mask, **kwargs) + logits = self.linear(output, **kwargs) + prob = tf.nn.softmax(logits, axis=-1) + next_token = tf.argmax(prob, axis=-1, output_type=ys.dtype) + # update ys with the next token and ignore the first token (SOS) + i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(self.max_length), indexing="ij") + indices = tf.stack([i_mesh[:, i + 1], j_mesh[:, i + 1]], axis=1) + + ys = tf.tensor_scatter_nd_update(ys, indices, next_token[:, i]) + + # Shape (N, max_length, vocab_size + 1) + return logits + + +class MASTERPostProcessor(_MASTERPostProcessor): + """Post 
processor for MASTER architectures + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __call__( + self, + logits: tf.Tensor, + ) -> List[Tuple[str, float]]: + # compute pred with argmax for attention models + out_idxs = tf.math.argmax(logits, axis=2) + # N x L + probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2) + # Take the minimum confidence of the sequence + probs = tf.math.reduce_min(probs, axis=1) + + # decode raw output of the model with tf_label_to_idx + out_idxs = tf.cast(out_idxs, dtype="int32") + embedding = tf.constant(self._embedding, dtype=tf.string) + decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1) + decoded_strings_pred = tf.strings.split(decoded_strings_pred, "") + decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0] + word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()] + + return list(zip(word_values, probs.numpy().clip(0, 1).tolist())) + + +def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool = True, **kwargs: Any) -> MASTER: + pretrained_backbone = pretrained_backbone and not pretrained + + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"]) + _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"]) + + kwargs["vocab"] = _cfg["vocab"] + kwargs["input_shape"] = _cfg["input_shape"] + + # Build the model + model = MASTER( + backbone_fn(pretrained=pretrained_backbone, input_shape=_cfg["input_shape"], include_top=False), + cfg=_cfg, + **kwargs, + ) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]["url"]) + + return model + + +def master(pretrained: bool = False, **kwargs: Any) -> MASTER: + """MASTER as described in paper: `_. 
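`MASTERPostProcessor` decodes attention-style outputs: argmax per position, map indices to characters, cut at the end-of-sequence token, and score the word with the minimum per-character probability. A NumPy sketch of that flow; the `"<eos>"`/`"<sos>"`/`"<pad>"` strings are an assumption (the special-token literals are not legible in this diff), only the layout "characters first, specials last" matters:

```python
import numpy as np

embedding = list("abc") + ["<eos>", "<sos>", "<pad>"]           # assumed token layout
logits = np.random.rand(2, 6, len(embedding))
probs = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)  # softmax over classes
out_idxs = logits.argmax(-1)                                    # (N, L)
char_probs = np.take_along_axis(probs, out_idxs[..., None], -1)[..., 0]
words = ["".join(embedding[i] for i in seq).split("<eos>")[0] for seq in out_idxs]
confidences = char_probs.min(-1)                                # minimum over the sequence
print(list(zip(words, confidences.round(3))))
```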
+ + >>> import tensorflow as tf + >>> from doctr.models import master + >>> model = master(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + **kwargs: keywoard arguments passed to the MASTER architecture + + Returns: + ------- + text recognition architecture + """ + return _master("master", pretrained, magc_resnet31, **kwargs) diff --git a/doctr/models/recognition/parseq/__init__.py b/doctr/models/recognition/parseq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/recognition/parseq/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/recognition/parseq/__pycache__/__init__.cpython-311.pyc b/doctr/models/recognition/parseq/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aac9f819905f199451bd9eef963d62dd01cf16ed Binary files /dev/null and b/doctr/models/recognition/parseq/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/recognition/parseq/__pycache__/__init__.cpython-38.pyc b/doctr/models/recognition/parseq/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba2b0a9ef943ac1faa231fe04b6e76d18b01d453 Binary files /dev/null and b/doctr/models/recognition/parseq/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/recognition/parseq/__pycache__/base.cpython-311.pyc b/doctr/models/recognition/parseq/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4158a3e8867e5e86f7010c700f4aa2c90a72f01b Binary files /dev/null and b/doctr/models/recognition/parseq/__pycache__/base.cpython-311.pyc differ diff --git a/doctr/models/recognition/parseq/__pycache__/base.cpython-38.pyc b/doctr/models/recognition/parseq/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0598e7289f960a8b46422afcaa0e373b4c80229 Binary files /dev/null and b/doctr/models/recognition/parseq/__pycache__/base.cpython-38.pyc differ diff --git a/doctr/models/recognition/parseq/__pycache__/pytorch.cpython-311.pyc b/doctr/models/recognition/parseq/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40380039c823161200f6c05f997de74602feb58a Binary files /dev/null and b/doctr/models/recognition/parseq/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/recognition/parseq/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/recognition/parseq/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..170faa175ca7519a4744d3f60205f4c64834fe8a Binary files /dev/null and b/doctr/models/recognition/parseq/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/recognition/parseq/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/recognition/parseq/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d82d172bf4c62c53d123a7f31f960bba40c377da Binary files /dev/null and 
b/doctr/models/recognition/parseq/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/recognition/parseq/base.py b/doctr/models/recognition/parseq/base.py new file mode 100644 index 0000000000000000000000000000000000000000..60aa1fcfcf073751bac75350ae25dfdfa29bc491 --- /dev/null +++ b/doctr/models/recognition/parseq/base.py @@ -0,0 +1,58 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import List, Tuple + +import numpy as np + +from ....datasets import encode_sequences +from ..core import RecognitionPostProcessor + + +class _PARSeq: + vocab: str + max_length: int + + def build_target( + self, + gts: List[str], + ) -> Tuple[np.ndarray, List[int]]: + """Encode a list of gts sequences into a np array and gives the corresponding* + sequence lengths. + + Args: + ---- + gts: list of ground-truth labels + + Returns: + ------- + A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) + """ + encoded = encode_sequences( + sequences=gts, + vocab=self.vocab, + target_size=self.max_length, + eos=len(self.vocab), + sos=len(self.vocab) + 1, + pad=len(self.vocab) + 2, + ) + seq_len = [len(word) for word in gts] + return encoded, seq_len + + +class _PARSeqPostProcessor(RecognitionPostProcessor): + """Abstract class to postprocess the raw output of the model + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __init__( + self, + vocab: str, + ) -> None: + super().__init__(vocab) + self._embedding = list(vocab) + ["", "", ""] diff --git a/doctr/models/recognition/parseq/pytorch.py b/doctr/models/recognition/parseq/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..8fff062da982c80b6b64bd7c57a57582251d38a1 --- /dev/null +++ b/doctr/models/recognition/parseq/pytorch.py @@ -0,0 +1,482 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
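One sizing detail worth keeping in mind for the PARSeq code that follows: the prediction head is `vocab_size + 1` wide while the character embedding table holds `vocab_size + 3` entries. A short sketch of why, assuming the same EOS/SOS/PAD convention as the other recognition heads in this diff:

```python
vocab = "abc"
EOS, SOS, PAD = len(vocab), len(vocab) + 1, len(vocab) + 2
num_output_classes = len(vocab) + 1  # the head only ever has to predict a character or EOS
num_embeddings = len(vocab) + 3      # SOS and PAD occur only as decoder inputs, never as outputs
print(num_output_classes, num_embeddings)  # 4 6
```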
+ +import math +from copy import deepcopy +from itertools import permutations +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from torchvision.models._utils import IntermediateLayerGetter + +from doctr.datasets import VOCABS +from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward + +from ...classification import vit_s +from ...utils.pytorch import _bf16_to_float32, load_pretrained_params +from .base import _PARSeq, _PARSeqPostProcessor + +__all__ = ["PARSeq", "parseq"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "parseq": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/parseq-56125471.pt&src=0", + }, +} + + +class CharEmbedding(nn.Module): + """Implements the character embedding module + + Args: + ---- + vocab_size: size of the vocabulary + d_model: dimension of the model + """ + + def __init__(self, vocab_size: int, d_model: int): + super().__init__() + self.embedding = nn.Embedding(vocab_size, d_model) + self.d_model = d_model + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return math.sqrt(self.d_model) * self.embedding(x) + + +class PARSeqDecoder(nn.Module): + """Implements decoder module of the PARSeq model + + Args: + ---- + d_model: dimension of the model + num_heads: number of attention heads + ffd: dimension of the feed forward layer + ffd_ratio: depth multiplier for the feed forward layer + dropout: dropout rate + """ + + def __init__( + self, + d_model: int, + num_heads: int = 12, + ffd: int = 2048, + ffd_ratio: int = 4, + dropout: float = 0.1, + ): + super().__init__() + self.attention = MultiHeadAttention(num_heads, d_model, dropout=dropout) + self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout) + self.position_feed_forward = PositionwiseFeedForward(d_model, ffd * ffd_ratio, dropout, nn.GELU()) + + self.attention_norm = nn.LayerNorm(d_model, eps=1e-5) + self.cross_attention_norm = nn.LayerNorm(d_model, eps=1e-5) + self.query_norm = nn.LayerNorm(d_model, eps=1e-5) + self.content_norm = nn.LayerNorm(d_model, eps=1e-5) + self.feed_forward_norm = nn.LayerNorm(d_model, eps=1e-5) + self.output_norm = nn.LayerNorm(d_model, eps=1e-5) + self.attention_dropout = nn.Dropout(dropout) + self.cross_attention_dropout = nn.Dropout(dropout) + self.feed_forward_dropout = nn.Dropout(dropout) + + def forward( + self, + target, + content, + memory, + target_mask: Optional[torch.Tensor] = None, + ): + query_norm = self.query_norm(target) + content_norm = self.content_norm(content) + target = target.clone() + self.attention_dropout( + self.attention(query_norm, content_norm, content_norm, mask=target_mask) + ) + target = target.clone() + self.cross_attention_dropout( + self.cross_attention(self.query_norm(target), memory, memory) + ) + target = target.clone() + self.feed_forward_dropout(self.position_feed_forward(self.feed_forward_norm(target))) + return self.output_norm(target) + + +class PARSeq(_PARSeq, nn.Module): + """Implements a PARSeq architecture as described in `"Scene Text Recognition + with Permuted Autoregressive Sequence Models" `_. 
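`CharEmbedding` above follows the usual transformer convention of scaling token embeddings by `sqrt(d_model)`. A minimal reproduction with toy sizes and assumed special-token indices:

```python
import math
import torch
from torch import nn

d_model, num_embeddings = 384, 13        # toy values: 10 characters + EOS/SOS/PAD
embed = nn.Embedding(num_embeddings, d_model)
tokens = torch.tensor([[11, 0, 1, 10]])  # SOS, 'a', 'b', EOS (assumed indices)
out = math.sqrt(d_model) * embed(tokens)
print(out.shape)                         # torch.Size([1, 4, 384])
```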
+ Slightly modified implementation based on the official Pytorch implementation: None: + super().__init__() + self.vocab = vocab + self.exportable = exportable + self.cfg = cfg + self.max_length = max_length + self.vocab_size = len(vocab) + self.rng = np.random.default_rng() + + self.feat_extractor = feature_extractor + self.decoder = PARSeqDecoder(embedding_units, dec_num_heads, dec_ff_dim, dec_ffd_ratio, dropout_prob) + self.head = nn.Linear(embedding_units, self.vocab_size + 1) # +1 for EOS + self.embed = CharEmbedding(self.vocab_size + 3, embedding_units) # +3 for SOS, EOS, PAD + + self.pos_queries = nn.Parameter(torch.Tensor(1, self.max_length + 1, embedding_units)) # +1 for EOS + self.dropout = nn.Dropout(p=dropout_prob) + + self.postprocessor = PARSeqPostProcessor(vocab=self.vocab) + + nn.init.trunc_normal_(self.pos_queries, std=0.02) + for n, m in self.named_modules(): + # Don't override the initialization of the backbone + if n.startswith("feat_extractor."): + continue + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.padding_idx is not None: + m.weight.data[m.padding_idx].zero_() + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def generate_permutations(self, seqlen: torch.Tensor) -> torch.Tensor: + # Generates permutations of the target sequence. + # Borrowed from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py + # with small modifications + + max_num_chars = int(seqlen.max().item()) # get longest sequence length in batch + perms = [torch.arange(max_num_chars, device=seqlen.device)] + + max_perms = math.factorial(max_num_chars) // 2 + num_gen_perms = min(3, max_perms) + if max_num_chars < 5: + # Pool of permutations to sample from. We only need the first half (if complementary option is selected) + # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves + if max_num_chars == 4: + selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21] + else: + selector = list(range(max_perms)) + perm_pool = torch.as_tensor(list(permutations(range(max_num_chars), max_num_chars)), device=seqlen.device)[ + selector + ] + # If the forward permutation is always selected, no need to add it to the pool for sampling + perm_pool = perm_pool[1:] + final_perms = torch.stack(perms) + if len(perm_pool): + i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(final_perms), replace=False) + final_perms = torch.cat([final_perms, perm_pool[i]]) + else: + perms.extend([ + torch.randperm(max_num_chars, device=seqlen.device) for _ in range(num_gen_perms - len(perms)) + ]) + final_perms = torch.stack(perms) + + comp = final_perms.flip(-1) + final_perms = torch.stack([final_perms, comp]).transpose(0, 1).reshape(-1, max_num_chars) + + sos_idx = torch.zeros(len(final_perms), 1, device=seqlen.device) + eos_idx = torch.full((len(final_perms), 1), max_num_chars + 1, device=seqlen.device) + combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int() + if len(combined) > 1: + combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device) + return combined + + def generate_permutations_attention_masks(self, permutation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # Generate source and target mask for the decoder attention. 
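`generate_permutations` above samples a handful of decoding orders per batch (the canonical left-to-right order, its reverse, plus random ones) and frames each with the SOS and EOS slots. A simplified sketch of the idea; the real method also handles the small-sequence permutation pool and complementary pairs, which are omitted here:

```python
import torch

max_num_chars = 3
perms = torch.stack([torch.arange(max_num_chars), torch.randperm(max_num_chars)])
perms = torch.cat([perms, perms.flip(-1)])            # add the reversed orderings
sos = torch.zeros(len(perms), 1, dtype=torch.long)    # SOS slot is position 0
eos = torch.full((len(perms), 1), max_num_chars + 1)  # EOS slot is the last position
print(torch.cat([sos, perms + 1, eos], dim=1))        # (num_perms, max_num_chars + 2)
```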
+ sz = permutation.shape[0] + mask = torch.ones((sz, sz), device=permutation.device) + + for i in range(sz): + query_idx = permutation[i] + masked_keys = permutation[i + 1 :] + mask[query_idx, masked_keys] = 0.0 + source_mask = mask[:-1, :-1].clone() + mask[torch.eye(sz, dtype=torch.bool, device=permutation.device)] = 0.0 + target_mask = mask[1:, :-1] + + return source_mask.int(), target_mask.int() + + def decode( + self, + target: torch.Tensor, + memory: torch.Tensor, + target_mask: Optional[torch.Tensor] = None, + target_query: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Add positional information to the target sequence and pass it through the decoder.""" + batch_size, sequence_length = target.shape + # apply positional information to the target sequence excluding the SOS token + null_ctx = self.embed(target[:, :1]) + content = self.pos_queries[:, : sequence_length - 1] + self.embed(target[:, 1:]) + content = self.dropout(torch.cat([null_ctx, content], dim=1)) + if target_query is None: + target_query = self.pos_queries[:, :sequence_length].expand(batch_size, -1, -1) + target_query = self.dropout(target_query) + return self.decoder(target_query, content, memory, target_mask) + + def decode_autoregressive(self, features: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor: + """Generate predictions for the given features.""" + max_length = max_len if max_len is not None else self.max_length + max_length = min(max_length, self.max_length) + 1 + # Padding symbol + SOS at the beginning + ys = torch.full( + (features.size(0), max_length), self.vocab_size + 2, dtype=torch.long, device=features.device + ) # pad + ys[:, 0] = self.vocab_size + 1 # SOS token + pos_queries = self.pos_queries[:, :max_length].expand(features.size(0), -1, -1) + # Create query mask for the decoder attention + query_mask = ( + torch.tril(torch.ones((max_length, max_length), device=features.device), diagonal=0).to(dtype=torch.bool) + ).int() + + pos_logits = [] + for i in range(max_length): + # Decode one token at a time without providing information about the future tokens + tgt_out = self.decode( + ys[:, : i + 1], + features, + query_mask[i : i + 1, : i + 1], + target_query=pos_queries[:, i : i + 1], + ) + pos_prob = self.head(tgt_out) + pos_logits.append(pos_prob) + + if i + 1 < max_length: + # Update with the next token + ys[:, i + 1] = pos_prob.squeeze().argmax(-1) + + # Stop decoding if all sequences have reached the EOS token + # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export + if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all(): + break + + logits = torch.cat(pos_logits, dim=1) # (N, max_length, vocab_size + 1) + + # One refine iteration + # Update query mask + query_mask[torch.triu(torch.ones(max_length, max_length, dtype=torch.bool, device=features.device), 2)] = 1 + + # Prepare target input for 1 refine iteration + sos = torch.full((features.size(0), 1), self.vocab_size + 1, dtype=torch.long, device=features.device) + ys = torch.cat([sos, logits[:, :-1].argmax(-1)], dim=1) + + # Create padding mask for refined target input maskes all behind EOS token as False + # (N, 1, 1, max_length) + target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1) + mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int() + logits = self.head(self.decode(ys, features, mask, target_query=pos_queries)) + + return logits # (N, max_length, vocab_size + 1) + + def forward( + self, 
+ x: torch.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + ) -> Dict[str, Any]: + features = self.feat_extractor(x)["features"] # (batch_size, patches_seqlen, d_model) + # remove cls token + features = features[:, 1:, :] + + if self.training and target is None: + raise ValueError("Need to provide labels during training") + + if target is not None: + # Build target tensor + _gt, _seq_len = self.build_target(target) + gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long).to(x.device), torch.tensor(_seq_len).to(x.device) + gt = gt[:, : int(seq_len.max().item()) + 2] # slice up to the max length of the batch + 2 (SOS + EOS) + + if self.training: + # Generate permutations for the target sequences + tgt_perms = self.generate_permutations(seq_len) + + gt_in = gt[:, :-1] # remove EOS token from longest target sequence + gt_out = gt[:, 1:] # remove SOS token + # Create padding mask for target input + # [True, True, True, ..., False, False, False] -> False is masked + padding_mask = ~( + ((gt_in == self.vocab_size + 2) | (gt_in == self.vocab_size)).int().cumsum(-1) > 0 + ).unsqueeze(1).unsqueeze(1) # (N, 1, 1, seq_len) + + loss = torch.tensor(0.0, device=features.device) + loss_numel: Union[int, float] = 0 + n = (gt_out != self.vocab_size + 2).sum().item() + for i, perm in enumerate(tgt_perms): + _, target_mask = self.generate_permutations_attention_masks(perm) # (seq_len, seq_len) + # combine both masks + mask = (target_mask.bool() & padding_mask.bool()).int() # (N, 1, seq_len, seq_len) + + logits = self.head(self.decode(gt_in, features, mask)).flatten(end_dim=1) + loss += n * F.cross_entropy(logits, gt_out.flatten(), ignore_index=self.vocab_size + 2) + loss_numel += n + # After the second iteration (i.e. 
done with canonical and reverse orderings), + # remove the [EOS] tokens for the succeeding perms + if i == 1: + gt_out = torch.where(gt_out == self.vocab_size, self.vocab_size + 2, gt_out) + n = (gt_out != self.vocab_size + 2).sum().item() + + loss /= loss_numel + + else: + gt = gt[:, 1:] # remove SOS token + max_len = gt.shape[1] - 1 # exclude EOS token + logits = self.decode_autoregressive(features, max_len) + loss = F.cross_entropy(logits.flatten(end_dim=1), gt.flatten(), ignore_index=self.vocab_size + 2) + else: + logits = self.decode_autoregressive(features) + + logits = _bf16_to_float32(logits) + + out: Dict[str, Any] = {} + if self.exportable: + out["logits"] = logits + return out + + if return_model_output: + out["out_map"] = logits + + if target is None or return_preds: + # Post-process boxes + out["preds"] = self.postprocessor(logits) + + if target is not None: + out["loss"] = loss + + return out + + +class PARSeqPostProcessor(_PARSeqPostProcessor): + """Post processor for PARSeq architecture + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __call__( + self, + logits: torch.Tensor, + ) -> List[Tuple[str, float]]: + # compute pred with argmax for attention models + out_idxs = logits.argmax(-1) + preds_prob = torch.softmax(logits, -1).max(dim=-1)[0] + + # Manual decoding + word_values = [ + "".join(self._embedding[idx] for idx in encoded_seq).split("")[0] + for encoded_seq in out_idxs.cpu().numpy() + ] + # compute probabilties for each word up to the EOS token + probs = [ + preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values) + ] + + return list(zip(word_values, probs)) + + +def _parseq( + arch: str, + pretrained: bool, + backbone_fn: Callable[[bool], nn.Module], + layer: str, + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> PARSeq: + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"]) + _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"]) + patch_size = kwargs.get("patch_size", (4, 8)) + + kwargs["vocab"] = _cfg["vocab"] + kwargs["input_shape"] = _cfg["input_shape"] + + # Feature extractor + feat_extractor = IntermediateLayerGetter( + # NOTE: we don't use a pretrained backbone for non-rectangular patches to avoid the pos embed mismatch + backbone_fn(False, input_shape=_cfg["input_shape"], patch_size=patch_size), # type: ignore[call-arg] + {layer: "features"}, + ) + + kwargs.pop("patch_size", None) + kwargs.pop("pretrained_backbone", None) + + # Build the model + model = PARSeq(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + # The number of classes is not the same as the number of classes in the pretrained model => + # remove the last layer weights + _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None + load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + return model + + +def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq: + """PARSeq architecture from + `"Scene Text Recognition with Permuted Autoregressive Sequence Models" `_. 
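`_parseq` above only keeps the pretrained embedding and head weights when the requested vocab matches the default one; with a custom vocab those layers are re-initialized, which is what `ignore_keys` is for. Illustrative usage under that assumption (downloads the checkpoint on first call):

```python
from doctr.models import parseq

# Custom 10-character vocab: the pretrained backbone is reused, while
# "embed.embedding.weight", "head.weight" and "head.bias" are dropped and re-initialized.
model = parseq(pretrained=True, vocab="0123456789")
print(model.vocab_size)  # 10 -> the head now predicts 11 classes (characters + EOS)
```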
+ + >>> import torch + >>> from doctr.models import parseq + >>> model = parseq(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 32, 128)) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + **kwargs: keyword arguments of the PARSeq architecture + + Returns: + ------- + text recognition architecture + """ + return _parseq( + "parseq", + pretrained, + vit_s, + "1", + embedding_units=384, + patch_size=(4, 8), + ignore_keys=["embed.embedding.weight", "head.weight", "head.bias"], + **kwargs, + ) diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..1365a6ac12c76de8bc28551290e7aface641f1a5 --- /dev/null +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -0,0 +1,512 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import math +from copy import deepcopy +from itertools import permutations +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import tensorflow as tf +from tensorflow.keras import Model, layers + +from doctr.datasets import VOCABS +from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward + +from ...classification import vit_s +from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from .base import _PARSeq, _PARSeqPostProcessor + +__all__ = ["PARSeq", "parseq"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "parseq": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 128, 3), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.6.0/parseq-24cf693e.zip&src=0", + }, +} + + +class CharEmbedding(layers.Layer): + """Implements the character embedding module + + Args: + ---- + vocab_size: size of the vocabulary + d_model: dimension of the model + """ + + def __init__(self, vocab_size: int, d_model: int): + super(CharEmbedding, self).__init__() + self.embedding = tf.keras.layers.Embedding(vocab_size, d_model) + self.d_model = d_model + + def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor: + return math.sqrt(self.d_model) * self.embedding(x, **kwargs) + + +class PARSeqDecoder(layers.Layer): + """Implements decoder module of the PARSeq model + + Args: + ---- + d_model: dimension of the model + num_heads: number of attention heads + ffd: dimension of the feed forward layer + ffd_ratio: depth multiplier for the feed forward layer + dropout: dropout rate + """ + + def __init__( + self, + d_model: int, + num_heads: int = 12, + ffd: int = 2048, + ffd_ratio: int = 4, + dropout: float = 0.1, + ): + super(PARSeqDecoder, self).__init__() + self.attention = MultiHeadAttention(num_heads, d_model, dropout=dropout) + self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout) + self.position_feed_forward = PositionwiseFeedForward( + d_model, ffd * ffd_ratio, dropout, layers.Activation(tf.nn.gelu) + ) + + self.attention_norm = layers.LayerNormalization(epsilon=1e-5) + self.cross_attention_norm = layers.LayerNormalization(epsilon=1e-5) + self.query_norm = layers.LayerNormalization(epsilon=1e-5) + self.content_norm = layers.LayerNormalization(epsilon=1e-5) + self.feed_forward_norm = layers.LayerNormalization(epsilon=1e-5) + self.output_norm = layers.LayerNormalization(epsilon=1e-5) + self.attention_dropout = 
layers.Dropout(dropout) + self.cross_attention_dropout = layers.Dropout(dropout) + self.feed_forward_dropout = layers.Dropout(dropout) + + def call( + self, + target, + content, + memory, + target_mask=None, + **kwargs: Any, + ): + query_norm = self.query_norm(target, **kwargs) + content_norm = self.content_norm(content, **kwargs) + target = target + self.attention_dropout( + self.attention(query_norm, content_norm, content_norm, mask=target_mask, **kwargs), **kwargs + ) + target = target + self.cross_attention_dropout( + self.cross_attention(self.query_norm(target, **kwargs), memory, memory, **kwargs), **kwargs + ) + target = target + self.feed_forward_dropout( + self.position_feed_forward(self.feed_forward_norm(target, **kwargs), **kwargs), **kwargs + ) + return self.output_norm(target, **kwargs) + + +class PARSeq(_PARSeq, Model): + """Implements a PARSeq architecture as described in `"Scene Text Recognition + with Permuted Autoregressive Sequence Models" `_. + Modified implementation based on the official Pytorch implementation: None: + super().__init__() + self.vocab = vocab + self.exportable = exportable + self.cfg = cfg + self.max_length = max_length + self.vocab_size = len(vocab) + self.rng = np.random.default_rng() + + self.feat_extractor = feature_extractor + self.decoder = PARSeqDecoder(embedding_units, dec_num_heads, dec_ff_dim, dec_ffd_ratio, dropout_prob) + self.embed = CharEmbedding(self.vocab_size + 3, embedding_units) # +3 for SOS, EOS, PAD + self.head = layers.Dense(self.vocab_size + 1, name="head") # +1 for EOS + self.pos_queries = self.add_weight( + shape=(1, self.max_length + 1, embedding_units), + initializer="zeros", + trainable=True, + name="positions", + ) + self.dropout = layers.Dropout(dropout_prob) + + self.postprocessor = PARSeqPostProcessor(vocab=self.vocab) + + @tf.function + def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor: + # Generates permutations of the target sequence. + # Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py + # with small modifications + + max_num_chars = int(tf.reduce_max(seqlen)) # get longest sequence length in batch + perms = [tf.range(max_num_chars, dtype=tf.int32)] + + max_perms = math.factorial(max_num_chars) // 2 + num_gen_perms = min(3, max_perms) + if max_num_chars < 5: + # Pool of permutations to sample from. 
We only need the first half (if complementary option is selected) + # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves + if max_num_chars == 4: + selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21] + else: + selector = list(range(max_perms)) + perm_pool_candidates = list(permutations(range(max_num_chars), max_num_chars)) + perm_pool = tf.convert_to_tensor([perm_pool_candidates[i] for i in selector]) + # If the forward permutation is always selected, no need to add it to the pool for sampling + perm_pool = perm_pool[1:] + final_perms = tf.stack(perms) + if len(perm_pool): + i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(final_perms), replace=False) + final_perms = tf.concat([final_perms, perm_pool[i[0] : i[1]]], axis=0) + else: + perms.extend([ + tf.random.shuffle(tf.range(max_num_chars, dtype=tf.int32)) for _ in range(num_gen_perms - len(perms)) + ]) + final_perms = tf.stack(perms) + + comp = tf.reverse(final_perms, axis=[-1]) + final_perms = tf.stack([final_perms, comp]) + final_perms = tf.transpose(final_perms, perm=[1, 0, 2]) + final_perms = tf.reshape(final_perms, shape=(-1, max_num_chars)) + + sos_idx = tf.zeros([tf.shape(final_perms)[0], 1], dtype=tf.int32) + eos_idx = tf.fill([tf.shape(final_perms)[0], 1], max_num_chars + 1) + combined = tf.concat([sos_idx, final_perms + 1, eos_idx], axis=1) + combined = tf.cast(combined, dtype=tf.int32) + if tf.shape(combined)[0] > 1: + combined = tf.tensor_scatter_nd_update( + combined, [[1, i] for i in range(1, max_num_chars + 2)], max_num_chars + 1 - tf.range(max_num_chars + 1) + ) + return combined + + @tf.function + def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + # Generate source and target mask for the decoder attention. 
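`generate_permutations_attention_masks` (in both backends) turns one decoding order into attention masks: a query position may only attend to positions that come earlier in that order, and the diagonal is cleared for the target mask. A NumPy sketch with a single toy permutation, where position 0 is the SOS slot and the last position the EOS slot:

```python
import numpy as np

perm = np.array([0, 2, 3, 1, 4])       # SOS, then a permutation of {1, 2, 3}, then EOS
sz = len(perm)
mask = np.ones((sz, sz))
for i in range(sz):
    mask[perm[i], perm[i + 1:]] = 0.0  # positions later in the decoding order are hidden
source_mask = mask[:-1, :-1].copy()
np.fill_diagonal(mask, 0.0)            # a position never attends to itself in the target mask
target_mask = mask[1:, :-1]
print(target_mask)
```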
+ sz = permutation.shape[0] + mask = tf.ones((sz, sz), dtype=tf.float32) + + for i in range(sz - 1): + query_idx = int(permutation[i]) + masked_keys = permutation[i + 1 :].numpy().tolist() + indices = tf.constant([[query_idx, j] for j in masked_keys], dtype=tf.int32) + mask = tf.tensor_scatter_nd_update(mask, indices, tf.zeros(len(masked_keys), dtype=tf.float32)) + + source_mask = tf.identity(mask[:-1, :-1]) + eye_indices = tf.eye(sz, dtype=tf.bool) + mask = tf.tensor_scatter_nd_update( + mask, tf.where(eye_indices), tf.zeros_like(tf.boolean_mask(mask, eye_indices)) + ) + target_mask = mask[1:, :-1] + return tf.cast(source_mask, dtype=tf.bool), tf.cast(target_mask, dtype=tf.bool) + + @tf.function + def decode( + self, + target: tf.Tensor, + memory: tf, + target_mask: Optional[tf.Tensor] = None, + target_query: Optional[tf.Tensor] = None, + **kwargs: Any, + ) -> tf.Tensor: + batch_size, sequence_length = target.shape + # apply positional information to the target sequence excluding the SOS token + null_ctx = self.embed(target[:, :1], **kwargs) + content = self.pos_queries[:, : sequence_length - 1] + self.embed(target[:, 1:], **kwargs) + content = self.dropout(tf.concat([null_ctx, content], axis=1), **kwargs) + if target_query is None: + target_query = tf.tile(self.pos_queries[:, :sequence_length], [batch_size, 1, 1]) + target_query = self.dropout(target_query, **kwargs) + return self.decoder(target_query, content, memory, target_mask, **kwargs) + + @tf.function + def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = None, **kwargs) -> tf.Tensor: + """Generate predictions for the given features.""" + max_length = max_len if max_len is not None else self.max_length + max_length = min(max_length, self.max_length) + 1 + b = tf.shape(features)[0] + # Padding symbol + SOS at the beginning + ys = tf.fill(dims=(b, max_length), value=self.vocab_size + 2) + start_vector = tf.fill(dims=(b, 1), value=self.vocab_size + 1) + ys = tf.concat([start_vector, ys], axis=-1) + pos_queries = tf.tile(self.pos_queries[:, :max_length], [b, 1, 1]) + query_mask = tf.cast(tf.linalg.band_part(tf.ones((max_length, max_length)), -1, 0), dtype=tf.bool) + + pos_logits = [] + for i in range(max_length): + # Decode one token at a time without providing information about the future tokens + tgt_out = self.decode( + ys[:, : i + 1], + features, + query_mask[i : i + 1, : i + 1], + target_query=pos_queries[:, i : i + 1], + **kwargs, + ) + pos_prob = self.head(tgt_out) + pos_logits.append(pos_prob) + + if i + 1 < max_length: + # update ys with the next token + i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(max_length), indexing="ij") + indices = tf.stack([i_mesh[:, i + 1], j_mesh[:, i + 1]], axis=1) + ys = tf.tensor_scatter_nd_update( + ys, indices, tf.cast(tf.argmax(pos_prob[:, -1, :], axis=-1), dtype=tf.int32) + ) + + # Stop decoding if all sequences have reached the EOS token + # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export + if ( + not self.exportable + and max_len is None + and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1)) + ): + break + + logits = tf.concat(pos_logits, axis=1) # (N, max_length, vocab_size + 1) + + # One refine iteration + # Update query mask + diag_matrix = tf.eye(max_length) + diag_matrix = tf.cast(tf.logical_not(tf.cast(diag_matrix, dtype=tf.bool)), dtype=tf.float32) + query_mask = tf.cast(tf.concat([diag_matrix[1:], tf.ones((1, max_length))], axis=0), dtype=tf.bool) + + sos = 
tf.fill((tf.shape(features)[0], 1), self.vocab_size + 1) + ys = tf.concat([sos, tf.cast(tf.argmax(logits[:, :-1], axis=-1), dtype=tf.int32)], axis=1) + # Create padding mask for refined target input maskes all behind EOS token as False + # (N, 1, 1, max_length) + mask = tf.cast(tf.equal(ys, self.vocab_size), tf.float32) + first_eos_indices = tf.argmax(mask, axis=1, output_type=tf.int32) + mask = tf.sequence_mask(first_eos_indices + 1, maxlen=ys.shape[-1], dtype=tf.float32) + target_pad_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.bool) + + mask = tf.math.logical_and(target_pad_mask, query_mask[:, : ys.shape[1]]) + logits = self.head(self.decode(ys, features, mask, target_query=pos_queries, **kwargs), **kwargs) + + return logits # (N, max_length, vocab_size + 1) + + def call( + self, + x: tf.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + features = self.feat_extractor(x, **kwargs) # (batch_size, patches_seqlen, d_model) + # remove cls token + features = features[:, 1:, :] + + if kwargs.get("training", False) and target is None: + raise ValueError("Need to provide labels during training") + + if target is not None: + gt, seq_len = self.build_target(target) + seq_len = tf.cast(seq_len, tf.int32) + gt = gt[:, : int(tf.reduce_max(seq_len)) + 2] # slice up to the max length of the batch + 2 (SOS + EOS) + + if kwargs.get("training", False): + # Generate permutations of the target sequences + tgt_perms = self.generate_permutations(seq_len) + + gt_in = gt[:, :-1] # remove EOS token from longest target sequence + gt_out = gt[:, 1:] # remove SOS token + + # Create padding mask for target input + # [True, True, True, ..., False, False, False] -> False is masked + padding_mask = tf.math.logical_and( + tf.math.not_equal(gt_in, self.vocab_size + 2), tf.math.not_equal(gt_in, self.vocab_size) + ) + padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :] # (N, 1, 1, seq_len) + + loss = tf.constant(0.0) + loss_numel = tf.constant(0.0) + n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32)) + for i, perm in enumerate(tgt_perms): + _, target_mask = self.generate_permutations_attention_masks(perm) # (seq_len, seq_len) + # combine both masks to (N, 1, seq_len, seq_len) + mask = tf.logical_and(padding_mask, tf.expand_dims(tf.expand_dims(target_mask, axis=0), axis=0)) + + logits = self.head(self.decode(gt_in, features, mask, **kwargs), **kwargs) + logits_flat = tf.reshape(logits, (-1, logits.shape[-1])) + targets_flat = tf.reshape(gt_out, (-1,)) + mask = tf.not_equal(targets_flat, self.vocab_size + 2) + loss += n * tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask) + ) + ) + loss_numel += n + + # After the second iteration (i.e. 
done with canonical and reverse orderings), + # remove the [EOS] tokens for the succeeding perms + if i == 1: + gt_out = tf.where(tf.equal(gt_out, self.vocab_size), self.vocab_size + 2, gt_out) + n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32)) + + loss /= loss_numel + + else: + gt = gt[:, 1:] # remove SOS token + max_len = gt.shape[1] - 1 # exclude EOS token + logits = self.decode_autoregressive(features, max_len, **kwargs) + logits_flat = tf.reshape(logits, (-1, logits.shape[-1])) + targets_flat = tf.reshape(gt, (-1,)) + mask = tf.not_equal(targets_flat, self.vocab_size + 2) + loss = tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask) + ) + ) + else: + logits = self.decode_autoregressive(features, **kwargs) + + logits = _bf16_to_float32(logits) + + out: Dict[str, tf.Tensor] = {} + if self.exportable: + out["logits"] = logits + return out + + if return_model_output: + out["out_map"] = logits + + if target is None or return_preds: + # Post-process boxes + out["preds"] = self.postprocessor(logits) + + if target is not None: + out["loss"] = loss + + return out + + +class PARSeqPostProcessor(_PARSeqPostProcessor): + """Post processor for PARSeq architecture + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __call__( + self, + logits: tf.Tensor, + ) -> List[Tuple[str, float]]: + # compute pred with argmax for attention models + out_idxs = tf.math.argmax(logits, axis=2) + preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1) + + # decode raw output of the model with tf_label_to_idx + out_idxs = tf.cast(out_idxs, dtype="int32") + embedding = tf.constant(self._embedding, dtype=tf.string) + decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1) + decoded_strings_pred = tf.strings.split(decoded_strings_pred, "") + decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0] + word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()] + + # compute probabilties for each word up to the EOS token + probs = [ + preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0 + for i, word in enumerate(word_values) + ] + + return list(zip(word_values, probs)) + + +def _parseq( + arch: str, + pretrained: bool, + backbone_fn, + input_shape: Optional[Tuple[int, int, int]] = None, + **kwargs: Any, +) -> PARSeq: + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["input_shape"] = input_shape or _cfg["input_shape"] + _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"]) + patch_size = kwargs.get("patch_size", (4, 8)) + + kwargs["vocab"] = _cfg["vocab"] + + # Feature extractor + feat_extractor = backbone_fn( + # NOTE: we don't use a pretrained backbone for non-rectangular patches to avoid the pos embed mismatch + pretrained=False, + input_shape=_cfg["input_shape"], + patch_size=patch_size, + include_top=False, + ) + + kwargs.pop("patch_size", None) + kwargs.pop("pretrained_backbone", None) + + # Build the model + model = PARSeq(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]["url"]) + + return model + + +def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq: + """PARSeq architecture from + `"Scene Text Recognition with Permuted Autoregressive Sequence Models" 
`_. + + >>> import tensorflow as tf + >>> from doctr.models import parseq + >>> model = parseq(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + **kwargs: keyword arguments of the PARSeq architecture + + Returns: + ------- + text recognition architecture + """ + return _parseq( + "parseq", + pretrained, + vit_s, + embedding_units=384, + patch_size=(4, 8), + **kwargs, + ) diff --git a/doctr/models/recognition/predictor/__init__.py b/doctr/models/recognition/predictor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff30c3b2e7d34bf85e30291e39f9d3206c0f4bdd --- /dev/null +++ b/doctr/models/recognition/predictor/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available + +if is_tf_available(): + from .tensorflow import * +else: + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/recognition/predictor/__pycache__/__init__.cpython-311.pyc b/doctr/models/recognition/predictor/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9303adc195e4c70306dc08672646960bc8148ea3 Binary files /dev/null and b/doctr/models/recognition/predictor/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/recognition/predictor/__pycache__/__init__.cpython-38.pyc b/doctr/models/recognition/predictor/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..947b33363e142f41e331e7a5e74e7b6cc31c490a Binary files /dev/null and b/doctr/models/recognition/predictor/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/recognition/predictor/__pycache__/_utils.cpython-311.pyc b/doctr/models/recognition/predictor/__pycache__/_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44c523d638893292a458b1c64b2d251850b86193 Binary files /dev/null and b/doctr/models/recognition/predictor/__pycache__/_utils.cpython-311.pyc differ diff --git a/doctr/models/recognition/predictor/__pycache__/_utils.cpython-38.pyc b/doctr/models/recognition/predictor/__pycache__/_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8f6b5ab8b2f83e7883a0a247b695fe383268e5e Binary files /dev/null and b/doctr/models/recognition/predictor/__pycache__/_utils.cpython-38.pyc differ diff --git a/doctr/models/recognition/predictor/__pycache__/pytorch.cpython-311.pyc b/doctr/models/recognition/predictor/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3de29c26bb951c58072fa2578a4d9ccf9c913046 Binary files /dev/null and b/doctr/models/recognition/predictor/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/recognition/predictor/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/recognition/predictor/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b6be27647ad888a14ca8ba93e76b7bc86eaa45a Binary files /dev/null and b/doctr/models/recognition/predictor/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/recognition/predictor/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/recognition/predictor/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9998fc8543777361d2d011d0df14977e1acc565 
Binary files /dev/null and b/doctr/models/recognition/predictor/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/recognition/predictor/_utils.py b/doctr/models/recognition/predictor/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac98d41862f04efec4ec105bbc50000b35b55555 --- /dev/null +++ b/doctr/models/recognition/predictor/_utils.py @@ -0,0 +1,86 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import List, Tuple, Union + +import numpy as np + +from ..utils import merge_multi_strings + +__all__ = ["split_crops", "remap_preds"] + + +def split_crops( + crops: List[np.ndarray], + max_ratio: float, + target_ratio: int, + dilation: float, + channels_last: bool = True, +) -> Tuple[List[np.ndarray], List[Union[int, Tuple[int, int]]], bool]: + """Chunk crops horizontally to match a given aspect ratio + + Args: + ---- + crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise + max_ratio: the maximum aspect ratio that won't trigger the chunk + target_ratio: when crops are chunked, they will be chunked to match this aspect ratio + dilation: the width dilation of final chunks (to provide some overlaps) + channels_last: whether the numpy array has dimensions in channels last order + + Returns: + ------- + a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required + """ + _remap_required = False + crop_map: List[Union[int, Tuple[int, int]]] = [] + new_crops: List[np.ndarray] = [] + for crop in crops: + h, w = crop.shape[:2] if channels_last else crop.shape[-2:] + aspect_ratio = w / h + if aspect_ratio > max_ratio: + # Determine the number of crops, reference aspect ratio = 4 = 128 / 32 + num_subcrops = int(aspect_ratio // target_ratio) + # Find the new widths, additional dilation factor to overlap crops + width = dilation * w / num_subcrops + centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)] + # Get the crops + if channels_last: + _crops = [ + crop[:, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2))), :] + for center in centers + ] + else: + _crops = [ + crop[:, :, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2)))] + for center in centers + ] + # Avoid sending zero-sized crops + _crops = [crop for crop in _crops if all(s > 0 for s in crop.shape)] + # Record the slice of crops + crop_map.append((len(new_crops), len(new_crops) + len(_crops))) + new_crops.extend(_crops) + # At least one crop will require merging + _remap_required = True + else: + crop_map.append(len(new_crops)) + new_crops.append(crop) + + return new_crops, crop_map, _remap_required + + +def remap_preds( + preds: List[Tuple[str, float]], crop_map: List[Union[int, Tuple[int, int]]], dilation: float +) -> List[Tuple[str, float]]: + remapped_out = [] + for _idx in crop_map: + # Crop hasn't been split + if isinstance(_idx, int): + remapped_out.append(preds[_idx]) + else: + # unzip + vals, probs = zip(*preds[_idx[0] : _idx[1]]) + # Merge the string values + remapped_out.append((merge_multi_strings(vals, dilation), min(probs))) # type: ignore[arg-type] + return remapped_out diff --git a/doctr/models/recognition/predictor/pytorch.py b/doctr/models/recognition/predictor/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..b71202f7c28c5dacdc6ab4e605695ddb24df3ff8 --- /dev/null +++ 
b/doctr/models/recognition/predictor/pytorch.py @@ -0,0 +1,86 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, List, Sequence, Tuple, Union + +import numpy as np +import torch +from torch import nn + +from doctr.models.preprocessor import PreProcessor +from doctr.models.utils import set_device_and_dtype + +from ._utils import remap_preds, split_crops + +__all__ = ["RecognitionPredictor"] + + +class RecognitionPredictor(nn.Module): + """Implements an object able to identify character sequences in images + + Args: + ---- + pre_processor: transform inputs for easier batched model inference + model: core detection architecture + split_wide_crops: wether to use crop splitting for high aspect ratio crops + """ + + def __init__( + self, + pre_processor: PreProcessor, + model: nn.Module, + split_wide_crops: bool = True, + ) -> None: + super().__init__() + self.pre_processor = pre_processor + self.model = model.eval() + self.split_wide_crops = split_wide_crops + self.critical_ar = 8 # Critical aspect ratio + self.dil_factor = 1.4 # Dilation factor to overlap the crops + self.target_ar = 6 # Target aspect ratio + + @torch.inference_mode() + def forward( + self, + crops: Sequence[Union[np.ndarray, torch.Tensor]], + **kwargs: Any, + ) -> List[Tuple[str, float]]: + if len(crops) == 0: + return [] + # Dimension check + if any(crop.ndim != 3 for crop in crops): + raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.") + + # Split crops that are too wide + remapped = False + if self.split_wide_crops: + new_crops, crop_map, remapped = split_crops( + crops, # type: ignore[arg-type] + self.critical_ar, + self.target_ar, + self.dil_factor, + isinstance(crops[0], np.ndarray), + ) + if remapped: + crops = new_crops + + # Resize & batch them + processed_batches = self.pre_processor(crops) + + # Forward it + _params = next(self.model.parameters()) + self.model, processed_batches = set_device_and_dtype( + self.model, processed_batches, _params.device, _params.dtype + ) + raw = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches] + + # Process outputs + out = [charseq for batch in raw for charseq in batch] + + # Remap crops + if self.split_wide_crops and remapped: + out = remap_preds(out, crop_map, self.dil_factor) + + return out diff --git a/doctr/models/recognition/predictor/tensorflow.py b/doctr/models/recognition/predictor/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..409f39323af4159b342f1a60b0682aca6142c51b --- /dev/null +++ b/doctr/models/recognition/predictor/tensorflow.py @@ -0,0 +1,80 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
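A minimal usage sketch of the crop-splitting helpers used by the predictors above (the crop shape and the dummy predictions are illustrative, not taken from the library):

import numpy as np
from doctr.models.recognition.predictor._utils import split_crops, remap_preds

# a 32 x 512 crop has aspect ratio 16, above the critical ratio of 8 used by the predictors
crops = [np.zeros((32, 512, 3), dtype=np.uint8)]
new_crops, crop_map, remapped = split_crops(crops, max_ratio=8, target_ratio=6, dilation=1.4)
# pretend the recognition model returned one (word, confidence) pair per sub-crop
sub_preds = [("abcd", 0.9) for _ in new_crops]
if remapped:
    # merged back to one (word, confidence) pair per original crop
    preds = remap_preds(sub_preds, crop_map, dilation=1.4)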
+ +from typing import Any, List, Tuple, Union + +import numpy as np +import tensorflow as tf + +from doctr.models.preprocessor import PreProcessor +from doctr.utils.repr import NestedObject + +from ..core import RecognitionModel +from ._utils import remap_preds, split_crops + +__all__ = ["RecognitionPredictor"] + + +class RecognitionPredictor(NestedObject): + """Implements an object able to identify character sequences in images + + Args: + ---- + pre_processor: transform inputs for easier batched model inference + model: core detection architecture + split_wide_crops: wether to use crop splitting for high aspect ratio crops + """ + + _children_names: List[str] = ["pre_processor", "model"] + + def __init__( + self, + pre_processor: PreProcessor, + model: RecognitionModel, + split_wide_crops: bool = True, + ) -> None: + super().__init__() + self.pre_processor = pre_processor + self.model = model + self.split_wide_crops = split_wide_crops + self.critical_ar = 8 # Critical aspect ratio + self.dil_factor = 1.4 # Dilation factor to overlap the crops + self.target_ar = 6 # Target aspect ratio + + def __call__( + self, + crops: List[Union[np.ndarray, tf.Tensor]], + **kwargs: Any, + ) -> List[Tuple[str, float]]: + if len(crops) == 0: + return [] + # Dimension check + if any(crop.ndim != 3 for crop in crops): + raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.") + + # Split crops that are too wide + remapped = False + if self.split_wide_crops: + new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.dil_factor) + if remapped: + crops = new_crops + + # Resize & batch them + processed_batches = self.pre_processor(crops) + + # Forward it + raw = [ + self.model(batch, return_preds=True, training=False, **kwargs)["preds"] # type: ignore[operator] + for batch in processed_batches + ] + + # Process outputs + out = [charseq for batch in raw for charseq in batch] + + # Remap crops + if self.split_wide_crops and remapped: + out = remap_preds(out, crop_map, self.dil_factor) + + return out diff --git a/doctr/models/recognition/sar/__init__.py b/doctr/models/recognition/sar/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/recognition/sar/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/recognition/sar/__pycache__/__init__.cpython-311.pyc b/doctr/models/recognition/sar/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..181dfbd6c1ec8fae4bda4424214ba827b3b7d332 Binary files /dev/null and b/doctr/models/recognition/sar/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/recognition/sar/__pycache__/__init__.cpython-38.pyc b/doctr/models/recognition/sar/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92d1922c5d6395a55dc3a0f41ec616da4a15c8a0 Binary files /dev/null and b/doctr/models/recognition/sar/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/recognition/sar/__pycache__/pytorch.cpython-311.pyc b/doctr/models/recognition/sar/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6428cb1acdac5744955c62b85624b8e24da44322 Binary files 
/dev/null and b/doctr/models/recognition/sar/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/recognition/sar/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/recognition/sar/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..174626194240184a65c01e913a437bf98dfe21d4 Binary files /dev/null and b/doctr/models/recognition/sar/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/recognition/sar/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/recognition/sar/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8e5ebf27c6c283fa174481de3128239318bfc74 Binary files /dev/null and b/doctr/models/recognition/sar/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/recognition/sar/pytorch.py b/doctr/models/recognition/sar/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..a66bd32036c11ea66ebb18fcaf4aec7742869e1d --- /dev/null +++ b/doctr/models/recognition/sar/pytorch.py @@ -0,0 +1,402 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import functional as F +from torchvision.models._utils import IntermediateLayerGetter + +from doctr.datasets import VOCABS + +from ...classification import resnet31 +from ...utils.pytorch import _bf16_to_float32, load_pretrained_params +from ..core import RecognitionModel, RecognitionPostProcessor + +__all__ = ["SAR", "sar_resnet31"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "sar_resnet31": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/sar_resnet31-9a1deedf.pt&src=0", + }, +} + + +class SAREncoder(nn.Module): + def __init__(self, in_feats: int, rnn_units: int, dropout_prob: float = 0.0) -> None: + super().__init__() + self.rnn = nn.LSTM(in_feats, rnn_units, 2, batch_first=True, dropout=dropout_prob) + self.linear = nn.Linear(rnn_units, rnn_units) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # (N, L, C) --> (N, T, C) + encoded = self.rnn(x)[0] + # (N, C) + return self.linear(encoded[:, -1, :]) + + +class AttentionModule(nn.Module): + def __init__(self, feat_chans: int, state_chans: int, attention_units: int) -> None: + super().__init__() + self.feat_conv = nn.Conv2d(feat_chans, attention_units, kernel_size=3, padding=1) + # No need to add another bias since both tensors are summed together + self.state_conv = nn.Conv2d(state_chans, attention_units, kernel_size=1, bias=False) + self.attention_projector = nn.Conv2d(attention_units, 1, kernel_size=1, bias=False) + + def forward( + self, + features: torch.Tensor, # (N, C, H, W) + hidden_state: torch.Tensor, # (N, C) + ) -> torch.Tensor: + H_f, W_f = features.shape[2:] + + # (N, feat_chans, H, W) --> (N, attention_units, H, W) + feat_projection = self.feat_conv(features) + # (N, state_chans, 1, 1) --> (N, attention_units, 1, 1) + hidden_state = hidden_state.view(hidden_state.size(0), hidden_state.size(1), 1, 1) + state_projection = self.state_conv(hidden_state) + state_projection = state_projection.expand(-1, -1, H_f, W_f) + # (N, attention_units, 1, 1) --> (N, attention_units, H_f, W_f) + attention_weights = 
torch.tanh(feat_projection + state_projection) + # (N, attention_units, H_f, W_f) --> (N, 1, H_f, W_f) + attention_weights = self.attention_projector(attention_weights) + B, C, H, W = attention_weights.size() + + # (N, H, W) --> (N, 1, H, W) + attention_weights = torch.softmax(attention_weights.view(B, -1), dim=-1).view(B, C, H, W) + # fuse features and attention weights (N, C) + return (features * attention_weights).sum(dim=(2, 3)) + + +class SARDecoder(nn.Module): + """Implements decoder module of the SAR model + + Args: + ---- + rnn_units: number of hidden units in recurrent cells + max_length: maximum length of a sequence + vocab_size: number of classes in the model alphabet + embedding_units: number of hidden embedding units + attention_units: number of hidden attention units + + """ + + def __init__( + self, + rnn_units: int, + max_length: int, + vocab_size: int, + embedding_units: int, + attention_units: int, + feat_chans: int = 512, + dropout_prob: float = 0.0, + ) -> None: + super().__init__() + self.vocab_size = vocab_size + self.max_length = max_length + + self.embed = nn.Linear(self.vocab_size + 1, embedding_units) + self.embed_tgt = nn.Embedding(embedding_units, self.vocab_size + 1) + self.attention_module = AttentionModule(feat_chans, rnn_units, attention_units) + self.lstm_cell = nn.LSTMCell(rnn_units, rnn_units) + self.output_dense = nn.Linear(2 * rnn_units, self.vocab_size + 1) + self.dropout = nn.Dropout(dropout_prob) + + def forward( + self, + features: torch.Tensor, # (N, C, H, W) + holistic: torch.Tensor, # (N, C) + gt: Optional[torch.Tensor] = None, # (N, L) + ) -> torch.Tensor: + if gt is not None: + gt_embedding = self.embed_tgt(gt) + + logits_list: List[torch.Tensor] = [] + + for t in range(self.max_length + 1): # 32 + if t == 0: + # step to init the first states of the LSTMCell + hidden_state_init = cell_state_init = torch.zeros( + features.size(0), features.size(1), device=features.device, dtype=features.dtype + ) + hidden_state, cell_state = hidden_state_init, cell_state_init + prev_symbol = holistic + elif t == 1: + # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros + # (N, vocab_size + 1) --> (N, embedding_units) + prev_symbol = torch.zeros( + features.size(0), self.vocab_size + 1, device=features.device, dtype=features.dtype + ) + prev_symbol = self.embed(prev_symbol) + else: + if gt is not None and self.training: + # (N, embedding_units) -2 because of and (same) + prev_symbol = self.embed(gt_embedding[:, t - 2]) + else: + # -1 to start at timestep where prev_symbol was initialized + index = logits_list[t - 1].argmax(-1) + # update prev_symbol with ones at the index of the previous logit vector + prev_symbol = self.embed(self.embed_tgt(index)) + + # (N, C), (N, C) take the last hidden state and cell state from current timestep + hidden_state_init, cell_state_init = self.lstm_cell(prev_symbol, (hidden_state_init, cell_state_init)) + hidden_state, cell_state = self.lstm_cell(hidden_state_init, (hidden_state, cell_state)) + # (N, C, H, W), (N, C) --> (N, C) + glimpse = self.attention_module(features, hidden_state) + # (N, C), (N, C) --> (N, 2 * C) + logits = torch.cat([hidden_state, glimpse], dim=1) + logits = self.dropout(logits) + # (N, vocab_size + 1) + logits_list.append(self.output_dense(logits)) + + # (max_length + 1, N, vocab_size + 1) --> (N, max_length + 1, vocab_size + 1) + return torch.stack(logits_list[1:]).permute(1, 0, 2) + + +class SAR(nn.Module, RecognitionModel): + """Implements a SAR architecture as described in 
`"Show, Attend and Read:A Simple and Strong Baseline for + Irregular Text Recognition" `_. + + Args: + ---- + feature_extractor: the backbone serving as feature extractor + vocab: vocabulary used for encoding + rnn_units: number of hidden units in both encoder and decoder LSTM + embedding_units: number of embedding units + attention_units: number of hidden units in attention module + max_length: maximum word length handled by the model + dropout_prob: dropout probability of the encoder LSTM + exportable: onnx exportable returns only logits + cfg: dictionary containing information about the model + """ + + def __init__( + self, + feature_extractor, + vocab: str, + rnn_units: int = 512, + embedding_units: int = 512, + attention_units: int = 512, + max_length: int = 30, + dropout_prob: float = 0.0, + input_shape: Tuple[int, int, int] = (3, 32, 128), + exportable: bool = False, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + self.vocab = vocab + self.exportable = exportable + self.cfg = cfg + + self.max_length = max_length + 1 # Add 1 timestep for EOS after the longest word + + self.feat_extractor = feature_extractor + + # Size the LSTM + self.feat_extractor.eval() + with torch.no_grad(): + out_shape = self.feat_extractor(torch.zeros((1, *input_shape)))["features"].shape + # Switch back to original mode + self.feat_extractor.train() + + self.encoder = SAREncoder(out_shape[1], rnn_units, dropout_prob) + self.decoder = SARDecoder( + rnn_units, + self.max_length, + len(self.vocab), + embedding_units, + attention_units, + dropout_prob=dropout_prob, + ) + + self.postprocessor = SARPostProcessor(vocab=vocab) + + for n, m in self.named_modules(): + # Don't override the initialization of the backbone + if n.startswith("feat_extractor."): + continue + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward( + self, + x: torch.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + ) -> Dict[str, Any]: + features = self.feat_extractor(x)["features"] + # NOTE: use max instead of functional max_pool2d which leads to ONNX incompatibility (kernel_size) + # Vertical max pooling (N, C, H, W) --> (N, C, W) + pooled_features = features.max(dim=-2).values + # (N, W, C) + pooled_features = pooled_features.permute(0, 2, 1).contiguous() + # (N, C) + encoded = self.encoder(pooled_features) + if target is not None: + _gt, _seq_len = self.build_target(target) + gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len) + gt, seq_len = gt.to(x.device), seq_len.to(x.device) + + if self.training and target is None: + raise ValueError("Need to provide labels during training for teacher forcing") + + decoded_features = _bf16_to_float32(self.decoder(features, encoded, gt=None if target is None else gt)) + + out: Dict[str, Any] = {} + if self.exportable: + out["logits"] = decoded_features + return out + + if return_model_output: + out["out_map"] = decoded_features + + if target is None or return_preds: + # Post-process boxes + out["preds"] = self.postprocessor(decoded_features) + + if target is not None: + out["loss"] = self.compute_loss(decoded_features, gt, seq_len) + + return out + + @staticmethod + def compute_loss( + model_output: torch.Tensor, + gt: torch.Tensor, + seq_len: torch.Tensor, + ) -> torch.Tensor: + """Compute categorical 
cross-entropy loss for the model. + Sequences are masked after the EOS character. + + Args: + ---- + model_output: predicted logits of the model + gt: the encoded tensor with gt labels + seq_len: lengths of each gt word inside the batch + + Returns: + ------- + The loss of the model on the batch + """ + # Input length : number of timesteps + input_len = model_output.shape[1] + # Add one for additional token + seq_len = seq_len + 1 + # Compute loss + # (N, L, vocab_size + 1) + cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none") + mask_2d = torch.arange(input_len, device=model_output.device)[None, :] >= seq_len[:, None] + cce[mask_2d] = 0 + + ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype) + return ce_loss.mean() + + +class SARPostProcessor(RecognitionPostProcessor): + """Post processor for SAR architectures + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __call__( + self, + logits: torch.Tensor, + ) -> List[Tuple[str, float]]: + # compute pred with argmax for attention models + out_idxs = logits.argmax(-1) + # N x L + probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1) + # Take the minimum confidence of the sequence + probs = probs.min(dim=1).values.detach().cpu() + + # Manual decoding + word_values = [ + "".join(self._embedding[idx] for idx in encoded_seq).split("")[0] + for encoded_seq in out_idxs.detach().cpu().numpy() + ] + + return list(zip(word_values, probs.numpy().clip(0, 1).tolist())) + + +def _sar( + arch: str, + pretrained: bool, + backbone_fn: Callable[[bool], nn.Module], + layer: str, + pretrained_backbone: bool = True, + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> SAR: + pretrained_backbone = pretrained_backbone and not pretrained + + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"]) + _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"]) + + # Feature extractor + feat_extractor = IntermediateLayerGetter( + backbone_fn(pretrained_backbone), + {layer: "features"}, + ) + kwargs["vocab"] = _cfg["vocab"] + kwargs["input_shape"] = _cfg["input_shape"] + + # Build the model + model = SAR(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + # The number of classes is not the same as the number of classes in the pretrained model => + # remove the last layer weights + _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None + load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + return model + + +def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR: + """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong + Baseline for Irregular Text Recognition" `_. 
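For reference, a hedged sketch of how the post-processor above turns raw logits into (word, confidence) pairs, assuming "<eos>" is the end-of-sequence placeholder appended to the embedding table:

import torch

logits = torch.randn(1, 4, 6)                    # (N, L, vocab_size + 1), toy sizes
embedding = list("abcde") + ["<eos>"]            # assumed embedding table
out_idxs = logits.argmax(-1)
probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
confidence = probs.min(dim=1).values             # the weakest step bounds the word confidence
word = "".join(embedding[i] for i in out_idxs[0].tolist()).split("<eos>")[0]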
+ + >>> import torch + >>> from doctr.models import sar_resnet31 + >>> model = sar_resnet31(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 32, 128)) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + **kwargs: keyword arguments of the SAR architecture + + Returns: + ------- + text recognition architecture + """ + return _sar( + "sar_resnet31", + pretrained, + resnet31, + "10", + ignore_keys=[ + "decoder.embed.weight", + "decoder.embed_tgt.weight", + "decoder.output_dense.weight", + "decoder.output_dense.bias", + ], + **kwargs, + ) diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..e5e557c2329e0a6c07efe1f6c10f2d8694109100 --- /dev/null +++ b/doctr/models/recognition/sar/tensorflow.py @@ -0,0 +1,421 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +import tensorflow as tf +from tensorflow.keras import Model, Sequential, layers + +from doctr.datasets import VOCABS +from doctr.utils.repr import NestedObject + +from ...classification import resnet31 +from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ..core import RecognitionModel, RecognitionPostProcessor + +__all__ = ["SAR", "sar_resnet31"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "sar_resnet31": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 128, 3), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.6.0/sar_resnet31-c41e32a5.zip&src=0", + }, +} + + +class SAREncoder(layers.Layer, NestedObject): + """Implements encoder module of the SAR model + + Args: + ---- + rnn_units: number of hidden rnn units + dropout_prob: dropout probability + """ + + def __init__(self, rnn_units: int, dropout_prob: float = 0.0) -> None: + super().__init__() + self.rnn = Sequential([ + layers.LSTM(units=rnn_units, return_sequences=True, recurrent_dropout=dropout_prob), + layers.LSTM(units=rnn_units, return_sequences=False, recurrent_dropout=dropout_prob), + ]) + + def call( + self, + x: tf.Tensor, + **kwargs: Any, + ) -> tf.Tensor: + # (N, C) + return self.rnn(x, **kwargs) + + +class AttentionModule(layers.Layer, NestedObject): + """Implements attention module of the SAR model + + Args: + ---- + attention_units: number of hidden attention units + + """ + + def __init__(self, attention_units: int) -> None: + super().__init__() + self.hidden_state_projector = layers.Conv2D( + attention_units, + 1, + strides=1, + use_bias=False, + padding="same", + kernel_initializer="he_normal", + ) + self.features_projector = layers.Conv2D( + attention_units, + 3, + strides=1, + use_bias=True, + padding="same", + kernel_initializer="he_normal", + ) + self.attention_projector = layers.Conv2D( + 1, + 1, + strides=1, + use_bias=False, + padding="same", + kernel_initializer="he_normal", + ) + self.flatten = layers.Flatten() + + def call( + self, + features: tf.Tensor, + hidden_state: tf.Tensor, + **kwargs: Any, + ) -> tf.Tensor: + [H, W] = features.get_shape().as_list()[1:3] + # shape (N, H, W, vgg_units) -> (N, H, W, attention_units) + features_projection = self.features_projector(features, **kwargs) + # shape (N, 1, 1, rnn_units) -> (N, 1, 1, attention_units) + hidden_state = 
tf.expand_dims(tf.expand_dims(hidden_state, axis=1), axis=1) + hidden_state_projection = self.hidden_state_projector(hidden_state, **kwargs) + projection = tf.math.tanh(hidden_state_projection + features_projection) + # shape (N, H, W, attention_units) -> (N, H, W, 1) + attention = self.attention_projector(projection, **kwargs) + # shape (N, H, W, 1) -> (N, H * W) + attention = self.flatten(attention) + attention = tf.nn.softmax(attention) + # shape (N, H * W) -> (N, H, W, 1) + attention_map = tf.reshape(attention, [-1, H, W, 1]) + glimpse = tf.math.multiply(features, attention_map) + # shape (N, H * W) -> (N, C) + return tf.reduce_sum(glimpse, axis=[1, 2]) + + +class SARDecoder(layers.Layer, NestedObject): + """Implements decoder module of the SAR model + + Args: + ---- + rnn_units: number of hidden units in recurrent cells + max_length: maximum length of a sequence + vocab_size: number of classes in the model alphabet + embedding_units: number of hidden embedding units + attention_units: number of hidden attention units + num_decoder_cells: number of LSTMCell layers to stack + dropout_prob: dropout probability + + """ + + def __init__( + self, + rnn_units: int, + max_length: int, + vocab_size: int, + embedding_units: int, + attention_units: int, + num_decoder_cells: int = 2, + dropout_prob: float = 0.0, + ) -> None: + super().__init__() + self.vocab_size = vocab_size + self.max_length = max_length + + self.embed = layers.Dense(embedding_units, use_bias=False) + self.embed_tgt = layers.Embedding(embedding_units, self.vocab_size + 1) + + self.lstm_cells = layers.StackedRNNCells([ + layers.LSTMCell(rnn_units, implementation=1) for _ in range(num_decoder_cells) + ]) + self.attention_module = AttentionModule(attention_units) + self.output_dense = layers.Dense(self.vocab_size + 1, use_bias=True) + self.dropout = layers.Dropout(dropout_prob) + + def call( + self, + features: tf.Tensor, + holistic: tf.Tensor, + gt: Optional[tf.Tensor] = None, + **kwargs: Any, + ) -> tf.Tensor: + if gt is not None: + gt_embedding = self.embed_tgt(gt, **kwargs) + + logits_list: List[tf.Tensor] = [] + + for t in range(self.max_length + 1): # 32 + if t == 0: + # step to init the first states of the LSTMCell + states = self.lstm_cells.get_initial_state( + inputs=None, batch_size=features.shape[0], dtype=features.dtype + ) + prev_symbol = holistic + elif t == 1: + # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros + # (N, vocab_size + 1) --> (N, embedding_units) + prev_symbol = tf.zeros([features.shape[0], self.vocab_size + 1], dtype=features.dtype) + prev_symbol = self.embed(prev_symbol, **kwargs) + else: + if gt is not None and kwargs.get("training", False): + # (N, embedding_units) -2 because of and (same) + prev_symbol = self.embed(gt_embedding[:, t - 2], **kwargs) + else: + # -1 to start at timestep where prev_symbol was initialized + index = tf.argmax(logits_list[t - 1], axis=-1) + # update prev_symbol with ones at the index of the previous logit vector + prev_symbol = self.embed(self.embed_tgt(index, **kwargs), **kwargs) + + # (N, C), (N, C) take the last hidden state and cell state from current timestep + _, states = self.lstm_cells(prev_symbol, states, **kwargs) + # states = (hidden_state, cell_state) + hidden_state = states[0][0] + # (N, H, W, C), (N, C) --> (N, C) + glimpse = self.attention_module(features, hidden_state, **kwargs) + # (N, C), (N, C) --> (N, 2 * C) + logits = tf.concat([hidden_state, glimpse], axis=1) + logits = self.dropout(logits, **kwargs) + # (N, vocab_size + 1) 
+ logits_list.append(self.output_dense(logits, **kwargs)) + + # (max_length + 1, N, vocab_size + 1) --> (N, max_length + 1, vocab_size + 1) + return tf.transpose(tf.stack(logits_list[1:]), (1, 0, 2)) + + +class SAR(Model, RecognitionModel): + """Implements a SAR architecture as described in `"Show, Attend and Read:A Simple and Strong Baseline for + Irregular Text Recognition" `_. + + Args: + ---- + feature_extractor: the backbone serving as feature extractor + vocab: vocabulary used for encoding + rnn_units: number of hidden units in both encoder and decoder LSTM + embedding_units: number of embedding units + attention_units: number of hidden units in attention module + max_length: maximum word length handled by the model + num_decoder_cells: number of LSTMCell layers to stack + dropout_prob: dropout probability for the encoder and decoder + exportable: onnx exportable returns only logits + cfg: dictionary containing information about the model + """ + + _children_names: List[str] = ["feat_extractor", "encoder", "decoder", "postprocessor"] + + def __init__( + self, + feature_extractor, + vocab: str, + rnn_units: int = 512, + embedding_units: int = 512, + attention_units: int = 512, + max_length: int = 30, + num_decoder_cells: int = 2, + dropout_prob: float = 0.0, + exportable: bool = False, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + self.vocab = vocab + self.exportable = exportable + self.cfg = cfg + self.max_length = max_length + 1 # Add 1 timestep for EOS after the longest word + + self.feat_extractor = feature_extractor + + self.encoder = SAREncoder(rnn_units, dropout_prob) + self.decoder = SARDecoder( + rnn_units, + self.max_length, + len(vocab), + embedding_units, + attention_units, + num_decoder_cells, + dropout_prob, + ) + + self.postprocessor = SARPostProcessor(vocab=vocab) + + @staticmethod + def compute_loss( + model_output: tf.Tensor, + gt: tf.Tensor, + seq_len: tf.Tensor, + ) -> tf.Tensor: + """Compute categorical cross-entropy loss for the model. + Sequences are masked after the EOS character. 
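A minimal TensorFlow sketch of the masking implemented below (toy shapes, values illustrative):

import tensorflow as tf

logits = tf.random.normal((2, 5, 7))                      # (N, L, vocab_size + 1)
gt = tf.random.uniform((2, 5), maxval=7, dtype=tf.int32)  # encoded targets
seq_len = tf.constant([2, 4]) + 1                         # word lengths + 1 for the EOS step
cce = tf.nn.softmax_cross_entropy_with_logits(tf.one_hot(gt, depth=7), logits)   # (N, L)
masked = tf.where(tf.sequence_mask(seq_len, maxlen=5), cce, tf.zeros_like(cce))
loss = tf.reduce_sum(masked, axis=1) / tf.cast(seq_len, masked.dtype)            # per-sequence loss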
+ + Args: + ---- + gt: the encoded tensor with gt labels + model_output: predicted logits of the model + seq_len: lengths of each gt word inside the batch + + Returns: + ------- + The loss of the model on the batch + """ + # Input length : number of timesteps + input_len = tf.shape(model_output)[1] + # Add one for additional token + seq_len = seq_len + 1 + # One-hot gt labels + oh_gt = tf.one_hot(gt, depth=model_output.shape[2]) + # Compute loss + cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt, model_output) + # Compute mask + mask_values = tf.zeros_like(cce) + mask_2d = tf.sequence_mask(seq_len, input_len) + masked_loss = tf.where(mask_2d, cce, mask_values) + ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype)) + return tf.expand_dims(ce_loss, axis=1) + + def call( + self, + x: tf.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + features = self.feat_extractor(x, **kwargs) + # vertical max pooling --> (N, C, W) + pooled_features = tf.reduce_max(features, axis=1) + # holistic (N, C) + encoded = self.encoder(pooled_features, **kwargs) + + if target is not None: + gt, seq_len = self.build_target(target) + seq_len = tf.cast(seq_len, tf.int32) + + if kwargs.get("training", False) and target is None: + raise ValueError("Need to provide labels during training for teacher forcing") + + decoded_features = _bf16_to_float32( + self.decoder(features, encoded, gt=None if target is None else gt, **kwargs) + ) + + out: Dict[str, tf.Tensor] = {} + if self.exportable: + out["logits"] = decoded_features + return out + + if return_model_output: + out["out_map"] = decoded_features + + if target is None or return_preds: + # Post-process boxes + out["preds"] = self.postprocessor(decoded_features) + + if target is not None: + out["loss"] = self.compute_loss(decoded_features, gt, seq_len) + + return out + + +class SARPostProcessor(RecognitionPostProcessor): + """Post processor for SAR architectures + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __call__( + self, + logits: tf.Tensor, + ) -> List[Tuple[str, float]]: + # compute pred with argmax for attention models + out_idxs = tf.math.argmax(logits, axis=2) + # N x L + probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2) + # Take the minimum confidence of the sequence + probs = tf.math.reduce_min(probs, axis=1) + + # decode raw output of the model with tf_label_to_idx + out_idxs = tf.cast(out_idxs, dtype="int32") + embedding = tf.constant(self._embedding, dtype=tf.string) + decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1) + decoded_strings_pred = tf.strings.split(decoded_strings_pred, "") + decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0] + word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()] + + return list(zip(word_values, probs.numpy().clip(0, 1).tolist())) + + +def _sar( + arch: str, + pretrained: bool, + backbone_fn, + pretrained_backbone: bool = True, + input_shape: Optional[Tuple[int, int, int]] = None, + **kwargs: Any, +) -> SAR: + pretrained_backbone = pretrained_backbone and not pretrained + + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["input_shape"] = input_shape or _cfg["input_shape"] + _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"]) + + # Feature 
extractor + feat_extractor = backbone_fn( + pretrained=pretrained_backbone, + input_shape=_cfg["input_shape"], + include_top=False, + ) + + kwargs["vocab"] = _cfg["vocab"] + + # Build the model + model = SAR(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]["url"]) + + return model + + +def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR: + """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong + Baseline for Irregular Text Recognition" `_. + + >>> import tensorflow as tf + >>> from doctr.models import sar_resnet31 + >>> model = sar_resnet31(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 64, 256, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + **kwargs: keyword arguments of the SAR architecture + + Returns: + ------- + text recognition architecture + """ + return _sar("sar_resnet31", pretrained, resnet31, **kwargs) diff --git a/doctr/models/recognition/utils.py b/doctr/models/recognition/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..09de8b9165bc80664cb4db403b5cca6af9d8d9a7 --- /dev/null +++ b/doctr/models/recognition/utils.py @@ -0,0 +1,89 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import List + +from rapidfuzz.distance import Levenshtein + +__all__ = ["merge_strings", "merge_multi_strings"] + + +def merge_strings(a: str, b: str, dil_factor: float) -> str: + """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters. + + Args: + ---- + a: first char seq, suffix should be similar to b's prefix. + b: second char seq, prefix should be similar to a's suffix. + dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is + only used when the mother sequence is splitted on a character repetition + + Returns: + ------- + A merged character sequence. 
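To make the overlap scoring concrete, a worked example consistent with the scoring loop below:

from rapidfuzz.distance import Levenshtein

a, b = "abcd", "cdefgh"
scores = [Levenshtein.distance(a[-i:], b[:i]) / i for i in range(1, min(len(a), len(b)) + 1)]
# scores == [1.0, 0.0, 2/3, 1.0]: the best (zero) score is at an overlap of 2 characters,
# so the merge keeps a[:-1] + b[2 - 1:] == "abc" + "defgh" == "abcdefgh"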
+ + Example:: + >>> from doctr.model.recognition.utils import merge_sequences + >>> merge_sequences('abcd', 'cdefgh', 1.4) + 'abcdefgh' + >>> merge_sequences('abcdi', 'cdefgh', 1.4) + 'abcdefgh' + """ + seq_len = min(len(a), len(b)) + if seq_len == 0: # One sequence is empty, return the other + return b if len(a) == 0 else a + + # Initialize merging index and corresponding score (mean Levenstein) + min_score, index = 1.0, 0 # No overlap, just concatenate + + scores = [Levenshtein.distance(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)] + + # Edge case (split in the middle of char repetitions): if it starts with 2 or more 0 + if len(scores) > 1 and (scores[0], scores[1]) == (0, 0): + # Compute n_overlap (number of overlapping chars, geometrically determined) + n_overlap = round(len(b) * (dil_factor - 1) / dil_factor) + # Find the number of consecutive zeros in the scores list + # Impossible to have a zero after a non-zero score in that case + n_zeros = sum(val == 0 for val in scores) + # Index is bounded by the geometrical overlap to avoid collapsing repetitions + min_score, index = 0, min(n_zeros, n_overlap) + + else: # Common case: choose the min score index + for i, score in enumerate(scores): + if score < min_score: + min_score, index = score, i + 1 # Add one because first index is an overlap of 1 char + + # Merge with correct overlap + if index == 0: + return a + b + return a[:-1] + b[index - 1 :] + + +def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str: + """Recursively merges consecutive string sequences with overlapping characters. + + Args: + ---- + seq_list: list of sequences to merge. Sequences need to be ordered from left to right. + dil_factor: dilation factor of the boxes to overlap, should be > 1. 
This parameter is + only used when the mother sequence is splitted on a character repetition + + Returns: + ------- + A merged character sequence + + Example:: + >>> from doctr.model.recognition.utils import merge_multi_sequences + >>> merge_multi_sequences(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4) + 'abcdefghijkl' + """ + + def _recursive_merge(a: str, seq_list: List[str], dil_factor: float) -> str: + # Recursive version of compute_overlap + if len(seq_list) == 1: + return merge_strings(a, seq_list[0], dil_factor) + return _recursive_merge(merge_strings(a, seq_list[0], dil_factor), seq_list[1:], dil_factor) + + return _recursive_merge("", seq_list, dil_factor) diff --git a/doctr/models/recognition/vitstr/__init__.py b/doctr/models/recognition/vitstr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/recognition/vitstr/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/recognition/vitstr/__pycache__/__init__.cpython-311.pyc b/doctr/models/recognition/vitstr/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1043d26ddae39714f33419d44f2767a66fb96752 Binary files /dev/null and b/doctr/models/recognition/vitstr/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/recognition/vitstr/__pycache__/__init__.cpython-38.pyc b/doctr/models/recognition/vitstr/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4436551edff1503661d74c414450c0a7c9e634f Binary files /dev/null and b/doctr/models/recognition/vitstr/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/recognition/vitstr/__pycache__/base.cpython-311.pyc b/doctr/models/recognition/vitstr/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc7ecff77d43b317314ea138409eb8b3426ded9f Binary files /dev/null and b/doctr/models/recognition/vitstr/__pycache__/base.cpython-311.pyc differ diff --git a/doctr/models/recognition/vitstr/__pycache__/base.cpython-38.pyc b/doctr/models/recognition/vitstr/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3da3b4c8b9848122ae37f14ade48a0cb070c1ea7 Binary files /dev/null and b/doctr/models/recognition/vitstr/__pycache__/base.cpython-38.pyc differ diff --git a/doctr/models/recognition/vitstr/__pycache__/pytorch.cpython-311.pyc b/doctr/models/recognition/vitstr/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc0bcd8a003e11afaf338da17e9a477863efbe13 Binary files /dev/null and b/doctr/models/recognition/vitstr/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/recognition/vitstr/__pycache__/tensorflow.cpython-311.pyc b/doctr/models/recognition/vitstr/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f14b3b2ad121589fa095ed529ac60a5ef8a66eee Binary files /dev/null and b/doctr/models/recognition/vitstr/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/recognition/vitstr/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/recognition/vitstr/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..0902fb814d3c83a44ca3cd8ed601660efe6fd224 Binary files /dev/null and b/doctr/models/recognition/vitstr/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/recognition/vitstr/base.py b/doctr/models/recognition/vitstr/base.py new file mode 100644 index 0000000000000000000000000000000000000000..af01dce60083efa84647614acce0dbdd79573269 --- /dev/null +++ b/doctr/models/recognition/vitstr/base.py @@ -0,0 +1,57 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import List, Tuple + +import numpy as np + +from ....datasets import encode_sequences +from ..core import RecognitionPostProcessor + + +class _ViTSTR: + vocab: str + max_length: int + + def build_target( + self, + gts: List[str], + ) -> Tuple[np.ndarray, List[int]]: + """Encode a list of gts sequences into a np array and gives the corresponding* + sequence lengths. + + Args: + ---- + gts: list of ground-truth labels + + Returns: + ------- + A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) + """ + encoded = encode_sequences( + sequences=gts, + vocab=self.vocab, + target_size=self.max_length, + eos=len(self.vocab), + sos=len(self.vocab) + 1, + ) + seq_len = [len(word) for word in gts] + return encoded, seq_len + + +class _ViTSTRPostProcessor(RecognitionPostProcessor): + """Abstract class to postprocess the raw output of the model + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __init__( + self, + vocab: str, + ) -> None: + super().__init__(vocab) + self._embedding = list(vocab) + ["", ""] diff --git a/doctr/models/recognition/vitstr/pytorch.py b/doctr/models/recognition/vitstr/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..ff6644220d86627cb10d8ad0aa9e762f12606a17 --- /dev/null +++ b/doctr/models/recognition/vitstr/pytorch.py @@ -0,0 +1,279 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import functional as F +from torchvision.models._utils import IntermediateLayerGetter + +from doctr.datasets import VOCABS + +from ...classification import vit_b, vit_s +from ...utils.pytorch import _bf16_to_float32, load_pretrained_params +from .base import _ViTSTR, _ViTSTRPostProcessor + +__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "vitstr_small": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/vitstr_small-fcd12655.pt&src=0", + }, + "vitstr_base": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (3, 32, 128), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/vitstr_base-50b21df2.pt&src=0", + }, +} + + +class ViTSTR(_ViTSTR, nn.Module): + """Implements a ViTSTR architecture as described in `"Vision Transformer for Fast and + Efficient Scene Text Recognition" `_. 
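As a rough shape walk-through of the forward pass implemented further below (toy sizes, assuming a vocab of V characters):

import torch
from torch import nn

N, num_tokens, d_model, V, max_length = 2, 130, 384, 120, 34   # toy sizes
features = torch.randn(N, num_tokens, d_model)                 # ViT output, cls token included
head = nn.Linear(d_model, V + 1)                               # +1 for the EOS class
logits = head(features[:, :max_length])                        # keep the first max_length tokens
logits = logits[:, 1:]                                         # drop the cls token -> (N, max_length - 1, V + 1)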
+ + Args: + ---- + feature_extractor: the backbone serving as feature extractor + vocab: vocabulary used for encoding + embedding_units: number of embedding units + max_length: maximum word length handled by the model + dropout_prob: dropout probability of the encoder LSTM + input_shape: input shape of the image + exportable: onnx exportable returns only logits + cfg: dictionary containing information about the model + """ + + def __init__( + self, + feature_extractor, + vocab: str, + embedding_units: int, + max_length: int = 32, # different from paper + input_shape: Tuple[int, int, int] = (3, 32, 128), # different from paper + exportable: bool = False, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + self.vocab = vocab + self.exportable = exportable + self.cfg = cfg + self.max_length = max_length + 2 # +2 for SOS and EOS + + self.feat_extractor = feature_extractor + self.head = nn.Linear(embedding_units, len(self.vocab) + 1) # +1 for EOS + + self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab) + + def forward( + self, + x: torch.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + ) -> Dict[str, Any]: + features = self.feat_extractor(x)["features"] # (batch_size, patches_seqlen, d_model) + + if target is not None: + _gt, _seq_len = self.build_target(target) + gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len) + gt, seq_len = gt.to(x.device), seq_len.to(x.device) + + if self.training and target is None: + raise ValueError("Need to provide labels during training") + + # borrowed from : https://github.com/baudm/parseq/blob/main/strhub/models/vitstr/model.py + features = features[:, : self.max_length] # (batch_size, max_length, d_model) + B, N, E = features.size() + features = features.reshape(B * N, E) + logits = self.head(features).view(B, N, len(self.vocab) + 1) # (batch_size, max_length, vocab + 1) + decoded_features = _bf16_to_float32(logits[:, 1:]) # remove cls_token + + out: Dict[str, Any] = {} + if self.exportable: + out["logits"] = decoded_features + return out + + if return_model_output: + out["out_map"] = decoded_features + + if target is None or return_preds: + # Post-process boxes + out["preds"] = self.postprocessor(decoded_features) + + if target is not None: + out["loss"] = self.compute_loss(decoded_features, gt, seq_len) + + return out + + @staticmethod + def compute_loss( + model_output: torch.Tensor, + gt: torch.Tensor, + seq_len: torch.Tensor, + ) -> torch.Tensor: + """Compute categorical cross-entropy loss for the model. + Sequences are masked after the EOS character. + + Args: + ---- + model_output: predicted logits of the model + gt: the encoded tensor with gt labels + seq_len: lengths of each gt word inside the batch + + Returns: + ------- + The loss of the model on the batch + """ + # Input length : number of steps + input_len = model_output.shape[1] + # Add one for additional token (sos disappear in shift!) + seq_len = seq_len + 1 + # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]! + # The "masked" first gt char is . 
+ cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none") + # Compute mask + mask_2d = torch.arange(input_len, device=model_output.device)[None, :] >= seq_len[:, None] + cce[mask_2d] = 0 + + ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype) + return ce_loss.mean() + + +class ViTSTRPostProcessor(_ViTSTRPostProcessor): + """Post processor for ViTSTR architecture + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __call__( + self, + logits: torch.Tensor, + ) -> List[Tuple[str, float]]: + # compute pred with argmax for attention models + out_idxs = logits.argmax(-1) + preds_prob = torch.softmax(logits, -1).max(dim=-1)[0] + + # Manual decoding + word_values = [ + "".join(self._embedding[idx] for idx in encoded_seq).split("")[0] + for encoded_seq in out_idxs.cpu().numpy() + ] + # compute probabilties for each word up to the EOS token + probs = [ + preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values) + ] + + return list(zip(word_values, probs)) + + +def _vitstr( + arch: str, + pretrained: bool, + backbone_fn: Callable[[bool], nn.Module], + layer: str, + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> ViTSTR: + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"]) + _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"]) + patch_size = kwargs.get("patch_size", (4, 8)) + + kwargs["vocab"] = _cfg["vocab"] + kwargs["input_shape"] = _cfg["input_shape"] + + # Feature extractor + feat_extractor = IntermediateLayerGetter( + # NOTE: we don't use a pretrained backbone for non-rectangular patches to avoid the pos embed mismatch + backbone_fn(False, input_shape=_cfg["input_shape"], patch_size=patch_size), # type: ignore[call-arg] + {layer: "features"}, + ) + + kwargs.pop("patch_size", None) + kwargs.pop("pretrained_backbone", None) + + # Build the model + model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + # The number of classes is not the same as the number of classes in the pretrained model => + # remove the last layer weights + _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None + load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + return model + + +def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR: + """ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition" + `_. + + >>> import torch + >>> from doctr.models import vitstr_small + >>> model = vitstr_small(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 32, 128)) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + kwargs: keyword arguments of the ViTSTR architecture + + Returns: + ------- + text recognition architecture + """ + return _vitstr( + "vitstr_small", + pretrained, + vit_s, + "1", + embedding_units=384, + patch_size=(4, 8), + ignore_keys=["head.weight", "head.bias"], + **kwargs, + ) + + +def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR: + """ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition" + `_. 
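# Illustrative sketch (not part of the diff): the EOS-masking idea used by `compute_loss`
# above, on dummy tensors and without the SOS shift (gt[:, 1:]) that the real method applies.
# Steps after each word's EOS position contribute nothing to the loss.
import torch
from torch.nn import functional as F

batch, steps, num_classes = 2, 5, 11                      # arbitrary example sizes
logits = torch.randn(batch, steps, num_classes)           # (N, T, C) model output
gt = torch.randint(0, num_classes, (batch, steps))        # (N, T) encoded targets
seq_len = torch.tensor([3, 4]) + 1                        # word length + 1 for the EOS step

cce = F.cross_entropy(logits.permute(0, 2, 1), gt, reduction="none")  # (N, T) per-step loss
mask = torch.arange(steps)[None, :] >= seq_len[:, None]               # True after EOS
cce = cce.masked_fill(mask, 0.0)
loss = (cce.sum(dim=1) / seq_len.to(cce.dtype)).mean()
print(loss.item())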
+ + >>> import torch + >>> from doctr.models import vitstr_base + >>> model = vitstr_base(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 32, 128)) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + kwargs: keyword arguments of the ViTSTR architecture + + Returns: + ------- + text recognition architecture + """ + return _vitstr( + "vitstr_base", + pretrained, + vit_b, + "1", + embedding_units=768, + patch_size=(4, 8), + ignore_keys=["head.weight", "head.bias"], + **kwargs, + ) diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..9c5359dde20391fe2ee0f32a27a03c23ff0240e3 --- /dev/null +++ b/doctr/models/recognition/vitstr/tensorflow.py @@ -0,0 +1,281 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +import tensorflow as tf +from tensorflow.keras import Model, layers + +from doctr.datasets import VOCABS + +from ...classification import vit_b, vit_s +from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from .base import _ViTSTR, _ViTSTRPostProcessor + +__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "vitstr_small": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 128, 3), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vitstr_small-358fab2e.zip&src=0", + }, + "vitstr_base": { + "mean": (0.694, 0.695, 0.693), + "std": (0.299, 0.296, 0.301), + "input_shape": (32, 128, 3), + "vocab": VOCABS["french"], + "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vitstr_base-2889159a.zip&src=0", + }, +} + + +class ViTSTR(_ViTSTR, Model): + """Implements a ViTSTR architecture as described in `"Vision Transformer for Fast and + Efficient Scene Text Recognition" `_. + + Args: + ---- + feature_extractor: the backbone serving as feature extractor + vocab: vocabulary used for encoding + embedding_units: number of embedding units + max_length: maximum word length handled by the model + dropout_prob: dropout probability for the encoder and decoder + input_shape: input shape of the image + exportable: onnx exportable returns only logits + cfg: dictionary containing information about the model + """ + + _children_names: List[str] = ["feat_extractor", "postprocessor"] + + def __init__( + self, + feature_extractor, + vocab: str, + embedding_units: int, + max_length: int = 32, + dropout_prob: float = 0.0, + input_shape: Tuple[int, int, int] = (32, 128, 3), # different from paper + exportable: bool = False, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + self.vocab = vocab + self.exportable = exportable + self.cfg = cfg + self.max_length = max_length + 2 # +2 for SOS and EOS + + self.feat_extractor = feature_extractor + self.head = layers.Dense(len(self.vocab) + 1, name="head") # +1 for EOS + + self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab) + + @staticmethod + def compute_loss( + model_output: tf.Tensor, + gt: tf.Tensor, + seq_len: List[int], + ) -> tf.Tensor: + """Compute categorical cross-entropy loss for the model. + Sequences are masked after the EOS character. 
+ + Args: + ---- + model_output: predicted logits of the model + gt: the encoded tensor with gt labels + seq_len: lengths of each gt word inside the batch + + Returns: + ------- + The loss of the model on the batch + """ + # Input length : number of steps + input_len = tf.shape(model_output)[1] + # Add one for additional token (sos disappear in shift!) + seq_len = tf.cast(seq_len, tf.int32) + 1 + # One-hot gt labels + oh_gt = tf.one_hot(gt, depth=model_output.shape[2]) + # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]! + # The "masked" first gt char is . + cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt[:, 1:, :], model_output) + # Compute mask + mask_values = tf.zeros_like(cce) + mask_2d = tf.sequence_mask(seq_len, input_len) + masked_loss = tf.where(mask_2d, cce, mask_values) + ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype)) + + return tf.expand_dims(ce_loss, axis=1) + + def call( + self, + x: tf.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + features = self.feat_extractor(x, **kwargs) # (batch_size, patches_seqlen, d_model) + + if target is not None: + gt, seq_len = self.build_target(target) + seq_len = tf.cast(seq_len, tf.int32) + + if kwargs.get("training", False) and target is None: + raise ValueError("Need to provide labels during training") + + features = features[:, : self.max_length] # (batch_size, max_length, d_model) + B, N, E = features.shape + features = tf.reshape(features, (B * N, E)) + logits = tf.reshape( + self.head(features, **kwargs), (B, N, len(self.vocab) + 1) + ) # (batch_size, max_length, vocab + 1) + decoded_features = _bf16_to_float32(logits[:, 1:]) # remove cls_token + + out: Dict[str, tf.Tensor] = {} + if self.exportable: + out["logits"] = decoded_features + return out + + if return_model_output: + out["out_map"] = decoded_features + + if target is None or return_preds: + # Post-process boxes + out["preds"] = self.postprocessor(decoded_features) + + if target is not None: + out["loss"] = self.compute_loss(decoded_features, gt, seq_len) + + return out + + +class ViTSTRPostProcessor(_ViTSTRPostProcessor): + """Post processor for ViTSTR architecture + + Args: + ---- + vocab: string containing the ordered sequence of supported characters + """ + + def __call__( + self, + logits: tf.Tensor, + ) -> List[Tuple[str, float]]: + # compute pred with argmax for attention models + out_idxs = tf.math.argmax(logits, axis=2) + preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1) + + # decode raw output of the model with tf_label_to_idx + out_idxs = tf.cast(out_idxs, dtype="int32") + embedding = tf.constant(self._embedding, dtype=tf.string) + decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1) + decoded_strings_pred = tf.strings.split(decoded_strings_pred, "") + decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0] + word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()] + + # compute probabilties for each word up to the EOS token + probs = [ + preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0 + for i, word in enumerate(word_values) + ] + + return list(zip(word_values, probs)) + + +def _vitstr( + arch: str, + pretrained: bool, + backbone_fn, + input_shape: Optional[Tuple[int, int, int]] = None, + 
**kwargs: Any, +) -> ViTSTR: + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg["input_shape"] = input_shape or _cfg["input_shape"] + _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"]) + patch_size = kwargs.get("patch_size", (4, 8)) + + kwargs["vocab"] = _cfg["vocab"] + + # Feature extractor + feat_extractor = backbone_fn( + # NOTE: we don't use a pretrained backbone for non-rectangular patches to avoid the pos embed mismatch + pretrained=False, + input_shape=_cfg["input_shape"], + patch_size=patch_size, + include_top=False, + ) + + kwargs.pop("patch_size", None) + kwargs.pop("pretrained_backbone", None) + + # Build the model + model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]["url"]) + + return model + + +def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR: + """ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition" + `_. + + >>> import tensorflow as tf + >>> from doctr.models import vitstr_small + >>> model = vitstr_small(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + **kwargs: keyword arguments of the ViTSTR architecture + + Returns: + ------- + text recognition architecture + """ + return _vitstr( + "vitstr_small", + pretrained, + vit_s, + embedding_units=384, + patch_size=(4, 8), + **kwargs, + ) + + +def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR: + """ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition" + `_. + + >>> import tensorflow as tf + >>> from doctr.models import vitstr_base + >>> model = vitstr_base(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + ---- + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + **kwargs: keyword arguments of the ViTSTR architecture + + Returns: + ------- + text recognition architecture + """ + return _vitstr( + "vitstr_base", + pretrained, + vit_b, + embedding_units=768, + patch_size=(4, 8), + **kwargs, + ) diff --git a/doctr/models/recognition/zoo.py b/doctr/models/recognition/zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..0393240431556cf2591c79752ea74962107d0a42 --- /dev/null +++ b/doctr/models/recognition/zoo.py @@ -0,0 +1,75 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, List + +from doctr.file_utils import is_tf_available +from doctr.models.preprocessor import PreProcessor + +from .. 
import recognition +from .predictor import RecognitionPredictor + +__all__ = ["recognition_predictor"] + + +ARCHS: List[str] = [ + "crnn_vgg16_bn", + "crnn_mobilenet_v3_small", + "crnn_mobilenet_v3_large", + "sar_resnet31", + "master", + "vitstr_small", + "vitstr_base", + "parseq", +] + + +def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor: + if isinstance(arch, str): + if arch not in ARCHS: + raise ValueError(f"unknown architecture '{arch}'") + + _model = recognition.__dict__[arch]( + pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True) + ) + else: + if not isinstance( + arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq) + ): + raise ValueError(f"unknown architecture: {type(arch)}") + _model = arch + + kwargs.pop("pretrained_backbone", None) + + kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) + kwargs["std"] = kwargs.get("std", _model.cfg["std"]) + kwargs["batch_size"] = kwargs.get("batch_size", 128) + input_shape = _model.cfg["input_shape"][:2] if is_tf_available() else _model.cfg["input_shape"][-2:] + predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model) + + return predictor + + +def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False, **kwargs: Any) -> RecognitionPredictor: + """Text recognition architecture. + + Example:: + >>> import numpy as np + >>> from doctr.models import recognition_predictor + >>> model = recognition_predictor(pretrained=True) + >>> input_page = (255 * np.random.rand(32, 128, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + ---- + arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn') + pretrained: If True, returns a model pre-trained on our text recognition dataset + **kwargs: optional parameters to be passed to the architecture + + Returns: + ------- + Recognition predictor + """ + return _predictor(arch, pretrained, **kwargs) diff --git a/doctr/models/utils/__init__.py b/doctr/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a --- /dev/null +++ b/doctr/models/utils/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/utils/__pycache__/__init__.cpython-311.pyc b/doctr/models/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fcdcb30f024d761c45edc92f862fb69a34fd836 Binary files /dev/null and b/doctr/models/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/models/utils/__pycache__/__init__.cpython-38.pyc b/doctr/models/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..627b83e410d2a2544dbed4e9c7f6040f975400d9 Binary files /dev/null and b/doctr/models/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/models/utils/__pycache__/pytorch.cpython-311.pyc b/doctr/models/utils/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e142c0e6e2a64bccc8cd2132ede5864ef9ded583 Binary files /dev/null and b/doctr/models/utils/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/models/utils/__pycache__/tensorflow.cpython-311.pyc 
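# Illustrative sketch (not part of the diff): `recognition_predictor` accepts either an
# architecture name or an already-built model instance, as the isinstance check above allows.
# With pretrained=False the weights are random, so the printed prediction is a placeholder.
import numpy as np

from doctr.models import recognition_predictor, vitstr_small

# by name, with a custom preprocessing batch size
predictor = recognition_predictor("vitstr_small", pretrained=False, batch_size=32)

# or by passing the model object directly
predictor = recognition_predictor(vitstr_small(pretrained=False))

crop = (255 * np.random.rand(32, 128, 3)).astype(np.uint8)
print(predictor([crop]))  # e.g. [('word', 0.42)] -- a (value, confidence) pair per crop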
b/doctr/models/utils/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c5228742b39ea5b531ee4fe4d2608791117e841 Binary files /dev/null and b/doctr/models/utils/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/models/utils/__pycache__/tensorflow.cpython-38.pyc b/doctr/models/utils/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae7b5670e291e525b94fb4f5d35b36e3c81517fb Binary files /dev/null and b/doctr/models/utils/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..0401cdef6cc14d256d171fe71376f3304c0fce64 --- /dev/null +++ b/doctr/models/utils/pytorch.py @@ -0,0 +1,170 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import logging +from typing import Any, List, Optional, Tuple, Union + +import torch +from torch import nn + +from doctr.utils.data import download_from_url + +__all__ = [ + "load_pretrained_params", + "conv_sequence_pt", + "set_device_and_dtype", + "export_model_to_onnx", + "_copy_tensor", + "_bf16_to_float32", +] + + +def _copy_tensor(x: torch.Tensor) -> torch.Tensor: + return x.clone().detach() + + +def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor: + # bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype + return x.float() if x.dtype == torch.bfloat16 else x + + +def load_pretrained_params( + model: nn.Module, + url: Optional[str] = None, + hash_prefix: Optional[str] = None, + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> None: + """Load a set of parameters onto a model + + >>> from doctr.models import load_pretrained_params + >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip") + + Args: + ---- + model: the PyTorch model to be loaded + url: URL of the zipped set of parameters + hash_prefix: first characters of SHA256 expected hash + ignore_keys: list of weights to be ignored from the state_dict + **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url` + """ + if url is None: + logging.warning("Invalid model URL, using default initialization.") + else: + archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs) + + # Read state_dict + state_dict = torch.load(archive_path, map_location="cpu") + + # Remove weights from the state_dict + if ignore_keys is not None and len(ignore_keys) > 0: + for key in ignore_keys: + state_dict.pop(key) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + if set(missing_keys) != set(ignore_keys) or len(unexpected_keys) > 0: + raise ValueError("unable to load state_dict, due to non-matching keys.") + else: + # Load weights + model.load_state_dict(state_dict) + + +def conv_sequence_pt( + in_channels: int, + out_channels: int, + relu: bool = False, + bn: bool = False, + **kwargs: Any, +) -> List[nn.Module]: + """Builds a convolutional-based layer sequence + + >>> from torch.nn import Sequential + >>> from doctr.models import conv_sequence + >>> module = Sequential(conv_sequence(3, 32, True, True, kernel_size=3)) + + Args: + ---- + in_channels: number of input channels + out_channels: number of output channels + relu: whether ReLU should be used + bn: should a batch normalization layer be 
added + **kwargs: additional arguments to be passed to the convolutional layer + + Returns: + ------- + list of layers + """ + # No bias before Batch norm + kwargs["bias"] = kwargs.get("bias", not bn) + # Add activation directly to the conv if there is no BN + conv_seq: List[nn.Module] = [nn.Conv2d(in_channels, out_channels, **kwargs)] + + if bn: + conv_seq.append(nn.BatchNorm2d(out_channels)) + + if relu: + conv_seq.append(nn.ReLU(inplace=True)) + + return conv_seq + + +def set_device_and_dtype( + model: Any, batches: List[torch.Tensor], device: Union[str, torch.device], dtype: torch.dtype +) -> Tuple[Any, List[torch.Tensor]]: + """Set the device and dtype of a model and its batches + + >>> import torch + >>> from torch import nn + >>> from doctr.models.utils import set_device_and_dtype + >>> model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4)) + >>> batches = [torch.rand(8) for _ in range(2)] + >>> model, batches = set_device_and_dtype(model, batches, device="cuda", dtype=torch.float16) + + Args: + ---- + model: the model to be set + batches: the batches to be set + device: the device to be used + dtype: the dtype to be used + + Returns: + ------- + the model and batches set + """ + return model.to(device=device, dtype=dtype), [batch.to(device=device, dtype=dtype) for batch in batches] + + +def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.Tensor, **kwargs: Any) -> str: + """Export model to ONNX format. + + >>> import torch + >>> from doctr.models.classification import resnet18 + >>> from doctr.models.utils import export_model_to_onnx + >>> model = resnet18(pretrained=True) + >>> export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32)) + + Args: + ---- + model: the PyTorch model to be exported + model_name: the name for the exported model + dummy_input: the dummy input to the model + kwargs: additional arguments to be passed to torch.onnx.export + + Returns: + ------- + the path to the exported model + """ + torch.onnx.export( + model, + dummy_input, + f"{model_name}.onnx", + input_names=["input"], + output_names=["logits"], + dynamic_axes={"input": {0: "batch_size"}, "logits": {0: "batch_size"}}, + export_params=True, + verbose=False, + **kwargs, + ) + logging.info(f"Model exported to {model_name}.onnx") + return f"{model_name}.onnx" diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..4c6f02c2a3271418b5f159559b856cb09cc28150 --- /dev/null +++ b/doctr/models/utils/tensorflow.py @@ -0,0 +1,181 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
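# Illustrative sketch (not part of the diff): chaining the PyTorch helpers above, assuming the
# PyTorch backend is installed. `conv_sequence_pt` returns a plain list of layers, and
# `set_device_and_dtype` moves a model and its batches together; "cpu"/float32 keep it portable.
import torch
from torch import nn

from doctr.models.utils.pytorch import conv_sequence_pt, set_device_and_dtype

module = nn.Sequential(*conv_sequence_pt(3, 16, relu=True, bn=True, kernel_size=3, padding=1))
batches = [torch.rand(2, 3, 32, 32) for _ in range(2)]
module, batches = set_device_and_dtype(module, batches, device="cpu", dtype=torch.float32)
print(module(batches[0]).shape)  # torch.Size([2, 16, 32, 32])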
+ +import logging +import os +from typing import Any, Callable, List, Optional, Tuple, Union +from zipfile import ZipFile + +import tensorflow as tf +import tf2onnx +from tensorflow.keras import Model, layers + +from doctr.utils.data import download_from_url + +logging.getLogger("tensorflow").setLevel(logging.DEBUG) + + +__all__ = [ + "load_pretrained_params", + "conv_sequence", + "IntermediateLayerGetter", + "export_model_to_onnx", + "_copy_tensor", + "_bf16_to_float32", +] + + +def _copy_tensor(x: tf.Tensor) -> tf.Tensor: + return tf.identity(x) + + +def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor: + # Convert bfloat16 to float32 for numpy compatibility + return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x + + +def load_pretrained_params( + model: Model, + url: Optional[str] = None, + hash_prefix: Optional[str] = None, + overwrite: bool = False, + internal_name: str = "weights", + **kwargs: Any, +) -> None: + """Load a set of parameters onto a model + + >>> from doctr.models import load_pretrained_params + >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip") + + Args: + ---- + model: the keras model to be loaded + url: URL of the zipped set of parameters + hash_prefix: first characters of SHA256 expected hash + overwrite: should the zip extraction be enforced if the archive has already been extracted + internal_name: name of the ckpt files + **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url` + """ + if url is None: + logging.warning("Invalid model URL, using default initialization.") + else: + archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs) + + # Unzip the archive + params_path = archive_path.parent.joinpath(archive_path.stem) + if not params_path.is_dir() or overwrite: + with ZipFile(archive_path, "r") as f: + f.extractall(path=params_path) + + # Load weights + model.load_weights(f"{params_path}{os.sep}{internal_name}") + + +def conv_sequence( + out_channels: int, + activation: Optional[Union[str, Callable]] = None, + bn: bool = False, + padding: str = "same", + kernel_initializer: str = "he_normal", + **kwargs: Any, +) -> List[layers.Layer]: + """Builds a convolutional-based layer sequence + + >>> from tensorflow.keras import Sequential + >>> from doctr.models import conv_sequence + >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3])) + + Args: + ---- + out_channels: number of output channels + activation: activation to be used (default: no activation) + bn: should a batch normalization layer be added + padding: padding scheme + kernel_initializer: kernel initializer + **kwargs: additional arguments to be passed to the convolutional layer + + Returns: + ------- + list of layers + """ + # No bias before Batch norm + kwargs["use_bias"] = kwargs.get("use_bias", not bn) + # Add activation directly to the conv if there is no BN + kwargs["activation"] = activation if not bn else None + conv_seq = [layers.Conv2D(out_channels, padding=padding, kernel_initializer=kernel_initializer, **kwargs)] + + if bn: + conv_seq.append(layers.BatchNormalization()) + + if (isinstance(activation, str) or callable(activation)) and bn: + # Activation function can either be a string or a function ('relu' or tf.nn.relu) + conv_seq.append(layers.Activation(activation)) + + return conv_seq + + +class IntermediateLayerGetter(Model): + """Implements an intermediate layer getter + + >>> from tensorflow.keras.applications import ResNet50 + >>> from 
doctr.models import IntermediateLayerGetter + >>> target_layers = ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"] + >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers) + + Args: + ---- + model: the model to extract feature maps from + layer_names: the list of layers to retrieve the feature map from + """ + + def __init__(self, model: Model, layer_names: List[str]) -> None: + intermediate_fmaps = [model.get_layer(layer_name).get_output_at(0) for layer_name in layer_names] + super().__init__(model.input, outputs=intermediate_fmaps) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + +def export_model_to_onnx( + model: Model, model_name: str, dummy_input: List[tf.TensorSpec], **kwargs: Any +) -> Tuple[str, List[str]]: + """Export model to ONNX format. + + >>> import tensorflow as tf + >>> from doctr.models.classification import resnet18 + >>> from doctr.models.utils import export_classification_model_to_onnx + >>> model = resnet18(pretrained=True, include_top=True) + >>> export_model_to_onnx(model, "my_model", + >>> dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")]) + + Args: + ---- + model: the keras model to be exported + model_name: the name for the exported model + dummy_input: the dummy input to the model + kwargs: additional arguments to be passed to tf2onnx + + Returns: + ------- + the path to the exported model and a list with the output layer names + """ + large_model = kwargs.get("large_model", False) + model_proto, _ = tf2onnx.convert.from_keras( + model, + input_signature=dummy_input, + output_path=f"{model_name}.zip" if large_model else f"{model_name}.onnx", + **kwargs, + ) + # Get the output layer names + output = [n.name for n in model_proto.graph.output] + + # models which are too large (weights > 2GB while converting to ONNX) needs to be handled + # about an external tensor storage where the graph and weights are seperatly stored in a archive + if large_model: + logging.info(f"Model exported to {model_name}.zip") + return f"{model_name}.zip", output + + logging.info(f"Model exported to {model_name}.zip") + return f"{model_name}.onnx", output diff --git a/doctr/models/zoo.py b/doctr/models/zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..eff5fe14c418841f3ac9afaaa19a218f3d7f6f28 --- /dev/null +++ b/doctr/models/zoo.py @@ -0,0 +1,241 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
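# Illustrative sketch (not part of the diff): the TensorFlow `conv_sequence` helper above,
# assembled into a tiny Keras model (TensorFlow backend assumed, arbitrary shapes).
import tensorflow as tf
from tensorflow.keras import Sequential

from doctr.models.utils.tensorflow import conv_sequence

model = Sequential(conv_sequence(32, "relu", bn=True, kernel_size=3, input_shape=(64, 64, 3)))
out = model(tf.random.uniform((1, 64, 64, 3)))
print(out.shape)  # (1, 64, 64, 32)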
+ +from typing import Any + +from .detection.zoo import detection_predictor +from .kie_predictor import KIEPredictor +from .predictor import OCRPredictor +from .recognition.zoo import recognition_predictor + +__all__ = ["ocr_predictor", "kie_predictor"] + + +def _predictor( + det_arch: Any, + reco_arch: Any, + pretrained: bool, + pretrained_backbone: bool = True, + assume_straight_pages: bool = True, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + det_bs: int = 2, + reco_bs: int = 128, + detect_orientation: bool = False, + straighten_pages: bool = False, + detect_language: bool = False, + **kwargs, +) -> OCRPredictor: + # Detection + det_predictor = detection_predictor( + det_arch, + pretrained=pretrained, + pretrained_backbone=pretrained_backbone, + batch_size=det_bs, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + ) + + # Recognition + reco_predictor = recognition_predictor( + reco_arch, + pretrained=pretrained, + pretrained_backbone=pretrained_backbone, + batch_size=reco_bs, + ) + + return OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + detect_orientation=detect_orientation, + straighten_pages=straighten_pages, + detect_language=detect_language, + **kwargs, + ) + + +def ocr_predictor( + det_arch: Any = "fast_base", + reco_arch: Any = "crnn_vgg16_bn", + pretrained: bool = False, + pretrained_backbone: bool = True, + assume_straight_pages: bool = True, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + export_as_straight_boxes: bool = False, + detect_orientation: bool = False, + straighten_pages: bool = False, + detect_language: bool = False, + **kwargs: Any, +) -> OCRPredictor: + """End-to-end OCR architecture using one model for localization, and another for text recognition. + + >>> import numpy as np + >>> from doctr.models import ocr_predictor + >>> model = ocr_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True) + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + ---- + det_arch: name of the detection architecture or the model itself to use + (e.g. 'db_resnet50', 'db_mobilenet_v3_large') + reco_arch: name of the recognition architecture or the model itself to use + (e.g. 'crnn_vgg16_bn', 'sar_resnet31') + pretrained: If True, returns a model pre-trained on our OCR dataset + pretrained_backbone: If True, returns a model with a pretrained backbone + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before + running the detection model on it. + symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right. + export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions + (potentially rotated) as straight bounding boxes. + detect_orientation: if True, the estimated general page orientation will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + straighten_pages: if True, estimates the page general orientation + based on the segmentation map median line orientation. + Then, rotates page before passing it again to the deep learning detection module. 
+ Doing so will improve performances for documents with page-uniform rotations. + detect_language: if True, the language prediction will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + kwargs: keyword args of `OCRPredictor` + + Returns: + ------- + OCR predictor + """ + return _predictor( + det_arch, + reco_arch, + pretrained, + pretrained_backbone=pretrained_backbone, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + export_as_straight_boxes=export_as_straight_boxes, + detect_orientation=detect_orientation, + straighten_pages=straighten_pages, + detect_language=detect_language, + **kwargs, + ) + + +def _kie_predictor( + det_arch: Any, + reco_arch: Any, + pretrained: bool, + pretrained_backbone: bool = True, + assume_straight_pages: bool = True, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + det_bs: int = 2, + reco_bs: int = 128, + detect_orientation: bool = False, + straighten_pages: bool = False, + detect_language: bool = False, + **kwargs, +) -> KIEPredictor: + # Detection + det_predictor = detection_predictor( + det_arch, + pretrained=pretrained, + pretrained_backbone=pretrained_backbone, + batch_size=det_bs, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + ) + + # Recognition + reco_predictor = recognition_predictor( + reco_arch, + pretrained=pretrained, + pretrained_backbone=pretrained_backbone, + batch_size=reco_bs, + ) + + return KIEPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + detect_orientation=detect_orientation, + straighten_pages=straighten_pages, + detect_language=detect_language, + **kwargs, + ) + + +def kie_predictor( + det_arch: Any = "fast_base", + reco_arch: Any = "crnn_vgg16_bn", + pretrained: bool = False, + pretrained_backbone: bool = True, + assume_straight_pages: bool = True, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + export_as_straight_boxes: bool = False, + detect_orientation: bool = False, + straighten_pages: bool = False, + detect_language: bool = False, + **kwargs: Any, +) -> KIEPredictor: + """End-to-end KIE architecture using one model for localization, and another for text recognition. + + >>> import numpy as np + >>> from doctr.models import ocr_predictor + >>> model = ocr_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True) + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + ---- + det_arch: name of the detection architecture or the model itself to use + (e.g. 'db_resnet50', 'db_mobilenet_v3_large') + reco_arch: name of the recognition architecture or the model itself to use + (e.g. 'crnn_vgg16_bn', 'sar_resnet31') + pretrained: If True, returns a model pre-trained on our OCR dataset + pretrained_backbone: If True, returns a model with a pretrained backbone + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before + running the detection model on it. + symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right. 
+ export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions + (potentially rotated) as straight bounding boxes. + detect_orientation: if True, the estimated general page orientation will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + straighten_pages: if True, estimates the page general orientation + based on the segmentation map median line orientation. + Then, rotates page before passing it again to the deep learning detection module. + Doing so will improve performances for documents with page-uniform rotations. + detect_language: if True, the language prediction will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + kwargs: keyword args of `OCRPredictor` + + Returns: + ------- + KIE predictor + """ + return _kie_predictor( + det_arch, + reco_arch, + pretrained, + pretrained_backbone=pretrained_backbone, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + export_as_straight_boxes=export_as_straight_boxes, + detect_orientation=detect_orientation, + straighten_pages=straighten_pages, + detect_language=detect_language, + **kwargs, + ) diff --git a/doctr/transforms/__init__.py b/doctr/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..270dcebaa5f4e79f101903087c3dfbd8dcfdddb3 --- /dev/null +++ b/doctr/transforms/__init__.py @@ -0,0 +1 @@ +from .modules import * diff --git a/doctr/transforms/__pycache__/__init__.cpython-311.pyc b/doctr/transforms/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40f1b4386e68cd36f6f5baa8ad474206f67cbade Binary files /dev/null and b/doctr/transforms/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/transforms/__pycache__/__init__.cpython-38.pyc b/doctr/transforms/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d8bc0c203c441197b36c1dec59e8cc1c83ef01e Binary files /dev/null and b/doctr/transforms/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/transforms/functional/__init__.py b/doctr/transforms/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..64556e403a5697432f805a5af28dab812fa8b932 --- /dev/null +++ b/doctr/transforms/functional/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * diff --git a/doctr/transforms/functional/__pycache__/__init__.cpython-311.pyc b/doctr/transforms/functional/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f8f80d9444496ef1ed77f1c83fc4fa64d2b994f Binary files /dev/null and b/doctr/transforms/functional/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/transforms/functional/__pycache__/__init__.cpython-38.pyc b/doctr/transforms/functional/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4419ddb42738f23b3442bee887703a8c1002d8c3 Binary files /dev/null and b/doctr/transforms/functional/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/transforms/functional/__pycache__/base.cpython-311.pyc b/doctr/transforms/functional/__pycache__/base.cpython-311.pyc new file mode 100644 index 
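# Illustrative sketch (not part of the diff): building the KIE predictor with a couple of the
# flags documented above. pretrained=False keeps the weights random, so only the structure of
# the returned object is meaningful here.
import numpy as np

from doctr.models import kie_predictor

predictor = kie_predictor(
    det_arch="db_mobilenet_v3_large",
    reco_arch="crnn_mobilenet_v3_small",
    pretrained=False,
    assume_straight_pages=True,
)
page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
result = predictor([page])
print(type(result).__name__)  # a KIE document result object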
0000000000000000000000000000000000000000..c277492e51d437e90e139dbd940741557298495e Binary files /dev/null and b/doctr/transforms/functional/__pycache__/base.cpython-311.pyc differ diff --git a/doctr/transforms/functional/__pycache__/base.cpython-38.pyc b/doctr/transforms/functional/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78907d86da035fd772525900f2430020cfc9b3f6 Binary files /dev/null and b/doctr/transforms/functional/__pycache__/base.cpython-38.pyc differ diff --git a/doctr/transforms/functional/__pycache__/pytorch.cpython-311.pyc b/doctr/transforms/functional/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..218340b6c7de67a20ddc4ad20ce028c3817a6fc9 Binary files /dev/null and b/doctr/transforms/functional/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/transforms/functional/__pycache__/tensorflow.cpython-311.pyc b/doctr/transforms/functional/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af47e02712837256010e63665989c1a69e896805 Binary files /dev/null and b/doctr/transforms/functional/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/transforms/functional/__pycache__/tensorflow.cpython-38.pyc b/doctr/transforms/functional/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7164330b3680ac33c9f405ad7e9deff2bdb7342 Binary files /dev/null and b/doctr/transforms/functional/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/transforms/functional/base.py b/doctr/transforms/functional/base.py new file mode 100644 index 0000000000000000000000000000000000000000..9084989b34b510dcb0e12c6d0444faa557254ac2 --- /dev/null +++ b/doctr/transforms/functional/base.py @@ -0,0 +1,203 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Tuple, Union + +import cv2 +import numpy as np + +from doctr.utils.geometry import rotate_abs_geoms + +__all__ = ["crop_boxes", "create_shadow_mask"] + + +def crop_boxes( + boxes: np.ndarray, + crop_box: Union[Tuple[int, int, int, int], Tuple[float, float, float, float]], +) -> np.ndarray: + """Crop localization boxes + + Args: + ---- + boxes: ndarray of shape (N, 4) in relative or abs coordinates + crop_box: box (xmin, ymin, xmax, ymax) to crop the image, in the same coord format that the boxes + + Returns: + ------- + the cropped boxes + """ + is_box_rel = boxes.max() <= 1 + is_crop_rel = max(crop_box) <= 1 + + if is_box_rel ^ is_crop_rel: + raise AssertionError("both the boxes and the crop need to have the same coordinate convention") + + xmin, ymin, xmax, ymax = crop_box + # Clip boxes & correct offset + boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(xmin, xmax) - xmin + boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(ymin, ymax) - ymin + # Rescale relative coords + if is_box_rel: + boxes[:, [0, 2]] /= xmax - xmin + boxes[:, [1, 3]] /= ymax - ymin + + # Remove 0-sized boxes + is_valid = np.logical_and(boxes[:, 1] < boxes[:, 3], boxes[:, 0] < boxes[:, 2]) + + return boxes[is_valid] + + +def expand_line(line: np.ndarray, target_shape: Tuple[int, int]) -> Tuple[float, float]: + """Expands a 2-point line, so that the first is on the edge. In other terms, we extend the line in + the same direction until we meet one of the edges. 
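# Illustrative sketch (not part of the diff): `crop_boxes` with relative coordinates. Boxes are
# clipped to the crop window, shifted, rescaled to the new frame, and empty boxes are dropped.
# The numbers are arbitrary (dyadic values chosen so the output prints exactly).
import numpy as np

from doctr.transforms.functional.base import crop_boxes

boxes = np.array([[0.125, 0.125, 0.25, 0.375], [0.75, 0.75, 0.875, 0.875]], dtype=np.float32)
print(crop_boxes(boxes, (0.0, 0.0, 0.5, 0.5)))  # [[0.25 0.25 0.5  0.75]] -- 2nd box falls outside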
+ + Args: + ---- + line: array of shape (2, 2) of the point supposed to be on one edge, and the shadow tip. + target_shape: the desired mask shape + + Returns: + ------- + 2D coordinates of the first point once we extended the line (on one of the edges) + """ + if any(coord == 0 or coord == size for coord, size in zip(line[0], target_shape[::-1])): + return line[0] + # Get the line equation + _tmp = line[1] - line[0] + _direction = _tmp > 0 + _flat = _tmp == 0 + # vertical case + if _tmp[0] == 0: + solutions = [ + # y = 0 + (line[0, 0], 0), + # y = bot + (line[0, 0], target_shape[0]), + ] + # horizontal + elif _tmp[1] == 0: + solutions = [ + # x = 0 + (0, line[0, 1]), + # x = right + (target_shape[1], line[0, 1]), + ] + else: + alpha = _tmp[1] / _tmp[0] + beta = line[1, 1] - alpha * line[1, 0] + + # Solve it for edges + solutions = [ + # x = 0 + (0, beta), + # y = 0 + (-beta / alpha, 0), + # x = right + (target_shape[1], alpha * target_shape[1] + beta), + # y = bot + ((target_shape[0] - beta) / alpha, target_shape[0]), + ] + for point in solutions: + # Skip points that are out of the final image + if any(val < 0 or val > size for val, size in zip(point, target_shape[::-1])): + continue + if all( + val == ref if _same else (val < ref if _dir else val > ref) + for val, ref, _dir, _same in zip(point, line[1], _direction, _flat) + ): + return point + raise ValueError + + +def create_shadow_mask( + target_shape: Tuple[int, int], + min_base_width=0.3, + max_tip_width=0.5, + max_tip_height=0.3, +) -> np.ndarray: + """Creates a random shadow mask + + Args: + ---- + target_shape: the target shape (H, W) + min_base_width: the relative minimum shadow base width + max_tip_width: the relative maximum shadow tip width + max_tip_height: the relative maximum shadow tip height + + Returns: + ------- + a numpy ndarray of shape (H, W, 1) with values in the range [0, 1] + """ + # Default base is top + _params = np.random.rand(6) + base_width = min_base_width + (1 - min_base_width) * _params[0] + base_center = base_width / 2 + (1 - base_width) * _params[1] + # Ensure tip width is smaller for shadow consistency + tip_width = min(_params[2] * base_width * target_shape[0] / target_shape[1], max_tip_width) + tip_center = tip_width / 2 + (1 - tip_width) * _params[3] + tip_height = _params[4] * max_tip_height + tip_mid = tip_height / 2 + (1 - tip_height) * _params[5] + _order = tip_center < base_center + contour: np.ndarray = np.array( + [ + [base_center - base_width / 2, 0], + [base_center + base_width / 2, 0], + [tip_center + tip_width / 2, tip_mid + tip_height / 2 if _order else tip_mid - tip_height / 2], + [tip_center - tip_width / 2, tip_mid - tip_height / 2 if _order else tip_mid + tip_height / 2], + ], + dtype=np.float32, + ) + + # Convert to absolute coords + abs_contour: np.ndarray = ( + np.stack( + (contour[:, 0] * target_shape[1], contour[:, 1] * target_shape[0]), + axis=-1, + ) + .round() + .astype(np.int32) + ) + + # Direction + _params = np.random.rand(1) + rotated_contour = ( + rotate_abs_geoms( + abs_contour[None, ...], + 360 * _params[0], + target_shape, + expand=False, + )[0] + .round() + .astype(np.int32) + ) + # Check approx quadrant + quad_idx = int(_params[0] / 0.25) + # Top-bot + if quad_idx % 2 == 0: + intensity_mask = np.repeat(np.arange(target_shape[0])[:, None], target_shape[1], axis=1) / (target_shape[0] - 1) + if quad_idx == 0: + intensity_mask = 1 - intensity_mask + # Left - right + else: + intensity_mask = np.repeat(np.arange(target_shape[1])[None, :], target_shape[0], axis=0) / 
(target_shape[1] - 1) + if quad_idx == 1: + intensity_mask = 1 - intensity_mask + + # Expand base + final_contour = rotated_contour.copy() + final_contour[0] = expand_line(final_contour[[0, 3]], target_shape) + final_contour[1] = expand_line(final_contour[[1, 2]], target_shape) + # If both base are not on the same side, add a point + if not np.any(final_contour[0] == final_contour[1]): + corner_x = 0 if max(final_contour[0, 0], final_contour[1, 0]) < target_shape[1] else target_shape[1] + corner_y = 0 if max(final_contour[0, 1], final_contour[1, 1]) < target_shape[0] else target_shape[0] + corner: np.ndarray = np.array([corner_x, corner_y]) + final_contour = np.concatenate((final_contour[:1], corner[None, ...], final_contour[1:]), axis=0) + + # Direction & rotate + mask: np.ndarray = np.zeros((*target_shape, 1), dtype=np.uint8) + mask = cv2.fillPoly(mask, [final_contour], (255,), lineType=cv2.LINE_AA)[..., 0] + + return (mask / 255).astype(np.float32).clip(0, 1) * intensity_mask.astype(np.float32) # type: ignore[operator] diff --git a/doctr/transforms/functional/pytorch.py b/doctr/transforms/functional/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..65649ea2c8184fc7e0cef95034d9102e9e1d3b5f --- /dev/null +++ b/doctr/transforms/functional/pytorch.py @@ -0,0 +1,145 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from copy import deepcopy +from typing import Tuple + +import numpy as np +import torch +from torchvision.transforms import functional as F + +from doctr.utils.geometry import rotate_abs_geoms + +from .base import create_shadow_mask, crop_boxes + +__all__ = ["invert_colors", "rotate_sample", "crop_detection", "random_shadow"] + + +def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor: + """Invert the colors of an image + + Args: + ---- + img : torch.Tensor, the image to invert + min_val : minimum value of the random shift + + Returns: + ------- + the inverted image + """ + out = F.rgb_to_grayscale(img, num_output_channels=3) + # Random RGB shift + shift_shape = [img.shape[0], 3, 1, 1] if img.ndim == 4 else [3, 1, 1] + rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape) + # Inverse the color + if out.dtype == torch.uint8: + out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8) + else: + out = out * rgb_shift.to(dtype=out.dtype) + # Inverse the color + out = 255 - out if out.dtype == torch.uint8 else 1 - out + return out + + +def rotate_sample( + img: torch.Tensor, + geoms: np.ndarray, + angle: float, + expand: bool = False, +) -> Tuple[torch.Tensor, np.ndarray]: + """Rotate image around the center, interpolation=NEAREST, pad with 0 (black) + + Args: + ---- + img: image to rotate + geoms: array of geometries of shape (N, 4) or (N, 4, 2) + angle: angle in degrees. 
+: counter-clockwise, -: clockwise + expand: whether the image should be padded before the rotation + + Returns: + ------- + A tuple of rotated img (tensor), rotated geometries of shape (N, 4, 2) + """ + rotated_img = F.rotate(img, angle=angle, fill=0, expand=expand) # Interpolation NEAREST by default + rotated_img = rotated_img[:3] # when expand=True, it expands to RGBA channels + # Get absolute coords + _geoms = deepcopy(geoms) + if _geoms.shape[1:] == (4,): + if np.max(_geoms) <= 1: + _geoms[:, [0, 2]] *= img.shape[-1] + _geoms[:, [1, 3]] *= img.shape[-2] + elif _geoms.shape[1:] == (4, 2): + if np.max(_geoms) <= 1: + _geoms[..., 0] *= img.shape[-1] + _geoms[..., 1] *= img.shape[-2] + else: + raise AssertionError("invalid format for arg `geoms`") + + # Rotate the boxes: xmin, ymin, xmax, ymax or polygons --> (4, 2) polygon + rotated_geoms: np.ndarray = rotate_abs_geoms( + _geoms, + angle, + img.shape[1:], # type: ignore[arg-type] + expand, + ).astype(np.float32) + + # Always return relative boxes to avoid label confusions when resizing is performed aferwards + rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[2] + rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[1] + + return rotated_img, np.clip(rotated_geoms, 0, 1) + + +def crop_detection( + img: torch.Tensor, boxes: np.ndarray, crop_box: Tuple[float, float, float, float] +) -> Tuple[torch.Tensor, np.ndarray]: + """Crop and image and associated bboxes + + Args: + ---- + img: image to crop + boxes: array of boxes to clip, absolute (int) or relative (float) + crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords. + + Returns: + ------- + A tuple of cropped image, cropped boxes, where the image is not resized. + """ + if any(val < 0 or val > 1 for val in crop_box): + raise AssertionError("coordinates of arg `crop_box` should be relative") + h, w = img.shape[-2:] + xmin, ymin = int(round(crop_box[0] * (w - 1))), int(round(crop_box[1] * (h - 1))) + xmax, ymax = int(round(crop_box[2] * (w - 1))), int(round(crop_box[3] * (h - 1))) + cropped_img = F.crop(img, ymin, xmin, ymax - ymin, xmax - xmin) + # Crop the box + boxes = crop_boxes(boxes, crop_box if boxes.max() <= 1 else (xmin, ymin, xmax, ymax)) + + return cropped_img, boxes + + +def random_shadow(img: torch.Tensor, opacity_range: Tuple[float, float], **kwargs) -> torch.Tensor: + """Crop and image and associated bboxes + + Args: + ---- + img: image to modify + opacity_range: the minimum and maximum desired opacity of the shadow + **kwargs: additional arguments to pass to `create_shadow_mask` + + Returns: + ------- + shaded image + """ + shadow_mask = create_shadow_mask(img.shape[1:], **kwargs) # type: ignore[arg-type] + + opacity = np.random.uniform(*opacity_range) + shadow_tensor = 1 - torch.from_numpy(shadow_mask[None, ...]) + + # Add some blur to make it believable + k = 7 + 2 * int(4 * np.random.rand(1)) + sigma = np.random.uniform(0.5, 5.0) + shadow_tensor = F.gaussian_blur(shadow_tensor, k, sigma=[sigma, sigma]) + + return opacity * shadow_tensor * img + (1 - opacity) * img diff --git a/doctr/transforms/functional/tensorflow.py b/doctr/transforms/functional/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..27336089b42750999bcafa1dc4e43a68e43a7821 --- /dev/null +++ b/doctr/transforms/functional/tensorflow.py @@ -0,0 +1,266 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
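# Illustrative sketch (not part of the diff): the PyTorch `rotate_sample` above. Relative
# (xmin, ymin, xmax, ymax) boxes go in and come back as relative (N, 4, 2) polygons clipped
# to [0, 1]; with expand=True the rotated image is larger than the input.
import numpy as np
import torch

from doctr.transforms.functional.pytorch import rotate_sample

img = torch.rand(3, 64, 128)
boxes = np.array([[0.25, 0.25, 0.75, 0.75]], dtype=np.float32)
rot_img, rot_polys = rotate_sample(img, boxes, angle=15.0, expand=True)
print(rot_polys.shape)     # (1, 4, 2)
print(rot_img.shape[-2:])  # larger than (64, 128) because expand=True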
+ +import math +import random +from copy import deepcopy +from typing import Iterable, Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from doctr.utils.geometry import compute_expanded_shape, rotate_abs_geoms + +from .base import create_shadow_mask, crop_boxes + +__all__ = ["invert_colors", "rotate_sample", "crop_detection", "random_shadow", "rotated_img_tensor"] + + +def invert_colors(img: tf.Tensor, min_val: float = 0.6) -> tf.Tensor: + """Invert the colors of an image + + Args: + ---- + img : tf.Tensor, the image to invert + min_val : minimum value of the random shift + + Returns: + ------- + the inverted image + """ + out = tf.image.rgb_to_grayscale(img) # Convert to gray + # Random RGB shift + shift_shape = [img.shape[0], 1, 1, 3] if img.ndim == 4 else [1, 1, 3] + rgb_shift = tf.random.uniform(shape=shift_shape, minval=min_val, maxval=1) + # Inverse the color + if out.dtype == tf.uint8: + out = tf.cast(tf.cast(out, dtype=rgb_shift.dtype) * rgb_shift, dtype=tf.uint8) + else: + out *= tf.cast(rgb_shift, dtype=out.dtype) + # Inverse the color + out = 255 - out if out.dtype == tf.uint8 else 1 - out + return out + + +def rotated_img_tensor(img: tf.Tensor, angle: float, expand: bool = False) -> tf.Tensor: + """Rotate image around the center, interpolation=NEAREST, pad with 0 (black) + + Args: + ---- + img: image to rotate + angle: angle in degrees. +: counter-clockwise, -: clockwise + expand: whether the image should be padded before the rotation + + Returns: + ------- + the rotated image (tensor) + """ + # Compute the expanded padding + h_crop, w_crop = 0, 0 + if expand: + exp_h, exp_w = compute_expanded_shape(img.shape[:-1], angle) + h_diff, w_diff = int(math.ceil(exp_h - img.shape[0])), int(math.ceil(exp_w - img.shape[1])) + h_pad, w_pad = max(h_diff, 0), max(w_diff, 0) + exp_img = tf.pad(img, tf.constant([[h_pad // 2, h_pad - h_pad // 2], [w_pad // 2, w_pad - w_pad // 2], [0, 0]])) + h_crop, w_crop = int(round(max(exp_img.shape[0] - exp_h, 0))), int(round(min(exp_img.shape[1] - exp_w, 0))) + else: + exp_img = img + + # Compute the rotation matrix + height, width = tf.cast(tf.shape(exp_img)[0], tf.float32), tf.cast(tf.shape(exp_img)[1], tf.float32) + cos_angle, sin_angle = tf.math.cos(angle * math.pi / 180.0), tf.math.sin(angle * math.pi / 180.0) + x_offset = ((width - 1) - (cos_angle * (width - 1) - sin_angle * (height - 1))) / 2.0 + y_offset = ((height - 1) - (sin_angle * (width - 1) + cos_angle * (height - 1))) / 2.0 + + rotation_matrix = tf.convert_to_tensor( + [cos_angle, -sin_angle, x_offset, sin_angle, cos_angle, y_offset, 0.0, 0.0], + dtype=tf.float32, + ) + # Rotate the image + rotated_img = tf.squeeze( + tf.raw_ops.ImageProjectiveTransformV3( + images=exp_img[None], # Add a batch dimension for compatibility with ImageProjectiveTransformV3 + transforms=rotation_matrix[None], # Add a batch dimension for compatibility with ImageProjectiveTransformV3 + output_shape=tf.shape(exp_img)[:2], + interpolation="NEAREST", + fill_mode="CONSTANT", + fill_value=tf.constant(0.0, dtype=tf.float32), + ) + ) + # Crop the rest + if h_crop > 0 or w_crop > 0: + h_slice = slice(h_crop // 2, -h_crop // 2) if h_crop > 0 else slice(rotated_img.shape[0]) + w_slice = slice(-w_crop // 2, -w_crop // 2) if w_crop > 0 else slice(rotated_img.shape[1]) + rotated_img = rotated_img[h_slice, w_slice] + + return rotated_img + + +def rotate_sample( + img: tf.Tensor, + geoms: np.ndarray, + angle: float, + expand: bool = False, +) -> Tuple[tf.Tensor, np.ndarray]: + """Rotate image around the 
center, interpolation=NEAREST, pad with 0 (black) + + Args: + ---- + img: image to rotate + geoms: array of geometries of shape (N, 4) or (N, 4, 2) + angle: angle in degrees. +: counter-clockwise, -: clockwise + expand: whether the image should be padded before the rotation + + Returns: + ------- + A tuple of rotated img (tensor), rotated boxes (np array) + """ + # Rotate the image + rotated_img = rotated_img_tensor(img, angle, expand) + + # Get absolute coords + _geoms = deepcopy(geoms) + if _geoms.shape[1:] == (4,): + if np.max(_geoms) <= 1: + _geoms[:, [0, 2]] *= img.shape[1] + _geoms[:, [1, 3]] *= img.shape[0] + elif _geoms.shape[1:] == (4, 2): + if np.max(_geoms) <= 1: + _geoms[..., 0] *= img.shape[1] + _geoms[..., 1] *= img.shape[0] + else: + raise AssertionError("invalid format for arg `geoms`") + + # Rotate the boxes: xmin, ymin, xmax, ymax or polygons --> (4, 2) polygon + rotated_geoms: np.ndarray = rotate_abs_geoms(_geoms, angle, img.shape[:-1], expand).astype(np.float32) + + # Always return relative boxes to avoid label confusions when resizing is performed afterwards + rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[1] + rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[0] + + return rotated_img, np.clip(rotated_geoms, 0, 1) + + +def crop_detection( + img: tf.Tensor, boxes: np.ndarray, crop_box: Tuple[float, float, float, float] +) -> Tuple[tf.Tensor, np.ndarray]: + """Crop an image and associated bboxes + + Args: + ---- + img: image to crop + boxes: array of boxes to clip, absolute (int) or relative (float) + crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords. + + Returns: + ------- + A tuple of cropped image, cropped boxes, where the image is not resized. + """ + if any(val < 0 or val > 1 for val in crop_box): + raise AssertionError("coordinates of arg `crop_box` should be relative") + h, w = img.shape[:2] + xmin, ymin = int(round(crop_box[0] * (w - 1))), int(round(crop_box[1] * (h - 1))) + xmax, ymax = int(round(crop_box[2] * (w - 1))), int(round(crop_box[3] * (h - 1))) + cropped_img = tf.image.crop_to_bounding_box(img, ymin, xmin, ymax - ymin, xmax - xmin) + # Crop the box + boxes = crop_boxes(boxes, crop_box if boxes.max() <= 1 else (xmin, ymin, xmax, ymax)) + + return cropped_img, boxes + + +def _gaussian_filter( + img: tf.Tensor, + kernel_size: Union[int, Iterable[int]], + sigma: float, + mode: Optional[str] = None, + pad_value: Optional[int] = 0, +): + """Apply Gaussian filter to image.
+ Adapted from: https://github.com/tensorflow/addons/blob/master/tensorflow_addons/image/filters.py + + Args: + ---- + img: image to filter of shape (N, H, W, C) + kernel_size: kernel size of the filter + sigma: standard deviation of the Gaussian filter + mode: padding mode, one of "CONSTANT", "REFLECT", "SYMMETRIC" + pad_value: value to pad the image with + + Returns: + ------- + A tensor of shape (N, H, W, C) + """ + ksize = tf.convert_to_tensor(tf.broadcast_to(kernel_size, [2]), dtype=tf.int32) + sigma = tf.convert_to_tensor(tf.broadcast_to(sigma, [2]), dtype=img.dtype) + assert mode in ("CONSTANT", "REFLECT", "SYMMETRIC"), "mode should be one of 'CONSTANT', 'REFLECT', 'SYMMETRIC'" + mode = "CONSTANT" if mode is None else str.upper(mode) + constant_values = ( + tf.zeros([], dtype=img.dtype) if pad_value is None else tf.convert_to_tensor(pad_value, dtype=img.dtype) + ) + + def kernel1d(ksize: tf.Tensor, sigma: tf.Tensor, dtype: tf.DType): + x = tf.range(ksize, dtype=dtype) + x = x - tf.cast(tf.math.floordiv(ksize, 2), dtype=dtype) + x = x + tf.where(tf.math.equal(tf.math.mod(ksize, 2), 0), tf.cast(0.5, dtype), 0) + g = tf.math.exp(-(tf.math.pow(x, 2) / (2 * tf.math.pow(sigma, 2)))) + g = g / tf.reduce_sum(g) + return g + + def kernel2d(ksize: tf.Tensor, sigma: tf.Tensor, dtype: tf.DType): + kernel_x = kernel1d(ksize[0], sigma[0], dtype) + kernel_y = kernel1d(ksize[1], sigma[1], dtype) + return tf.matmul( + tf.expand_dims(kernel_x, axis=-1), + tf.transpose(tf.expand_dims(kernel_y, axis=-1)), + ) + + g = kernel2d(ksize, sigma, img.dtype) + # Pad the image + height, width = ksize[0], ksize[1] + paddings = [ + [0, 0], + [(height - 1) // 2, height - 1 - (height - 1) // 2], + [(width - 1) // 2, width - 1 - (width - 1) // 2], + [0, 0], + ] + img = tf.pad(img, paddings, mode=mode, constant_values=constant_values) + + channel = tf.shape(img)[-1] + shape = tf.concat([ksize, tf.constant([1, 1], ksize.dtype)], axis=0) + g = tf.reshape(g, shape) + shape = tf.concat([ksize, [channel], tf.constant([1], ksize.dtype)], axis=0) + g = tf.broadcast_to(g, shape) + return tf.nn.depthwise_conv2d(img, g, [1, 1, 1, 1], padding="VALID", data_format="NHWC") + + +def random_shadow(img: tf.Tensor, opacity_range: Tuple[float, float], **kwargs) -> tf.Tensor: + """Apply a random shadow to a given image + + Args: + ---- + img: image to modify + opacity_range: the minimum and maximum desired opacity of the shadow + **kwargs: additional arguments to pass to `create_shadow_mask` + + Returns: + ------- + shadowed image + """ + shadow_mask = create_shadow_mask(img.shape[:2], **kwargs) + + opacity = np.random.uniform(*opacity_range) + shadow_tensor = 1 - tf.convert_to_tensor(shadow_mask[..., None], dtype=tf.float32) + + # Add some blur to make it believable + k = 7 + int(2 * 4 * random.random()) + sigma = random.uniform(0.5, 5.0) + shadow_tensor = _gaussian_filter( + shadow_tensor[tf.newaxis, ...], + kernel_size=k, + sigma=sigma, + mode="REFLECT", + ) + + return tf.squeeze(opacity * shadow_tensor * img + (1 - opacity) * img, axis=0) diff --git a/doctr/transforms/modules/__init__.py b/doctr/transforms/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4053ff5520cd64ae8bf2f165df2d7254b0e349fe --- /dev/null +++ b/doctr/transforms/modules/__init__.py @@ -0,0 +1,8 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +from .base import * + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff 
--git a/doctr/transforms/modules/__pycache__/__init__.cpython-311.pyc b/doctr/transforms/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e02edb64fe4377c09e02be439a37e1ac64ffae23 Binary files /dev/null and b/doctr/transforms/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/transforms/modules/__pycache__/__init__.cpython-38.pyc b/doctr/transforms/modules/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c3c601c27b021cbad221bda66e6c898be7a1af7 Binary files /dev/null and b/doctr/transforms/modules/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/transforms/modules/__pycache__/base.cpython-311.pyc b/doctr/transforms/modules/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d2e5c3e04968d9fba8b12f30470bb07e1d0b643 Binary files /dev/null and b/doctr/transforms/modules/__pycache__/base.cpython-311.pyc differ diff --git a/doctr/transforms/modules/__pycache__/base.cpython-38.pyc b/doctr/transforms/modules/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..030b2ab1912a72ca29e3fb8bb8215fad2b9428fe Binary files /dev/null and b/doctr/transforms/modules/__pycache__/base.cpython-38.pyc differ diff --git a/doctr/transforms/modules/__pycache__/pytorch.cpython-311.pyc b/doctr/transforms/modules/__pycache__/pytorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad30cc9ac1c6012803476c1ded3ab67f4096ac53 Binary files /dev/null and b/doctr/transforms/modules/__pycache__/pytorch.cpython-311.pyc differ diff --git a/doctr/transforms/modules/__pycache__/tensorflow.cpython-311.pyc b/doctr/transforms/modules/__pycache__/tensorflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d5ec2a2721c0899fff4cbca989c978acccb64c6 Binary files /dev/null and b/doctr/transforms/modules/__pycache__/tensorflow.cpython-311.pyc differ diff --git a/doctr/transforms/modules/__pycache__/tensorflow.cpython-38.pyc b/doctr/transforms/modules/__pycache__/tensorflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55868414cc2840ee552d22335ea7c90169b6858a Binary files /dev/null and b/doctr/transforms/modules/__pycache__/tensorflow.cpython-38.pyc differ diff --git a/doctr/transforms/modules/base.py b/doctr/transforms/modules/base.py new file mode 100644 index 0000000000000000000000000000000000000000..25d15c98ef0c065da943d7a29053e0dfd5253ef5 --- /dev/null +++ b/doctr/transforms/modules/base.py @@ -0,0 +1,299 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import math +import random +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np + +from doctr.utils.repr import NestedObject + +from .. import functional as F + +__all__ = ["SampleCompose", "ImageTransform", "ColorInversion", "OneOf", "RandomApply", "RandomRotate", "RandomCrop"] + + +class SampleCompose(NestedObject): + """Implements a wrapper that will apply transformations sequentially on both image and target + + .. tabs:: + + .. tab:: TensorFlow + + .. 
code:: python + + >>> import numpy as np + >>> import tensorflow as tf + >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate + >>> transfo = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)]) + >>> out, out_boxes = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4))) + + .. tab:: PyTorch + + .. code:: python + + >>> import numpy as np + >>> import torch + >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate + >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)]) + >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4))) + + Args: + ---- + transforms: list of transformation modules + """ + + _children_names: List[str] = ["sample_transforms"] + + def __init__(self, transforms: List[Callable[[Any, Any], Tuple[Any, Any]]]) -> None: + self.sample_transforms = transforms + + def __call__(self, x: Any, target: Any) -> Tuple[Any, Any]: + for t in self.sample_transforms: + x, target = t(x, target) + + return x, target + + +class ImageTransform(NestedObject): + """Implements a transform wrapper to turn an image-only transformation into an image+target transform + + .. tabs:: + + .. tab:: TensorFlow + + .. code:: python + + >>> import tensorflow as tf + >>> from doctr.transforms import ImageTransform, ColorInversion + >>> transfo = ImageTransform(ColorInversion((32, 32))) + >>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None) + + .. tab:: PyTorch + + .. code:: python + + >>> import torch + >>> from doctr.transforms import ImageTransform, ColorInversion + >>> transfo = ImageTransform(ColorInversion((32, 32))) + >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None) + + Args: + ---- + transform: the image transformation module to wrap + """ + + _children_names: List[str] = ["img_transform"] + + def __init__(self, transform: Callable[[Any], Any]) -> None: + self.img_transform = transform + + def __call__(self, img: Any, target: Any) -> Tuple[Any, Any]: + img = self.img_transform(img) + return img, target + + +class ColorInversion(NestedObject): + """Applies the following tranformation to a tensor (image or batch of images): + convert to grayscale, colorize (shift 0-values randomly), and then invert colors + + .. tabs:: + + .. tab:: TensorFlow + + .. code:: python + + >>> import tensorflow as tf + >>> from doctr.transforms import ColorInversion + >>> transfo = ColorInversion(min_val=0.6) + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + .. tab:: PyTorch + + .. code:: python + + >>> import torch + >>> from doctr.transforms import ColorInversion + >>> transfo = ColorInversion(min_val=0.6) + >>> out = transfo(torch.rand(8, 64, 64, 3)) + + Args: + ---- + min_val: range [min_val, 1] to colorize RGB pixels + """ + + def __init__(self, min_val: float = 0.5) -> None: + self.min_val = min_val + + def extra_repr(self) -> str: + return f"min_val={self.min_val}" + + def __call__(self, img: Any) -> Any: + return F.invert_colors(img, self.min_val) + + +class OneOf(NestedObject): + """Randomly apply one of the input transformations + + .. tabs:: + + .. tab:: TensorFlow + + .. code:: python + + >>> import tensorflow as tf + >>> from doctr.transforms import OneOf + >>> transfo = OneOf([JpegQuality(), Gamma()]) + >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + + .. tab:: PyTorch + + .. 
code:: python + + >>> import torch + >>> from doctr.transforms import OneOf + >>> transfo = OneOf([JpegQuality(), Gamma()]) + >>> out = transfo(torch.rand(1, 64, 64, 3)) + + Args: + ---- + transforms: list of transformations, one only will be picked + """ + + _children_names: List[str] = ["transforms"] + + def __init__(self, transforms: List[Callable[[Any], Any]]) -> None: + self.transforms = transforms + + def __call__(self, img: Any, target: Optional[np.ndarray] = None) -> Union[Any, Tuple[Any, np.ndarray]]: + # Pick transformation + transfo = self.transforms[int(random.random() * len(self.transforms))] + # Apply + return transfo(img) if target is None else transfo(img, target) # type: ignore[call-arg] + + +class RandomApply(NestedObject): + """Apply with a probability p the input transformation + + .. tabs:: + + .. tab:: TensorFlow + + .. code:: python + + >>> import tensorflow as tf + >>> from doctr.transforms import RandomApply + >>> transfo = RandomApply(Gamma(), p=.5) + >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + + .. tab:: PyTorch + + .. code:: python + + >>> import torch + >>> from doctr.transforms import RandomApply + >>> transfo = RandomApply(Gamma(), p=.5) + >>> out = transfo(torch.rand(1, 64, 64, 3)) + + Args: + ---- + transform: transformation to apply + p: probability to apply + """ + + def __init__(self, transform: Callable[[Any], Any], p: float = 0.5) -> None: + self.transform = transform + self.p = p + + def extra_repr(self) -> str: + return f"transform={self.transform}, p={self.p}" + + def __call__(self, img: Any, target: Optional[np.ndarray] = None) -> Union[Any, Tuple[Any, np.ndarray]]: + if random.random() < self.p: + return self.transform(img) if target is None else self.transform(img, target) # type: ignore[call-arg] + return img if target is None else (img, target) + + +class RandomRotate(NestedObject): + """Randomly rotate a tensor image and its boxes + + .. image:: https://doctr-static.mindee.com/models?id=v0.4.0/rotation_illustration.png&src=0 + :align: center + + Args: + ---- + max_angle: maximum angle for rotation, in degrees. 
Angles will be uniformly picked in + [-max_angle, max_angle] + expand: whether the image should be padded before the rotation + """ + + def __init__(self, max_angle: float = 5.0, expand: bool = False) -> None: + self.max_angle = max_angle + self.expand = expand + + def extra_repr(self) -> str: + return f"max_angle={self.max_angle}, expand={self.expand}" + + def __call__(self, img: Any, target: np.ndarray) -> Tuple[Any, np.ndarray]: + angle = random.uniform(-self.max_angle, self.max_angle) + r_img, r_polys = F.rotate_sample(img, target, angle, self.expand) + # Removes deleted boxes + is_kept = (r_polys.max(1) > r_polys.min(1)).sum(1) == 2 + return r_img, r_polys[is_kept] + + +class RandomCrop(NestedObject): + """Randomly crop a tensor image and its boxes + + Args: + ---- + scale: tuple of floats, relative (min_area, max_area) of the crop + ratio: tuple of float, relative (min_ratio, max_ratio) where ratio = h/w + """ + + def __init__(self, scale: Tuple[float, float] = (0.08, 1.0), ratio: Tuple[float, float] = (0.75, 1.33)) -> None: + self.scale = scale + self.ratio = ratio + + def extra_repr(self) -> str: + return f"scale={self.scale}, ratio={self.ratio}" + + def __call__(self, img: Any, target: np.ndarray) -> Tuple[Any, np.ndarray]: + scale = random.uniform(self.scale[0], self.scale[1]) + ratio = random.uniform(self.ratio[0], self.ratio[1]) + + height, width = img.shape[:2] + + # Calculate crop size + crop_area = scale * width * height + aspect_ratio = ratio * (width / height) + crop_width = int(round(math.sqrt(crop_area * aspect_ratio))) + crop_height = int(round(math.sqrt(crop_area / aspect_ratio))) + + # Ensure crop size does not exceed image dimensions + crop_width = min(crop_width, width) + crop_height = min(crop_height, height) + + # Randomly select crop position + x = random.randint(0, width - crop_width) + y = random.randint(0, height - crop_height) + + # relative crop box + crop_box = (x / width, y / height, (x + crop_width) / width, (y + crop_height) / height) + if target.shape[1:] == (4, 2): + min_xy = np.min(target, axis=1) + max_xy = np.max(target, axis=1) + _target = np.concatenate((min_xy, max_xy), axis=1) + else: + _target = target + + # Crop image and targets + croped_img, crop_boxes = F.crop_detection(img, _target, crop_box) + # hard fallback if no box is kept + if crop_boxes.shape[0] == 0: + return img, target + # clip boxes + return croped_img, np.clip(crop_boxes, 0, 1) diff --git a/doctr/transforms/modules/pytorch.py b/doctr/transforms/modules/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..f893afc2f7ea62cf5c6278b73616b6967ac3cfe0 --- /dev/null +++ b/doctr/transforms/modules/pytorch.py @@ -0,0 +1,270 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
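As an illustrative aside before the backend-specific modules: a short sketch of chaining the target-aware transforms from the base module above. The tensor and box values are hypothetical; the TensorFlow flavour is shown and the PyTorch usage is analogous.

import numpy as np
import tensorflow as tf
from doctr.transforms import SampleCompose, RandomRotate, RandomCrop

# Rotate by up to 10 degrees (padding so corners are not cut off), then take a random crop
transfo = SampleCompose([RandomRotate(max_angle=10.0, expand=True), RandomCrop(scale=(0.5, 1.0))])
img = tf.random.uniform(shape=[256, 256, 3], minval=0, maxval=1)
boxes = np.array([[0.2, 0.2, 0.6, 0.7]], dtype=np.float32)  # relative (xmin, ymin, xmax, ymax)
out_img, out_boxes = transfo(img, boxes)  # boxes come back relative; degenerate ones are dropped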
+ +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from PIL.Image import Image +from torch.nn.functional import pad +from torchvision.transforms import functional as F +from torchvision.transforms import transforms as T + +from ..functional.pytorch import random_shadow + +__all__ = ["Resize", "GaussianNoise", "ChannelShuffle", "RandomHorizontalFlip", "RandomShadow", "RandomResize"] + + +class Resize(T.Resize): + """Resize the input image to the given size""" + + def __init__( + self, + size: Union[int, Tuple[int, int]], + interpolation=F.InterpolationMode.BILINEAR, + preserve_aspect_ratio: bool = False, + symmetric_pad: bool = False, + ) -> None: + super().__init__(size, interpolation, antialias=True) + self.preserve_aspect_ratio = preserve_aspect_ratio + self.symmetric_pad = symmetric_pad + + if not isinstance(self.size, (int, tuple, list)): + raise AssertionError("size should be either a tuple, a list or an int") + + def forward( + self, + img: torch.Tensor, + target: Optional[np.ndarray] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, np.ndarray]]: + if isinstance(self.size, int): + target_ratio = img.shape[-2] / img.shape[-1] + else: + target_ratio = self.size[0] / self.size[1] + actual_ratio = img.shape[-2] / img.shape[-1] + + if not self.preserve_aspect_ratio or (target_ratio == actual_ratio and (isinstance(self.size, (tuple, list)))): + # If we don't preserve the aspect ratio or the wanted aspect ratio is the same than the original one + # We can use with the regular resize + if target is not None: + return super().forward(img), target + return super().forward(img) + else: + # Resize + if isinstance(self.size, (tuple, list)): + if actual_ratio > target_ratio: + tmp_size = (self.size[0], max(int(self.size[0] / actual_ratio), 1)) + else: + tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1]) + elif isinstance(self.size, int): # self.size is the longest side, infer the other + if img.shape[-2] <= img.shape[-1]: + tmp_size = (max(int(self.size * actual_ratio), 1), self.size) + else: + tmp_size = (self.size, max(int(self.size / actual_ratio), 1)) + + # Scale image + img = F.resize(img, tmp_size, self.interpolation, antialias=True) + raw_shape = img.shape[-2:] + if isinstance(self.size, (tuple, list)): + # Pad (inverted in pytorch) + _pad = (0, self.size[1] - img.shape[-1], 0, self.size[0] - img.shape[-2]) + if self.symmetric_pad: + half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2)) + _pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1]) + img = pad(img, _pad) + + # In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio) + if target is not None: + if self.preserve_aspect_ratio: + # Get absolute coords + if target.shape[1:] == (4,): + if isinstance(self.size, (tuple, list)) and self.symmetric_pad: + if np.max(target) <= 1: + offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2] + target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1] + target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2] + else: + target[:, [0, 2]] *= raw_shape[-1] / img.shape[-1] + target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2] + elif target.shape[1:] == (4, 2): + if isinstance(self.size, (tuple, list)) and self.symmetric_pad: + if np.max(target) <= 1: + offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2] + target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1] + target[..., 1] 
= offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2] + else: + target[..., 0] *= raw_shape[-1] / img.shape[-1] + target[..., 1] *= raw_shape[-2] / img.shape[-2] + else: + raise AssertionError + return img, target + + return img + + def __repr__(self) -> str: + interpolate_str = self.interpolation.value + _repr = f"output_size={self.size}, interpolation='{interpolate_str}'" + if self.preserve_aspect_ratio: + _repr += f", preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}" + return f"{self.__class__.__name__}({_repr})" + + +class GaussianNoise(torch.nn.Module): + """Adds Gaussian Noise to the input tensor + + >>> import torch + >>> from doctr.transforms import GaussianNoise + >>> transfo = GaussianNoise(0., 1.) + >>> out = transfo(torch.rand((3, 224, 224))) + + Args: + ---- + mean : mean of the gaussian distribution + std : std of the gaussian distribution + """ + + def __init__(self, mean: float = 0.0, std: float = 1.0) -> None: + super().__init__() + self.std = std + self.mean = mean + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Reshape the distribution + noise = self.mean + 2 * self.std * torch.rand(x.shape, device=x.device) - self.std + if x.dtype == torch.uint8: + return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8) + else: + return (x + noise.to(dtype=x.dtype)).clamp(0, 1) + + def extra_repr(self) -> str: + return f"mean={self.mean}, std={self.std}" + + +class ChannelShuffle(torch.nn.Module): + """Randomly shuffle channel order of a given image""" + + def __init__(self): + super().__init__() + + def forward(self, img: torch.Tensor) -> torch.Tensor: + # Get a random order + chan_order = torch.rand(img.shape[0]).argsort() + return img[chan_order] + + +class RandomHorizontalFlip(T.RandomHorizontalFlip): + """Randomly flip the input image horizontally""" + + def forward( + self, img: Union[torch.Tensor, Image], target: np.ndarray + ) -> Tuple[Union[torch.Tensor, Image], np.ndarray]: + if torch.rand(1) < self.p: + _img = F.hflip(img) + _target = target.copy() + # Changing the relative bbox coordinates + if target.shape[1:] == (4,): + _target[:, ::2] = 1 - target[:, [2, 0]] + else: + _target[..., 0] = 1 - target[..., 0] + return _img, _target + return img, target + + +class RandomShadow(torch.nn.Module): + """Adds random shade to the input image + + >>> import torch + >>> from doctr.transforms import RandomShadow + >>> transfo = RandomShadow((0., 1.)) + >>> out = transfo(torch.rand((3, 64, 64))) + + Args: + ---- + opacity_range : minimum and maximum opacity of the shade + """ + + def __init__(self, opacity_range: Optional[Tuple[float, float]] = None) -> None: + super().__init__() + self.opacity_range = opacity_range if isinstance(opacity_range, tuple) else (0.2, 0.8) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + # Reshape the distribution + try: + if x.dtype == torch.uint8: + return ( + ( + 255 + * random_shadow( + x.to(dtype=torch.float32) / 255, + self.opacity_range, + ) + ) + .round() + .clip(0, 255) + .to(dtype=torch.uint8) + ) + else: + return random_shadow(x, self.opacity_range).clip(0, 1) + except ValueError: + return x + + def extra_repr(self) -> str: + return f"opacity_range={self.opacity_range}" + + +class RandomResize(torch.nn.Module): + """Randomly resize the input image and align corresponding targets + + >>> import torch + >>> from doctr.transforms import RandomResize + >>> transfo = RandomResize((0.3, 0.9), preserve_aspect_ratio=True, symmetric_pad=True, p=0.5) + >>> out = 
transfo(torch.rand((3, 64, 64))) + + Args: + ---- + scale_range: range of the resizing factor for width and height (independently) + preserve_aspect_ratio: whether to preserve the aspect ratio of the image, + given a float value, the aspect ratio will be preserved with this probability + symmetric_pad: whether to symmetrically pad the image, + given a float value, the symmetric padding will be applied with this probability + p: probability to apply the transformation + """ + + def __init__( + self, + scale_range: Tuple[float, float] = (0.3, 0.9), + preserve_aspect_ratio: Union[bool, float] = False, + symmetric_pad: Union[bool, float] = False, + p: float = 0.5, + ) -> None: + super().__init__() + self.scale_range = scale_range + self.preserve_aspect_ratio = preserve_aspect_ratio + self.symmetric_pad = symmetric_pad + self.p = p + self._resize = Resize + + def forward(self, img: torch.Tensor, target: np.ndarray) -> Tuple[torch.Tensor, np.ndarray]: + if torch.rand(1) < self.p: + scale_h = np.random.uniform(*self.scale_range) + scale_w = np.random.uniform(*self.scale_range) + new_size = (int(img.shape[-2] * scale_h), int(img.shape[-1] * scale_w)) + + _img, _target = self._resize( + new_size, + preserve_aspect_ratio=self.preserve_aspect_ratio + if isinstance(self.preserve_aspect_ratio, bool) + else bool(torch.rand(1) <= self.symmetric_pad), + symmetric_pad=self.symmetric_pad + if isinstance(self.symmetric_pad, bool) + else bool(torch.rand(1) <= self.symmetric_pad), + )(img, target) + + return _img, _target + return img, target + + def extra_repr(self) -> str: + return f"scale_range={self.scale_range}, preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}, p={self.p}" # noqa: E501 diff --git a/doctr/transforms/modules/tensorflow.py b/doctr/transforms/modules/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f7bcfd8aa15c68b8a8044ff30f3978c393cafa --- /dev/null +++ b/doctr/transforms/modules/tensorflow.py @@ -0,0 +1,573 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
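A quick sanity-check sketch for the PyTorch `Resize` above when aspect ratio is preserved with symmetric padding (shapes and boxes are hypothetical): a square 200x200 image resized onto a 128x256 canvas is centred horizontally, and relative boxes are remapped onto the padded canvas.

import numpy as np
import torch
from doctr.transforms import Resize  # PyTorch flavour when torch is the active backend

transfo = Resize((128, 256), preserve_aspect_ratio=True, symmetric_pad=True)
img = torch.rand(3, 200, 200)
boxes = np.array([[0.25, 0.25, 0.75, 0.75]], dtype=np.float32)  # relative (xmin, ymin, xmax, ymax)
out_img, out_boxes = transfo(img, boxes)  # out_img has shape (3, 128, 256); x-coords are shifted by the pad offset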
+ +import random +from typing import Any, Callable, Iterable, List, Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from doctr.utils.repr import NestedObject + +from ..functional.tensorflow import _gaussian_filter, random_shadow + +__all__ = [ + "Compose", + "Resize", + "Normalize", + "LambdaTransformation", + "ToGray", + "RandomBrightness", + "RandomContrast", + "RandomSaturation", + "RandomHue", + "RandomGamma", + "RandomJpegQuality", + "GaussianBlur", + "ChannelShuffle", + "GaussianNoise", + "RandomHorizontalFlip", + "RandomShadow", + "RandomResize", +] + + +class Compose(NestedObject): + """Implements a wrapper that will apply transformations sequentially + + >>> import tensorflow as tf + >>> from doctr.transforms import Compose, Resize + >>> transfos = Compose([Resize((32, 32))]) + >>> out = transfos(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + + Args: + ---- + transforms: list of transformation modules + """ + + _children_names: List[str] = ["transforms"] + + def __init__(self, transforms: List[Callable[[Any], Any]]) -> None: + self.transforms = transforms + + def __call__(self, x: Any) -> Any: + for t in self.transforms: + x = t(x) + + return x + + +class Resize(NestedObject): + """Resizes a tensor to a target size + + >>> import tensorflow as tf + >>> from doctr.transforms import Resize + >>> transfo = Resize((32, 32)) + >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + + Args: + ---- + output_size: expected output size + method: interpolation method + preserve_aspect_ratio: if `True`, preserve aspect ratio and pad the rest with zeros + symmetric_pad: if `True` while preserving aspect ratio, the padding will be done symmetrically + """ + + def __init__( + self, + output_size: Union[int, Tuple[int, int]], + method: str = "bilinear", + preserve_aspect_ratio: bool = False, + symmetric_pad: bool = False, + ) -> None: + self.output_size = output_size + self.method = method + self.preserve_aspect_ratio = preserve_aspect_ratio + self.symmetric_pad = symmetric_pad + self.antialias = True + + if isinstance(self.output_size, int): + self.wanted_size = (self.output_size, self.output_size) + elif isinstance(self.output_size, (tuple, list)): + self.wanted_size = self.output_size + else: + raise AssertionError("Output size should be either a list, a tuple or an int") + + def extra_repr(self) -> str: + _repr = f"output_size={self.output_size}, method='{self.method}'" + if self.preserve_aspect_ratio: + _repr += f", preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}" + return _repr + + def __call__( + self, + img: tf.Tensor, + target: Optional[np.ndarray] = None, + ) -> Union[tf.Tensor, Tuple[tf.Tensor, np.ndarray]]: + input_dtype = img.dtype + + img = tf.image.resize(img, self.wanted_size, self.method, self.preserve_aspect_ratio, self.antialias) + # It will produce an un-padded resized image, with a side shorter than wanted if we preserve aspect ratio + raw_shape = img.shape[:2] + if self.preserve_aspect_ratio: + if isinstance(self.output_size, (tuple, list)): + # In that case we need to pad because we want to enforce both width and height + if not self.symmetric_pad: + offset = (0, 0) + elif self.output_size[0] == img.shape[0]: + offset = (0, int((self.output_size[1] - img.shape[1]) / 2)) + else: + offset = (int((self.output_size[0] - img.shape[0]) / 2), 0) + img = tf.image.pad_to_bounding_box(img, *offset, *self.output_size) + + # In case boxes are provided, resize boxes if needed (for 
detection task if preserve aspect ratio) + if target is not None: + if self.preserve_aspect_ratio: + # Get absolute coords + if target.shape[1:] == (4,): + if isinstance(self.output_size, (tuple, list)) and self.symmetric_pad: + if np.max(target) <= 1: + offset = offset[0] / img.shape[0], offset[1] / img.shape[1] + target[:, [0, 2]] = offset[1] + target[:, [0, 2]] * raw_shape[1] / img.shape[1] + target[:, [1, 3]] = offset[0] + target[:, [1, 3]] * raw_shape[0] / img.shape[0] + else: + target[:, [0, 2]] *= raw_shape[1] / img.shape[1] + target[:, [1, 3]] *= raw_shape[0] / img.shape[0] + elif target.shape[1:] == (4, 2): + if isinstance(self.output_size, (tuple, list)) and self.symmetric_pad: + if np.max(target) <= 1: + offset = offset[0] / img.shape[0], offset[1] / img.shape[1] + target[..., 0] = offset[1] + target[..., 0] * raw_shape[1] / img.shape[1] + target[..., 1] = offset[0] + target[..., 1] * raw_shape[0] / img.shape[0] + else: + target[..., 0] *= raw_shape[1] / img.shape[1] + target[..., 1] *= raw_shape[0] / img.shape[0] + else: + raise AssertionError + return tf.cast(img, dtype=input_dtype), target + + return tf.cast(img, dtype=input_dtype) + + +class Normalize(NestedObject): + """Normalize a tensor to a Gaussian distribution for each channel + + >>> import tensorflow as tf + >>> from doctr.transforms import Normalize + >>> transfo = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + Args: + ---- + mean: average value per channel + std: standard deviation per channel + """ + + def __init__(self, mean: Tuple[float, float, float], std: Tuple[float, float, float]) -> None: + self.mean = tf.constant(mean) + self.std = tf.constant(std) + + def extra_repr(self) -> str: + return f"mean={self.mean.numpy().tolist()}, std={self.std.numpy().tolist()}" + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + img -= tf.cast(self.mean, dtype=img.dtype) + img /= tf.cast(self.std, dtype=img.dtype) + return img + + +class LambdaTransformation(NestedObject): + """Normalize a tensor to a Gaussian distribution for each channel + + >>> import tensorflow as tf + >>> from doctr.transforms import LambdaTransformation + >>> transfo = LambdaTransformation(lambda x: x/ 255.) 
+ >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + Args: + ---- + fn: the function to be applied to the input tensor + """ + + def __init__(self, fn: Callable[[tf.Tensor], tf.Tensor]) -> None: + self.fn = fn + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + return self.fn(img) + + +class ToGray(NestedObject): + """Convert a RGB tensor (batch of images or image) to a 3-channels grayscale tensor + + >>> import tensorflow as tf + >>> from doctr.transforms import ToGray + >>> transfo = ToGray() + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + """ + + def __init__(self, num_output_channels: int = 1): + self.num_output_channels = num_output_channels + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + img = tf.image.rgb_to_grayscale(img) + return img if self.num_output_channels == 1 else tf.repeat(img, self.num_output_channels, axis=-1) + + +class RandomBrightness(NestedObject): + """Randomly adjust brightness of a tensor (batch of images or image) by adding a delta + to all pixels + + >>> import tensorflow as tf + >>> from doctr.transforms import RandomBrightness + >>> transfo = RandomBrightness() + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + Args: + ---- + max_delta: offset to add to each pixel is randomly picked in [-max_delta, max_delta] + p: probability to apply transformation + """ + + def __init__(self, max_delta: float = 0.3) -> None: + self.max_delta = max_delta + + def extra_repr(self) -> str: + return f"max_delta={self.max_delta}" + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + return tf.image.random_brightness(img, max_delta=self.max_delta) + + +class RandomContrast(NestedObject): + """Randomly adjust contrast of a tensor (batch of images or image) by adjusting + each pixel: (img - mean) * contrast_factor + mean. + + >>> import tensorflow as tf + >>> from doctr.transforms import RandomContrast + >>> transfo = RandomContrast() + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + Args: + ---- + delta: multiplicative factor is picked in [1-delta, 1+delta] (reduce contrast if factor<1) + """ + + def __init__(self, delta: float = 0.3) -> None: + self.delta = delta + + def extra_repr(self) -> str: + return f"delta={self.delta}" + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + return tf.image.random_contrast(img, lower=1 - self.delta, upper=1 / (1 - self.delta)) + + +class RandomSaturation(NestedObject): + """Randomly adjust saturation of a tensor (batch of images or image) by converting to HSV and + increasing saturation by a factor. 
+ + >>> import tensorflow as tf + >>> from doctr.transforms import RandomSaturation + >>> transfo = RandomSaturation() + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + Args: + ---- + delta: multiplicative factor is picked in [1-delta, 1+delta] (reduce saturation if factor<1) + """ + + def __init__(self, delta: float = 0.5) -> None: + self.delta = delta + + def extra_repr(self) -> str: + return f"delta={self.delta}" + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + return tf.image.random_saturation(img, lower=1 - self.delta, upper=1 + self.delta) + + +class RandomHue(NestedObject): + """Randomly adjust hue of a tensor (batch of images or image) by converting to HSV and adding a delta + + >>> import tensorflow as tf + >>> from doctr.transforms import RandomHue + >>> transfo = RandomHue() + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + Args: + ---- + max_delta: offset to add to each pixel is randomly picked in [-max_delta, max_delta] + """ + + def __init__(self, max_delta: float = 0.3) -> None: + self.max_delta = max_delta + + def extra_repr(self) -> str: + return f"max_delta={self.max_delta}" + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + return tf.image.random_hue(img, max_delta=self.max_delta) + + +class RandomGamma(NestedObject): + """randomly performs gamma correction for a tensor (batch of images or image) + + >>> import tensorflow as tf + >>> from doctr.transforms import RandomGamma + >>> transfo = RandomGamma() + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + Args: + ---- + min_gamma: non-negative real number, lower bound for gamma param + max_gamma: non-negative real number, upper bound for gamma + min_gain: lower bound for constant multiplier + max_gain: upper bound for constant multiplier + """ + + def __init__( + self, + min_gamma: float = 0.5, + max_gamma: float = 1.5, + min_gain: float = 0.8, + max_gain: float = 1.2, + ) -> None: + self.min_gamma = min_gamma + self.max_gamma = max_gamma + self.min_gain = min_gain + self.max_gain = max_gain + + def extra_repr(self) -> str: + return f"""gamma_range=({self.min_gamma}, {self.max_gamma}), + gain_range=({self.min_gain}, {self.max_gain})""" + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + gamma = random.uniform(self.min_gamma, self.max_gamma) + gain = random.uniform(self.min_gain, self.max_gain) + return tf.image.adjust_gamma(img, gamma=gamma, gain=gain) + + +class RandomJpegQuality(NestedObject): + """Randomly adjust jpeg quality of a 3 dimensional RGB image + + >>> import tensorflow as tf + >>> from doctr.transforms import RandomJpegQuality + >>> transfo = RandomJpegQuality() + >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + + Args: + ---- + min_quality: int between [0, 100] + max_quality: int between [0, 100] + """ + + def __init__(self, min_quality: int = 60, max_quality: int = 100) -> None: + self.min_quality = min_quality + self.max_quality = max_quality + + def extra_repr(self) -> str: + return f"min_quality={self.min_quality}" + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + return tf.image.random_jpeg_quality(img, min_jpeg_quality=self.min_quality, max_jpeg_quality=self.max_quality) + + +class GaussianBlur(NestedObject): + """Randomly adjust jpeg quality of a 3 dimensional RGB image + + >>> import tensorflow as tf + >>> from doctr.transforms import GaussianBlur + >>> transfo = GaussianBlur(3, (.1, 5)) + >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], 
minval=0, maxval=1)) + + Args: + ---- + kernel_shape: size of the blurring kernel + std: min and max value of the standard deviation + """ + + def __init__(self, kernel_shape: Union[int, Iterable[int]], std: Tuple[float, float]) -> None: + self.kernel_shape = kernel_shape + self.std = std + + def extra_repr(self) -> str: + return f"kernel_shape={self.kernel_shape}, std={self.std}" + + @tf.function + def __call__(self, img: tf.Tensor) -> tf.Tensor: + return tf.squeeze( + _gaussian_filter( + img[tf.newaxis, ...], + kernel_size=self.kernel_shape, + sigma=random.uniform(self.std[0], self.std[1]), + mode="REFLECT", + ), + axis=0, + ) + + +class ChannelShuffle(NestedObject): + """Randomly shuffle channel order of a given image""" + + def __init__(self): + pass + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + return tf.transpose(tf.random.shuffle(tf.transpose(img, perm=[2, 0, 1])), perm=[1, 2, 0]) + + +class GaussianNoise(NestedObject): + """Adds Gaussian Noise to the input tensor + + >>> import tensorflow as tf + >>> from doctr.transforms import GaussianNoise + >>> transfo = GaussianNoise(0., 1.) + >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + + Args: + ---- + mean : mean of the gaussian distribution + std : std of the gaussian distribution + """ + + def __init__(self, mean: float = 0.0, std: float = 1.0) -> None: + super().__init__() + self.std = std + self.mean = mean + + def __call__(self, x: tf.Tensor) -> tf.Tensor: + # Reshape the distribution + noise = self.mean + 2 * self.std * tf.random.uniform(x.shape) - self.std + if x.dtype == tf.uint8: + return tf.cast( + tf.clip_by_value(tf.math.round(tf.cast(x, dtype=tf.float32) + 255 * noise), 0, 255), dtype=tf.uint8 + ) + else: + return tf.cast(tf.clip_by_value(x + noise, 0, 1), dtype=x.dtype) + + def extra_repr(self) -> str: + return f"mean={self.mean}, std={self.std}" + + +class RandomHorizontalFlip(NestedObject): + """Adds random horizontal flip to the input tensor/np.ndarray + + >>> import tensorflow as tf + >>> from doctr.transforms import RandomHorizontalFlip + >>> transfo = RandomHorizontalFlip(p=0.5) + >>> image = tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1) + >>> target = np.array([[0.1, 0.1, 0.4, 0.5] ], dtype= np.float32) + >>> out = transfo(image, target) + + Args: + ---- + p : probability of Horizontal Flip + """ + + def __init__(self, p: float) -> None: + super().__init__() + self.p = p + + def __call__(self, img: Union[tf.Tensor, np.ndarray], target: np.ndarray) -> Tuple[tf.Tensor, np.ndarray]: + if np.random.rand(1) <= self.p: + _img = tf.image.flip_left_right(img) + _target = target.copy() + # Changing the relative bbox coordinates + if target.shape[1:] == (4,): + _target[:, ::2] = 1 - target[:, [2, 0]] + else: + _target[..., 0] = 1 - target[..., 0] + return _img, _target + return img, target + + +class RandomShadow(NestedObject): + """Adds random shade to the input image + + >>> import tensorflow as tf + >>> from doctr.transforms import RandomShadow + >>> transfo = RandomShadow(0., 1.) 
+ >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + + Args: + ---- + opacity_range : minimum and maximum opacity of the shade + """ + + def __init__(self, opacity_range: Optional[Tuple[float, float]] = None) -> None: + super().__init__() + self.opacity_range = opacity_range if isinstance(opacity_range, tuple) else (0.2, 0.8) + + def __call__(self, x: tf.Tensor) -> tf.Tensor: + # Reshape the distribution + if x.dtype == tf.uint8: + return tf.cast( + tf.clip_by_value( + tf.math.round(255 * random_shadow(tf.cast(x, dtype=tf.float32) / 255, self.opacity_range)), + 0, + 255, + ), + dtype=tf.uint8, + ) + else: + return tf.clip_by_value(random_shadow(x, self.opacity_range), 0, 1) + + def extra_repr(self) -> str: + return f"opacity_range={self.opacity_range}" + + +class RandomResize(NestedObject): + """Randomly resize the input image and align corresponding targets + + >>> import tensorflow as tf + >>> from doctr.transforms import RandomResize + >>> transfo = RandomResize((0.3, 0.9), preserve_aspect_ratio=True, symmetric_pad=True, p=0.5) + >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + + Args: + ---- + scale_range: range of the resizing factor for width and height (independently) + preserve_aspect_ratio: whether to preserve the aspect ratio of the image, + given a float value, the aspect ratio will be preserved with this probability + symmetric_pad: whether to symmetrically pad the image, + given a float value, the symmetric padding will be applied with this probability + p: probability to apply the transformation + """ + + def __init__( + self, + scale_range: Tuple[float, float] = (0.3, 0.9), + preserve_aspect_ratio: Union[bool, float] = False, + symmetric_pad: Union[bool, float] = False, + p: float = 0.5, + ): + super().__init__() + self.scale_range = scale_range + self.preserve_aspect_ratio = preserve_aspect_ratio + self.symmetric_pad = symmetric_pad + self.p = p + self._resize = Resize + + def __call__(self, img: tf.Tensor, target: np.ndarray) -> Tuple[tf.Tensor, np.ndarray]: + if np.random.rand(1) <= self.p: + scale_h = random.uniform(*self.scale_range) + scale_w = random.uniform(*self.scale_range) + new_size = (int(img.shape[-3] * scale_h), int(img.shape[-2] * scale_w)) + + _img, _target = self._resize( + new_size, + preserve_aspect_ratio=self.preserve_aspect_ratio + if isinstance(self.preserve_aspect_ratio, bool) + else bool(np.random.rand(1) <= self.symmetric_pad), + symmetric_pad=self.symmetric_pad + if isinstance(self.symmetric_pad, bool) + else bool(np.random.rand(1) <= self.symmetric_pad), + )(img, target) + + return _img, _target + return img, target + + def extra_repr(self) -> str: + return f"scale_range={self.scale_range}, preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}, p={self.p}" # noqa: E501 diff --git a/doctr/utils/__init__.py b/doctr/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eeb9b15920cc1c971c5391ebafecb9c7289a157e --- /dev/null +++ b/doctr/utils/__init__.py @@ -0,0 +1,4 @@ +from .common_types import * +from .data import * +from .geometry import * +from .metrics import * diff --git a/doctr/utils/__pycache__/__init__.cpython-310.pyc b/doctr/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bfb1cb97fe6754f44aaf62bae4e0a082d264ec0 Binary files /dev/null and b/doctr/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/doctr/utils/__pycache__/__init__.cpython-311.pyc 
b/doctr/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a253a23b2b17bde4aa2e7089f22f0200773ca0e4 Binary files /dev/null and b/doctr/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/doctr/utils/__pycache__/__init__.cpython-38.pyc b/doctr/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ffb20d00bd5242a1b3cededc37f1e4deed00c06 Binary files /dev/null and b/doctr/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/doctr/utils/__pycache__/common_types.cpython-310.pyc b/doctr/utils/__pycache__/common_types.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ab9d6017980e163cd8cc783433ed0560fae3607 Binary files /dev/null and b/doctr/utils/__pycache__/common_types.cpython-310.pyc differ diff --git a/doctr/utils/__pycache__/common_types.cpython-311.pyc b/doctr/utils/__pycache__/common_types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a52ab4c5450372185b204535968221006bf3aaf2 Binary files /dev/null and b/doctr/utils/__pycache__/common_types.cpython-311.pyc differ diff --git a/doctr/utils/__pycache__/common_types.cpython-38.pyc b/doctr/utils/__pycache__/common_types.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81112584d1c7f16d651a766589303aaf41919418 Binary files /dev/null and b/doctr/utils/__pycache__/common_types.cpython-38.pyc differ diff --git a/doctr/utils/__pycache__/data.cpython-310.pyc b/doctr/utils/__pycache__/data.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dcc8d50c202c88c05e49fdfe422b65551cb969d Binary files /dev/null and b/doctr/utils/__pycache__/data.cpython-310.pyc differ diff --git a/doctr/utils/__pycache__/data.cpython-311.pyc b/doctr/utils/__pycache__/data.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..993cc761325a985a9b178e913309c10fa4e5fa73 Binary files /dev/null and b/doctr/utils/__pycache__/data.cpython-311.pyc differ diff --git a/doctr/utils/__pycache__/data.cpython-38.pyc b/doctr/utils/__pycache__/data.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..872aa76c31f7ec81e2ef5c3d0cde9c0c0b5b20bc Binary files /dev/null and b/doctr/utils/__pycache__/data.cpython-38.pyc differ diff --git a/doctr/utils/__pycache__/fonts.cpython-310.pyc b/doctr/utils/__pycache__/fonts.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b352df5fc9e8c8d90489b1c7f27fddb7fe6b88a Binary files /dev/null and b/doctr/utils/__pycache__/fonts.cpython-310.pyc differ diff --git a/doctr/utils/__pycache__/fonts.cpython-311.pyc b/doctr/utils/__pycache__/fonts.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3182022aa92f9e81c8986ec83cbb929819f525ad Binary files /dev/null and b/doctr/utils/__pycache__/fonts.cpython-311.pyc differ diff --git a/doctr/utils/__pycache__/fonts.cpython-38.pyc b/doctr/utils/__pycache__/fonts.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08680b54bf8d016a9f1f191570fae46ce61bb2cf Binary files /dev/null and b/doctr/utils/__pycache__/fonts.cpython-38.pyc differ diff --git a/doctr/utils/__pycache__/geometry.cpython-310.pyc b/doctr/utils/__pycache__/geometry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42597067b002886a5f72ab8febb37f11098a8390 Binary files /dev/null and 
b/doctr/utils/__pycache__/geometry.cpython-310.pyc differ diff --git a/doctr/utils/__pycache__/geometry.cpython-311.pyc b/doctr/utils/__pycache__/geometry.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c93df7e495ffee81d9b06091df8027a1a3498c8 Binary files /dev/null and b/doctr/utils/__pycache__/geometry.cpython-311.pyc differ diff --git a/doctr/utils/__pycache__/geometry.cpython-38.pyc b/doctr/utils/__pycache__/geometry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3caaf2637ee8564a016fa069d86693a348669f6 Binary files /dev/null and b/doctr/utils/__pycache__/geometry.cpython-38.pyc differ diff --git a/doctr/utils/__pycache__/metrics.cpython-310.pyc b/doctr/utils/__pycache__/metrics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8574f75bac7fb8e3ee4c48de031b9213de90cea Binary files /dev/null and b/doctr/utils/__pycache__/metrics.cpython-310.pyc differ diff --git a/doctr/utils/__pycache__/metrics.cpython-311.pyc b/doctr/utils/__pycache__/metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90790431c52e3ecef1a4a51214d192817c5bfbc3 Binary files /dev/null and b/doctr/utils/__pycache__/metrics.cpython-311.pyc differ diff --git a/doctr/utils/__pycache__/metrics.cpython-38.pyc b/doctr/utils/__pycache__/metrics.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..babf6ed214572202ba44e58262740930d32c4644 Binary files /dev/null and b/doctr/utils/__pycache__/metrics.cpython-38.pyc differ diff --git a/doctr/utils/__pycache__/multithreading.cpython-311.pyc b/doctr/utils/__pycache__/multithreading.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..950728a31d6e0463a069db8283e3abe21e16a98f Binary files /dev/null and b/doctr/utils/__pycache__/multithreading.cpython-311.pyc differ diff --git a/doctr/utils/__pycache__/multithreading.cpython-38.pyc b/doctr/utils/__pycache__/multithreading.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07bbf98de9b66ba6566840bdfa58df93499181b7 Binary files /dev/null and b/doctr/utils/__pycache__/multithreading.cpython-38.pyc differ diff --git a/doctr/utils/__pycache__/reconstitution.cpython-310.pyc b/doctr/utils/__pycache__/reconstitution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14ed19c0c59b1ab90d6a85f38af244cea88da229 Binary files /dev/null and b/doctr/utils/__pycache__/reconstitution.cpython-310.pyc differ diff --git a/doctr/utils/__pycache__/reconstitution.cpython-311.pyc b/doctr/utils/__pycache__/reconstitution.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c688ae7fa65a063ee8133dd8006421d00d2dc85 Binary files /dev/null and b/doctr/utils/__pycache__/reconstitution.cpython-311.pyc differ diff --git a/doctr/utils/__pycache__/reconstitution.cpython-38.pyc b/doctr/utils/__pycache__/reconstitution.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7277050999d53f43ccfdea8724011dd3cb2f45a Binary files /dev/null and b/doctr/utils/__pycache__/reconstitution.cpython-38.pyc differ diff --git a/doctr/utils/__pycache__/repr.cpython-310.pyc b/doctr/utils/__pycache__/repr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92711f5904e609b85851ad622de5df09bef92b9f Binary files /dev/null and b/doctr/utils/__pycache__/repr.cpython-310.pyc differ diff --git 
a/doctr/utils/__pycache__/repr.cpython-311.pyc b/doctr/utils/__pycache__/repr.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53a1b19b0a653cb384809ea084cca1527a84521c Binary files /dev/null and b/doctr/utils/__pycache__/repr.cpython-311.pyc differ diff --git a/doctr/utils/__pycache__/repr.cpython-38.pyc b/doctr/utils/__pycache__/repr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e1fe1233291423d7d1dd55588dde762ee389624 Binary files /dev/null and b/doctr/utils/__pycache__/repr.cpython-38.pyc differ diff --git a/doctr/utils/__pycache__/visualization.cpython-310.pyc b/doctr/utils/__pycache__/visualization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c61a295b30e73720b2b8ae254e06ff785c7b220 Binary files /dev/null and b/doctr/utils/__pycache__/visualization.cpython-310.pyc differ diff --git a/doctr/utils/__pycache__/visualization.cpython-311.pyc b/doctr/utils/__pycache__/visualization.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef46034a94fa51f4518dc79ed9536ff7cdb4b49c Binary files /dev/null and b/doctr/utils/__pycache__/visualization.cpython-311.pyc differ diff --git a/doctr/utils/__pycache__/visualization.cpython-38.pyc b/doctr/utils/__pycache__/visualization.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efd9065409bd7898c2a0fa3a2bd6f2b84c0354dd Binary files /dev/null and b/doctr/utils/__pycache__/visualization.cpython-38.pyc differ diff --git a/doctr/utils/common_types.py b/doctr/utils/common_types.py new file mode 100644 index 0000000000000000000000000000000000000000..a82e8db5d38a0f3493a7a68a99de19ad90254e91 --- /dev/null +++ b/doctr/utils/common_types.py @@ -0,0 +1,18 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from pathlib import Path +from typing import List, Tuple, Union + +__all__ = ["Point2D", "BoundingBox", "Polygon4P", "Polygon", "Bbox"] + + +Point2D = Tuple[float, float] +BoundingBox = Tuple[Point2D, Point2D] +Polygon4P = Tuple[Point2D, Point2D, Point2D, Point2D] +Polygon = List[Point2D] +AbstractPath = Union[str, Path] +AbstractFile = Union[AbstractPath, bytes] +Bbox = Tuple[float, float, float, float] diff --git a/doctr/utils/data.py b/doctr/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..7aec7720d87ed102246acc49db0202cec1ae4c72 --- /dev/null +++ b/doctr/utils/data.py @@ -0,0 +1,126 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
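For reference, a small sketch showing how the type aliases from common_types.py above compose (the coordinate values are made up):

from doctr.utils.common_types import BoundingBox, Polygon4P

box: BoundingBox = ((0.1, 0.2), (0.4, 0.6))  # ((xmin, ymin), (xmax, ymax)) in relative coords
poly: Polygon4P = ((0.1, 0.2), (0.4, 0.2), (0.4, 0.6), (0.1, 0.6))  # the same region as 4 corner points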
+ +# Adapted from https://github.com/pytorch/vision/blob/master/torchvision/datasets/utils.py + +import hashlib +import logging +import os +import re +import urllib +import urllib.error +import urllib.request +from pathlib import Path +from typing import Optional, Union + +from tqdm.auto import tqdm + +__all__ = ["download_from_url"] + + +# matches bfd8deac from resnet18-bfd8deac.ckpt +HASH_REGEX = re.compile(r"-([a-f0-9]*)\.") +USER_AGENT = "mindee/doctr" + + +def _urlretrieve(url: str, filename: Union[Path, str], chunk_size: int = 1024) -> None: + with open(filename, "wb") as fh: + with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response: + with tqdm(total=response.length) as pbar: + for chunk in iter(lambda: response.read(chunk_size), ""): + if not chunk: + break + pbar.update(chunk_size) + fh.write(chunk) + + +def _check_integrity(file_path: Union[str, Path], hash_prefix: str) -> bool: + with open(file_path, "rb") as f: + sha_hash = hashlib.sha256(f.read()).hexdigest() + + return sha_hash[: len(hash_prefix)] == hash_prefix + + +def download_from_url( + url: str, + file_name: Optional[str] = None, + hash_prefix: Optional[str] = None, + cache_dir: Optional[str] = None, + cache_subdir: Optional[str] = None, +) -> Path: + """Download a file using its URL + + >>> from doctr.models import download_from_url + >>> download_from_url("https://yoursource.com/yourcheckpoint-yourhash.zip") + + Args: + ---- + url: the URL of the file to download + file_name: optional name of the file once downloaded + hash_prefix: optional expected SHA256 hash of the file + cache_dir: cache directory + cache_subdir: subfolder to use in the cache + + Returns: + ------- + the location of the downloaded file + + Note: + ---- + You can change cache directory location by using `DOCTR_CACHE_DIR` environment variable. + """ + if not isinstance(file_name, str): + file_name = url.rpartition("/")[-1].split("&")[0] + + cache_dir = ( + str(os.environ.get("DOCTR_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "doctr"))) + if cache_dir is None + else cache_dir + ) + + # Check hash in file name + if hash_prefix is None: + r = HASH_REGEX.search(file_name) + hash_prefix = r.group(1) if r else None + + folder_path = Path(cache_dir) if cache_subdir is None else Path(cache_dir, cache_subdir) + file_path = folder_path.joinpath(file_name) + # Check file existence + if file_path.is_file() and (hash_prefix is None or _check_integrity(file_path, hash_prefix)): + logging.info(f"Using downloaded & verified file: {file_path}") + return file_path + + try: + # Create folder hierarchy + folder_path.mkdir(parents=True, exist_ok=True) + except OSError: + error_message = f"Failed creating cache direcotry at {folder_path}" + if os.environ.get("DOCTR_CACHE_DIR", ""): + error_message += " using path from 'DOCTR_CACHE_DIR' environment variable." + else: + error_message += ( + ". You can change default cache directory using 'DOCTR_CACHE_DIR' environment variable if needed." + ) + logging.error(error_message) + raise + # Download the file + try: + print(f"Downloading {url} to {file_path}") + _urlretrieve(url, file_path) + except (urllib.error.URLError, IOError) as e: + if url[:5] == "https": + url = url.replace("https:", "http:") + print("Failed download. Trying https -> http instead." 
f" Downloading {url} to {file_path}") + _urlretrieve(url, file_path) + else: + raise e + + # Remove corrupted files + if isinstance(hash_prefix, str) and not _check_integrity(file_path, hash_prefix): + # Remove file + os.remove(file_path) + raise ValueError(f"corrupted download, the hash of {url} does not match its expected value") + + return file_path diff --git a/doctr/utils/fonts.py b/doctr/utils/fonts.py new file mode 100644 index 0000000000000000000000000000000000000000..a6f0772f683db8686256bcb79852230ff776784a --- /dev/null +++ b/doctr/utils/fonts.py @@ -0,0 +1,41 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import logging +import platform +from typing import Optional + +from PIL import ImageFont + +__all__ = ["get_font"] + + +def get_font(font_family: Optional[str] = None, font_size: int = 13) -> ImageFont.ImageFont: + """Resolves a compatible ImageFont for the system + + Args: + ---- + font_family: the font family to use + font_size: the size of the font upon rendering + + Returns: + ------- + the Pillow font + """ + # Font selection + if font_family is None: + try: + font = ImageFont.truetype("FreeMono.ttf" if platform.system() == "Linux" else "Arial.ttf", font_size) + except OSError: + font = ImageFont.load_default() + logging.warning( + "unable to load recommended font family. Loading default PIL font," + "font size issues may be expected." + "To prevent this, it is recommended to specify the value of 'font_family'." + ) + else: + font = ImageFont.truetype(font_family, font_size) + + return font diff --git a/doctr/utils/geometry.py b/doctr/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..fea02e375edc3a874b11a56b2c4db63dbecbc60e --- /dev/null +++ b/doctr/utils/geometry.py @@ -0,0 +1,456 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
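# Illustrative usage sketch (not part of the patch): get_font falls back to PIL's default
# bitmap font when the recommended system font (FreeMono on Linux, Arial elsewhere) cannot
# be resolved, so it stays usable on headless systems; the text and sizes below are made up.
from PIL import Image, ImageDraw
from doctr.utils.fonts import get_font

font = get_font(font_size=24)  # resolves a system font, or the PIL default with a warning
canvas = Image.new("RGB", (200, 40), color=(255, 255, 255))
ImageDraw.Draw(canvas).text((4, 4), "docTR", font=font, fill=(0, 0, 0))

# Note: download_from_url (above) stores files under ~/.cache/doctr unless the
# DOCTR_CACHE_DIR environment variable or the cache_dir argument says otherwise.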
+ +from copy import deepcopy +from math import ceil +from typing import List, Optional, Tuple, Union + +import cv2 +import numpy as np + +from .common_types import BoundingBox, Polygon4P + +__all__ = [ + "bbox_to_polygon", + "polygon_to_bbox", + "resolve_enclosing_bbox", + "resolve_enclosing_rbbox", + "rotate_boxes", + "compute_expanded_shape", + "rotate_image", + "estimate_page_angle", + "convert_to_relative_coords", + "rotate_abs_geoms", + "extract_crops", + "extract_rcrops", +] + + +def bbox_to_polygon(bbox: BoundingBox) -> Polygon4P: + """Convert a bounding box to a polygon + + Args: + ---- + bbox: a bounding box + + Returns: + ------- + a polygon + """ + return bbox[0], (bbox[1][0], bbox[0][1]), (bbox[0][0], bbox[1][1]), bbox[1] + + +def polygon_to_bbox(polygon: Polygon4P) -> BoundingBox: + """Convert a polygon to a bounding box + + Args: + ---- + polygon: a polygon + + Returns: + ------- + a bounding box + """ + x, y = zip(*polygon) + return (min(x), min(y)), (max(x), max(y)) + + +def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Union[BoundingBox, np.ndarray]: + """Compute enclosing bbox either from: + + Args: + ---- + bboxes: boxes in one of the following formats: + + - an array of boxes: (*, 5), where boxes have this shape: + (xmin, ymin, xmax, ymax, score) + + - a list of BoundingBox + + Returns: + ------- + a (1, 5) array (enclosing boxarray), or a BoundingBox + """ + if isinstance(bboxes, np.ndarray): + xmin, ymin, xmax, ymax, score = np.split(bboxes, 5, axis=1) + return np.array([xmin.min(), ymin.min(), xmax.max(), ymax.max(), score.mean()]) + else: + x, y = zip(*[point for box in bboxes for point in box]) + return (min(x), min(y)), (max(x), max(y)) + + +def resolve_enclosing_rbbox(rbboxes: List[np.ndarray], intermed_size: int = 1024) -> np.ndarray: + """Compute enclosing rotated bbox either from: + + Args: + ---- + rbboxes: boxes in one of the following formats: + + - an array of boxes: (*, 5), where boxes have this shape: + (xmin, ymin, xmax, ymax, score) + + - a list of BoundingBox + intermed_size: size of the intermediate image + + Returns: + ------- + a (1, 5) array (enclosing boxarray), or a BoundingBox + """ + cloud: np.ndarray = np.concatenate(rbboxes, axis=0) + # Convert to absolute for minAreaRect + cloud *= intermed_size + rect = cv2.minAreaRect(cloud.astype(np.int32)) + return cv2.boxPoints(rect) / intermed_size # type: ignore[operator] + + +def rotate_abs_points(points: np.ndarray, angle: float = 0.0) -> np.ndarray: + """Rotate points counter-clockwise. 
+ + Args: + ---- + points: array of size (N, 2) + angle: angle between -90 and +90 degrees + + Returns: + ------- + Rotated points + """ + angle_rad = angle * np.pi / 180.0 # compute radian angle for np functions + rotation_mat = np.array( + [[np.cos(angle_rad), -np.sin(angle_rad)], [np.sin(angle_rad), np.cos(angle_rad)]], dtype=points.dtype + ) + return np.matmul(points, rotation_mat.T) + + +def compute_expanded_shape(img_shape: Tuple[int, int], angle: float) -> Tuple[int, int]: + """Compute the shape of an expanded rotated image + + Args: + ---- + img_shape: the height and width of the image + angle: angle between -90 and +90 degrees + + Returns: + ------- + the height and width of the rotated image + """ + points: np.ndarray = np.array([ + [img_shape[1] / 2, img_shape[0] / 2], + [-img_shape[1] / 2, img_shape[0] / 2], + ]) + + rotated_points = rotate_abs_points(points, angle) + + wh_shape = 2 * np.abs(rotated_points).max(axis=0) + return wh_shape[1], wh_shape[0] + + +def rotate_abs_geoms( + geoms: np.ndarray, + angle: float, + img_shape: Tuple[int, int], + expand: bool = True, +) -> np.ndarray: + """Rotate a batch of bounding boxes or polygons by an angle around the + image center. + + Args: + ---- + geoms: (N, 4) or (N, 4, 2) array of ABSOLUTE coordinate boxes + angle: anti-clockwise rotation angle in degrees + img_shape: the height and width of the image + expand: whether the image should be padded to avoid information loss + + Returns: + ------- + A batch of rotated polygons (N, 4, 2) + """ + # Switch to polygons + polys = ( + np.stack([geoms[:, [0, 1]], geoms[:, [2, 1]], geoms[:, [2, 3]], geoms[:, [0, 3]]], axis=1) + if geoms.ndim == 2 + else geoms + ) + polys = polys.astype(np.float32) + + # Switch to image center as referential + polys[..., 0] -= img_shape[1] / 2 + polys[..., 1] = img_shape[0] / 2 - polys[..., 1] + + # Rotated them around image center + rotated_polys = rotate_abs_points(polys.reshape(-1, 2), angle).reshape(-1, 4, 2) + # Switch back to top-left corner as referential + target_shape = compute_expanded_shape(img_shape, angle) if expand else img_shape + # Clip coords to fit since there is no expansion + rotated_polys[..., 0] = (rotated_polys[..., 0] + target_shape[1] / 2).clip(0, target_shape[1]) + rotated_polys[..., 1] = (target_shape[0] / 2 - rotated_polys[..., 1]).clip(0, target_shape[0]) + + return rotated_polys + + +def remap_boxes(loc_preds: np.ndarray, orig_shape: Tuple[int, int], dest_shape: Tuple[int, int]) -> np.ndarray: + """Remaps a batch of rotated locpred (N, 4, 2) expressed for an origin_shape to a destination_shape. + This does not impact the absolute shape of the boxes, but allow to calculate the new relative RotatedBbox + coordinates after a resizing of the image. 
+ + Args: + ---- + loc_preds: (N, 4, 2) array of RELATIVE loc_preds + orig_shape: shape of the origin image + dest_shape: shape of the destination image + + Returns: + ------- + A batch of rotated loc_preds (N, 4, 2) expressed in the destination referencial + """ + if len(dest_shape) != 2: + raise ValueError(f"Mask length should be 2, was found at: {len(dest_shape)}") + if len(orig_shape) != 2: + raise ValueError(f"Image_shape length should be 2, was found at: {len(orig_shape)}") + orig_height, orig_width = orig_shape + dest_height, dest_width = dest_shape + mboxes = loc_preds.copy() + mboxes[:, :, 0] = ((loc_preds[:, :, 0] * orig_width) + (dest_width - orig_width) / 2) / dest_width + mboxes[:, :, 1] = ((loc_preds[:, :, 1] * orig_height) + (dest_height - orig_height) / 2) / dest_height + + return mboxes + + +def rotate_boxes( + loc_preds: np.ndarray, + angle: float, + orig_shape: Tuple[int, int], + min_angle: float = 1.0, + target_shape: Optional[Tuple[int, int]] = None, +) -> np.ndarray: + """Rotate a batch of straight bounding boxes (xmin, ymin, xmax, ymax, c) or rotated bounding boxes + (4, 2) of an angle, if angle > min_angle, around the center of the page. + If target_shape is specified, the boxes are remapped to the target shape after the rotation. This + is done to remove the padding that is created by rotate_page(expand=True) + + Args: + ---- + loc_preds: (N, 5) or (N, 4, 2) array of RELATIVE boxes + angle: angle between -90 and +90 degrees + orig_shape: shape of the origin image + min_angle: minimum angle to rotate boxes + target_shape: shape of the destination image + + Returns: + ------- + A batch of rotated boxes (N, 4, 2): or a batch of straight bounding boxes + """ + # Change format of the boxes to rotated boxes + _boxes = loc_preds.copy() + if _boxes.ndim == 2: + _boxes = np.stack( + [ + _boxes[:, [0, 1]], + _boxes[:, [2, 1]], + _boxes[:, [2, 3]], + _boxes[:, [0, 3]], + ], + axis=1, + ) + # If small angle, return boxes (no rotation) + if abs(angle) < min_angle or abs(angle) > 90 - min_angle: + return _boxes + # Compute rotation matrix + angle_rad = angle * np.pi / 180.0 # compute radian angle for np functions + rotation_mat = np.array( + [[np.cos(angle_rad), -np.sin(angle_rad)], [np.sin(angle_rad), np.cos(angle_rad)]], dtype=_boxes.dtype + ) + # Rotate absolute points + points: np.ndarray = np.stack((_boxes[:, :, 0] * orig_shape[1], _boxes[:, :, 1] * orig_shape[0]), axis=-1) + image_center = (orig_shape[1] / 2, orig_shape[0] / 2) + rotated_points = image_center + np.matmul(points - image_center, rotation_mat) + rotated_boxes: np.ndarray = np.stack( + (rotated_points[:, :, 0] / orig_shape[1], rotated_points[:, :, 1] / orig_shape[0]), axis=-1 + ) + + # Apply a mask if requested + if target_shape is not None: + rotated_boxes = remap_boxes(rotated_boxes, orig_shape=orig_shape, dest_shape=target_shape) + + return rotated_boxes + + +def rotate_image( + image: np.ndarray, + angle: float, + expand: bool = False, + preserve_origin_shape: bool = False, +) -> np.ndarray: + """Rotate an image counterclockwise by an given angle. + + Args: + ---- + image: numpy tensor to rotate + angle: rotation angle in degrees, between -90 and +90 + expand: whether the image should be padded before the rotation + preserve_origin_shape: if expand is set to True, resizes the final output to the original image size + + Returns: + ------- + Rotated array, padded by 0 by default. 
+ """ + # Compute the expanded padding + exp_img: np.ndarray + if expand: + exp_shape = compute_expanded_shape(image.shape[:2], angle) # type: ignore[arg-type] + h_pad, w_pad = ( + int(max(0, ceil(exp_shape[0] - image.shape[0]))), + int(max(0, ceil(exp_shape[1] - image.shape[1]))), + ) + exp_img = np.pad(image, ((h_pad // 2, h_pad - h_pad // 2), (w_pad // 2, w_pad - w_pad // 2), (0, 0))) + else: + exp_img = image + + height, width = exp_img.shape[:2] + rot_mat = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1.0) + rot_img = cv2.warpAffine(exp_img, rot_mat, (width, height)) + if expand: + # Pad to get the same aspect ratio + if (image.shape[0] / image.shape[1]) != (rot_img.shape[0] / rot_img.shape[1]): + # Pad width + if (rot_img.shape[0] / rot_img.shape[1]) > (image.shape[0] / image.shape[1]): + h_pad, w_pad = 0, int(rot_img.shape[0] * image.shape[1] / image.shape[0] - rot_img.shape[1]) + # Pad height + else: + h_pad, w_pad = int(rot_img.shape[1] * image.shape[0] / image.shape[1] - rot_img.shape[0]), 0 + rot_img = np.pad(rot_img, ((h_pad // 2, h_pad - h_pad // 2), (w_pad // 2, w_pad - w_pad // 2), (0, 0))) + if preserve_origin_shape: + # rescale + rot_img = cv2.resize(rot_img, image.shape[:-1][::-1], interpolation=cv2.INTER_LINEAR) + + return rot_img + + +def estimate_page_angle(polys: np.ndarray) -> float: + """Takes a batch of rotated previously ORIENTED polys (N, 4, 2) (rectified by the classifier) and return the + estimated angle ccw in degrees + """ + # Compute mean left points and mean right point with respect to the reading direction (oriented polygon) + xleft = polys[:, 0, 0] + polys[:, 3, 0] + yleft = polys[:, 0, 1] + polys[:, 3, 1] + xright = polys[:, 1, 0] + polys[:, 2, 0] + yright = polys[:, 1, 1] + polys[:, 2, 1] + with np.errstate(divide="raise", invalid="raise"): + try: + return float( + np.median(np.arctan((yleft - yright) / (xright - xleft)) * 180 / np.pi) # Y axis from top to bottom! 
+ ) + except FloatingPointError: + return 0.0 + + +def convert_to_relative_coords(geoms: np.ndarray, img_shape: Tuple[int, int]) -> np.ndarray: + """Convert a geometry to relative coordinates + + Args: + ---- + geoms: a set of polygons of shape (N, 4, 2) or of straight boxes of shape (N, 4) + img_shape: the height and width of the image + + Returns: + ------- + the updated geometry + """ + # Polygon + if geoms.ndim == 3 and geoms.shape[1:] == (4, 2): + polygons: np.ndarray = np.empty(geoms.shape, dtype=np.float32) + polygons[..., 0] = geoms[..., 0] / img_shape[1] + polygons[..., 1] = geoms[..., 1] / img_shape[0] + return polygons.clip(0, 1) + if geoms.ndim == 2 and geoms.shape[1] == 4: + boxes: np.ndarray = np.empty(geoms.shape, dtype=np.float32) + boxes[:, ::2] = geoms[:, ::2] / img_shape[1] + boxes[:, 1::2] = geoms[:, 1::2] / img_shape[0] + return boxes.clip(0, 1) + + raise ValueError(f"invalid format for arg `geoms`: {geoms.shape}") + + +def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True) -> List[np.ndarray]: + """Created cropped images from list of bounding boxes + + Args: + ---- + img: input image + boxes: bounding boxes of shape (N, 4) where N is the number of boxes, and the relative + coordinates (xmin, ymin, xmax, ymax) + channels_last: whether the channel dimensions is the last one instead of the last one + + Returns: + ------- + list of cropped images + """ + if boxes.shape[0] == 0: + return [] + if boxes.shape[1] != 4: + raise AssertionError("boxes are expected to be relative and in order (xmin, ymin, xmax, ymax)") + + # Project relative coordinates + _boxes = boxes.copy() + h, w = img.shape[:2] if channels_last else img.shape[-2:] + if not np.issubdtype(_boxes.dtype, np.integer): + _boxes[:, [0, 2]] *= w + _boxes[:, [1, 3]] *= h + _boxes = _boxes.round().astype(int) + # Add last index + _boxes[2:] += 1 + if channels_last: + return deepcopy([img[box[1] : box[3], box[0] : box[2]] for box in _boxes]) + + return deepcopy([img[:, box[1] : box[3], box[0] : box[2]] for box in _boxes]) + + +def extract_rcrops( + img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True +) -> List[np.ndarray]: + """Created cropped images from list of rotated bounding boxes + + Args: + ---- + img: input image + polys: bounding boxes of shape (N, 4, 2) + dtype: target data type of bounding boxes + channels_last: whether the channel dimensions is the last one instead of the last one + + Returns: + ------- + list of cropped images + """ + if polys.shape[0] == 0: + return [] + if polys.shape[1:] != (4, 2): + raise AssertionError("polys are expected to be quadrilateral, of shape (N, 4, 2)") + + # Project relative coordinates + _boxes = polys.copy() + height, width = img.shape[:2] if channels_last else img.shape[-2:] + if not np.issubdtype(_boxes.dtype, np.integer): + _boxes[:, :, 0] *= width + _boxes[:, :, 1] *= height + + src_pts = _boxes[:, :3].astype(np.float32) + # Preserve size + d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1) + d2 = np.linalg.norm(src_pts[:, 1] - src_pts[:, 2], axis=-1) + # (N, 3, 2) + dst_pts = np.zeros((_boxes.shape[0], 3, 2), dtype=dtype) + dst_pts[:, 1, 0] = dst_pts[:, 2, 0] = d1 - 1 + dst_pts[:, 2, 1] = d2 - 1 + # Use a warp transformation to extract the crop + crops = [ + cv2.warpAffine( + img if channels_last else img.transpose(1, 2, 0), + # Transformation matrix + cv2.getAffineTransform(src_pts[idx], dst_pts[idx]), + (int(d1[idx]), int(d2[idx])), + ) + for idx in range(_boxes.shape[0]) + ] + return crops diff --git 
a/doctr/utils/metrics.py b/doctr/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..faea10a3ab43b87d5779dfc0b12b59ca6adbbc16 --- /dev/null +++ b/doctr/utils/metrics.py @@ -0,0 +1,571 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Dict, List, Optional, Tuple + +import numpy as np +from anyascii import anyascii +from scipy.optimize import linear_sum_assignment +from shapely.geometry import Polygon + +__all__ = [ + "TextMatch", + "box_iou", + "polygon_iou", + "nms", + "LocalizationConfusion", + "OCRMetric", + "DetectionMetric", +] + + +def string_match(word1: str, word2: str) -> Tuple[bool, bool, bool, bool]: + """Performs string comparison with multiple levels of tolerance + + Args: + ---- + word1: a string + word2: another string + + Returns: + ------- + a tuple with booleans specifying respectively whether the raw strings, their lower-case counterparts, their + anyascii counterparts and their lower-case anyascii counterparts match + """ + raw_match = word1 == word2 + caseless_match = word1.lower() == word2.lower() + anyascii_match = anyascii(word1) == anyascii(word2) + + # Warning: the order is important here otherwise the pair ("EUR", "€") cannot be matched + unicase_match = anyascii(word1).lower() == anyascii(word2).lower() + + return raw_match, caseless_match, anyascii_match, unicase_match + + +class TextMatch: + r"""Implements text match metric (word-level accuracy) for recognition task. + + The raw aggregated metric is computed as follows: + + .. math:: + \forall X, Y \in \mathcal{W}^N, + TextMatch(X, Y) = \frac{1}{N} \sum\limits_{i=1}^N f_{Y_i}(X_i) + + with the indicator function :math:`f_{a}` defined as: + + .. math:: + \forall a, x \in \mathcal{W}, + f_a(x) = \left\{ + \begin{array}{ll} + 1 & \mbox{if } x = a \\ + 0 & \mbox{otherwise.} + \end{array} + \right. + + where :math:`\mathcal{W}` is the set of all possible character sequences, + :math:`N` is a strictly positive integer. 
+ + >>> from doctr.utils import TextMatch + >>> metric = TextMatch() + >>> metric.update(['Hello', 'world'], ['hello', 'world']) + >>> metric.summary() + """ + + def __init__(self) -> None: + self.reset() + + def update( + self, + gt: List[str], + pred: List[str], + ) -> None: + """Update the state of the metric with new predictions + + Args: + ---- + gt: list of ground-truth character sequences + pred: list of predicted character sequences + """ + if len(gt) != len(pred): + raise AssertionError("prediction size does not match with ground-truth labels size") + + for gt_word, pred_word in zip(gt, pred): + _raw, _caseless, _anyascii, _unicase = string_match(gt_word, pred_word) + self.raw += int(_raw) + self.caseless += int(_caseless) + self.anyascii += int(_anyascii) + self.unicase += int(_unicase) + + self.total += len(gt) + + def summary(self) -> Dict[str, float]: + """Computes the aggregated metrics + + Returns + ------- + a dictionary with the exact match score for the raw data, its lower-case counterpart, its anyascii + counterpart and its lower-case anyascii counterpart + """ + if self.total == 0: + raise AssertionError("you need to update the metric before getting the summary") + + return dict( + raw=self.raw / self.total, + caseless=self.caseless / self.total, + anyascii=self.anyascii / self.total, + unicase=self.unicase / self.total, + ) + + def reset(self) -> None: + self.raw = 0 + self.caseless = 0 + self.anyascii = 0 + self.unicase = 0 + self.total = 0 + + +def box_iou(boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray: + """Computes the IoU between two sets of bounding boxes + + Args: + ---- + boxes_1: bounding boxes of shape (N, 4) in format (xmin, ymin, xmax, ymax) + boxes_2: bounding boxes of shape (M, 4) in format (xmin, ymin, xmax, ymax) + + Returns: + ------- + the IoU matrix of shape (N, M) + """ + iou_mat: np.ndarray = np.zeros((boxes_1.shape[0], boxes_2.shape[0]), dtype=np.float32) + + if boxes_1.shape[0] > 0 and boxes_2.shape[0] > 0: + l1, t1, r1, b1 = np.split(boxes_1, 4, axis=1) + l2, t2, r2, b2 = np.split(boxes_2, 4, axis=1) + + left = np.maximum(l1, l2.T) + top = np.maximum(t1, t2.T) + right = np.minimum(r1, r2.T) + bot = np.minimum(b1, b2.T) + + intersection = np.clip(right - left, 0, np.Inf) * np.clip(bot - top, 0, np.Inf) + union = (r1 - l1) * (b1 - t1) + ((r2 - l2) * (b2 - t2)).T - intersection + iou_mat = intersection / union + + return iou_mat + + +def polygon_iou(polys_1: np.ndarray, polys_2: np.ndarray) -> np.ndarray: + """Computes the IoU between two sets of rotated bounding boxes + + Args: + ---- + polys_1: rotated bounding boxes of shape (N, 4, 2) + polys_2: rotated bounding boxes of shape (M, 4, 2) + + Returns: + ------- + the IoU matrix of shape (N, M) + """ + if polys_1.ndim != 3 or polys_2.ndim != 3: + raise AssertionError("expects boxes to be in format (N, 4, 2)") + + iou_mat = np.zeros((polys_1.shape[0], polys_2.shape[0]), dtype=np.float32) + + shapely_polys_1 = [Polygon(poly) for poly in polys_1] + shapely_polys_2 = [Polygon(poly) for poly in polys_2] + + for i, poly1 in enumerate(shapely_polys_1): + for j, poly2 in enumerate(shapely_polys_2): + intersection_area = poly1.intersection(poly2).area + union_area = poly1.area + poly2.area - intersection_area + iou_mat[i, j] = intersection_area / union_area + + return iou_mat + + +def nms(boxes: np.ndarray, thresh: float = 0.5) -> List[int]: + """Perform
non-max suppression, borrowed from `_. + + Args: + ---- + boxes: np array of straight boxes: (*, 5), (xmin, ymin, xmax, ymax, score) + thresh: iou threshold to perform box suppression. + + Returns: + ------- + A list of box indexes to keep + """ + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + scores = boxes[:, 4] + + areas = (x2 - x1) * (y2 - y1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1) + h = np.maximum(0.0, yy2 - yy1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + return keep + + +class LocalizationConfusion: + r"""Implements common confusion metrics and mean IoU for localization evaluation. + + The aggregated metrics are computed as follows: + + .. math:: + \forall Y \in \mathcal{B}^N, \forall X \in \mathcal{B}^M, \\ + Recall(X, Y) = \frac{1}{N} \sum\limits_{i=1}^N g_{X}(Y_i) \\ + Precision(X, Y) = \frac{1}{M} \sum\limits_{i=1}^M g_{X}(Y_i) \\ + meanIoU(X, Y) = \frac{1}{M} \sum\limits_{i=1}^M \max\limits_{j \in [1, N]} IoU(X_i, Y_j) + + with the function :math:`IoU(x, y)` being the Intersection over Union between bounding boxes :math:`x` and + :math:`y`, and the function :math:`g_{X}` defined as: + + .. math:: + \forall y \in \mathcal{B}, + g_X(y) = \left\{ + \begin{array}{ll} + 1 & \mbox{if } y\mbox{ has been assigned to any }(X_i)_i\mbox{ with an }IoU \geq 0.5 \\ + 0 & \mbox{otherwise.} + \end{array} + \right. + + where :math:`\mathcal{B}` is the set of possible bounding boxes, + :math:`N` (number of ground truths) and :math:`M` (number of predictions) are strictly positive integers. 
+ + >>> import numpy as np + >>> from doctr.utils import LocalizationConfusion + >>> metric = LocalizationConfusion(iou_thresh=0.5) + >>> metric.update(np.asarray([[0, 0, 100, 100]]), np.asarray([[0, 0, 70, 70], [110, 95, 200, 150]])) + >>> metric.summary() + + Args: + ---- + iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match + use_polygons: if set to True, predictions and targets will be expected to have rotated format + """ + + def __init__( + self, + iou_thresh: float = 0.5, + use_polygons: bool = False, + ) -> None: + self.iou_thresh = iou_thresh + self.use_polygons = use_polygons + self.reset() + + def update(self, gts: np.ndarray, preds: np.ndarray) -> None: + """Updates the metric + + Args: + ---- + gts: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones + preds: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones + """ + if preds.shape[0] > 0: + # Compute IoU + if self.use_polygons: + iou_mat = polygon_iou(gts, preds) + else: + iou_mat = box_iou(gts, preds) + self.tot_iou += float(iou_mat.max(axis=0).sum()) + + # Assign pairs + gt_indices, pred_indices = linear_sum_assignment(-iou_mat) + self.matches += int((iou_mat[gt_indices, pred_indices] >= self.iou_thresh).sum()) + + # Update counts + self.num_gts += gts.shape[0] + self.num_preds += preds.shape[0] + + def summary(self) -> Tuple[Optional[float], Optional[float], Optional[float]]: + """Computes the aggregated metrics + + Returns + ------- + a tuple with the recall, precision and meanIoU scores + """ + # Recall + recall = self.matches / self.num_gts if self.num_gts > 0 else None + + # Precision + precision = self.matches / self.num_preds if self.num_preds > 0 else None + + # mean IoU + mean_iou = round(self.tot_iou / self.num_preds, 2) if self.num_preds > 0 else None + + return recall, precision, mean_iou + + def reset(self) -> None: + self.num_gts = 0 + self.num_preds = 0 + self.matches = 0 + self.tot_iou = 0.0 + + +class OCRMetric: + r"""Implements an end-to-end OCR metric. + + The aggregated metrics are computed as follows: + + .. math:: + \forall (B, L) \in \mathcal{B}^N \times \mathcal{L}^N, + \forall (\hat{B}, \hat{L}) \in \mathcal{B}^M \times \mathcal{L}^M, \\ + Recall(B, \hat{B}, L, \hat{L}) = \frac{1}{N} \sum\limits_{i=1}^N h_{B,L}(\hat{B}_i, \hat{L}_i) \\ + Precision(B, \hat{B}, L, \hat{L}) = \frac{1}{M} \sum\limits_{i=1}^M h_{B,L}(\hat{B}_i, \hat{L}_i) \\ + meanIoU(B, \hat{B}) = \frac{1}{M} \sum\limits_{i=1}^M \max\limits_{j \in [1, N]} IoU(\hat{B}_i, B_j) + + with the function :math:`IoU(x, y)` being the Intersection over Union between bounding boxes :math:`x` and + :math:`y`, and the function :math:`h_{B, L}` defined as: + + .. math:: + \forall (b, l) \in \mathcal{B} \times \mathcal{L}, + h_{B,L}(b, l) = \left\{ + \begin{array}{ll} + 1 & \mbox{if } b\mbox{ has been assigned to a given }B_j\mbox{ with an } \\ + & IoU \geq 0.5 \mbox{ and that for this assignment, } l = L_j\\ + 0 & \mbox{otherwise.} + \end{array} + \right. + + where :math:`\mathcal{B}` is the set of possible bounding boxes, + :math:`\mathcal{L}` is the set of possible character sequences, + :math:`N` (number of ground truths) and :math:`M` (number of predictions) are strictly positive integers. 
+ + >>> import numpy as np + >>> from doctr.utils import OCRMetric + >>> metric = OCRMetric(iou_thresh=0.5) + >>> metric.update(np.asarray([[0, 0, 100, 100]]), np.asarray([[0, 0, 70, 70], [110, 95, 200, 150]]), + >>> ['hello'], ['hello', 'world']) + >>> metric.summary() + + Args: + ---- + iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match + use_polygons: if set to True, predictions and targets will be expected to have rotated format + """ + + def __init__( + self, + iou_thresh: float = 0.5, + use_polygons: bool = False, + ) -> None: + self.iou_thresh = iou_thresh + self.use_polygons = use_polygons + self.reset() + + def update( + self, + gt_boxes: np.ndarray, + pred_boxes: np.ndarray, + gt_labels: List[str], + pred_labels: List[str], + ) -> None: + """Updates the metric + + Args: + ---- + gt_boxes: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones + pred_boxes: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones + gt_labels: a list of N string labels + pred_labels: a list of M string labels + """ + if gt_boxes.shape[0] != len(gt_labels) or pred_boxes.shape[0] != len(pred_labels): + raise AssertionError( + "there should be the same number of boxes and string both for the ground truth " "and the predictions" + ) + + # Compute IoU + if pred_boxes.shape[0] > 0: + if self.use_polygons: + iou_mat = polygon_iou(gt_boxes, pred_boxes) + else: + iou_mat = box_iou(gt_boxes, pred_boxes) + + self.tot_iou += float(iou_mat.max(axis=0).sum()) + + # Assign pairs + gt_indices, pred_indices = linear_sum_assignment(-iou_mat) + is_kept = iou_mat[gt_indices, pred_indices] >= self.iou_thresh + # String comparison + for gt_idx, pred_idx in zip(gt_indices[is_kept], pred_indices[is_kept]): + _raw, _caseless, _anyascii, _unicase = string_match(gt_labels[gt_idx], pred_labels[pred_idx]) + self.raw_matches += int(_raw) + self.caseless_matches += int(_caseless) + self.anyascii_matches += int(_anyascii) + self.unicase_matches += int(_unicase) + + self.num_gts += gt_boxes.shape[0] + self.num_preds += pred_boxes.shape[0] + + def summary(self) -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]], Optional[float]]: + """Computes the aggregated metrics + + Returns + ------- + a tuple with the recall & precision for each string comparison and the mean IoU + """ + # Recall + recall = dict( + raw=self.raw_matches / self.num_gts if self.num_gts > 0 else None, + caseless=self.caseless_matches / self.num_gts if self.num_gts > 0 else None, + anyascii=self.anyascii_matches / self.num_gts if self.num_gts > 0 else None, + unicase=self.unicase_matches / self.num_gts if self.num_gts > 0 else None, + ) + + # Precision + precision = dict( + raw=self.raw_matches / self.num_preds if self.num_preds > 0 else None, + caseless=self.caseless_matches / self.num_preds if self.num_preds > 0 else None, + anyascii=self.anyascii_matches / self.num_preds if self.num_preds > 0 else None, + unicase=self.unicase_matches / self.num_preds if self.num_preds > 0 else None, + ) + + # mean IoU (overall detected boxes) + mean_iou = round(self.tot_iou / self.num_preds, 2) if self.num_preds > 0 else None + + return recall, precision, mean_iou + + def reset(self) -> None: + self.num_gts = 0 + self.num_preds = 0 + self.tot_iou = 0.0 + self.raw_matches = 0 + self.caseless_matches = 0 + self.anyascii_matches = 0 + self.unicase_matches = 0 + + +class DetectionMetric: + r"""Implements an object detection metric. 
+ + The aggregated metrics are computed as follows: + + .. math:: + \forall (B, C) \in \mathcal{B}^N \times \mathcal{C}^N, + \forall (\hat{B}, \hat{C}) \in \mathcal{B}^M \times \mathcal{C}^M, \\ + Recall(B, \hat{B}, C, \hat{C}) = \frac{1}{N} \sum\limits_{i=1}^N h_{B,C}(\hat{B}_i, \hat{C}_i) \\ + Precision(B, \hat{B}, C, \hat{C}) = \frac{1}{M} \sum\limits_{i=1}^M h_{B,C}(\hat{B}_i, \hat{C}_i) \\ + meanIoU(B, \hat{B}) = \frac{1}{M} \sum\limits_{i=1}^M \max\limits_{j \in [1, N]} IoU(\hat{B}_i, B_j) + + with the function :math:`IoU(x, y)` being the Intersection over Union between bounding boxes :math:`x` and + :math:`y`, and the function :math:`h_{B, C}` defined as: + + .. math:: + \forall (b, c) \in \mathcal{B} \times \mathcal{C}, + h_{B,C}(b, c) = \left\{ + \begin{array}{ll} + 1 & \mbox{if } b\mbox{ has been assigned to a given }B_j\mbox{ with an } \\ + & IoU \geq 0.5 \mbox{ and that for this assignment, } c = C_j\\ + 0 & \mbox{otherwise.} + \end{array} + \right. + + where :math:`\mathcal{B}` is the set of possible bounding boxes, + :math:`\mathcal{C}` is the set of possible class indices, + :math:`N` (number of ground truths) and :math:`M` (number of predictions) are strictly positive integers. + + >>> import numpy as np + >>> from doctr.utils import DetectionMetric + >>> metric = DetectionMetric(iou_thresh=0.5) + >>> metric.update(np.asarray([[0, 0, 100, 100]]), np.asarray([[0, 0, 70, 70], [110, 95, 200, 150]]), + >>> np.zeros(1, dtype=np.int64), np.array([0, 1], dtype=np.int64)) + >>> metric.summary() + + Args: + ---- + iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match + use_polygons: if set to True, predictions and targets will be expected to have rotated format + """ + + def __init__( + self, + iou_thresh: float = 0.5, + use_polygons: bool = False, + ) -> None: + self.iou_thresh = iou_thresh + self.use_polygons = use_polygons + self.reset() + + def update( + self, + gt_boxes: np.ndarray, + pred_boxes: np.ndarray, + gt_labels: np.ndarray, + pred_labels: np.ndarray, + ) -> None: + """Updates the metric + + Args: + ---- + gt_boxes: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones + pred_boxes: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones + gt_labels: an array of class indices of shape (N,) + pred_labels: an array of class indices of shape (M,) + """ + if gt_boxes.shape[0] != gt_labels.shape[0] or pred_boxes.shape[0] != pred_labels.shape[0]: + raise AssertionError( + "there should be the same number of boxes and string both for the ground truth " "and the predictions" + ) + + # Compute IoU + if pred_boxes.shape[0] > 0: + if self.use_polygons: + iou_mat = polygon_iou(gt_boxes, pred_boxes) + else: + iou_mat = box_iou(gt_boxes, pred_boxes) + + self.tot_iou += float(iou_mat.max(axis=0).sum()) + + # Assign pairs + gt_indices, pred_indices = linear_sum_assignment(-iou_mat) + is_kept = iou_mat[gt_indices, pred_indices] >= self.iou_thresh + # Category comparison + self.num_matches += int((gt_labels[gt_indices[is_kept]] == pred_labels[pred_indices[is_kept]]).sum()) + + self.num_gts += gt_boxes.shape[0] + self.num_preds += pred_boxes.shape[0] + + def summary(self) -> Tuple[Optional[float], Optional[float], Optional[float]]: + """Computes the aggregated metrics + + Returns + ------- + a tuple with the recall & precision for each class prediction and the mean IoU + """ + # Recall + recall = self.num_matches / self.num_gts if self.num_gts > 0 else None + + # Precision + precision = 
self.num_matches / self.num_preds if self.num_preds > 0 else None + + # mean IoU (overall detected boxes) + mean_iou = round(self.tot_iou / self.num_preds, 2) if self.num_preds > 0 else None + + return recall, precision, mean_iou + + def reset(self) -> None: + self.num_gts = 0 + self.num_preds = 0 + self.tot_iou = 0.0 + self.num_matches = 0 diff --git a/doctr/utils/multithreading.py b/doctr/utils/multithreading.py new file mode 100644 index 0000000000000000000000000000000000000000..6450a0bfd25e5336bfbfb7e824a29bcada735df1 --- /dev/null +++ b/doctr/utils/multithreading.py @@ -0,0 +1,50 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +import multiprocessing as mp +import os +from multiprocessing.pool import ThreadPool +from typing import Any, Callable, Iterable, Iterator, Optional + +from doctr.file_utils import ENV_VARS_TRUE_VALUES + +__all__ = ["multithread_exec"] + + +def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: Optional[int] = None) -> Iterator[Any]: + """Execute a given function in parallel for each element of a given sequence + + >>> from doctr.utils.multithreading import multithread_exec + >>> entries = [1, 4, 8] + >>> results = multithread_exec(lambda x: x ** 2, entries) + + Args: + ---- + func: function to be executed on each element of the iterable + seq: iterable + threads: number of workers to be used for multiprocessing + + Returns: + ------- + iterator of the function's results using the iterable as inputs + + Notes: + ----- + This function uses ThreadPool from multiprocessing package, which uses `/dev/shm` directory for shared memory. + If you do not have write permissions for this directory (if you run `doctr` on AWS Lambda for instance), + you might want to disable multiprocessing. To achieve that, set 'DOCTR_MULTIPROCESSING_DISABLE' to 'TRUE'. + """ + threads = threads if isinstance(threads, int) else min(16, mp.cpu_count()) + # Single-thread + if threads < 2 or os.environ.get("DOCTR_MULTIPROCESSING_DISABLE", "").upper() in ENV_VARS_TRUE_VALUES: + results = map(func, seq) + # Multi-threading + else: + with ThreadPool(threads) as tp: + # ThreadPool's map function returns a list, but seq could be of a different type + # That's why wrapping result in map to return iterator + results = map(lambda x: x, tp.map(func, seq)) # noqa: C417 + return results diff --git a/doctr/utils/reconstitution.py b/doctr/utils/reconstitution.py new file mode 100644 index 0000000000000000000000000000000000000000..82ae20cdd0c9824801546961416912a8977e70f1 --- /dev/null +++ b/doctr/utils/reconstitution.py @@ -0,0 +1,126 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. +from typing import Any, Dict, Optional + +import numpy as np +from anyascii import anyascii +from PIL import Image, ImageDraw + +from .fonts import get_font + +__all__ = ["synthesize_page", "synthesize_kie_page"] + + +def synthesize_page( + page: Dict[str, Any], + draw_proba: bool = False, + font_family: Optional[str] = None, +) -> np.ndarray: + """Draw a the content of the element page (OCR response) on a blank page. + + Args: + ---- + page: exported Page object to represent + draw_proba: if True, draw words in colors to represent confidence. 
Blue: p=1, red: p=0 + font_size: size of the font, default font = 13 + font_family: family of the font + + Returns: + ------- + the synthesized page + """ + # Draw template + h, w = page["dimensions"] + response = 255 * np.ones((h, w, 3), dtype=np.int32) + + # Draw each word + for block in page["blocks"]: + for line in block["lines"]: + for word in line["words"]: + # Get absolute word geometry + (xmin, ymin), (xmax, ymax) = word["geometry"] + xmin, xmax = int(round(w * xmin)), int(round(w * xmax)) + ymin, ymax = int(round(h * ymin)), int(round(h * ymax)) + + # White drawing context adapted to font size, 0.75 factor to convert pts --> pix + font = get_font(font_family, int(0.75 * (ymax - ymin))) + img = Image.new("RGB", (xmax - xmin, ymax - ymin), color=(255, 255, 255)) + d = ImageDraw.Draw(img) + # Draw in black the value of the word + try: + d.text((0, 0), word["value"], font=font, fill=(0, 0, 0)) + except UnicodeEncodeError: + # When character cannot be encoded, use its anyascii version + d.text((0, 0), anyascii(word["value"]), font=font, fill=(0, 0, 0)) + + # Colorize if draw_proba + if draw_proba: + p = int(255 * word["confidence"]) + mask = np.where(np.array(img) == 0, 1, 0) + proba: np.ndarray = np.array([255 - p, 0, p]) + color = mask * proba[np.newaxis, np.newaxis, :] + white_mask = 255 * (1 - mask) + img = color + white_mask + + # Write to response page + response[ymin:ymax, xmin:xmax, :] = np.array(img) + + return response + + +def synthesize_kie_page( + page: Dict[str, Any], + draw_proba: bool = False, + font_family: Optional[str] = None, +) -> np.ndarray: + """Draw a the content of the element page (OCR response) on a blank page. + + Args: + ---- + page: exported Page object to represent + draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0 + font_size: size of the font, default font = 13 + font_family: family of the font + + Returns: + ------- + the synthesized page + """ + # Draw template + h, w = page["dimensions"] + response = 255 * np.ones((h, w, 3), dtype=np.int32) + + # Draw each word + for predictions in page["predictions"].values(): + for prediction in predictions: + # Get aboslute word geometry + (xmin, ymin), (xmax, ymax) = prediction["geometry"] + xmin, xmax = int(round(w * xmin)), int(round(w * xmax)) + ymin, ymax = int(round(h * ymin)), int(round(h * ymax)) + + # White drawing context adapted to font size, 0.75 factor to convert pts --> pix + font = get_font(font_family, int(0.75 * (ymax - ymin))) + img = Image.new("RGB", (xmax - xmin, ymax - ymin), color=(255, 255, 255)) + d = ImageDraw.Draw(img) + # Draw in black the value of the word + try: + d.text((0, 0), prediction["value"], font=font, fill=(0, 0, 0)) + except UnicodeEncodeError: + # When character cannot be encoded, use its anyascii version + d.text((0, 0), anyascii(prediction["value"]), font=font, fill=(0, 0, 0)) + + # Colorize if draw_proba + if draw_proba: + p = int(255 * prediction["confidence"]) + mask = np.where(np.array(img) == 0, 1, 0) + proba: np.ndarray = np.array([255 - p, 0, p]) + color = mask * proba[np.newaxis, np.newaxis, :] + white_mask = 255 * (1 - mask) + img = color + white_mask + + # Write to response page + response[ymin:ymax, xmin:xmax, :] = np.array(img) + + return response diff --git a/doctr/utils/repr.py b/doctr/utils/repr.py new file mode 100644 index 0000000000000000000000000000000000000000..ccae2d6afc3272cd6450c86072eb9e056a12e723 --- /dev/null +++ b/doctr/utils/repr.py @@ -0,0 +1,64 @@ +# Copyright (C) 2021-2024, Mindee. 
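# Illustrative usage sketch (not part of the patch): synthesize_page only needs the exported
# dict structure, so a small hand-built page is enough for a quick check. All values below
# are made up for the example.
import numpy as np
from doctr.utils.reconstitution import synthesize_page

page = {
    "dimensions": (200, 400),  # (height, width) in pixels
    "blocks": [{
        "lines": [{
            "words": [{
                "value": "Hello",
                "confidence": 0.98,
                "geometry": ((0.05, 0.1), (0.45, 0.3)),  # relative (xmin, ymin), (xmax, ymax)
            }]
        }]
    }],
}
canvas = synthesize_page(page, draw_proba=True)  # words tinted blue->red by confidence
assert canvas.shape == (200, 400, 3)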
+ +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +# Adapted from https://github.com/pytorch/torch/blob/master/torch/nn/modules/module.py + +from typing import List + +__all__ = ["NestedObject"] + + +def _addindent(s_, num_spaces): + s = s_.split("\n") + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + +class NestedObject: + """Base class for all nested objects in doctr""" + + _children_names: List[str] + + def extra_repr(self) -> str: + return "" + + def __repr__(self): + # We treat the extra repr like the sub-object, one item per line + extra_lines = [] + extra_repr = self.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split("\n") + child_lines = [] + if hasattr(self, "_children_names"): + for key in self._children_names: + child = getattr(self, key) + if isinstance(child, list) and len(child) > 0: + child_str = ",\n".join([repr(subchild) for subchild in child]) + if len(child) > 1: + child_str = _addindent(f"\n{child_str},", 2) + "\n" + child_str = f"[{child_str}]" + else: + child_str = repr(child) + child_str = _addindent(child_str, 2) + child_lines.append("(" + key + "): " + child_str) + lines = extra_lines + child_lines + + main_str = self.__class__.__name__ + "(" + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += "\n " + "\n ".join(lines) + "\n" + + main_str += ")" + return main_str diff --git a/doctr/utils/visualization.py b/doctr/utils/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..4e97f751fe5ef7fadd3b4df63c6d5e45c7f86e30 --- /dev/null +++ b/doctr/utils/visualization.py @@ -0,0 +1,388 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
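# Illustrative usage sketch (not part of the patch): NestedObject renders child components
# recursively, mirroring torch.nn.Module's repr. Pipeline and Stage are made-up classes,
# not doctr APIs.
from doctr.utils.repr import NestedObject

class Stage(NestedObject):
    def __init__(self, name: str) -> None:
        self.name = name

    def extra_repr(self) -> str:
        return f"name='{self.name}'"

class Pipeline(NestedObject):
    _children_names = ["stages"]

    def __init__(self) -> None:
        self.stages = [Stage("det"), Stage("reco")]

print(Pipeline())
# prints something like:
# Pipeline(
#   (stages): [
#     Stage(name='det'),
#     Stage(name='reco'),
#   ]
# )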
+import colorsys +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple, Union + +import cv2 +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.figure import Figure + +from .common_types import BoundingBox, Polygon4P + +__all__ = ["visualize_page", "visualize_kie_page", "draw_boxes"] + + +def rect_patch( + geometry: BoundingBox, + page_dimensions: Tuple[int, int], + label: Optional[str] = None, + color: Tuple[float, float, float] = (0, 0, 0), + alpha: float = 0.3, + linewidth: int = 2, + fill: bool = True, + preserve_aspect_ratio: bool = False, +) -> patches.Rectangle: + """Create a matplotlib rectangular patch for the element + + Args: + ---- + geometry: bounding box of the element + page_dimensions: dimensions of the Page in format (height, width) + label: label to display when hovered + color: color to draw box + alpha: opacity parameter to fill the boxes, 0 = transparent + linewidth: line width + fill: whether the patch should be filled + preserve_aspect_ratio: pass True if you passed True to the predictor + + Returns: + ------- + a rectangular Patch + """ + if len(geometry) != 2 or any(not isinstance(elt, tuple) or len(elt) != 2 for elt in geometry): + raise ValueError("invalid geometry format") + + # Unpack + height, width = page_dimensions + (xmin, ymin), (xmax, ymax) = geometry + # Switch to absolute coords + if preserve_aspect_ratio: + width = height = max(height, width) + xmin, w = xmin * width, (xmax - xmin) * width + ymin, h = ymin * height, (ymax - ymin) * height + + return patches.Rectangle( + (xmin, ymin), + w, + h, + fill=fill, + linewidth=linewidth, + edgecolor=(*color, alpha), + facecolor=(*color, alpha), + label=label, + ) + + +def polygon_patch( + geometry: np.ndarray, + page_dimensions: Tuple[int, int], + label: Optional[str] = None, + color: Tuple[float, float, float] = (0, 0, 0), + alpha: float = 0.3, + linewidth: int = 2, + fill: bool = True, + preserve_aspect_ratio: bool = False, +) -> patches.Polygon: + """Create a matplotlib polygon patch for the element + + Args: + ---- + geometry: bounding box of the element + page_dimensions: dimensions of the Page in format (height, width) + label: label to display when hovered + color: color to draw box + alpha: opacity parameter to fill the boxes, 0 = transparent + linewidth: line width + fill: whether the patch should be filled + preserve_aspect_ratio: pass True if you passed True to the predictor + + Returns: + ------- + a polygon Patch + """ + if not geometry.shape == (4, 2): + raise ValueError("invalid geometry format") + + # Unpack + height, width = page_dimensions + geometry[:, 0] = geometry[:, 0] * (max(width, height) if preserve_aspect_ratio else width) + geometry[:, 1] = geometry[:, 1] * (max(width, height) if preserve_aspect_ratio else height) + + return patches.Polygon( + geometry, + fill=fill, + linewidth=linewidth, + edgecolor=(*color, alpha), + facecolor=(*color, alpha), + label=label, + ) + + +def create_obj_patch( + geometry: Union[BoundingBox, Polygon4P, np.ndarray], + page_dimensions: Tuple[int, int], + **kwargs: Any, +) -> patches.Patch: + """Create a matplotlib patch for the element + + Args: + ---- + geometry: bounding box (straight or rotated) of the element + page_dimensions: dimensions of the page in format (height, width) + **kwargs: keyword arguments for the patch + + Returns: + ------- + a matplotlib Patch + """ + if isinstance(geometry, tuple): + if len(geometry) == 2: # straight word BB (2 pts) + return 
rect_patch(geometry, page_dimensions, **kwargs) + elif len(geometry) == 4: # rotated word BB (4 pts) + return polygon_patch(np.asarray(geometry), page_dimensions, **kwargs) + elif isinstance(geometry, np.ndarray) and geometry.shape == (4, 2): # rotated line + return polygon_patch(geometry, page_dimensions, **kwargs) + raise ValueError("invalid geometry format") + + +def get_colors(num_colors: int) -> List[Tuple[float, float, float]]: + """Generate num_colors color for matplotlib + + Args: + ---- + num_colors: number of colors to generate + + Returns: + ------- + colors: list of generated colors + """ + colors = [] + for i in np.arange(0.0, 360.0, 360.0 / num_colors): + hue = i / 360.0 + lightness = (50 + np.random.rand() * 10) / 100.0 + saturation = (90 + np.random.rand() * 10) / 100.0 + colors.append(colorsys.hls_to_rgb(hue, lightness, saturation)) + return colors + + +def visualize_page( + page: Dict[str, Any], + image: np.ndarray, + words_only: bool = True, + display_artefacts: bool = True, + scale: float = 10, + interactive: bool = True, + add_labels: bool = True, + **kwargs: Any, +) -> Figure: + """Visualize a full page with predicted blocks, lines and words + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from doctr.utils.visualization import visualize_page + >>> from doctr.models import ocr_db_crnn + >>> model = ocr_db_crnn(pretrained=True) + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([[input_page]]) + >>> visualize_page(out[0].pages[0].export(), input_page) + >>> plt.show() + + Args: + ---- + page: the exported Page of a Document + image: np array of the page, needs to have the same shape than page['dimensions'] + words_only: whether only words should be displayed + display_artefacts: whether artefacts should be displayed + scale: figsize of the largest windows side + interactive: whether the plot should be interactive + add_labels: for static plot, adds text labels on top of bounding box + **kwargs: keyword arguments for the polygon patch + + Returns: + ------- + the matplotlib figure + """ + # Get proper scale and aspect ratio + h, w = image.shape[:2] + size = (scale * w / h, scale) if h > w else (scale, h / w * scale) + fig, ax = plt.subplots(figsize=size) + # Display the image + ax.imshow(image) + # hide both axis + ax.axis("off") + + if interactive: + artists: List[patches.Patch] = [] # instantiate an empty list of patches (to be drawn on the page) + + for block in page["blocks"]: + if not words_only: + rect = create_obj_patch( + block["geometry"], page["dimensions"], label="block", color=(0, 1, 0), linewidth=1, **kwargs + ) + # add patch on figure + ax.add_patch(rect) + if interactive: + # add patch to cursor's artists + artists.append(rect) + + for line in block["lines"]: + if not words_only: + rect = create_obj_patch( + line["geometry"], page["dimensions"], label="line", color=(1, 0, 0), linewidth=1, **kwargs + ) + ax.add_patch(rect) + if interactive: + artists.append(rect) + + for word in line["words"]: + rect = create_obj_patch( + word["geometry"], + page["dimensions"], + label=f"{word['value']} (confidence: {word['confidence']:.2%})", + color=(0, 0, 1), + **kwargs, + ) + ax.add_patch(rect) + if interactive: + artists.append(rect) + elif add_labels: + if len(word["geometry"]) == 5: + text_loc = ( + int(page["dimensions"][1] * (word["geometry"][0] - word["geometry"][2] / 2)), + int(page["dimensions"][0] * (word["geometry"][1] - word["geometry"][3] / 2)), + ) + else: + text_loc = ( + 
int(page["dimensions"][1] * word["geometry"][0][0]), + int(page["dimensions"][0] * word["geometry"][0][1]), + ) + + if len(word["geometry"]) == 2: + # We draw only if boxes are in straight format + ax.text( + *text_loc, + word["value"], + size=10, + alpha=0.5, + color=(0, 0, 1), + ) + + if display_artefacts: + for artefact in block["artefacts"]: + rect = create_obj_patch( + artefact["geometry"], + page["dimensions"], + label="artefact", + color=(0.5, 0.5, 0.5), + linewidth=1, + **kwargs, + ) + ax.add_patch(rect) + if interactive: + artists.append(rect) + + if interactive: + import mplcursors + + # Create mlp Cursor to hover patches in artists + mplcursors.Cursor(artists, hover=2).connect("add", lambda sel: sel.annotation.set_text(sel.artist.get_label())) + fig.tight_layout(pad=0.0) + + return fig + + +def visualize_kie_page( + page: Dict[str, Any], + image: np.ndarray, + words_only: bool = False, + display_artefacts: bool = True, + scale: float = 10, + interactive: bool = True, + add_labels: bool = True, + **kwargs: Any, +) -> Figure: + """Visualize a full page with predicted blocks, lines and words + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from doctr.utils.visualization import visualize_page + >>> from doctr.models import ocr_db_crnn + >>> model = ocr_db_crnn(pretrained=True) + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([[input_page]]) + >>> visualize_kie_page(out[0].pages[0].export(), input_page) + >>> plt.show() + + Args: + ---- + page: the exported Page of a Document + image: np array of the page, needs to have the same shape than page['dimensions'] + words_only: whether only words should be displayed + display_artefacts: whether artefacts should be displayed + scale: figsize of the largest windows side + interactive: whether the plot should be interactive + add_labels: for static plot, adds text labels on top of bounding box + **kwargs: keyword arguments for the polygon patch + + Returns: + ------- + the matplotlib figure + """ + # Get proper scale and aspect ratio + h, w = image.shape[:2] + size = (scale * w / h, scale) if h > w else (scale, h / w * scale) + fig, ax = plt.subplots(figsize=size) + # Display the image + ax.imshow(image) + # hide both axis + ax.axis("off") + + if interactive: + artists: List[patches.Patch] = [] # instantiate an empty list of patches (to be drawn on the page) + + colors = {k: color for color, k in zip(get_colors(len(page["predictions"])), page["predictions"])} + for key, value in page["predictions"].items(): + for prediction in value: + if not words_only: + rect = create_obj_patch( + prediction["geometry"], + page["dimensions"], + label=f"{key} \n {prediction['value']} (confidence: {prediction['confidence']:.2%}", + color=colors[key], + linewidth=1, + **kwargs, + ) + # add patch on figure + ax.add_patch(rect) + if interactive: + # add patch to cursor's artists + artists.append(rect) + + if interactive: + import mplcursors + + # Create mlp Cursor to hover patches in artists + mplcursors.Cursor(artists, hover=2).connect("add", lambda sel: sel.annotation.set_text(sel.artist.get_label())) + fig.tight_layout(pad=0.0) + + return fig + + +def draw_boxes(boxes: np.ndarray, image: np.ndarray, color: Optional[Tuple[int, int, int]] = None, **kwargs) -> None: + """Draw an array of relative straight boxes on an image + + Args: + ---- + boxes: array of relative boxes, of shape (*, 4) + image: np array, float32 or uint8 + color: color to use for bounding box edges + **kwargs: keyword arguments 
from `matplotlib.pyplot.plot` + """ + h, w = image.shape[:2] + # Convert boxes to absolute coords + _boxes = deepcopy(boxes) + _boxes[:, [0, 2]] *= w + _boxes[:, [1, 3]] *= h + _boxes = _boxes.astype(np.int32) + for box in _boxes.tolist(): + xmin, ymin, xmax, ymax = box + image = cv2.rectangle( + image, (xmin, ymin), (xmax, ymax), color=color if isinstance(color, tuple) else (0, 0, 255), thickness=2 + ) + plt.imshow(image) + plt.plot(**kwargs) diff --git a/doctr/version.py b/doctr/version.py new file mode 100644 index 0000000000000000000000000000000000000000..061861ce25d330b37778363b3a5b2920e34c8247 --- /dev/null +++ b/doctr/version.py @@ -0,0 +1 @@ +__version__ = '0.9.0a0'