conex / espnet2 /iterators /multiple_iter_factory.py
tobiasc's picture
Initial commit
ad16788
raw
history blame
1.14 kB
import logging
from typing import Callable
from typing import Collection
from typing import Iterator
import numpy as np
from typeguard import check_argument_types
from espnet2.iterators.abs_iter_factory import AbsIterFactory
class MultipleIterFactory(AbsIterFactory):
    """Iterator factory that chains the iterators of several sub-factories.

    Each entry of ``build_funcs`` is a zero-argument callable returning an
    ``AbsIterFactory``.  The callables are only invoked while the iterator
    returned by :meth:`build_iter` is being consumed, one after another.

    Args:
        build_funcs: Callables that each construct one ``AbsIterFactory``.
            They are copied into a list at construction time.
        seed: Offset added to the epoch number to seed the shuffling of
            the callable order.
        shuffle: Default shuffle flag used when :meth:`build_iter` is
            called without an explicit ``shuffle`` value.
    """

    def __init__(
        self,
        build_funcs: Collection[Callable[[], AbsIterFactory]],
        seed: int = 0,
        shuffle: bool = False,
    ):
        # Must be called directly from this frame: typeguard inspects the
        # caller's arguments against the annotations above.
        assert check_argument_types()
        self.build_funcs = list(build_funcs)
        self.seed = seed
        self.shuffle = shuffle

    def build_iter(self, epoch: int, shuffle: bool = None) -> Iterator:
        """Yield every item of every sub-factory's iterator for ``epoch``.

        Args:
            epoch: Epoch number; forwarded to each sub-factory and, together
                with ``self.seed``, used to seed the ordering shuffle.
            shuffle: Whether to shuffle.  ``None`` falls back to the value
                given at construction.  The resolved flag controls the
                order in which the build functions are visited and is also
                forwarded to each sub-factory's ``build_iter``.

        Yields:
            The items produced by each sub-factory's iterator, in turn.
        """
        do_shuffle = self.shuffle if shuffle is None else shuffle

        # Work on a copy so the stored order is never mutated by shuffling.
        pending = list(self.build_funcs)
        if do_shuffle:
            # Seeding with epoch + seed makes the order reproducible per epoch.
            rng = np.random.RandomState(epoch + self.seed)
            rng.shuffle(pending)

        for index, make_factory in enumerate(pending):
            logging.info(f"Building {index}th iter-factory...")
            factory = make_factory()
            assert isinstance(factory, AbsIterFactory), type(factory)
            yield from factory.build_iter(epoch, do_shuffle)