import os
import unittest

import torch
from torch.utils.data import DataLoader

from training.datasets import LibriTTSDatasetAcoustic


class TestLibriTTSDatasetAcoustic(unittest.TestCase):
    def setUp(self):
        self.batch_size = 2
        self.lang = "en"
        self.download = False

        self.dataset = LibriTTSDatasetAcoustic(
            root="datasets_cache/LIBRITTS",
            lang=self.lang,
            download=self.download,
        )

    # def test_len(self):
    #     self.assertEqual(len(self.dataset), 33236)

    # def test_getitem(self):
    #     sample = self.dataset[0]

    #     self.assertEqual(sample["id"], "1034_121119_000001_000001")
    #     self.assertEqual(sample["speaker"], 1034)
    #     self.assertEqual(sample["text"].shape, torch.Size([6]))
    #     self.assertEqual(sample["mel"].shape, torch.Size([100, 58]))
    #     self.assertEqual(sample["pitch"].shape, torch.Size([58]))
    #     self.assertEqual(sample["raw_text"], "The Law.")
    #     self.assertEqual(sample["normalized_text"], "The Law.")
    #     self.assertFalse(sample["pitch_is_normalized"])
    #     self.assertEqual(sample["lang"], 3)
    #     self.assertEqual(sample["attn_prior"].shape, torch.Size([6, 58]))
    #     self.assertEqual(sample["wav"].shape, torch.Size([1, 14994]))
    #     self.assertEqual(sample["energy"].shape, torch.Size([1, 957]))

    # def test_cache_item(self):
    #     dataset = LibriTTSDatasetAcoustic(
    #         cache=True,
    #     )

    #     idxs = [0, 1, 1000, 1002, 2010]

    #     for idx in idxs:
    #         # Get a sample from the dataset
    #         sample = dataset[idx]

    #         cache_subdir_path = os.path.join(dataset.cache_dir, dataset.cache_subdir(idx))
    #         cache_file = os.path.join(cache_subdir_path, f"{idx}.pt")

    #         # Check if the data is in the cache
    #         self.assertTrue(os.path.exists(cache_file))

    #         # Load the data from the cache file
    #         cached_sample = torch.load(cache_file)

    #         # Check if the cached data is the same as the original data
    #         for key in sample:
    #             if torch.is_tensor(sample[key]):
    #                 self.assertTrue(torch.all(sample[key] == cached_sample[key]))
    #             else:
    #                 self.assertEqual(sample[key], cached_sample[key])

    # def test_collate_fn(self):
    #     data = [
    #         self.dataset[0],
    #         self.dataset[2],
    #     ]

    #     # Call the collate_fn method
    #     result = self.dataset.collate_fn(data)

    #     # Check the output
    #     self.assertEqual(len(result), 13)

    #     # Check that all the batches are the same size
    #     for batch in result:
    #         self.assertEqual(len(batch), 2)

    def test_normalize_pitch(self):
        pitches = [
            torch.tensor([100.0, 200.0, 300.0]),
            torch.tensor([150.0, 250.0, 350.0]),
        ]

        result = self.dataset.normalize_pitch(pitches)

        # Expected output is (min, max, mean, std) over all pitch values combined
        expected_output = (100.0, 350.0, 225.0, 93.54143524169922)

        self.assertEqual(result, expected_output)

    # def test_dataloader(self):
    #     # Create a DataLoader from the dataset
    #     dataloader = DataLoader(
    #         self.dataset,
    #         batch_size=self.batch_size,
    #         shuffle=False,
    #         collate_fn=self.dataset.collate_fn,
    #     )

    #     iter_dataloader = iter(dataloader)

    #     # Iterate over the DataLoader and check the output
    #     for _, items in enumerate([next(iter_dataloader), next(iter_dataloader)]):
    #         # items = batch[0]

    #         # Check the batch size
    #         self.assertEqual(len(items), 13)

    #         for it in items:
    #             self.assertEqual(len(it), self.batch_size)


if __name__ == "__main__":
    unittest.main()