# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np

from doctr.file_utils import requires_package

from .base import _BasePredictor

__all__ = ["ArtefactDetector"]

# Per-architecture defaults: model input shape as (C, H, W), class labels, and weights URL.
default_cfgs: Dict[str, Dict[str, Any]] = {
    "yolov8_artefact": {
        "input_shape": (3, 1024, 1024),
        "labels": ["bar_code", "qr_code", "logo", "photo"],
        "url": "https://doctr-static.mindee.com/models?id=v0.8.1/yolo_artefact-f9d66f14.onnx&src=0",
    },
}


class ArtefactDetector(_BasePredictor):
    """
    A class to detect artefacts in images

    >>> from doctr.io import DocumentFile
    >>> from doctr.contrib.artefacts import ArtefactDetector
    >>> doc = DocumentFile.from_images(["path/to/image.jpg"])
    >>> detector = ArtefactDetector()
    >>> results = detector(doc)

    Args:
    ----
        arch: the architecture to use (a key of `default_cfgs`)
        batch_size: the batch size to use
        model_path: the path to the model to use
        labels: the labels to use (defaults to the labels of `arch`)
        input_shape: the input shape to use (defaults to the input shape of `arch`)
        conf_threshold: the confidence threshold to use
        iou_threshold: the intersection over union threshold to use
        **kwargs: additional arguments forwarded to the base predictor
            (documented upstream as being passed to `download_from_url`)
    """

    def __init__(
        self,
        arch: str = "yolov8_artefact",
        batch_size: int = 2,
        model_path: Optional[str] = None,
        labels: Optional[List[str]] = None,
        input_shape: Optional[Tuple[int, int, int]] = None,
        conf_threshold: float = 0.5,
        iou_threshold: float = 0.5,
        **kwargs: Any,
    ) -> None:
        cfg = default_cfgs[arch]
        super().__init__(batch_size=batch_size, url=cfg["url"], model_path=model_path, **kwargs)
        # Explicit arguments win; otherwise fall back to the architecture defaults.
        self.labels = labels or cfg["labels"]
        self.input_shape = input_shape or cfg["input_shape"]
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """
        Resize an image to the model input size and move its last axis first

        Args:
        ----
            img: the image to preprocess (must be 3-dimensional; presumably H x W x C - confirm with caller)

        Returns:
        -------
            the resized image with axes reordered as (2, 0, 1) and values divided by 255
        """
        # `cv2.resize` takes the target size as (width, height); `input_shape` is indexed as (C, H, W).
        target_size = (self.input_shape[2], self.input_shape[1])
        resized = cv2.resize(img, target_size)
        return np.transpose(resized, (2, 0, 1)) / np.array(255.0)

    def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarray]]) -> List[List[Dict[str, Any]]]:
        """
        Convert raw model outputs into filtered detections expressed in original image coordinates

        Args:
        ----
            output: the raw model outputs, one entry per batch
            input_images: the original images, grouped in the same batches as `output`

        Returns:
        -------
            one list of detections per processed prediction; each detection is a dict with the keys
            "label", "confidence" and "box" (`[xmin, ymin, xmax, ymax]` in pixels of the original image)
        """
        results = []

        for batch_outputs, batch_images in zip(output, input_images):
            for out, img in zip(batch_outputs, batch_images):
                org_height, org_width = img.shape[:2]
                # Factors mapping model-input coordinates back to the original image size.
                width_scale = org_width / self.input_shape[2]
                height_scale = org_height / self.input_shape[1]
                # NOTE(review): the scale factors come from `img` and are applied to every `res` in `out`;
                # confirm against the base predictor that each `out` only holds predictions for that image.
                for res in out:
                    detections = []
                    # After the transpose each row is read as (x_center, y_center, width, height, *class_scores).
                    for row in np.transpose(np.squeeze(res)):
                        classes_scores = row[4:]
                        max_score = np.amax(classes_scores)
                        if max_score < self.conf_threshold:
                            continue
                        class_id = int(np.argmax(classes_scores))
                        x, y, w, h = row[0], row[1], row[2], row[3]
                        # to rescaled xmin, ymin, xmax, ymax
                        xmin = int((x - w / 2) * width_scale)
                        ymin = int((y - h / 2) * height_scale)
                        xmax = int((x + w / 2) * width_scale)
                        ymax = int((y + h / 2) * height_scale)

                        detections.append({
                            "label": self.labels[class_id],
                            "confidence": float(max_score),
                            "box": [xmin, ymin, xmax, ymax],
                        })

                    results.append(self._suppress_overlaps(detections))

        # Kept on the instance so that `show` can draw the latest predictions.
        self._results = results
        return results

    def _suppress_overlaps(self, detections: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Filter out overlapping boxes with non-maximum suppression

        Args:
        ----
            detections: the candidate detections, each with a "box" (`[xmin, ymin, xmax, ymax]`)
                and a "confidence"

        Returns:
        -------
            the detections kept by NMS, in the order returned by `cv2.dnn.NMSBoxes`
        """
        if not detections:
            return []

        # `cv2.dnn.NMSBoxes` expects boxes as [x, y, width, height], not as corner coordinates.
        nms_boxes = [[xmin, ymin, xmax - xmin, ymax - ymin] for xmin, ymin, xmax, ymax in (d["box"] for d in detections)]
        scores = [d["confidence"] for d in detections]
        keep_indices = cv2.dnn.NMSBoxes(nms_boxes, scores, self.conf_threshold, self.iou_threshold)  # type: ignore[arg-type]
        # Depending on the OpenCV version the indices are a flat sequence or an N x 1 array: flatten both.
        return [detections[int(idx)] for idx in np.asarray(keep_indices).reshape(-1)]

    def show(self, **kwargs: Any) -> None:
        """
        Display the results

        Args:
        ----
            **kwargs: additional keyword arguments to be passed to `plt.show`
        """
        requires_package("matplotlib", "`.show()` requires matplotlib installed")
        # Imported lazily, after the availability check above.
        import matplotlib.pyplot as plt
        from matplotlib.patches import Rectangle

        # visualize the results with matplotlib
        if self._results and self._inputs:
            for img, res in zip(self._inputs, self._results):
                plt.figure(figsize=(10, 10))
                plt.imshow(img)
                for obj in res:
                    xmin, ymin, xmax, ymax = obj["box"]
                    label = obj["label"]
                    plt.text(xmin, ymin, f"{label} {obj['confidence']:.2f}", color="red")
                    # Rectangle takes the anchor corner plus width and height.
                    plt.gca().add_patch(
                        Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor="red", linewidth=2)
                    )
            plt.show(**kwargs)