File size: 1,881 Bytes
153628e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import json
import os
from pathlib import Path
from typing import Any, List, Tuple

from .datasets import AbstractDataset

__all__ = ["RecognitionDataset"]


class RecognitionDataset(AbstractDataset):
    """Dataset implementation for text recognition tasks

    >>> from doctr.datasets import RecognitionDataset
    >>> train_set = RecognitionDataset(img_folder="/path/to/images",
    ...                                labels_path="/path/to/labels.json")
    >>> img, target = train_set[0]

    Args:
    ----
        img_folder: path to the images folder
        labels_path: path to the json file containing all labels (character sequences)
        **kwargs: keyword arguments from `AbstractDataset`.
    """

    def __init__(
        self,
        img_folder: str,
        labels_path: str,
        **kwargs: Any,
    ) -> None:
        super().__init__(img_folder, **kwargs)

        # Each sample is an (image name relative to self.root, label) pair
        self.data: List[Tuple[str, str]] = []
        # The labels file is read as a JSON object mapping image file names to labels.
        # NOTE(review): label values are assumed to be strings (character sequences) — confirm.
        with open(labels_path, encoding="utf-8") as f:
            labels = json.load(f)

        for img_name, label in labels.items():
            img_path = os.path.join(self.root, img_name)
            # Fail fast at construction time instead of on first access of a missing sample
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"unable to locate {img_path}")

            self.data.append((img_name, label))

    def merge_dataset(self, ds: AbstractDataset) -> None:
        """Append the samples of another dataset to this one, in place.

        Every entry (existing and merged) is rewritten as a full path so that
        samples coming from different image folders can share a single root,
        which is then set to the filesystem root.

        Args:
        ----
            ds: dataset whose `data` entries are (image path relative to `ds.root`, label) pairs.
                Merging a dataset with itself duplicates its samples.
        """
        # Anchor roots as absolute paths: with the new root being "/", a relative root would
        # otherwise leave relative entries that end up resolved against "/" instead of the CWD.
        own_root = Path(self.root).absolute()
        self.data = [(str(own_root.joinpath(img_path)), label) for img_path, label in self.data]
        # Define new root
        self.root = Path("/")
        # Materialize the other dataset's entries before extending: if `ds is self`,
        # appending while iterating the same list would never terminate.
        other_root = Path(ds.root).absolute()
        new_samples = [(str(other_root.joinpath(img_path)), label) for img_path, label in ds.data]
        self.data.extend(new_samples)