from pathlib import Path
import unittest

from torch import Tensor
import torchaudio  # used by the optional listening snippet in test_getitem
from voicefixer import Vocoder  # used by the optional listening snippet in test_getitem

from training.datasets.hifi_libri_dataset import HifiLibriDataset, HifiLibriItem


class TestHifiLibriDataset(unittest.TestCase):
    """Tests for ``HifiLibriDataset`` against the real on-disk data.

    The dataset is read through the cache directory ``cache_dir``; the tests
    expect the full data to be present (see ``test_init`` for the item count).
    Items of type ``"hifitts"`` are expected at the start of the dataset and
    items of type ``"libritts"`` at the end.
    """

    cache_dir = "datasets_cache"

    @classmethod
    def setUpClass(cls):
        # Build the dataset once for the whole class instead of once per test:
        # it holds a cutset of ~130k entries (see test_init), and no test
        # mutates the dataset object itself.
        cls.dataset = HifiLibriDataset(cache_dir=cls.cache_dir, cache=True)

    def _assert_item(self, idx, dataset_type):
        """Fetch ``self.dataset[idx]``, check its type and ``dataset_type``, return it."""
        item = self.dataset[idx]
        self.assertIsInstance(item, HifiLibriItem)
        self.assertEqual(item.dataset_type, dataset_type)
        return item

    def _assert_cached(self, idx):
        """Assert that the cache file for index ``idx`` exists on disk."""
        cache_file = self.dataset.get_cache_file_path(idx)
        self.assertTrue(cache_file.exists(), f"cache file {cache_file} was not created")

    def test_init(self):
        self.assertEqual(len(self.dataset.cutset), 129751)

    def test_get_cache_subdir_path(self):
        idx = 1234
        expected_path = Path(self.cache_dir) / "cache-hifitts-librittsr" / "2000"
        self.assertEqual(self.dataset.get_cache_subdir_path(idx), expected_path)

    def test_get_cache_file_path(self):
        idx = 1234
        expected_path = (
            Path(self.cache_dir) / "cache-hifitts-librittsr" / "2000" / f"{idx}.pt"
        )
        self.assertEqual(self.dataset.get_cache_file_path(idx), expected_path)

    def test_getitem(self):
        # HiFi-TTS items are taken from the beginning of the dataset.
        self._assert_item(0, "hifitts")
        self._assert_cached(0)

        # To listen to an item, convert its mel spectrogram to a waveform.
        # NOTE: Vocoder expects the mel spectrogram to be prepared in a specific way.
        # item = self.dataset[0]
        # vocoder = Vocoder(44100)
        # wav = vocoder.forward(item.mel.permute((1, 0)).unsqueeze(0))
        # torchaudio.save(str(Path(f"results/{item.id}.wav")), wav, 44100)

        # Read the same index again so the lookup can go through the cache file.
        self._assert_item(0, "hifitts")

        self._assert_item(10, "hifitts")
        self._assert_cached(10)
        self._assert_item(20, "hifitts")

        # LibriTTS items are taken from the end of the dataset.
        last = len(self.dataset)
        self._assert_item(last - 20, "libritts")
        self._assert_cached(last - 20)
        self._assert_item(last - 10, "libritts")
        self._assert_item(last - 5, "libritts")

    def test_collate_fn(self):
        batch = [self.dataset[0] for _ in range(10)]
        collated = self.dataset.collate_fn(batch)
        self.assertIsInstance(collated, list)

        # Expected type of each position in the collated batch, in order.
        expected_types = [
            ("ids", list),
            ("raw_texts", list),
            ("speakers", Tensor),
            ("texts", Tensor),
            ("src_lens", Tensor),
            ("mels", Tensor),
            ("pitches", Tensor),
            ("pitches_stat", list),
            ("mel_lens", Tensor),
            ("langs", Tensor),
            ("attn_priors", Tensor),
            ("wavs", Tensor),
            ("energy", Tensor),
        ]
        # Fail with a clear message rather than an IndexError below.
        self.assertGreaterEqual(len(collated), len(expected_types))
        for position, (field, expected_type) in enumerate(expected_types):
            with self.subTest(position=position, field=field):
                self.assertIsInstance(collated[position], expected_type)

    def test_include_libri(self):
        # Use the same cache directory as the rest of the suite.
        dataset_with_libri = HifiLibriDataset(
            cache_dir=self.cache_dir,
            include_libri=True,
        )
        dataset_without_libri = HifiLibriDataset(
            cache_dir=self.cache_dir,
            include_libri=False,
        )

        # Including LibriTTS must make the dataset strictly larger.
        self.assertGreater(len(dataset_with_libri), len(dataset_without_libri))

        # With LibriTTS, the tail of the dataset holds 'libritts' items.
        libri_item = dataset_with_libri[len(dataset_with_libri) - 10]
        self.assertIsInstance(libri_item, HifiLibriItem)
        self.assertEqual(libri_item.dataset_type, "libritts")

        # Without LibriTTS, even the tail of the dataset holds 'hifitts' items.
        hifi_item = dataset_without_libri[len(dataset_without_libri) - 10]
        self.assertIsInstance(hifi_item, HifiLibriItem)
        self.assertEqual(hifi_item.dataset_type, "hifitts")

    def test_dur_filter(self):
        # (duration in seconds, whether dur_filter should accept it)
        cases = [
            (0.2, False),
            (1.0, True),
            (2.0, True),
            (30.0, False),
        ]
        for duration, accepted in cases:
            with self.subTest(duration=duration):
                self.assertEqual(bool(self.dataset.dur_filter(duration)), accepted)


if __name__ == "__main__":
    unittest.main()