from multiprocessing import Pool, cpu_count
import os
from pathlib import Path
from typing import Optional, Tuple, Union

import pandas as pd
from torch import Tensor
from torch.utils.data import Dataset
import torchaudio
from torchaudio._internal import download_url_to_file  # type: ignore
from torchaudio.datasets.utils import _extract_tar

URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriTTS"
_CHECKSUMS = {
    "http://us.openslr.org/resources/141/dev_clean.tar.gz": "2c1f5312914890634cc2d15783032ff3",
    "http://us.openslr.org/resources/141/dev_other.tar.gz": "62d3a80ad8a282b6f31b3904f0507e4f",
    "http://us.openslr.org/resources/141/test_clean.tar.gz": "4d373d453eb96c0691e598061bbafab7",
    "http://us.openslr.org/resources/141/test_other.tar.gz": "dbc0959d8bdb6d52200595cabc9995ae",
    "http://us.openslr.org/resources/141/train_clean_100.tar.gz": "6df668d8f5f33e70876bfa33862ad02b",
    "http://us.openslr.org/resources/141/train_clean_360.tar.gz": "382eb3e64394b3da6a559f864339b22c",
    "http://us.openslr.org/resources/141/train_other_500.tar.gz": "a37a8e9f4fe79d20601639bf23d1add8",
}


def load_libritts_item(
    fileid: str,
    path: str,
    ext_audio: str,
    ext_original_txt: str,
    ext_normalized_txt: str,
) -> Tuple[Tensor, int, str, str, int, int, str]:
    """Load the waveform, transcripts, and metadata for a single LibriTTS-R utterance."""
    speaker_id, chapter_id, segment_id, utterance_id = fileid.split("_")
    utterance_id = fileid

    file_audio = utterance_id + ext_audio
    file_audio = os.path.join(path, speaker_id, chapter_id, file_audio)

    # Load audio
    waveform, sample_rate = torchaudio.load(file_audio) # type: ignore

    # Try to load transcriptions from individual files
    normalized_text_filename = utterance_id + ext_normalized_txt
    normalized_text_path = os.path.join(path, speaker_id, chapter_id, normalized_text_filename)

    original_text_filename = utterance_id + ext_original_txt
    original_text_path = os.path.join(path, speaker_id, chapter_id, original_text_filename)

    try:
        # Load normalized text
        with open(normalized_text_path) as ft:
            normalized_text = ft.readline()

        # Load original text
        with open(original_text_path) as ft:
            original_text = ft.readline()

    except FileNotFoundError:
        # If individual files are not found, load from .tsv file
        trans_file = f"{speaker_id}_{chapter_id}.trans.tsv"
        trans_file = os.path.join(path, speaker_id, chapter_id, trans_file)
        df = pd.read_csv(trans_file, sep="\t", header=None, names=["id", "original_text", "normalized_text"])

        row = df[df["id"] == utterance_id].iloc[0]

        original_text = row["original_text"]
        normalized_text = row["normalized_text"]

        # Cache the transcripts as individual text files so they can be read directly next time
        with open(normalized_text_path, "w") as ft:
            ft.write(normalized_text)

        with open(original_text_path, "w") as ft:
            ft.write(original_text)

    return (
        waveform,
        sample_rate,
        original_text,
        normalized_text,
        int(speaker_id),
        int(chapter_id),
        utterance_id,
    )


def check_audio_length(args: Tuple[str, str, str, str, str, float, Optional[float]]) -> Optional[str]:
    """Check if the duration of an audio file is within a specified range.



    Args:

        args (Tuple[str, str, str, str, str, float, Optional[float]]): A tuple containing the following:

            - fileid (str): The ID of the file to check.

            - path (str): The path to the directory containing the audio file.

            - ext_audio (str): The file extension of the audio file.

            - ext_original_txt (str): The file extension of the original text file.

            - ext_normalized_txt (str): The file extension of the normalized text file.

            - min_audio_length (float): The minimum audio length in seconds. If the audio is shorter than this, it will be excluded.

            - max_audio_length (Optional[float]): The maximum audio length in seconds. If the audio is longer than this, it will be excluded. If None, no maximum length is enforced.



    Returns:

        Optional[str]: The ID of the file if its duration is within the specified range, or None if it's not.

    """
    (
        fileid,
        path,
        ext_audio,
        ext_original_txt,
        ext_normalized_txt,
        min_audio_length,
        max_audio_length,
    ) = args

    waveform, sample_rate, _, _, _, _, _ = load_libritts_item(
        fileid,
        path,
        ext_audio,
        ext_original_txt,
        ext_normalized_txt,
    )
    duration = waveform.shape[1] / sample_rate

    min_length_condition = duration > min_audio_length if min_audio_length > 0.0 else True
    max_length_condition = duration <= max_audio_length if max_audio_length is not None else True

    if min_length_condition and max_length_condition:
        return fileid
    else:
        return None


class LIBRITTS_R(Dataset):
    """*LibriTTS-R*: A Restored Multi-Speaker Text-to-Speech Corpus, arXiv, 2023



    Args:

        root (str or Path): Path to the directory where the dataset is found or downloaded.

        url (str, optional): The URL to download the dataset from,

            or the type of the dataset to dowload.

            Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,

            ``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and

            ``"train-other-500"``. (default: ``"train-clean-100"``)

        folder_in_archive (str, optional):

            The top-level directory of the dataset. (default: ``"LibriTTS"``)

        download (bool, optional):

            Whether to download the dataset if it is not found at root path. (default: ``False``).

        selected_speaker_ids (list, optional): List of speaker IDs to be selected. (default: ``None``)

        min_audio_length (float, optional): Minimum audio length in seconds. (default: ``0.0``)

        max_audio_length (float, optional): Maximum audio length in seconds. (default: ``None``)
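
    Example (illustrative; assumes the corpus has already been downloaded and
    extracted so that ``<root>/LibriTTS/dev-clean/...`` exists)::

        >>> dataset = LIBRITTS_R("./data", url="dev-clean", download=False)
        >>> (waveform, sample_rate, original_text, normalized_text,
        ...  speaker_id, chapter_id, utterance_id) = dataset[0]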

    """

    _ext_original_txt = ".original.txt"
    _ext_normalized_txt = ".normalized.txt"
    _ext_audio = ".wav"

    def __init__(
        self,
        root: Union[str, Path],
        url: str = URL,
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        download: bool = False,
        selected_speaker_ids: Union[None, list] = None,
        min_audio_length: float = 0.0,
        max_audio_length: Union[None, float] = None,
    ) -> None:
        if url in [
            "dev-clean",
            "dev-other",
            "test-clean",
            "test-other",
            "train-clean-100",
            "train-clean-360",
            "train-clean-460",
            "train-other-500",
            "train-960",
        ]:

            ext_archive = ".tar.gz"
            base_url = "http://us.openslr.org/resources/141/"

            url = os.path.join(base_url, url + ext_archive)

        # Get string representation of 'root' in case Path object is passed
        root = os.fspath(root)

        basename = os.path.basename(url)
        archive = os.path.join(root, basename)

        basename = basename.split(".")[0]
        folder_in_archive = os.path.join(folder_in_archive, basename)

        self._path = os.path.join(root, folder_in_archive)

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _CHECKSUMS.get(url)
                    download_url_to_file(url, archive, hash_prefix=checksum)
                _extract_tar(archive)
        else:
            if not os.path.exists(self._path):
                raise RuntimeError(
                    f"The path {self._path} doesn't exist. "
                    "Please check the ``root`` path or set `download=True` to download it",
                )

        self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))

        # Filter the walker based on the selected speaker IDs
        selected_speaker_ids_ = set(selected_speaker_ids) if selected_speaker_ids is not None else None
        if selected_speaker_ids_ is not None:
            self._walker = [w for w in self._walker if int(w.split("_")[0]) in selected_speaker_ids_]

        # Filter the walker based on the maximum audio length
        if max_audio_length is not None or min_audio_length > 0.0:
            params = (
                self._path,
                self._ext_audio,
                self._ext_original_txt,
                self._ext_normalized_txt,
                min_audio_length,
                max_audio_length,
            )
            with Pool(cpu_count()) as p:
                self._walker = [
                    fileid
                    for fileid in p.map(
                        check_audio_length,
                        [(fileid, *params) for fileid in self._walker],
                    )
                    if fileid is not None
                ]

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int, int, str]:
        """Load the n-th sample from the dataset.



        Args:

            n (int): The index of the sample to be loaded



        Returns:

            Tuple of the following items;



            Tensor:

                Waveform

            int:

                Sample rate

            str:

                Original text

            str:

                Normalized text

            int:

                Speaker ID

            int:

                Chapter ID

            str:

                Utterance ID

        """
        fileid = self._walker[n]
        return load_libritts_item(
            fileid,
            self._path,
            self._ext_audio,
            self._ext_original_txt,
            self._ext_normalized_txt,
        )

    def __len__(self) -> int:
        return len(self._walker)
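

if __name__ == "__main__":
    # Minimal usage sketch, not part of the dataset implementation. It assumes
    # the corpus has already been downloaded and extracted so that
    # ``./data/LibriTTS/dev-clean/<speaker>/<chapter>/*.wav`` exists; adjust the
    # root path and subset to match your local layout.
    from torch.utils.data import DataLoader

    def collate_fn(batch):
        # Waveforms have different lengths, so keep them in plain Python lists
        # instead of stacking them into a single tensor.
        waveforms, sample_rates, original_texts, normalized_texts, speaker_ids, chapter_ids, utterance_ids = zip(*batch)
        return list(waveforms), list(sample_rates), list(normalized_texts), list(utterance_ids)

    dataset = LIBRITTS_R(
        "./data",
        url="dev-clean",
        download=False,
        min_audio_length=1.0,  # drop clips shorter than one second
        max_audio_length=20.0,  # drop clips longer than twenty seconds
    )
    loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)

    waveforms, sample_rates, normalized_texts, utterance_ids = next(iter(loader))
    print(f"{len(dataset)} utterances after filtering; first batch: {utterance_ids}")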