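"""Unit tests for the dataset classes in hw_asr.datasets.

Each test builds a small dataset from the shared test config and checks that
a training example exposes the expected fields, types, and tensor shapes.
"""
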
import unittest

import torch

from hw_asr.datasets import LibrispeechDataset, CustomDirAudioDataset, CustomAudioDataset
from hw_asr.tests.utils import clear_log_folder_after_use
from hw_asr.utils import ROOT_PATH
from hw_asr.utils.parse_config import ConfigParser


class TestDataset(unittest.TestCase):
    def test_librispeech(self):
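        """Build a LibrispeechDataset from the dev-clean split and validate one example."""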
        config_parser = ConfigParser.get_test_configs()
        with clear_log_folder_after_use(config_parser):
            ds = LibrispeechDataset(
                "dev-clean",
                text_encoder=config_parser.get_text_encoder(),
                config_parser=config_parser,
                max_text_length=140,
                max_audio_length=13,
                limit=10,
            )
            self._assert_training_example_is_good(ds[0])

    def test_custom_dir_dataset(self):
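        """Build a CustomDirAudioDataset from the bundled test audio and transcription dirs."""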
        config_parser = ConfigParser.get_test_configs()
        with clear_log_folder_after_use(config_parser):
            audio_dir = str(ROOT_PATH / "test_data" / "audio")
            transc_dir = str(ROOT_PATH / "test_data" / "transcriptions")

            ds = CustomDirAudioDataset(
                audio_dir,
                transc_dir,
                text_encoder=config_parser.get_text_encoder(),
                config_parser=config_parser,
                limit=10,
                max_audio_length=8,
                max_text_length=130,
            )
            self._assert_training_example_is_good(ds[0])

    def test_custom_dataset(self):
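        """Build a CustomAudioDataset from explicit entries, one without a transcription."""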
        config_parser = ConfigParser.get_test_configs()
        with clear_log_folder_after_use(config_parser):
            audio_path = ROOT_PATH / "test_data" / "audio"
            transc_path = ROOT_PATH / "test_data" / "transcriptions"
            with (transc_path / "84-121550-0000.txt").open() as f:
                transcription = f.read().strip()
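            # The first entry deliberately omits "text" to exercise the
            # no-transcription path; the second pairs audio with the text read above.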
            data = [
                {
                    "path": str(audio_path / "84-121550-0001.flac"),
                },
                {
                    "path": str(audio_path / "84-121550-0000.flac"),
                    "text": transcription
                }
            ]

            ds = CustomAudioDataset(
                data=data,
                text_encoder=config_parser.get_text_encoder(),
                config_parser=config_parser,
            )
            self._assert_training_example_is_good(ds[0], contains_text=False)
            self._assert_training_example_is_good(ds[1])

    def _assert_training_example_is_good(self, training_example: dict, contains_text=True):
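        """Assert that an example has the expected fields, types, and tensor shapes."""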

        for field, expected_type in [
            ("audio", torch.Tensor),
            ("spectrogram", torch.Tensor),
            ("duration", float),
            ("audio_path", str),
            ("text", str),
            ("text_encoded", torch.Tensor)
        ]:
            self.assertIn(field, training_example, f"Missing field '{field}'")
            self.assertIsInstance(training_example[field], expected_type,
                                  f"Field '{field}' is not of type {expected_type.__name__}")

        # check waveform dimensions
        batch_dim, audio_dim = training_example["audio"].size()
        self.assertEqual(batch_dim, 1)
        self.assertGreater(audio_dim, 1)

        # check spectrogram dimensions
        batch_dim, freq_dim, time_dim = training_example["spectrogram"].size()
        self.assertEqual(batch_dim, 1)
        self.assertEqual(freq_dim, 128)
        self.assertGreater(time_dim, 1)

        # check text tensor dimensions
        batch_dim, length_dim = training_example["text_encoded"].size()
        self.assertEqual(batch_dim, 1)
        if contains_text:
            self.assertGreater(length_dim, 1)
        else:
            self.assertEqual(length_dim, 0)
            self.assertEqual(training_example["text"], "")
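

# Allow running this test module directly, in addition to discovery via `unittest`.
if __name__ == "__main__":
    unittest.main()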