PeechTTSv22050 / training /datasets /libritts_mm_dataset_acoustic.py
nickovchinnikov's picture
Init
9d61c9b
from typing import Any, List, Tuple
import numpy as np
import torch
from torch.utils.data import Dataset
from training.tools import pad_1D, pad_2D, pad_3D
class LibriTTSMMDatasetAcoustic(Dataset):
    r"""Dataset over preprocessed acoustic samples stored in a single file.

    The file is read once in the constructor with ``torch.load``; samples are
    afterwards served by plain indexing into the loaded object.
    """

    def __init__(self, file_path: str):
        r"""Loads the preprocessed samples from disk.

        Args:
            file_path (str): Path to the file holding the preprocessed data
                (described by the original author as a memory-mapped file).
        """
        # NOTE(review): ``torch.load`` can unpickle arbitrary objects; only
        # point this at files coming from a trusted preprocessing run.
        self.data = torch.load(file_path)

    def __getitem__(self, idx: int) -> Any:
        r"""Returns the sample stored at the given index.

        Args:
            idx (int): Index of the sample to return.

        Returns:
            Any: The stored sample. ``collate_fn`` reads it by string keys
            (``"id"``, ``"speaker"``, ``"text"``, ``"raw_text"``, ``"mel"``,
            ``"pitch"``, ``"attn_prior"``, ``"lang"``, ``"wav"``).
        """
        return self.data[idx]

    def __len__(self) -> int:
        r"""Returns the number of samples in the dataset.

        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.data)

    def collate_fn(self, data: List) -> List:
        r"""Collates a list of samples into one padded batch.

        Args:
            data (List): Samples as returned by ``__getitem__``.

        Returns:
            List: Twelve items, in this order:
                ``ids`` (list), ``raw_texts`` (list), ``speakers`` (tensor),
                ``texts`` (int tensor), ``src_lens`` (tensor), ``mels`` (tensor),
                ``pitches`` (tensor), ``pitches_stat`` (``[min, max]`` of the
                pitch values in this batch), ``mel_lens`` (tensor),
                ``langs`` (tensor), ``attn_priors`` (tensor), ``wavs`` (tensor).
        """
        # Pull every field out of the samples, one column per field.
        ids = [entry["id"] for entry in data]
        raw_texts = [entry["raw_text"] for entry in data]
        speakers = [entry["speaker"] for entry in data]
        texts = [entry["text"] for entry in data]
        mels = [entry["mel"] for entry in data]
        pitches = [entry["pitch"] for entry in data]
        # Attention priors and waveforms go to the padding helpers as numpy arrays.
        attn_priors = [entry["attn_prior"].numpy() for entry in data]
        wavs = [entry["wav"].numpy() for entry in data]

        lang_ids = np.array([entry["lang"] for entry in data])
        # Unpadded lengths: text length is axis 0, mel length is axis 1.
        src_lens = np.array([text.shape[0] for text in texts])
        mel_lens = np.array([mel.shape[1] for mel in mels])

        # NOTE: Instead of the pitches for the whole dataset, the statistics of
        # the current batch are used, and only the min and max values are kept.
        pitches_stat = list(self.normalize_pitch(pitches)[:2])

        # The pad_* helpers live in ``training.tools``; their results are fed to
        # ``torch.from_numpy`` below, so they are expected to return numpy arrays.
        texts_padded = pad_1D(texts)
        mels_padded = pad_2D(mels)
        pitches_padded = pad_1D(pitches)
        attn_priors_padded = pad_3D(
            attn_priors, len(data), max(src_lens), max(mel_lens),
        )

        # Repeat the per-sample speaker / language ids along a new axis 1 so
        # they have one value per position of the padded text.
        speakers_tiled = np.repeat(
            np.expand_dims(np.array(speakers), axis=1),
            texts_padded.shape[1],
            axis=1,
        )
        langs_tiled = np.repeat(
            np.expand_dims(lang_ids, axis=1),
            texts_padded.shape[1],
            axis=1,
        )

        wavs_padded = pad_2D(wavs)

        return [
            ids,
            raw_texts,
            torch.from_numpy(speakers_tiled),
            torch.from_numpy(texts_padded).int(),
            torch.from_numpy(src_lens),
            torch.from_numpy(mels_padded),
            torch.from_numpy(pitches_padded),
            pitches_stat,
            torch.from_numpy(mel_lens),
            torch.from_numpy(langs_tiled),
            torch.from_numpy(attn_priors_padded),
            torch.from_numpy(wavs_padded),
        ]

    def normalize_pitch(
        self, pitches: List[torch.Tensor],
    ) -> Tuple[float, float, float, float]:
        r"""Computes summary statistics over a list of pitch tensors.

        Despite its name, this method does not rescale anything: it joins all
        tensors into one and reports statistics of the joined values.

        Args:
            pitches (List[torch.Tensor]): Pitch tensors to join with ``torch.cat``.

        Returns:
            Tuple[float, float, float, float]: ``(min, max, mean, std)`` of all
            pitch values, as Python floats.
        """
        # ``torch.cat`` is the long-standing name of ``torch.concatenate``.
        joined = torch.cat(pitches)
        return (
            torch.min(joined).item(),
            torch.max(joined).item(),
            torch.mean(joined).item(),
            torch.std(joined).item(),
        )