Spaces:
Runtime error
Runtime error
# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import math | |
from typing import Callable, Optional | |
import numpy as np | |
import tensorflow as tf | |
from doctr.utils.multithreading import multithread_exec | |
# Public API of this module: only `DataLoader` is exported by `from ... import *`.
__all__ = ["DataLoader"]
def default_collate(samples):
    """Merge a list of samples into a batch by stacking them element-wise.

    The i-th element of every sample is gathered and stacked with ``tf.stack``
    along a new leading (batch) axis.

    Args:
    ----
        samples: list of N tuples containing M elements

    Returns:
    -------
        Tuple of M sequences containing N elements each
    """
    # Transpose N samples of M fields into M columns of N entries, then stack each column
    stacked_columns = [tf.stack(column, axis=0) for column in zip(*samples)]
    return tuple(stacked_columns)
class DataLoader:
    """Implements a dataset wrapper for fast data loading

    >>> from doctr.datasets import CORD, DataLoader
    >>> train_set = CORD(train=True, download=True)
    >>> train_loader = DataLoader(train_set, batch_size=32)
    >>> train_iter = iter(train_loader)
    >>> images, targets = next(train_iter)

    Args:
    ----
        dataset: the dataset (needs to support ``len()`` and ``__getitem__``)
        shuffle: whether the samples should be shuffled before passing it to the iterator
        batch_size: number of elements in each batch (must be at least 1)
        drop_last: if `True`, drops the last batch if it isn't full
        num_workers: number of workers to use for data loading
        collate_fn: function to merge samples into a batch. When omitted, the dataset's own
            ``collate_fn`` attribute is used if it has one, otherwise ``default_collate``.

    Raises:
    ------
        ValueError: if ``batch_size`` is lower than 1
    """

    def __init__(
        self,
        dataset,
        shuffle: bool = True,
        batch_size: int = 1,
        drop_last: bool = False,
        num_workers: Optional[int] = None,
        collate_fn: Optional[Callable] = None,
    ) -> None:
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.dataset = dataset
        self.shuffle = shuffle
        self.batch_size = batch_size
        # Exact integer arithmetic instead of float division followed by floor/ceil
        full_batches, remainder = divmod(len(self.dataset), batch_size)
        self.num_batches = full_batches if (drop_last or remainder == 0) else full_batches + 1
        if collate_fn is None:
            # Prefer a dataset-specific collate function when the dataset provides one
            self.collate_fn = self.dataset.collate_fn if hasattr(self.dataset, "collate_fn") else default_collate
        else:
            self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.reset()

    def __len__(self) -> int:
        """Number of batches produced per epoch."""
        return self.num_batches

    def reset(self) -> None:
        """Start a new epoch: rewind the batch counter and rebuild the sample order."""
        self._num_yielded = 0
        self.indices = np.arange(len(self.dataset))
        # Truthiness test (not `is True`) so numpy booleans or ints from configs also enable shuffling
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __iter__(self):
        # The loader is its own iterator: each new iteration starts a fresh epoch
        self.reset()
        return self

    def __next__(self):
        """Load, collate and return the next batch; raise StopIteration at the end of the epoch."""
        if self._num_yielded >= self.num_batches:
            raise StopIteration
        start = self._num_yielded * self.batch_size
        # NumPy slicing clips the upper bound, so the final partial batch needs no special case
        batch_indices = self.indices[start : start + self.batch_size]
        samples = list(multithread_exec(self.dataset.__getitem__, batch_indices, threads=self.num_workers))
        batch_data = self.collate_fn(samples)
        self._num_yielded += 1
        return batch_data