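"""Tests for the LibriTTS-R dataset wrapper.

These tests exercise ``load_libritts_item`` and the ``LIBRITTS_R`` dataset
filters (``selected_speaker_ids``, ``min_audio_length``, ``max_audio_length``).
They assume a local copy of the corpus is available under
``datasets_cache/LIBRITTS``.
"""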
import os
import unittest

from training.datasets.libritts_r import LIBRITTS_R, load_libritts_item


class TestLibriTTS(unittest.TestCase):
    def setUp(self):
        # Set up any necessary values for the tests
        self.fileid = "1061_146197_000015_000000"
        self.path = "datasets_cache/LIBRITTS/LibriTTS/train-clean-360"
        self.ext_audio = ".wav"
        self.ext_original_txt = ".original.txt"
        self.ext_normalized_txt = ".normalized.txt"

    def test_load_libritts_item(self):
        # Test the load_libritts_item function
        (
            waveform,
            sample_rate,
            original_text,
            normalized_text,
            speaker_id,
            chapter_id,
            utterance_id,
        ) = load_libritts_item(
            self.fileid,
            self.path,
            self.ext_audio,
            self.ext_original_txt,
            self.ext_normalized_txt,
        )

        base_path = os.path.join(
            self.path,
            f"{speaker_id}",
            f"{chapter_id}",
        )

        # Check that the files were created
        self.assertTrue(
            os.path.exists(
                os.path.join(
                    base_path,
                    self.fileid + self.ext_original_txt,
                ),
            ),
        )
        self.assertTrue(
            os.path.exists(
                os.path.join(
                    base_path,
                    self.fileid + self.ext_normalized_txt,
                ),
            ),
        )

    def test_selected_speaker_ids(self):
        # Initialize the dataset with selected speaker IDs
        dataset = LIBRITTS_R(
            root="datasets_cache/LIBRITTS",
            url="train-clean-100",
            selected_speaker_ids=[19, 26],
        )

        # Iterate over the dataset and check the speaker IDs
        for _, _, _, _, speaker_id, _, _ in dataset:
            # Assert that the speaker ID is in the list of selected speaker IDs
            self.assertIn(speaker_id, [19, 26])

    def test_max_audio_length(self):
        # Initialize the dataset with a maximum audio length
        dataset = LIBRITTS_R(
            root="datasets_cache/LIBRITTS",
            url="train-clean-100",
            max_audio_length=3.0,
            selected_speaker_ids=[19, 26],
        )

        # Iterate over the dataset and check the audio lengths
        for waveform, sample_rate, _, _, speaker_id, _, _ in dataset:
            # Get the duration of the waveform in seconds
            duration = waveform.shape[1] / sample_rate

            # Assert that the speaker ID is in the list of selected speaker IDs
            self.assertIn(speaker_id, [19, 26])
            # Assert that the duration is less than or equal to the maximum length
            self.assertLessEqual(duration, 3.0)

    def test_min_audio_length(self):
        # Initialize the dataset with a minimum audio length
        dataset = LIBRITTS_R(
            root="datasets_cache/LIBRITTS",
            url="train-clean-100",
            min_audio_length=30.0,
        )

        # Iterate over the dataset and check the audio lengths
        for waveform, sample_rate, _, _, _, _, _ in dataset:
            # Get the duration of the waveform in seconds
            duration = waveform.shape[1] / sample_rate

            # Assert that the duration is greater than or equal to the minimum length
            self.assertGreaterEqual(duration, 30.0)

    def tearDown(self):
        speaker_id, chapter_id, _, _ = self.fileid.split("_")

        normalized_text_filename = self.fileid + self.ext_normalized_txt
        normalized_text_path = os.path.join(self.path, speaker_id, chapter_id, normalized_text_filename)

        original_text_filename = self.fileid + self.ext_original_txt
        original_text_path = os.path.join(self.path, speaker_id, chapter_id, original_text_filename)

        # Clean up any created files after tests are done
        if os.path.exists(normalized_text_path):
            os.remove(normalized_text_path)
        if os.path.exists(original_text_path):
            os.remove(original_text_path)


if __name__ == "__main__":
    unittest.main()