# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""audio_test.py

Unit tests for the audio helpers in ``utils.audio``: loading a file (whole or
a segment) and reading its header info, slicing arrays with and without
padding, padding slices for sub-batching, and writing WAV files.
"""
import os
import tempfile
import unittest
import wave

import numpy as np

from utils.audio import load_audio_file
from utils.audio import get_audio_file_info
from utils.audio import slice_padded_array
from utils.audio import slice_padded_array_for_subbatch
from utils.audio import write_wav_file

# Fixed RNG seed so that any failing test can be reproduced exactly.
_SEED = 0


class TestLoadAudioFile(unittest.TestCase):

    def create_temp_wav_file(self, duration: float, fs: int = 16000) -> str:
        """Write a mono, 16-bit WAV file of random samples and return its path.

        The file is deleted automatically when the calling test finishes.

        Args:
            duration: Length of the file in seconds.
            fs: Sampling rate in Hz.

        Returns:
            Path of the temporary WAV file.
        """
        n_samples = int(duration * fs)

        # Only the unique name is needed. Close the handle right away so the
        # descriptor does not leak; an open handle would also prevent
        # wave.open() from reopening the path on Windows.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_filename = temp_file.name
        # delete=False means nobody else removes it; clean up after the test.
        self.addCleanup(os.remove, temp_filename)

        rng = np.random.default_rng(_SEED)
        # `high` is exclusive, so this spans the full int16 range.
        data = rng.integers(-2**15, 2**15, n_samples, dtype=np.int16)
        with wave.open(temp_filename, 'wb') as f:
            f.setnchannels(1)
            f.setsampwidth(2)  # 2 bytes per sample -> 16-bit PCM
            f.setframerate(fs)
            f.writeframes(data.tobytes())
        return temp_filename

    def test_load_audio_file(self):
        duration = 3.0
        fs = 16000
        temp_filename = self.create_temp_wav_file(duration, fs)

        # Test load entire file
        audio_data = load_audio_file(temp_filename, dtype=np.int16)
        file_fs, n_frames, n_channels = get_audio_file_info(temp_filename)
        self.assertEqual(len(audio_data), n_frames)
        self.assertEqual(file_fs, fs)
        self.assertEqual(n_channels, 1)

        # Test load specific segment
        seg_start_sec = 1.0
        seg_length_sec = 1.0
        audio_data = load_audio_file(temp_filename, seg_start_sec, seg_length_sec, dtype=np.int16)
        self.assertEqual(len(audio_data), int(seg_length_sec * fs))

        # Test unsupported file extension
        with self.assertRaises(NotImplementedError):
            load_audio_file("unsupported.xyz")


class TestSliceArray(unittest.TestCase):

    def setUp(self):
        rng = np.random.default_rng(_SEED)
        # A single row of 10000 integer samples, shape (1, 10000).
        self.x = rng.integers(0, 10, size=(1, 10000))

    def test_without_padding(self):
        # (10000 - 100) // 50 + 1 = 199 full-length slices.
        sliced_x = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=False)
        self.assertEqual(sliced_x.shape, (199, 100))

    def test_with_padding(self):
        sliced_x = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=True)
        self.assertEqual(sliced_x.shape, (199, 100))

    def test_content(self):
        sliced_x = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=True)
        for i in range(sliced_x.shape[0] - 1):
            np.testing.assert_array_equal(sliced_x[i, :], self.x[:, i * 50:i * 50 + 100].flatten())

        # Test the last slice separately to account for potential padding
        last_slice = sliced_x[-1, :]
        last_slice_no_padding = self.x[:, -100:].flatten()
        np.testing.assert_array_equal(last_slice[:len(last_slice_no_padding)], last_slice_no_padding)


class TestSlicePadForSubbatch(unittest.TestCase):

    def test_slice_padded_array_for_subbatch(self):
        rng = np.random.default_rng(_SEED)
        input_array = rng.standard_normal((6, 10))
        slice_length = 4
        slice_hop = 2
        pad = True
        sub_batch_size = 4
        expected_output_shape = (4, 4)

        # Call the slice_padded_array_for_subbatch function
        result = slice_padded_array_for_subbatch(input_array, slice_length, slice_hop, pad, sub_batch_size)

        # Check if the output shape is correct
        self.assertEqual(result.shape, expected_output_shape)

        # Check if the number of slices is divisible by sub_batch_size
        self.assertEqual(result.shape[0] % sub_batch_size, 0)


class TestWriteWavFile(unittest.TestCase):

    def test_write_wav_file_z(self):
        # Generate one second of a 440 Hz sine wave in [-1, 1]
        samplerate = 16000
        duration = 1  # seconds
        t = np.linspace(0, duration, int(samplerate * duration), endpoint=False)
        x = np.sin(2 * np.pi * 440 * t)

        # Write into a private temporary directory instead of a hard-coded
        # "extras/" path, so the test does not depend on the working
        # directory or on that folder existing. The directory (and the file)
        # is removed even if one of the assertions below fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            filename = os.path.join(tmp_dir, "test.wav")
            write_wav_file(filename, x, samplerate)

            # Read the written WAV file and check its contents
            with wave.open(filename, "rb") as wav_file:
                # Check the WAV file parameters
                self.assertEqual(wav_file.getnchannels(), 1)
                self.assertEqual(wav_file.getsampwidth(), 2)
                self.assertEqual(wav_file.getframerate(), samplerate)
                self.assertEqual(wav_file.getnframes(), len(x))

                # Read the audio samples from the WAV file
                data = wav_file.readframes(len(x))

        # Convert the audio sample byte string to a NumPy array and normalize it to the range [-1, 1]
        x_read = np.frombuffer(data, dtype=np.int16) / 32767.0

        # Check that the audio samples read from the WAV file are equal to the original audio samples
        np.testing.assert_allclose(x_read, x, atol=1e-4)


if __name__ == '__main__':
    unittest.main()