from typing import Any, Dict, List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset

from training.tools import pad_1D, pad_2D, pad_3D


class LibriTTSMMDatasetAcoustic(Dataset):
    def __init__(self, file_path: str):
        r"""A PyTorch dataset for loading preprocessed acoustic data stored in memory-mapped files.

        Args:
            file_path (str): Path to the memory-mapped file.
        """
        self.data = torch.load(file_path)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        r"""Returns a sample from the dataset at the given index.

        Args:
            idx (int): Index of the sample to return.

        Returns:
            Dict[str, Any]: A dictionary containing the sample data.
        """
        return self.data[idx]

    def __len__(self) -> int:
        r"""Returns the number of samples in the dataset.

        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.data)

    def collate_fn(self, data: List) -> List:
        r"""Collates a batch of data samples.

        Args:
            data (List): A list of data samples.

        Returns:
            List: A list of collated batch tensors and metadata.
        """
        data_size = len(data)
        idxs = list(range(data_size))

        # Initialize empty lists to store extracted values
        empty_lists: List[List] = [[] for _ in range(11)]
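        # Unpack the eleven lists by name; the order here must match the
        # order in which fields are appended in the loop below.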
        (
            ids,
            speakers,
            texts,
            raw_texts,
            mels,
            pitches,
            attn_priors,
            langs,
            src_lens,
            mel_lens,
            wavs,
        ) = empty_lists

        # Extract fields from each data dictionary and populate the lists
        for idx in idxs:
            data_entry = data[idx]

            ids.append(data_entry["id"])
            speakers.append(data_entry["speaker"])
            texts.append(data_entry["text"])
            raw_texts.append(data_entry["raw_text"])
            mels.append(data_entry["mel"])
            pitches.append(data_entry["pitch"])
            attn_priors.append(data_entry["attn_prior"].numpy())
            langs.append(data_entry["lang"])
            src_lens.append(data_entry["text"].shape[0])
            mel_lens.append(data_entry["mel"].shape[1])
            wavs.append(data_entry["wav"].numpy())

        # Convert langs, src_lens, and mel_lens to numpy arrays
        langs = np.array(langs)
        src_lens = np.array(src_lens)
        mel_lens = np.array(mel_lens)

        # NOTE: Pitch statistics are computed over the batch rather than over
        # the whole dataset; only the min and max values are kept.
        pitches_stat = list(self.normalize_pitch(pitches)[:2])

        texts = pad_1D(texts)
        mels = pad_2D(mels)
        pitches = pad_1D(pitches)
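        # Pad each 2D alignment prior into a single
        # [batch, max_src_len, max_mel_len] array.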
        attn_priors = pad_3D(attn_priors, len(idxs), max(src_lens), max(mel_lens))
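
        # Broadcast the per-utterance speaker and language IDs across the
        # padded text length so they align with texts token-for-token.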
        speakers = np.repeat(
            np.expand_dims(np.array(speakers), axis=1), texts.shape[1], axis=1,
        )
        langs = np.repeat(
            np.expand_dims(np.array(langs), axis=1), texts.shape[1], axis=1,
        )

        wavs = pad_2D(wavs)
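
        # Convert the padded numpy arrays to tensors; ids, raw_texts, and
        # pitches_stat are returned as plain Python objects.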
        return [
            ids,
            raw_texts,
            torch.from_numpy(speakers),
            torch.from_numpy(texts).int(),
            torch.from_numpy(src_lens),
            torch.from_numpy(mels),
            torch.from_numpy(pitches),
            pitches_stat,
            torch.from_numpy(mel_lens),
            torch.from_numpy(langs),
            torch.from_numpy(attn_priors),
            torch.from_numpy(wavs),
        ]

    def normalize_pitch(
        self, pitches: List[torch.Tensor],
    ) -> Tuple[float, float, float, float]:
        r"""Computes pitch statistics over a list of pitch tensors.

        Args:
            pitches (List[torch.Tensor]): A list of pitch tensors.

        Returns:
            Tuple: The (min, max, mean, std) of the concatenated pitch values.
        """
        pitches_t = torch.cat(pitches)

        min_value = torch.min(pitches_t).item()
        max_value = torch.max(pitches_t).item()
        mean = torch.mean(pitches_t).item()
        std = torch.std(pitches_t).item()

        return min_value, max_value, mean, std
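

# Minimal usage sketch, not part of the original module. The path
# "metadata.bin" is a hypothetical placeholder for a file produced by the
# preprocessing pipeline, and the DataLoader settings are illustrative only.
# collate_fn is passed to the DataLoader so that variable-length samples are
# padded into uniform batch tensors.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = LibriTTSMMDatasetAcoustic("metadata.bin")  # hypothetical path
    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )

    # Each batch follows the ordering returned by collate_fn above.
    for batch in loader:
        ids, raw_texts, speakers, texts, src_lens, *rest = batch
        print(texts.shape, src_lens)
        break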