Spaces:
Sleeping
Sleeping
File size: 5,136 Bytes
a03c9b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""audio_test.py"""
import unittest
import os
import numpy as np
import wave
import tempfile
from utils.audio import load_audio_file
from utils.audio import get_audio_file_info
from utils.audio import slice_padded_array
from utils.audio import slice_padded_array_for_subbatch
from utils.audio import write_wav_file
class TestLoadAudioFile(unittest.TestCase):
    """Tests for `load_audio_file` and `get_audio_file_info` on a generated WAV file."""

    def create_temp_wav_file(self, duration: float, fs: int = 16000) -> str:
        """Write a mono 16-bit PCM WAV file of random samples and return its path.

        The file is registered with `addCleanup`, so it is removed when the
        calling test finishes, whether that test passes or fails.

        Args:
            duration: Length of the generated audio in seconds.
            fs: Sampling rate in Hz.

        Returns:
            Path of the temporary WAV file.
        """
        n_samples = int(duration * fs)
        # Only the unique file name is needed. Closing the handle immediately
        # avoids leaking a file descriptor and lets `wave.open` reopen the
        # path on platforms (e.g. Windows) that forbid a second open while the
        # NamedTemporaryFile handle is still alive.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_filename = temp_file.name
        # delete=False means nobody else will remove the file for us.
        self.addCleanup(os.remove, temp_filename)

        data = np.random.randint(-2**15, 2**15, n_samples, dtype=np.int16)
        with wave.open(temp_filename, 'wb') as f:
            f.setnchannels(1)  # mono
            f.setsampwidth(2)  # 2 bytes per sample, matching the int16 data
            f.setframerate(fs)
            f.writeframes(data.tobytes())
        return temp_filename

    def test_load_audio_file(self):
        """Load a whole file, a one-second segment, and an unsupported extension."""
        duration = 3.0
        fs = 16000
        temp_filename = self.create_temp_wav_file(duration, fs)

        # Test load entire file
        audio_data = load_audio_file(temp_filename, dtype=np.int16)
        file_fs, n_frames, n_channels = get_audio_file_info(temp_filename)
        self.assertEqual(len(audio_data), n_frames)
        self.assertEqual(file_fs, fs)
        self.assertEqual(n_channels, 1)

        # Test load specific segment
        seg_start_sec = 1.0
        seg_length_sec = 1.0
        audio_data = load_audio_file(temp_filename, seg_start_sec, seg_length_sec, dtype=np.int16)
        self.assertEqual(len(audio_data), int(seg_length_sec * fs))

        # Test unsupported file extension
        with self.assertRaises(NotImplementedError):
            load_audio_file("unsupported.xyz")
class TestSliceArray(unittest.TestCase):
    """Shape and content checks for `slice_padded_array` on a (1, 10000) integer array."""

    def setUp(self):
        """Create a fresh random single-row input before every test."""
        self.x = np.random.randint(0, 10, size=(1, 10000))

    def test_without_padding(self):
        """With pad=False, 100-sample windows at hop 50 must come out as (199, 100)."""
        frames = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=False)
        self.assertEqual(frames.shape, (199, 100))

    def test_with_padding(self):
        """With pad=True the same window settings must also come out as (199, 100)."""
        frames = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=True)
        self.assertEqual(frames.shape, (199, 100))

    def test_content(self):
        """Every slice must reproduce the matching window of the input."""
        frames = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=True)
        # The input has exactly one row, so index it once instead of
        # flattening a 2-D slice on every iteration.
        signal = self.x[0]

        for index, frame in enumerate(frames[:-1]):
            start = index * 50
            np.testing.assert_array_equal(frame, signal[start:start + 100])

        # The last slice is checked on its own because it may contain padding:
        # only its leading part is compared with the tail of the input.
        tail = signal[-100:]
        np.testing.assert_array_equal(frames[-1, :][:len(tail)], tail)
class TestSlicePadForSubbatch(unittest.TestCase):
    """Output-shape check for `slice_padded_array_for_subbatch`."""

    def test_slice_padded_array_for_subbatch(self):
        """Slice a random (6, 10) array and verify the resulting shape."""
        sub_batch_size = 4

        # Positional arguments, in order: input array, slice_length=4,
        # slice_hop=2, pad=True, sub_batch_size.
        result = slice_padded_array_for_subbatch(np.random.randn(6, 10), 4, 2, True, sub_batch_size)

        # The output shape must be exactly (4, 4) ...
        self.assertEqual(result.shape, (4, 4))
        # ... and the number of slices must be a multiple of sub_batch_size.
        n_slices = result.shape[0]
        self.assertEqual(n_slices % sub_batch_size, 0)
class TestWriteWavFile(unittest.TestCase):
    """Round-trip test: write a sine wave with `write_wav_file`, read it back with `wave`."""

    def test_write_wav_file_z(self):
        """Write one second of a 440 Hz sine and compare the file with the source signal."""
        # Generate some test audio data
        samplerate = 16000
        duration = 1  # 1 second
        t = np.linspace(0, duration, int(samplerate * duration), endpoint=False)
        x = np.sin(2 * np.pi * 440 * t)

        # Write into a private temporary directory instead of a fixed relative
        # path: the test no longer depends on the current working directory or
        # on a pre-existing folder, and the file is removed even when an
        # assertion below fails.
        with tempfile.TemporaryDirectory() as temp_dir:
            filename = os.path.join(temp_dir, "test.wav")
            write_wav_file(filename, x, samplerate)

            # Read the written WAV file and check its contents
            with wave.open(filename, "rb") as wav_file:
                # Check the WAV file parameters
                self.assertEqual(wav_file.getnchannels(), 1)
                self.assertEqual(wav_file.getsampwidth(), 2)
                self.assertEqual(wav_file.getframerate(), samplerate)
                self.assertEqual(wav_file.getnframes(), len(x))
                # Read the audio samples from the WAV file
                data = wav_file.readframes(len(x))

        # Convert the audio sample byte string to a NumPy array and normalize
        # it to the range [-1, 1]
        x_read = np.frombuffer(data, dtype=np.int16) / 32767.0
        # Check that the audio samples read from the WAV file are equal to the
        # original audio samples (within the given absolute tolerance)
        np.testing.assert_allclose(x_read, x, atol=1e-4)
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|