from typing import List, Optional, Tuple

from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, SubsetRandomSampler

from training.datasets import LibriTTSDatasetAcoustic


def train_dataloader(
    batch_size: int = 6,
    num_workers: int = 5,
    root: str = "datasets_cache/LIBRITTS",
    cache: bool = True,
    cache_dir: str = "datasets_cache",
    mem_cache: bool = False,
    url: str = "train-clean-360",
    lang: str = "en",
    selected_speaker_ids: Optional[List[int]] = None,
) -> DataLoader:
r"""Returns the training dataloader, that is using the LibriTTS dataset.
Args:
batch_size (int): The batch size.
num_workers (int): The number of workers.
root (str): The root directory of the dataset.
cache (bool): Whether to cache the preprocessed data.
cache_dir (str): The directory for the cache.
mem_cache (bool): Whether to use memory cache.
url (str): The URL of the dataset.
lang (str): The language of the dataset.
selected_speaker_ids (Optional[List[int]]): A list of selected speakers.
Returns:
DataLoader: The training and validation dataloaders.
"""
    dataset = LibriTTSDatasetAcoustic(
        root=root,
        lang=lang,
        cache=cache,
        cache_dir=cache_dir,
        mem_cache=mem_cache,
        url=url,
        selected_speaker_ids=selected_speaker_ids,
    )

    train_loader = DataLoader(
        dataset,
        # Memory note: on 4x80 GB GPUs, batch_size=20 handled ~10 s max audio,
        # while the default batch_size=6 handles ~20.4 s max audio.
        batch_size=batch_size,
        # TODO: find the optimal num_workers
        num_workers=num_workers,
        persistent_workers=True,
        pin_memory=True,
        shuffle=False,
        collate_fn=dataset.collate_fn,
    )

    return train_loader
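

# Example usage (a minimal sketch; the batch size, worker count, and subset
# below are illustrative assumptions, not values required by this module):
#
#     loader = train_dataloader(batch_size=4, num_workers=2, url="train-clean-100")
#     first_batch = next(iter(loader))
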

def train_val_dataloader(
    batch_size: int = 6,
    num_workers: int = 5,
    root: str = "datasets_cache/LIBRITTS",
    cache: bool = True,
    cache_dir: str = "datasets_cache",
    mem_cache: bool = False,
    url: str = "train-clean-360",
    lang: str = "en",
    validation_split: float = 0.02,  # Fraction of the data to use for validation
) -> Tuple[DataLoader, DataLoader]:
r"""Returns the training dataloader, that is using the LibriTTS dataset.
Args:
batch_size (int): The batch size.
num_workers (int): The number of workers.
root (str): The root directory of the dataset.
cache (bool): Whether to cache the preprocessed data.
cache_dir (str): The directory for the cache.
mem_cache (bool): Whether to use memory cache.
url (str): The URL of the dataset.
lang (str): The language of the dataset.
validation_split (float): The percentage of data to use for validation.
Returns:
Tupple[DataLoader, DataLoader]: The training and validation dataloaders.
"""
    dataset = LibriTTSDatasetAcoustic(
        root=root,
        lang=lang,
        cache=cache,
        cache_dir=cache_dir,
        mem_cache=mem_cache,
        url=url,
    )

    # Split the dataset indices into train and validation subsets.
    train_indices, val_indices = train_test_split(
        list(range(len(dataset))),
        test_size=validation_split,
        random_state=42,
    )

    # Create samplers over the split indices. SequentialSampler(indices) would
    # iterate positions 0..len(indices)-1 rather than the index values, so both
    # loaders would read the same leading slice of the dataset;
    # SubsetRandomSampler draws the actual split indices instead.
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)

    # dataset = LibriTTSMMDatasetAcoustic("checkpoints/libri_preprocessed_data.pt")

    train_loader = DataLoader(
        dataset,
        # Memory note: on 4x80 GB GPUs, batch_size=20 handled ~10 s max audio,
        # while the default batch_size=6 handles ~20.4 s max audio.
        batch_size=batch_size,
        # TODO: find the optimal num_workers
        num_workers=num_workers,
        sampler=train_sampler,
        persistent_workers=True,
        pin_memory=True,
        collate_fn=dataset.collate_fn,
    )

    val_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=val_sampler,
        persistent_workers=True,
        pin_memory=True,
        collate_fn=dataset.collate_fn,
    )

    return train_loader, val_loader
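

# A minimal smoke test (a sketch, not part of the original module; note that
# constructing the dataset downloads the selected LibriTTS subset on the first
# run, and persistent_workers=True requires num_workers > 0):
if __name__ == "__main__":
    train_loader, val_loader = train_val_dataloader(batch_size=2, num_workers=2)
    print(f"train batches: {len(train_loader)}, val batches: {len(val_loader)}")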