PeechTTSv22050 / training /datasets /libritts_dataset_vocoder.py
nickovchinnikov's picture
Init
9d61c9b
from typing import Any, Dict, List
import numpy as np
import torch
from torch.utils.data import Dataset
from torchaudio import datasets
from models.config import PreprocessingConfigUnivNet, get_lang_map
from training.preprocess import PreprocessLibriTTS
from training.tools import pad_1D, pad_2D
class LibriTTSDatasetVocoder(Dataset):
    r"""Loading preprocessed univnet model data.

    Wraps ``torchaudio.datasets.LIBRITTS`` and turns each row into a
    ``(mel, audio, speaker_id)`` triple via ``PreprocessLibriTTS.univnet``.
    Rows for which preprocessing returns ``None`` are replaced by a randomly
    chosen row.
    """

    # Maximum number of random replacement rows ``__getitem__`` tries when
    # preprocessing keeps returning ``None``. The previous recursive version
    # was bounded only by the interpreter's recursion limit; a loop keeps the
    # stack flat and fails with a clear error instead.
    max_resample_attempts: int = 1000

    def __init__(
        self,
        root: str,
        batch_size: int,
        download: bool = True,
        lang: str = "en",
    ):
        r"""A PyTorch dataset for loading preprocessed univnet data.

        Args:
            root (str): Path to the directory where the dataset is found or downloaded.
            batch_size (int): Batch size for the dataset. Stored as
                ``self.batch_size``; this class does not use it itself.
            download (bool, optional): Whether to download the dataset if it is not found. Defaults to True.
            lang (str, optional): Language code passed to ``get_lang_map`` to
                choose the preprocessing language type. Defaults to "en".
        """
        self.dataset = datasets.LIBRITTS(root=root, download=download)
        self.batch_size = batch_size

        lang_map = get_lang_map(lang)
        self.preprocess_libtts = PreprocessLibriTTS(
            PreprocessingConfigUnivNet(lang_map.processing_lang_type),
        )

    def __len__(self) -> int:
        r"""Returns the number of samples in the dataset.

        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        r"""Returns a sample from the dataset at the given index.

        If preprocessing the requested row returns ``None``, a uniformly
        random row is drawn instead. This repeats up to
        ``max_resample_attempts`` times.

        Args:
            idx (int): Index of the sample to return.

        Returns:
            Dict[str, Any]: A dictionary with keys ``"mel"``, ``"audio"`` and
            ``"speaker_id"``.

        Raises:
            RuntimeError: If no row could be preprocessed within
                ``max_resample_attempts`` replacements. (The old recursive
                implementation raised ``RecursionError``, a subclass of
                ``RuntimeError``, so existing handlers still match.)
        """
        current_idx = idx
        attempts = 0
        while True:
            # Retrieve the dataset row and run univnet preprocessing on it
            data = self.preprocess_libtts.univnet(self.dataset[current_idx])
            if data is not None:
                mel, audio, speaker_id = data
                return {
                    "mel": mel,
                    "audio": audio,
                    "speaker_id": speaker_id,
                }

            if attempts >= self.max_resample_attempts:
                raise RuntimeError(
                    f"Preprocessing failed for index {idx} and for "
                    f"{attempts} random replacement rows",
                )
            attempts += 1
            # Preprocessing failed for this row: fall back to a random one.
            current_idx = np.random.randint(0, len(self))

    def collate_fn(self, data: List) -> List:
        r"""Collates a batch of data samples.

        Args:
            data (List): A list of sample dicts as returned by ``__getitem__``.

        Returns:
            List: ``[mels, mel_lens, audios, speaker_ids]``:
            ``mels`` is padded with ``pad_2D`` and ``audios`` with ``pad_1D``,
            both as float32 tensors; ``mel_lens`` and ``speaker_ids`` are
            int64 tensors.
        """
        mels = [entry["mel"] for entry in data]
        # Size of each mel along axis 1. This is presumably the frame (time)
        # axis of an (n_mels, frames) spectrogram. TODO confirm against
        # PreprocessLibriTTS.univnet.
        mel_lens = [mel.shape[1] for mel in mels]
        audios = [entry["audio"] for entry in data]
        speaker_ids = [entry["speaker_id"] for entry in data]

        return [
            torch.tensor(pad_2D(mels), dtype=torch.float32),
            torch.tensor(mel_lens, dtype=torch.int64),
            torch.tensor(pad_1D(audios), dtype=torch.float32),
            torch.tensor(speaker_ids, dtype=torch.int64),
        ]