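"""Smoke tests for the hw_asr dataset classes.

Each test instantiates a dataset from the shared test configuration and
checks that a single training example exposes the expected fields, types,
and tensor shapes.
"""
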
import unittest
import torch
from hw_asr.datasets import LibrispeechDataset, CustomDirAudioDataset, CustomAudioDataset
from hw_asr.tests.utils import clear_log_folder_after_use
from hw_asr.utils import ROOT_PATH
from hw_asr.utils.parse_config import ConfigParser


class TestDataset(unittest.TestCase):
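    """Sanity checks for LibrispeechDataset, CustomDirAudioDataset, and CustomAudioDataset."""
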
def test_librispeech(self):
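        # A small slice of the dev-clean split; the length caps and limit=10
        # keep the test fast.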
config_parser = ConfigParser.get_test_configs()
with clear_log_folder_after_use(config_parser):
ds = LibrispeechDataset(
"dev-clean",
text_encoder=config_parser.get_text_encoder(),
config_parser=config_parser,
max_text_length=140,
max_audio_length=13,
limit=10,
)
self._assert_training_example_is_good(ds[0])

    def test_custom_dir_dataset(self):
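        # Dataset assembled from a directory of audio files plus a directory
        # of matching transcriptions.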
config_parser = ConfigParser.get_test_configs()
with clear_log_folder_after_use(config_parser):
audio_dir = str(ROOT_PATH / "test_data" / "audio")
transc_dir = str(ROOT_PATH / "test_data" / "transcriptions")
ds = CustomDirAudioDataset(
audio_dir,
transc_dir,
text_encoder=config_parser.get_text_encoder(),
config_parser=config_parser,
limit=10,
max_audio_length=8,
max_text_length=130,
)
self._assert_training_example_is_good(ds[0])

    def test_custom_dataset(self):
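        # Dataset built from an explicit list of records; the first record has
        # no "text", so the example without a transcription is exercised too.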
config_parser = ConfigParser.get_test_configs()
with clear_log_folder_after_use(config_parser):
audio_path = ROOT_PATH / "test_data" / "audio"
transc_path = ROOT_PATH / "test_data" / "transcriptions"
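            # Read the reference transcription for 84-121550-0000.flac below.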
with (transc_path / "84-121550-0000.txt").open() as f:
transcription = f.read().strip()
data = [
{
"path": str(audio_path / "84-121550-0001.flac"),
},
{
"path": str(audio_path / "84-121550-0000.flac"),
"text": transcription
}
]
ds = CustomAudioDataset(
data=data,
text_encoder=config_parser.get_text_encoder(),
config_parser=config_parser,
)
self._assert_training_example_is_good(ds[0], contains_text=False)
self._assert_training_example_is_good(ds[1])

    def _assert_training_example_is_good(self, training_example: dict, contains_text: bool = True):
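        """Assert that a dataset item exposes the expected fields, types, and tensor shapes."""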
for field, expected_type in [
("audio", torch.Tensor),
("spectrogram", torch.Tensor),
("duration", float),
("audio_path", str),
("text", str),
("text_encoded", torch.Tensor)
]:
            self.assertIn(field, training_example, f"Field {field!r} is missing from the training example")
            self.assertIsInstance(training_example[field], expected_type,
                                  f"Field {field!r} has an unexpected type")
# check waveform dimensions
        batch_dim, audio_dim = training_example["audio"].size()
self.assertEqual(batch_dim, 1)
self.assertGreater(audio_dim, 1)
# check spectrogram dimensions
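        # (128 frequency bins presumably correspond to the mel-spectrogram
        # settings in the test config)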
batch_dim, freq_dim, time_dim = training_example["spectrogram"].size()
self.assertEqual(batch_dim, 1)
self.assertEqual(freq_dim, 128)
self.assertGreater(time_dim, 1)
# check text tensor dimensions
        batch_dim, length_dim = training_example["text_encoded"].size()
self.assertEqual(batch_dim, 1)
if contains_text:
self.assertGreater(length_dim, 1)
else:
self.assertEqual(length_dim, 0)
self.assertEqual(training_example["text"], "")
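

if __name__ == "__main__":
    # Allows running this test module directly; unittest discovery also works.
    unittest.main()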