|
import unittest |
|
|
|
import torch |
|
|
|
from hw_asr.datasets import LibrispeechDataset, CustomDirAudioDataset, CustomAudioDataset |
|
from hw_asr.tests.utils import clear_log_folder_after_use |
|
from hw_asr.text_encoder.ctc_char_text_encoder import CTCCharTextEncoder |
|
from hw_asr.utils import ROOT_PATH |
|
from hw_asr.utils.parse_config import ConfigParser |
|
|
|
|
|
class TestDataset(unittest.TestCase): |
|
def test_librispeech(self): |
|
config_parser = ConfigParser.get_test_configs() |
|
with clear_log_folder_after_use(config_parser): |
|
ds = LibrispeechDataset( |
|
"dev-clean", |
|
text_encoder=config_parser.get_text_encoder(), |
|
config_parser=config_parser, |
|
max_text_length=140, |
|
max_audio_length=13, |
|
limit=10, |
|
) |
|
self._assert_training_example_is_good(ds[0]) |
|
|
|
def test_custom_dir_dataset(self): |
|
config_parser = ConfigParser.get_test_configs() |
|
with clear_log_folder_after_use(config_parser): |
|
audio_dir = str(ROOT_PATH / "test_data" / "audio") |
|
transc_dir = str(ROOT_PATH / "test_data" / "transcriptions") |
|
|
|
ds = CustomDirAudioDataset( |
|
audio_dir, |
|
transc_dir, |
|
text_encoder=config_parser.get_text_encoder(), |
|
config_parser=config_parser, |
|
limit=10, |
|
max_audio_length=8, |
|
max_text_length=130, |
|
) |
|
self._assert_training_example_is_good(ds[0]) |
|
|
|
def test_custom_dataset(self): |
|
config_parser = ConfigParser.get_test_configs() |
|
with clear_log_folder_after_use(config_parser): |
|
audio_path = ROOT_PATH / "test_data" / "audio" |
|
transc_path = ROOT_PATH / "test_data" / "transcriptions" |
|
with (transc_path / "84-121550-0000.txt").open() as f: |
|
transcription = f.read().strip() |
|
data = [ |
|
{ |
|
"path": str(audio_path / "84-121550-0001.flac"), |
|
}, |
|
{ |
|
"path": str(audio_path / "84-121550-0000.flac"), |
|
"text": transcription |
|
} |
|
] |
|
|
|
ds = CustomAudioDataset( |
|
data=data, |
|
text_encoder=config_parser.get_text_encoder(), |
|
config_parser=config_parser, |
|
) |
|
self._assert_training_example_is_good(ds[0], contains_text=False) |
|
self._assert_training_example_is_good(ds[1]) |
|
|
|
def _assert_training_example_is_good(self, training_example: dict, contains_text=True): |
|
|
|
for field, expected_type in [ |
|
("audio", torch.Tensor), |
|
("spectrogram", torch.Tensor), |
|
("duration", float), |
|
("audio_path", str), |
|
("text", str), |
|
("text_encoded", torch.Tensor) |
|
]: |
|
self.assertIn(field, training_example, f"Error during checking field {field}") |
|
self.assertIsInstance(training_example[field], expected_type, |
|
f"Error during checking field {field}") |
|
|
|
|
|
batch_dim, audio_dim, = training_example["audio"].size() |
|
self.assertEqual(batch_dim, 1) |
|
self.assertGreater(audio_dim, 1) |
|
|
|
|
|
batch_dim, freq_dim, time_dim = training_example["spectrogram"].size() |
|
self.assertEqual(batch_dim, 1) |
|
self.assertEqual(freq_dim, 128) |
|
self.assertGreater(time_dim, 1) |
|
|
|
|
|
batch_dim, length_dim, = training_example["text_encoded"].size() |
|
self.assertEqual(batch_dim, 1) |
|
if contains_text: |
|
self.assertGreater(length_dim, 1) |
|
else: |
|
self.assertEqual(length_dim, 0) |
|
self.assertEqual(training_example["text"], "") |
|
|