# conex / espnet / utils / training / iterators.py
# tobiasc's picture
# Initial commit
# ad16788
import chainer
from chainer.iterators import MultiprocessIterator
from chainer.iterators import SerialIterator
from chainer.iterators import ShuffleOrderSampler
from chainer.training.extension import Extension
import numpy as np
class ShufflingEnabler(Extension):
    """Trainer extension that switches shuffling on for a set of iterators.

    Each wrapped iterator must provide a ``start_shuffle()`` method. The
    extension calls it on every iterator the first time it is invoked and is
    a no-op on all later invocations.
    """

    def __init__(self, iterators):
        """Create the extension.

        :param list[Iterator] iterators: Iterators exposing ``start_shuffle()``
            that should start shuffling once this extension fires
        """
        self.iterators = iterators
        # Becomes True once shuffling has been enabled, so it happens only once.
        self.set = False

    def __call__(self, trainer):
        """Enable shuffling on all wrapped iterators (first call only).

        :param trainer: The trainer invoking this extension (unused)
        """
        if self.set:
            return
        for it in self.iterators:
            it.start_shuffle()
        self.set = True
class ToggleableShufflingSerialIterator(SerialIterator):
    """A SerialIterator whose shuffling can be switched on during training.

    The iterator may be created with ``shuffle=False`` and later switched to a
    shuffled order by calling :meth:`start_shuffle` (for example from
    :class:`ShufflingEnabler`).
    """

    def __init__(self, dataset, batch_size, repeat=True, shuffle=True):
        """Initialize the iterator.

        :param dataset: The dataset to take batches from; it must support
            ``len()`` (used by :meth:`start_shuffle`) plus whatever access
            :class:`chainer.iterators.SerialIterator` requires
        :param int batch_size: The batch size
        :param bool repeat: Whether to repeat data (allow multiple epochs)
        :param bool shuffle: Whether to shuffle the batches from the start
        """
        super(ToggleableShufflingSerialIterator, self).__init__(
            dataset, batch_size, repeat, shuffle
        )

    def start_shuffle(self):
        """Start shuffling (or reshuffle) the batches.

        Draws a fresh random order over all indices of ``self.dataset`` and
        installs it as the iteration order.
        """
        self._shuffle = True
        # Parse the whole major version instead of only the first character
        # of the version string, which would misread a multi-digit major.
        major_version = int(chainer.__version__.split(".")[0])
        if major_version <= 4:
            # Legacy path: set a random permutation directly as the order.
            self._order = np.random.permutation(len(self.dataset))
        else:
            self.order_sampler = ShuffleOrderSampler()
            self._order = self.order_sampler(np.arange(len(self.dataset)), 0)
class ToggleableShufflingMultiprocessIterator(MultiprocessIterator):
    """A MultiprocessIterator whose shuffling can be switched on during training.

    The iterator may be created with ``shuffle=False`` and later switched to a
    shuffled order by calling :meth:`start_shuffle` (for example from
    :class:`ShufflingEnabler`).
    """

    def __init__(
        self,
        dataset,
        batch_size,
        repeat=True,
        shuffle=True,
        n_processes=None,
        n_prefetch=1,
        shared_mem=None,
        maxtasksperchild=20,
    ):
        """Initialize the iterator.

        :param dataset: The dataset to take batches from; it must support
            ``len()`` (used by :meth:`start_shuffle`) plus whatever access
            :class:`chainer.iterators.MultiprocessIterator` requires
        :param int batch_size: The batch size
        :param bool repeat: Whether to repeat batches or not (enables multiple epochs)
        :param bool shuffle: Whether to shuffle the order of the batches from the start
        :param int n_processes: How many worker processes to use
        :param int n_prefetch: The number of batches to prefetch
        :param int shared_mem: How much memory to share between processes
        :param int maxtasksperchild: Maximum number of tasks per child process
        """
        super(ToggleableShufflingMultiprocessIterator, self).__init__(
            dataset=dataset,
            batch_size=batch_size,
            repeat=repeat,
            shuffle=shuffle,
            n_processes=n_processes,
            n_prefetch=n_prefetch,
            shared_mem=shared_mem,
            maxtasksperchild=maxtasksperchild,
        )

    def start_shuffle(self):
        """Start shuffling (or reshuffle) the batches.

        Draws a fresh random order over all indices of ``self.dataset`` and
        pushes the updated state to the prefetching machinery.
        """
        self.shuffle = True
        # Parse the whole major version instead of only the first character
        # of the version string, which would misread a multi-digit major.
        major_version = int(chainer.__version__.split(".")[0])
        if major_version <= 4:
            # Legacy path: set a random permutation directly as the order.
            self._order = np.random.permutation(len(self.dataset))
        else:
            self.order_sampler = ShuffleOrderSampler()
            self._order = self.order_sampler(np.arange(len(self.dataset)), 0)
            # NOTE(review): the prefetch loop appears to keep its own
            # ``order_sampler`` reference captured at construction time; when
            # the iterator was built with ``shuffle=False`` that reference is
            # None, and it would be called once the new (non-None) order
            # reaches an epoch end. Keep it in sync; guarded because
            # ``_prefetch_loop`` is a private Chainer attribute.
            prefetch_loop = getattr(self, "_prefetch_loop", None)
            if prefetch_loop is not None and hasattr(prefetch_loop, "order_sampler"):
                prefetch_loop.order_sampler = self.order_sampler
        self._set_prefetch_state()