# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""audio_test.py"""
import unittest | |
import os | |
import numpy as np | |
import wave | |
import tempfile | |
from utils.audio import load_audio_file | |
from utils.audio import get_audio_file_info | |
from utils.audio import slice_padded_array | |
from utils.audio import slice_padded_array_for_subbatch | |
from utils.audio import write_wav_file | |
class TestLoadAudioFile(unittest.TestCase):
    """Checks `load_audio_file` and `get_audio_file_info` against a generated WAV file."""

    def create_temp_wav_file(self, duration: float, fs: int = 16000) -> str:
        """Create a temporary mono 16-bit WAV file filled with random samples.

        The file is registered with `addCleanup`, so it is deleted when the
        calling test finishes, whether the test passes or fails.

        Args:
            duration: Length of the generated audio in seconds.
            fs: Sampling rate in Hz.

        Returns:
            Path of the generated WAV file.
        """
        n_samples = int(duration * fs)
        # Only a unique path is needed from NamedTemporaryFile. Close the handle
        # immediately so it is not leaked and `wave.open` can reopen the path.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_filename = temp_file.name
        # delete=False leaves the file on disk, so schedule its removal explicitly.
        self.addCleanup(os.remove, temp_filename)

        data = np.random.randint(-2**15, 2**15, n_samples, dtype=np.int16)
        with wave.open(temp_filename, 'wb') as f:
            f.setnchannels(1)  # mono
            f.setsampwidth(2)  # 2 bytes per sample, matching int16
            f.setframerate(fs)
            f.writeframes(data.tobytes())
        return temp_filename

    def test_load_audio_file(self):
        """Load a whole file, then a one-second segment, then an unsupported extension."""
        duration = 3.0
        fs = 16000
        temp_filename = self.create_temp_wav_file(duration, fs)

        # Whole file: the loaded length must agree with the reported file info.
        audio_data = load_audio_file(temp_filename, dtype=np.int16)
        file_fs, n_frames, n_channels = get_audio_file_info(temp_filename)
        self.assertEqual(len(audio_data), n_frames)
        self.assertEqual(file_fs, fs)
        self.assertEqual(n_channels, 1)

        # Segment: a one-second request must yield exactly `fs` samples.
        seg_start_sec = 1.0
        seg_length_sec = 1.0
        audio_data = load_audio_file(temp_filename, seg_start_sec, seg_length_sec, dtype=np.int16)
        self.assertEqual(len(audio_data), int(seg_length_sec * fs))

        # Unsupported file extension must be rejected.
        with self.assertRaises(NotImplementedError):
            load_audio_file("unsupported.xyz")
class TestSliceArray(unittest.TestCase):
    """Shape and content checks for `slice_padded_array` on a (1, 10000) integer array."""

    def setUp(self):
        """Create a fresh random input array before every test."""
        self.x = np.random.randint(0, 10, size=(1, 10000))

    def test_without_padding(self):
        """With pad=False, length-100 slices at hop 50 give a (199, 100) result."""
        frames = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=False)
        expected_shape = (199, 100)
        self.assertEqual(frames.shape, expected_shape)

    def test_with_padding(self):
        """With pad=True, the same settings also give a (199, 100) result."""
        frames = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=True)
        expected_shape = (199, 100)
        self.assertEqual(frames.shape, expected_shape)

    def test_content(self):
        """Each slice must equal the matching window of the input array."""
        frames = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=True)

        # Every slice except the last is compared against its exact window.
        n_full = frames.shape[0] - 1
        for idx in range(n_full):
            start = idx * 50
            window = self.x[:, start:start + 100].flatten()
            np.testing.assert_array_equal(frames[idx, :], window)

        # The last slice is checked separately because it may contain padding:
        # only its leading part is compared with the tail of the input.
        tail = self.x[:, -100:].flatten()
        final_frame = frames[-1, :]
        np.testing.assert_array_equal(final_frame[:len(tail)], tail)
class TestSlicePadForSubbatch(unittest.TestCase):
    """Shape check for `slice_padded_array_for_subbatch`."""

    def test_slice_padded_array_for_subbatch(self):
        """The result has shape (4, 4) and a row count divisible by the sub-batch size."""
        sub_batch_size = 4
        signal = np.random.randn(6, 10)

        # Positional arguments after the array are, in order:
        # slice length (4), slice hop (2), pad flag (True), sub-batch size.
        sliced = slice_padded_array_for_subbatch(signal, 4, 2, True, sub_batch_size)

        # The output shape must match the expected one.
        self.assertEqual(sliced.shape, (4, 4))

        # The number of slices must be a multiple of the sub-batch size.
        remainder = sliced.shape[0] % sub_batch_size
        self.assertEqual(remainder, 0)
class TestWriteWavFile(unittest.TestCase):
    """Round-trip test for `write_wav_file`: write a sine tone, read it back with `wave`."""

    def test_write_wav_file_z(self):
        """Write one second of a 440 Hz sine and verify the WAV header and samples."""
        # Generate some test audio data
        samplerate = 16000
        duration = 1  # 1 second
        t = np.linspace(0, duration, int(samplerate * duration), endpoint=False)
        x = np.sin(2 * np.pi * 440 * t)

        # Write into a throwaway directory so the test does not depend on an
        # `extras/` folder in the current working directory, and so the file is
        # removed even when one of the assertions below fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            filename = os.path.join(tmp_dir, "test.wav")
            write_wav_file(filename, x, samplerate)

            # Read the written WAV file and check its contents
            with wave.open(filename, "rb") as wav_file:
                # Check the WAV file parameters: mono, 16-bit, same rate and length
                self.assertEqual(wav_file.getnchannels(), 1)
                self.assertEqual(wav_file.getsampwidth(), 2)
                self.assertEqual(wav_file.getframerate(), samplerate)
                self.assertEqual(wav_file.getnframes(), len(x))
                # Read the audio samples from the WAV file
                data = wav_file.readframes(len(x))

        # Convert the sample bytes to a NumPy array and rescale int16 values to
        # roughly [-1, 1].
        # NOTE(review): presumes write_wav_file scales floats by 32767 -- confirm.
        x_read = np.frombuffer(data, dtype=np.int16) / 32767.0
        # The samples read back must match the original within quantization error
        np.testing.assert_allclose(x_read, x, atol=1e-4)
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()