from typing import Any, Dict, List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset

from training.tools import pad_1D, pad_2D, pad_3D


class LibriTTSMMDatasetAcoustic(Dataset):
    def __init__(self, file_path: str):
        r"""A PyTorch dataset for loading preprocessed acoustic data stored in memory-mapped files.



        Args:

            file_path (str): Path to the memory-mapped file.

        """
        self.data = torch.load(file_path)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        r"""Returns a sample from the dataset at the given index.

        Args:
            idx (int): Index of the sample to return.

        Returns:
            Dict[str, Any]: A dictionary containing the sample data.
        """
        return self.data[idx]

    def __len__(self) -> int:
        r"""Returns the number of samples in the dataset.

        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.data)

    def collate_fn(self, data: List[Dict[str, Any]]) -> List:
        r"""Collates a batch of data samples into padded, batched tensors.

        Args:
            data (List[Dict[str, Any]]): A list of data samples.

        Returns:
            List: A list of collated batch tensors and metadata.
        """
        data_size = len(data)

        idxs = list(range(data_size))

        # Initialize empty lists to store extracted values
        empty_lists: List[List] = [[] for _ in range(11)]
        (
            ids,
            speakers,
            texts,
            raw_texts,
            mels,
            pitches,
            attn_priors,
            langs,
            src_lens,
            mel_lens,
            wavs,
        ) = empty_lists

        # Extract fields from data dictionary and populate the lists
        for idx in idxs:
            data_entry = data[idx]
            ids.append(data_entry["id"])
            speakers.append(data_entry["speaker"])
            texts.append(data_entry["text"])
            raw_texts.append(data_entry["raw_text"])
            mels.append(data_entry["mel"])
            pitches.append(data_entry["pitch"])
            attn_priors.append(data_entry["attn_prior"].numpy())
            langs.append(data_entry["lang"])
            src_lens.append(data_entry["text"].shape[0])
            mel_lens.append(data_entry["mel"].shape[1])
            wavs.append(data_entry["wav"].numpy())

        # Convert langs, src_lens, and mel_lens to numpy arrays
        langs = np.array(langs)
        src_lens = np.array(src_lens)
        mel_lens = np.array(mel_lens)

        # NOTE: Use pitch statistics computed over the batch instead of
        # statistics for the whole dataset.
        # Keep only the min and max values for pitch
        pitches_stat = list(self.normalize_pitch(pitches)[:2])

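        # Pad variable-length fields so every sample in the batch shares the
        # same text/time dimensions.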
        texts = pad_1D(texts)
        mels = pad_2D(mels)
        pitches = pad_1D(pitches)
        attn_priors = pad_3D(attn_priors, len(idxs), max(src_lens), max(mel_lens))

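        # Broadcast the per-utterance speaker and language IDs across the
        # padded text length so every token position carries an ID.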
        speakers = np.repeat(
            np.expand_dims(np.array(speakers), axis=1), texts.shape[1], axis=1,
        )
        langs = np.repeat(
            np.expand_dims(langs, axis=1), texts.shape[1], axis=1,
        )

        wavs = pad_2D(wavs)

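        # Convert the padded numpy arrays to tensors; the batch is returned as
        # a positional list, so downstream code must unpack it in this order.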
        return [
            ids,
            raw_texts,
            torch.from_numpy(speakers),
            torch.from_numpy(texts).int(),
            torch.from_numpy(src_lens),
            torch.from_numpy(mels),
            torch.from_numpy(pitches),
            pitches_stat,
            torch.from_numpy(mel_lens),
            torch.from_numpy(langs),
            torch.from_numpy(attn_priors),
            torch.from_numpy(wavs),
        ]

    def normalize_pitch(
        self, pitches: List[torch.Tensor],
    ) -> Tuple[float, float, float, float]:
        r"""Normalizes the pitch values.



        Args:

            pitches (List[torch.Tensor]): A list of pitch values.



        Returns:

            Tuple: A tuple containing the normalized pitch values.

        """
        pitches_t = torch.cat(pitches)

        min_value = torch.min(pitches_t).item()
        max_value = torch.max(pitches_t).item()

        mean = torch.mean(pitches_t).item()
        std = torch.std(pitches_t).item()

        return min_value, max_value, mean, std
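

if __name__ == "__main__":
    # Usage sketch, not part of the dataset API. The path below is a
    # hypothetical placeholder; substitute the .pt file produced by your
    # preprocessing step. Assumes each stored sample is a dict with the keys
    # read in collate_fn ("id", "speaker", "text", "mel", "pitch", ...).
    from torch.utils.data import DataLoader

    dataset = LibriTTSMMDatasetAcoustic("path/to/preprocessed_acoustic.pt")
    loader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )

    # Each batch is the 12-element positional list built by collate_fn.
    (
        ids,
        raw_texts,
        speakers,
        texts,
        src_lens,
        mels,
        pitches,
        pitches_stat,
        mel_lens,
        langs,
        attn_priors,
        wavs,
    ) = next(iter(loader))
    print(f"Batch of {len(ids)} samples, padded text shape: {tuple(texts.shape)}")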