# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import math
from typing import Callable, Optional

import numpy as np
import tensorflow as tf

from doctr.utils.multithreading import multithread_exec

__all__ = ["DataLoader"]


def default_collate(samples):
"""Collate multiple elements into batches

    Args:
    ----
        samples: list of N tuples containing M elements

    Returns:
    -------
        Tuple of M sequences containing N elements each
    """
    batch_data = zip(*samples)
    tf_data = tuple(tf.stack(elt, axis=0) for elt in batch_data)
    return tf_data


class DataLoader:
"""Implements a dataset wrapper for fast data loading
>>> from doctr.datasets import CORD, DataLoader
>>> train_set = CORD(train=True, download=True)
>>> train_loader = DataLoader(train_set, batch_size=32)
>>> train_iter = iter(train_loader)
>>> images, targets = next(train_iter)
Args:
----
dataset: the dataset
shuffle: whether the samples should be shuffled before passing it to the iterator
batch_size: number of elements in each batch
drop_last: if `True`, drops the last batch if it isn't full
num_workers: number of workers to use for data loading
collate_fn: function to merge samples into a batch
"""
    def __init__(
        self,
        dataset,
        shuffle: bool = True,
        batch_size: int = 1,
        drop_last: bool = False,
        num_workers: Optional[int] = None,
        collate_fn: Optional[Callable] = None,
    ) -> None:
        self.dataset = dataset
        self.shuffle = shuffle
        self.batch_size = batch_size
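        # Number of batches per epoch: floor drops a trailing partial batch, ceil keeps it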
        nb = len(self.dataset) / batch_size
        self.num_batches = math.floor(nb) if drop_last else math.ceil(nb)
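        # Prefer a collate_fn defined by the dataset itself, else fall back to tensor stacking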
        if collate_fn is None:
            self.collate_fn = self.dataset.collate_fn if hasattr(self.dataset, "collate_fn") else default_collate
        else:
            self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.reset()

    def __len__(self) -> int:
        return self.num_batches

    def reset(self) -> None:
        # Updates indices after each epoch
        self._num_yielded = 0
        self.indices = np.arange(len(self.dataset))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __iter__(self):
        self.reset()
        return self

    def __next__(self):
        if self._num_yielded < self.num_batches:
            # Get next indices
            idx = self._num_yielded * self.batch_size
            indices = self.indices[idx : min(len(self.dataset), idx + self.batch_size)]
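            # Fetch the samples for this batch in parallel worker threads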
            samples = list(multithread_exec(self.dataset.__getitem__, indices, threads=self.num_workers))
            batch_data = self.collate_fn(samples)
            self._num_yielded += 1
            return batch_data
        else:
            raise StopIteration
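

# A minimal smoke-test sketch (illustrative only, not part of the upstream module):
# it exercises DataLoader end to end with a tiny in-memory dataset of
# (image, target) pairs, so no download is required. `_ToyDataset` is a
# hypothetical helper introduced here for demonstration.
if __name__ == "__main__":
    class _ToyDataset:
        def __init__(self, n: int = 10) -> None:
            self._images = [tf.random.uniform((32, 32, 3)) for _ in range(n)]
            self._targets = [tf.constant(i % 2) for i in range(n)]

        def __len__(self) -> int:
            return len(self._images)

        def __getitem__(self, idx):
            return self._images[idx], self._targets[idx]

    loader = DataLoader(_ToyDataset(), batch_size=4)
    for images, targets in loader:
        # Expect batches of shape (4, 32, 32, 3) and (4,), with a trailing
        # partial batch of 2 since drop_last defaults to False
        print(images.shape, targets.shape)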