from typing import List, Optional, Tuple from sklearn.model_selection import train_test_split from torch.utils.data import DataLoader, SequentialSampler from training.datasets import LibriTTSDatasetAcoustic def train_dataloader( batch_size: int = 6, num_workers: int = 5, root: str = "datasets_cache/LIBRITTS", cache: bool = True, cache_dir: str = "datasets_cache", mem_cache: bool = False, url: str = "train-clean-360", lang: str = "en", selected_speaker_ids: Optional[List[int]] = None, ) -> DataLoader: r"""Returns the training dataloader, that is using the LibriTTS dataset. Args: batch_size (int): The batch size. num_workers (int): The number of workers. root (str): The root directory of the dataset. cache (bool): Whether to cache the preprocessed data. cache_dir (str): The directory for the cache. mem_cache (bool): Whether to use memory cache. url (str): The URL of the dataset. lang (str): The language of the dataset. selected_speaker_ids (Optional[List[int]]): A list of selected speakers. Returns: DataLoader: The training and validation dataloaders. """ dataset = LibriTTSDatasetAcoustic( root=root, lang=lang, cache=cache, cache_dir=cache_dir, mem_cache=mem_cache, url=url, selected_speaker_ids=selected_speaker_ids, ) train_loader = DataLoader( dataset, # 4x80Gb max 10 sec audio # batch_size=20, # self.train_config.batch_size, # 4*80Gb max ~20.4 sec audio batch_size=batch_size, # TODO: find the optimal num_workers num_workers=num_workers, persistent_workers=True, pin_memory=True, shuffle=False, collate_fn=dataset.collate_fn, ) return train_loader def train_val_dataloader( batch_size: int = 6, num_workers: int = 5, root: str = "datasets_cache/LIBRITTS", cache: bool = True, cache_dir: str = "datasets_cache", mem_cache: bool = False, url: str = "train-clean-360", lang: str = "en", validation_split: float = 0.02, # Percentage of data to use for validation ) -> Tuple[DataLoader, DataLoader]: r"""Returns the training dataloader, that is using the LibriTTS dataset. Args: batch_size (int): The batch size. num_workers (int): The number of workers. root (str): The root directory of the dataset. cache (bool): Whether to cache the preprocessed data. cache_dir (str): The directory for the cache. mem_cache (bool): Whether to use memory cache. url (str): The URL of the dataset. lang (str): The language of the dataset. validation_split (float): The percentage of data to use for validation. Returns: Tupple[DataLoader, DataLoader]: The training and validation dataloaders. """ dataset = LibriTTSDatasetAcoustic( root=root, lang=lang, cache=cache, cache_dir=cache_dir, mem_cache=mem_cache, url=url, ) # Split dataset into train and validation train_indices, val_indices = train_test_split( list(range(len(dataset))), test_size=validation_split, random_state=42, ) # Create Samplers train_sampler = SequentialSampler(train_indices) val_sampler = SequentialSampler(val_indices) # dataset = LibriTTSMMDatasetAcoustic("checkpoints/libri_preprocessed_data.pt") train_loader = DataLoader( dataset, # 4x80Gb max 10 sec audio # batch_size=20, # self.train_config.batch_size, # 4*80Gb max ~20.4 sec audio batch_size=batch_size, # TODO: find the optimal num_workers num_workers=num_workers, sampler=train_sampler, persistent_workers=True, pin_memory=True, shuffle=False, collate_fn=dataset.collate_fn, ) val_loader = DataLoader( dataset, # 4x80Gb max 10 sec audio # batch_size=20, # self.train_config.batch_size, # 4*80Gb max ~20.4 sec audio batch_size=batch_size, # TODO: find the optimal num_workers num_workers=num_workers, sampler=val_sampler, persistent_workers=True, pin_memory=True, shuffle=False, collate_fn=dataset.collate_fn, ) return train_loader, val_loader