# File size: 5,136 Bytes
# a03c9b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""audio_test.py"""
import unittest
import os
import numpy as np
import wave
import tempfile
from utils.audio import load_audio_file
from utils.audio import get_audio_file_info
from utils.audio import slice_padded_array
from utils.audio import slice_padded_array_for_subbatch
from utils.audio import write_wav_file


class TestLoadAudioFile(unittest.TestCase):
    """Tests for `load_audio_file` and `get_audio_file_info` using temporary WAV files."""

    @staticmethod
    def _remove_if_exists(path: str) -> None:
        """Delete `path`, tolerating the case where it has already been removed."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    def create_temp_wav_file(self, duration: float, fs: int = 16000) -> str:
        """Write a mono, 16-bit PCM WAV file filled with random samples.

        The file is registered for deletion once the calling test finishes, so
        repeated test runs do not accumulate files in the temp directory.

        Args:
            duration: Length of the audio in seconds.
            fs: Sampling rate in Hz.

        Returns:
            Path of the temporary WAV file.
        """
        n_samples = int(duration * fs)

        # Only a unique path is needed. Close the handle immediately so it is not
        # leaked and so `wave.open` can reopen the path (a still-open
        # NamedTemporaryFile cannot be opened a second time on Windows).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_filename = temp_file.name
        self.addCleanup(self._remove_if_exists, temp_filename)

        data = np.random.randint(-2**15, 2**15, n_samples, dtype=np.int16)

        with wave.open(temp_filename, 'wb') as f:
            f.setnchannels(1)  # mono
            f.setsampwidth(2)  # 2 bytes per sample, matching the int16 data
            f.setframerate(fs)
            f.writeframes(data.tobytes())

        return temp_filename

    def test_load_audio_file(self):
        """Load a whole file and a one-second segment; reject an unknown extension."""
        duration = 3.0
        fs = 16000
        temp_filename = self.create_temp_wav_file(duration, fs)

        # Test load entire file
        audio_data = load_audio_file(temp_filename, dtype=np.int16)
        file_fs, n_frames, n_channels = get_audio_file_info(temp_filename)

        self.assertEqual(len(audio_data), n_frames)
        self.assertEqual(file_fs, fs)
        self.assertEqual(n_channels, 1)

        # Test load specific segment
        seg_start_sec = 1.0
        seg_length_sec = 1.0
        audio_data = load_audio_file(temp_filename, seg_start_sec, seg_length_sec, dtype=np.int16)

        self.assertEqual(len(audio_data), int(seg_length_sec * fs))

        # Test unsupported file extension
        with self.assertRaises(NotImplementedError):
            load_audio_file("unsupported.xyz")


class TestSliceArray(unittest.TestCase):
    """Shape and content checks for `slice_padded_array` on a (1, 10000) integer array."""

    def setUp(self):
        """Create a fresh random input array before every test."""
        self.x = np.random.randint(0, 10, size=(1, 10000))

    def test_without_padding(self):
        """Slicing with pad=False yields 199 slices of length 100."""
        slices = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=False)
        self.assertEqual(slices.shape, (199, 100))

    def test_with_padding(self):
        """Slicing with pad=True yields 199 slices of length 100."""
        slices = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=True)
        self.assertEqual(slices.shape, (199, 100))

    def test_content(self):
        """Every slice matches the corresponding window of the input array."""
        slice_length, slice_hop = 100, 50
        slices = slice_padded_array(self.x, slice_length=slice_length, slice_hop=slice_hop, pad=True)

        # All slices except the last are compared against their full window.
        for idx, row in enumerate(slices[:-1]):
            start = idx * slice_hop
            expected = self.x[:, start:start + slice_length].flatten()
            np.testing.assert_array_equal(row, expected)

        # The last slice is checked separately to account for potential padding.
        tail_expected = self.x[:, -slice_length:].flatten()
        tail_actual = slices[-1, :]
        np.testing.assert_array_equal(tail_actual[:len(tail_expected)], tail_expected)


class TestSlicePadForSubbatch(unittest.TestCase):
    """Output-shape checks for `slice_padded_array_for_subbatch`."""

    def test_slice_padded_array_for_subbatch(self):
        """The result has the expected shape and a sub-batch-aligned slice count."""
        sub_batch_size = 4
        slice_length, slice_hop, pad = 4, 2, True
        input_array = np.random.randn(6, 10)

        # Call slice_padded_array_for_subbatch with positional arguments
        result = slice_padded_array_for_subbatch(input_array, slice_length, slice_hop, pad, sub_batch_size)
        n_slices = result.shape[0]

        # The output shape must be exactly (4, 4)
        self.assertEqual(result.shape, (4, 4))

        # The number of slices must be divisible by sub_batch_size
        self.assertEqual(n_slices % sub_batch_size, 0)


class TestWriteWavFile(unittest.TestCase):
    """Round-trip test for `write_wav_file`."""

    def test_write_wav_file_z(self):
        """Write a 440 Hz sine wave to a WAV file, then verify its header and samples."""
        # Generate some test audio data
        samplerate = 16000
        duration = 1  # 1 second
        t = np.linspace(0, duration, int(samplerate * duration), endpoint=False)
        x = np.sin(2 * np.pi * 440 * t)

        # Write into a throwaway directory so the test does not depend on the
        # current working directory or on an existing "extras/" folder. The
        # cleanup is registered up front so the file is removed even when an
        # assertion below fails.
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        filename = os.path.join(tmp_dir.name, "test.wav")
        write_wav_file(filename, x, samplerate)

        # Read the written WAV file and check its contents
        with wave.open(filename, "rb") as wav_file:
            # Check the WAV file parameters
            self.assertEqual(wav_file.getnchannels(), 1)
            self.assertEqual(wav_file.getsampwidth(), 2)
            self.assertEqual(wav_file.getframerate(), samplerate)
            self.assertEqual(wav_file.getnframes(), len(x))

            # Read the audio samples from the WAV file
            data = wav_file.readframes(len(x))

        # Convert the audio sample byte string to a NumPy array and normalize it to the range [-1, 1]
        x_read = np.frombuffer(data, dtype=np.int16) / 32767.0

        # Check that the audio samples read from the WAV file are equal to the original audio samples
        np.testing.assert_allclose(x_read, x, atol=1e-4)


# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()