# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import math
from typing import Callable, Optional

import numpy as np
import tensorflow as tf

from doctr.utils.multithreading import multithread_exec

__all__ = ["DataLoader"]


def default_collate(samples):
    """Collate multiple elements into batches

    Args:
    ----
        samples: list of N tuples containing M elements

    Returns:
    -------
        Tuple of M stacked tensors containing N elements each
    """
    batch_data = zip(*samples)

    tf_data = tuple(tf.stack(elt, axis=0) for elt in batch_data)

    return tf_data


class DataLoader:
    """Implements a dataset wrapper for fast data loading

    >>> from doctr.datasets import CORD, DataLoader
    >>> train_set = CORD(train=True, download=True)
    >>> train_loader = DataLoader(train_set, batch_size=32)
    >>> train_iter = iter(train_loader)
    >>> images, targets = next(train_iter)

    Args:
    ----
        dataset: the dataset
        shuffle: whether the samples should be shuffled before being passed to the iterator
        batch_size: number of elements in each batch
        drop_last: if `True`, drops the last batch if it isn't full
        num_workers: number of workers to use for data loading
        collate_fn: function to merge samples into a batch
    """

    def __init__(
        self,
        dataset,
        shuffle: bool = True,
        batch_size: int = 1,
        drop_last: bool = False,
        num_workers: Optional[int] = None,
        collate_fn: Optional[Callable] = None,
    ) -> None:
        self.dataset = dataset
        self.shuffle = shuffle
        self.batch_size = batch_size
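        # Number of batches per epoch: floor drops any trailing partial batch, ceil keeps it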
        nb = len(self.dataset) / batch_size
        self.num_batches = math.floor(nb) if drop_last else math.ceil(nb)
        if collate_fn is None:
            self.collate_fn = self.dataset.collate_fn if hasattr(self.dataset, "collate_fn") else default_collate
        else:
            self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.reset()

    def __len__(self) -> int:
        return self.num_batches

    def reset(self) -> None:
        # Resets the yield counter and sample indices before each epoch
        self._num_yielded = 0
        self.indices = np.arange(len(self.dataset))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __iter__(self):
        self.reset()
        return self

    def __next__(self):
        if self._num_yielded < self.num_batches:
            # Compute the indices of the next batch
            idx = self._num_yielded * self.batch_size
            indices = self.indices[idx : min(len(self.dataset), idx + self.batch_size)]

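            # Fetch the selected samples in parallel, then merge them into a batch with the collate function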
            samples = list(multithread_exec(self.dataset.__getitem__, indices, threads=self.num_workers))

            batch_data = self.collate_fn(samples)

            self._num_yielded += 1
            return batch_data
        else:
            raise StopIteration