# PeechTTSv22050/training/datasets/libritts_dataset_vocoder.py
from typing import Any, Dict, List
import numpy as np
import torch
from torch.utils.data import Dataset
from torchaudio import datasets
from models.config import PreprocessingConfigUnivNet, get_lang_map
from training.preprocess import PreprocessLibriTTS
from training.tools import pad_1D, pad_2D
class LibriTTSDatasetVocoder(Dataset):
r"""Loading preprocessed univnet model data."""
def __init__(
self,
root: str,
batch_size: int,
download: bool = True,
lang: str = "en",
):
r"""A PyTorch dataset for loading preprocessed univnet data.
Args:
root (str): Path to the directory where the dataset is found or downloaded.
batch_size (int): Batch size for the dataset.
download (bool, optional): Whether to download the dataset if it is not found. Defaults to True.
"""
self.dataset = datasets.LIBRITTS(root=root, download=download)
self.batch_size = batch_size
lang_map = get_lang_map(lang)
self.preprocess_libtts = PreprocessLibriTTS(
PreprocessingConfigUnivNet(lang_map.processing_lang_type),
)
def __len__(self) -> int:
r"""Returns the number of samples in the dataset.
Returns
int: Number of samples in the dataset.
"""
return len(self.dataset)
def __getitem__(self, idx: int) -> Dict[str, Any]:
r"""Returns a sample from the dataset at the given index.
Args:
idx (int): Index of the sample to return.
Returns:
Dict[str, Any]: A dictionary containing the sample data.
"""
        # Retrieve the dataset row
data = self.dataset[idx]
data = self.preprocess_libtts.univnet(data)
        if data is None:
            # Preprocessing rejected this sample (e.g. it was filtered out);
            # fall back to a randomly chosen sample instead
            rand_idx = np.random.randint(0, len(self))
            return self[rand_idx]
mel, audio, speaker_id = data
return {
"mel": mel,
"audio": audio,
"speaker_id": speaker_id,
}
def collate_fn(self, data: List) -> List:
r"""Collates a batch of data samples.
Args:
data (List): A list of data samples.
Returns:
List: A list of reprocessed data batches.
"""
        # Collect each field across the batch into its own list
        mels: List = []
        mel_lens: List = []
        audios: List = []
        speaker_ids: List = []

        for data_entry in data:
            mels.append(data_entry["mel"])
            mel_lens.append(data_entry["mel"].shape[1])
            audios.append(data_entry["audio"])
            speaker_ids.append(data_entry["speaker_id"])
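
        # Pad every mel/audio to the longest item in the batch and convert to tensors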
mels = torch.tensor(pad_2D(mels), dtype=torch.float32)
mel_lens = torch.tensor(mel_lens, dtype=torch.int64)
audios = torch.tensor(pad_1D(audios), dtype=torch.float32)
speaker_ids = torch.tensor(speaker_ids, dtype=torch.int64)
return [
mels,
mel_lens,
audios,
speaker_ids,
]
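
# A minimal usage sketch (not part of the original module): wiring the dataset
# into a torch DataLoader with its custom collate_fn. The root path and batch
# size below are illustrative assumptions, not values taken from the training config.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = LibriTTSDatasetVocoder(root="./datasets_cache", batch_size=8)
    loader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )

    # Each batch is [mels, mel_lens, audios, speaker_ids], as built by collate_fn
    mels, mel_lens, audios, speaker_ids = next(iter(loader))
    print(mels.shape, mel_lens.shape, audios.shape, speaker_ids.shape)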