File size: 3,117 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import unittest

import torch

from training.preprocess.tacotron_stft import TacotronSTFT


class TestTacotronSTFT(unittest.TestCase):
    """Output-shape checks for the methods exposed by ``TacotronSTFT``."""

    # Frame count asserted for the short (filter_length // 2 samples) clips.
    # NOTE(review): presumably follows from the padding/hop settings used by
    # TacotronSTFT; the derivation is not visible from this file -- confirm.
    _SHORT_CLIP_FRAMES = 6

    @staticmethod
    def _unit_peak_noise(*shape):
        """Draw ``torch.randn`` noise of ``shape`` and scale its peak magnitude to 1."""
        noise = torch.randn(*shape)
        return noise / noise.abs().max()

    def setUp(self):
        """Seed torch's RNG, record the configuration and build the model."""
        # Seed before anything else so the random inputs are reproducible.
        torch.random.manual_seed(0)

        # Sizes of the synthetic inputs used by the tests.
        self.batch_size, self.seq_len = 2, 100

        # STFT / mel configuration handed to the model.
        self.filter_length = self.win_length = 1024
        self.hop_length = 256
        self.n_mel_channels = self.filter_length // 2 + 1
        self.sampling_rate = 22050
        self.mel_fmin, self.mel_fmax = 0, 8000
        self.center = True

        stft_config = {
            "filter_length": self.filter_length,
            "hop_length": self.hop_length,
            "win_length": self.win_length,
            "n_mel_channels": self.n_mel_channels,
            "sampling_rate": self.sampling_rate,
            "mel_fmin": self.mel_fmin,
            "mel_fmax": self.mel_fmax,
            "center": self.center,
        }
        self.model = TacotronSTFT(**stft_config)

    def test_spectrogram(self):
        """``_spectrogram`` yields a 4-D tensor for a batch of short clips."""
        n_samples = self.filter_length // 2
        n_bins = self.filter_length // 2 + 1
        waveform = self._unit_peak_noise(self.batch_size, n_samples)

        result = self.model._spectrogram(waveform)

        # NOTE(review): the trailing dimension is asserted against batch_size
        # (2). Presumably it is a real/imaginary pair rather than a batch
        # axis -- confirm against TacotronSTFT before changing batch_size.
        expected_shape = (
            self.batch_size,
            n_bins,
            self._SHORT_CLIP_FRAMES,
            self.batch_size,
        )
        self.assertEqual(result.shape, expected_shape)

    def test_linear_spectrogram(self):
        """``linear_spectrogram`` yields a 3-D tensor for a batch of short clips."""
        n_samples = self.filter_length // 2
        n_bins = self.filter_length // 2 + 1
        waveform = self._unit_peak_noise(self.batch_size, n_samples)

        result = self.model.linear_spectrogram(waveform)

        expected_shape = (self.batch_size, n_bins, self._SHORT_CLIP_FRAMES)
        self.assertEqual(result.shape, expected_shape)

    def test_forward(self):
        """Calling the model returns a spectrogram and a mel of equal shape."""
        n_bins = self.filter_length // 2 + 1
        # NOTE(review): the leading (batch) dimension of the input is
        # n_mel_channels here, not batch_size -- kept as in the original test.
        waveform = self._unit_peak_noise(
            self.n_mel_channels, self.filter_length // 2,
        )

        spec, mel = self.model(waveform)

        expected_shape = (n_bins, n_bins, self._SHORT_CLIP_FRAMES)
        for output in (spec, mel):
            self.assertEqual(output.shape, expected_shape)

    def test_spectral_normalize_torch(self):
        """``spectral_normalize_torch`` keeps the shape of its input."""
        input_shape = (self.batch_size, self.n_mel_channels, self.seq_len)
        magnitudes = torch.randn(*input_shape)

        normalized = self.model.spectral_normalize_torch(magnitudes)

        self.assertEqual(normalized.shape, input_shape)

    def test_dynamic_range_compression_torch(self):
        """``dynamic_range_compression_torch`` keeps the shape of its input."""
        input_shape = (self.batch_size, self.n_mel_channels, self.seq_len)
        values = torch.randn(*input_shape)

        compressed = self.model.dynamic_range_compression_torch(values)

        self.assertEqual(compressed.shape, input_shape)

    def test_get_mel_from_wav(self):
        """``get_mel_from_wav`` maps 44100 samples to a 176-frame mel."""
        audio = self._unit_peak_noise(44100)

        melspec = self.model.get_mel_from_wav(audio)

        self.assertEqual(melspec.shape, (self.n_mel_channels, 176))


# Run this module's test cases when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()