import unittest

from tqdm import tqdm

from hw_asr.collate_fn.collate import collate_fn
from hw_asr.datasets import LibrispeechDataset
from hw_asr.tests.utils import clear_log_folder_after_use
from hw_asr.utils.object_loading import get_dataloaders
from hw_asr.utils.parse_config import ConfigParser


class TestDataloader(unittest.TestCase):
    def test_collate_fn(self):
        config_parser = ConfigParser.get_test_configs()
        with clear_log_folder_after_use(config_parser):
            ds = LibrispeechDataset(
                "dev-clean", text_encoder=config_parser.get_text_encoder(),
                config_parser=config_parser
            )
            batch_size = 3
            batch = collate_fn([ds[i] for i in range(batch_size)])

            self.assertIn("spectrogram", batch)  # torch.tensor
            batch_size_dim, feature_length_dim, time_dim = batch["spectrogram"].shape
            self.assertEqual(batch_size_dim, batch_size)
            self.assertEqual(feature_length_dim, 128)
self.assertIn("text_encoded", batch) # [int] torch.tensor
# joined and padded indexes representation of transcriptions
batch_size_dim, text_length_dim = batch["text_encoded"].shape
self.assertEqual(batch_size_dim, batch_size)
self.assertIn("text_encoded_length", batch) # [int] torch.tensor
# contains lengths of each text entry
self.assertEqual(len(batch["text_encoded_length"].shape), 1)
batch_size_dim = batch["text_encoded_length"].shape[0]
self.assertEqual(batch_size_dim, batch_size)
self.assertIn("text", batch) # List[str]
# simple list of initial normalized texts
batch_size_dim = len(batch["text"])
self.assertEqual(batch_size_dim, batch_size)

    def test_dataloaders(self):
        _TOTAL_ITERATIONS = 10
        config_parser = ConfigParser.get_test_configs()
        with clear_log_folder_after_use(config_parser):
            dataloaders = get_dataloaders(config_parser, config_parser.get_text_encoder())
            for part in ["train", "val"]:
                dl = dataloaders[part]
                # smoke-test: a few batches from each split must load without errors
                for i, batch in tqdm(enumerate(dl), total=_TOTAL_ITERATIONS,
                                     desc=f"Iterating over {part}"):
                    if i >= _TOTAL_ITERATIONS:
                        break
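

# A minimal sketch of the batch contract asserted in test_collate_fn above,
# assuming each dataset item is a dict with "spectrogram" of shape
# [1, n_feats, time], "text_encoded" of shape [1, length], and "text" as a
# plain str. The name and body below are illustrative only and are NOT the
# repo's actual implementation (that lives in hw_asr.collate_fn.collate).
def _reference_collate_sketch(dataset_items):
    import torch
    import torch.nn.functional as F

    spectrograms = [item["spectrogram"].squeeze(0) for item in dataset_items]
    texts_encoded = [item["text_encoded"].squeeze(0) for item in dataset_items]
    max_time = max(s.shape[-1] for s in spectrograms)
    max_len = max(t.shape[-1] for t in texts_encoded)
    return {
        # [batch, n_feats, max_time], right-padded with zeros along time
        "spectrogram": torch.stack(
            [F.pad(s, (0, max_time - s.shape[-1])) for s in spectrograms]
        ),
        # [batch, max_len], right-padded with zeros
        "text_encoded": torch.stack(
            [F.pad(t, (0, max_len - t.shape[-1])) for t in texts_encoded]
        ),
        # 1-D tensor of the original (unpadded) text lengths
        "text_encoded_length": torch.tensor(
            [t.shape[-1] for t in texts_encoded], dtype=torch.long
        ),
        # plain list of normalized transcriptions
        "text": [item["text"] for item in dataset_items],
    }


# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()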