# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import json
import os
from typing import Any, Dict, List, Tuple

import numpy as np

from .datasets import VisionDataset

__all__ = ["DocArtefacts"]


class DocArtefacts(VisionDataset):
    """Object detection dataset for non-textual elements in documents.
    The dataset includes a variety of synthetic document pages with non-textual elements.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/artefacts-grid.png&src=0
        :align: center

    >>> from doctr.datasets import DocArtefacts
    >>> train_set = DocArtefacts(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
    ----
        train: whether to load the training split (otherwise the validation split)
        use_polygons: whether to return each box as its 4 corners (top left, top right, bottom right, bottom left)
            instead of (xmin, ymin, xmax, ymax)
        **kwargs: keyword arguments from `VisionDataset`.
    """

    URL = "https://doctr-static.mindee.com/models?id=v0.4.0/artefact_detection-13fab8ce.zip&src=0"
    SHA256 = "13fab8ced7f84583d9dccd0c634f046c3417e62a11fe1dea6efbbaba5052471b"
    CLASSES = ["background", "qr_code", "bar_code", "logo", "photo"]

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(self.URL, None, self.SHA256, True, **kwargs)
        self.train = train

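        # Point root at the requested split: "train" or "val"
        self.root = os.path.join(self.root, "train" if train else "val")
        # Images live in the split's "images" subfolder, annotations in its "labels.json"
        tmp_root = os.path.join(self.root, "images")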
        with open(os.path.join(self.root, "labels.json"), "rb") as f:
            labels = json.load(f)
        self.data: List[Tuple[str, Dict[str, Any]]] = []
        img_list = os.listdir(tmp_root)
        if len(labels) != len(img_list):
            raise AssertionError("the number of images and labels do not match")
        np_dtype = np.float32
        for img_name, label in labels.items():
            # File existence check
            if not os.path.exists(os.path.join(tmp_root, img_name)):
                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}")

            # xmin, ymin, xmax, ymax; reshape keeps a (0, 4) shape when an image has no annotations
            boxes: np.ndarray = np.asarray([obj["geometry"] for obj in label], dtype=np_dtype).reshape(-1, 4)
            classes: np.ndarray = np.asarray([self.CLASSES.index(obj["label"]) for obj in label], dtype=np.int64)
            if use_polygons:
                # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                boxes = np.stack(
                    [
                        np.stack([boxes[:, 0], boxes[:, 1]], axis=-1),
                        np.stack([boxes[:, 2], boxes[:, 1]], axis=-1),
                        np.stack([boxes[:, 2], boxes[:, 3]], axis=-1),
                        np.stack([boxes[:, 0], boxes[:, 3]], axis=-1),
                    ],
                    axis=1,
                )
            self.data.append((img_name, dict(boxes=boxes, labels=classes)))
        self.root = tmp_root

    def extra_repr(self) -> str:
        return f"train={self.train}"
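

# Illustrative sketch only: `_parse_artefact_label` is a hypothetical helper, not part of doctr's API.
# It restates the per-image parsing done in `DocArtefacts.__init__` as a standalone function, assuming
# each `labels.json` entry is a list of objects shaped like
# {"geometry": [xmin, ymin, xmax, ymax], "label": "<class name>"}, as the loop above reads them.
def _parse_artefact_label(
    label: List[Dict[str, Any]],
    class_names: List[str],
    use_polygons: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Convert one image's annotations into (boxes, class indices).

    Boxes come back with shape (N, 4) as (xmin, ymin, xmax, ymax), or with shape (N, 4, 2) holding
    the top left, top right, bottom right and bottom left corners when `use_polygons` is True.
    Class indices are an int64 array of shape (N,), indexed into `class_names`.
    """
    # reshape(-1, 4) keeps the array 2D even when the image has no annotations
    boxes = np.asarray([obj["geometry"] for obj in label], dtype=np.float32).reshape(-1, 4)
    classes = np.asarray([class_names.index(obj["label"]) for obj in label], dtype=np.int64)
    if use_polygons:
        # Each row of boxes.T is one coordinate column: xmin, ymin, xmax, ymax
        xmin, ymin, xmax, ymax = boxes.T
        # Build the 4 axis-aligned corners in clockwise order, starting from the top left
        boxes = np.stack(
            [
                np.stack([xmin, ymin], axis=-1),
                np.stack([xmax, ymin], axis=-1),
                np.stack([xmax, ymax], axis=-1),
                np.stack([xmin, ymax], axis=-1),
            ],
            axis=1,
        )
    return boxes, classes


# Usage (hypothetical annotation):
# _parse_artefact_label([{"geometry": [0.1, 0.2, 0.5, 0.6], "label": "logo"}], DocArtefacts.CLASSES, True)
# yields one (4, 2) polygon and the class index of "logo". The corners are still axis-aligned: the
# polygon form only changes the representation, since the source annotations are straight boxes.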