from typing import Any, Dict, List

import numpy as np
import torch
from torch.utils.data import Dataset
from torchaudio import datasets

from models.config import PreprocessingConfigUnivNet, get_lang_map
from training.preprocess import PreprocessLibriTTS
from training.tools import pad_1D, pad_2D


class LibriTTSDatasetVocoder(Dataset):
    r"""Loading preprocessed univnet model data."""

    def __init__(
        self,
        root: str,
        batch_size: int,
        download: bool = True,
        lang: str = "en",
    ):
        r"""A PyTorch dataset for loading preprocessed univnet data.

        Args:
            root (str): Path to the directory where the dataset is found or downloaded.
            batch_size (int): Batch size for the dataset.
            download (bool, optional): Whether to download the dataset if it is not found. Defaults to True.
            lang (str, optional): Language of the dataset. Defaults to "en".
        """
        self.dataset = datasets.LIBRITTS(root=root, download=download)
        self.batch_size = batch_size

        lang_map = get_lang_map(lang)
        self.preprocess_libtts = PreprocessLibriTTS(
            PreprocessingConfigUnivNet(lang_map.processing_lang_type),
        )

    def __len__(self) -> int:
        r"""Returns the number of samples in the dataset.

        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        r"""Returns a sample from the dataset at the given index.

        Args:
            idx (int): Index of the sample to return.

        Returns:
            Dict[str, Any]: A dictionary containing the sample data.
        """
        # Retrieve the dataset row
        data = self.dataset[idx]

        data = self.preprocess_libtts.univnet(data)

        if data is None:
            # Preprocessing failed for this sample; fall back to a random one
            rand_idx = np.random.randint(0, self.__len__())
            return self.__getitem__(rand_idx)

        mel, audio, speaker_id = data

        return {
            "mel": mel,
            "audio": audio,
            "speaker_id": speaker_id,
        }

    def collate_fn(self, data: List) -> List:
        r"""Collates a batch of data samples.

        Args:
            data (List): A list of data samples.

        Returns:
            List: A list of padded and batched tensors.
        """
        data_size = len(data)
        idxs = list(range(data_size))

        # Initialize empty lists to store extracted values
        empty_lists: List[List] = [[] for _ in range(4)]
        (
            mels,
            mel_lens,
            audios,
            speaker_ids,
        ) = empty_lists

        # Extract fields from each data dictionary and populate the lists
        for idx in idxs:
            data_entry = data[idx]

            mels.append(data_entry["mel"])
            mel_lens.append(data_entry["mel"].shape[1])
            audios.append(data_entry["audio"])
            speaker_ids.append(data_entry["speaker_id"])

        # Pad variable-length entries to the longest in the batch and convert to tensors
        mels = torch.tensor(pad_2D(mels), dtype=torch.float32)
        mel_lens = torch.tensor(mel_lens, dtype=torch.int64)
        audios = torch.tensor(pad_1D(audios), dtype=torch.float32)
        speaker_ids = torch.tensor(speaker_ids, dtype=torch.int64)

        return [
            mels,
            mel_lens,
            audios,
            speaker_ids,
        ]
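

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring this dataset into a torch DataLoader using the
# custom collate_fn above. The root path and batch size are assumptions made
# for illustration only; adjust them to your setup.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = LibriTTSDatasetVocoder(root="./data", batch_size=16, download=False)
    loader = DataLoader(
        dataset,
        batch_size=dataset.batch_size,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )

    # Each batch is [mels, mel_lens, audios, speaker_ids], as returned by collate_fn.
    mels, mel_lens, audios, speaker_ids = next(iter(loader))
    print(mels.shape, mel_lens.shape, audios.shape, speaker_ids.shape)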