File size: 3,607 Bytes
153628e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import json
import os
from typing import Any, Dict, List, Tuple, Type, Union

import numpy as np

from doctr.file_utils import CLASS_NAME

from .datasets import AbstractDataset
from .utils import pre_transform_multiclass

__all__ = ["DetectionDataset"]


class DetectionDataset(AbstractDataset):
    """Implements a text detection dataset

    >>> from doctr.datasets import DetectionDataset
    >>> train_set = DetectionDataset(img_folder="/path/to/images",
    >>>                              label_path="/path/to/labels.json")
    >>> img, target = train_set[0]

    Args:
    ----
        img_folder: folder with all the images of the dataset
        label_path: path to the annotations of each image
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        **kwargs: keyword arguments from `AbstractDataset`.
    """

    def __init__(
        self,
        img_folder: str,
        label_path: str,
        use_polygons: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            img_folder,
            pre_transforms=pre_transform_multiclass,
            **kwargs,
        )

        # File existence check
        self._class_names: List = []
        if not os.path.exists(label_path):
            raise FileNotFoundError(f"unable to locate {label_path}")
        with open(label_path, "rb") as f:
            labels = json.load(f)

        self.data: List[Tuple[str, Tuple[np.ndarray, List[str]]]] = []
        np_dtype = np.float32
        for img_name, label in labels.items():
            # File existence check
            if not os.path.exists(os.path.join(self.root, img_name)):
                raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")

            geoms, polygons_classes = self.format_polygons(label["polygons"], use_polygons, np_dtype)

            self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))

    def format_polygons(
        self, polygons: Union[List, Dict], use_polygons: bool, np_dtype: Type
    ) -> Tuple[np.ndarray, List[str]]:
        """Format polygons into an array

        Args:
        ----
            polygons: the bounding boxes
            use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
            np_dtype: dtype of array

        Returns:
        -------
            geoms: bounding boxes as np array
            polygons_classes: list of classes for each bounding box
        """
        if isinstance(polygons, list):
            self._class_names += [CLASS_NAME]
            polygons_classes = [CLASS_NAME for _ in polygons]
            _polygons: np.ndarray = np.asarray(polygons, dtype=np_dtype)
        elif isinstance(polygons, dict):
            self._class_names += list(polygons.keys())
            polygons_classes = [k for k, v in polygons.items() for _ in v]
            _polygons = np.concatenate([np.asarray(poly, dtype=np_dtype) for poly in polygons.values() if poly], axis=0)
        else:
            raise TypeError(f"polygons should be a dictionary or list, it was {type(polygons)}")
        geoms = _polygons if use_polygons else np.concatenate((_polygons.min(axis=1), _polygons.max(axis=1)), axis=1)
        return geoms, polygons_classes

    @property
    def class_names(self):
        return sorted(set(self._class_names))