File size: 4,666 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import unittest

import torch

from models.config import PreprocessingConfigUnivNet as PreprocessingConfig
from models.config import VocoderModelConfig
from models.vocoder.univnet import Discriminator, Generator


# One of the most important test cases for UnivNet.
# Integration test: the Generator's output is fed straight into the Discriminator.
class TestDiscriminator(unittest.TestCase):
    """Integration test wiring ``Generator`` output into ``Discriminator``."""

    def setUp(self):
        # Seed before building the models so both the initial weights and the
        # random conditioning input are reproducible; a failure can then be
        # replayed exactly.
        torch.manual_seed(0)

        self.model_config = VocoderModelConfig()
        self.preprocess_config = PreprocessingConfig("english_only")

        self.generator = Generator(self.model_config, self.preprocess_config)

        self.model = Discriminator(self.model_config)

        self.batch_size = 1
        self.in_length = 100

        # Random conditioning input laid out as
        # (batch_size, n_mel_channels, in_length).
        self.c = torch.randn(
            self.batch_size,
            self.preprocess_config.stft.n_mel_channels,
            self.in_length,
        )

    def test_forward(self):
        """Run Generator -> Discriminator and check the output structure."""
        # Only outputs are inspected, so skip building the autograd graph
        # (saves memory and time for the two full forward passes).
        with torch.no_grad():
            x = self.generator(self.c)
            output = self.model(x)

        # The Discriminator returns two groups: MRD results first, MPD second.
        self.assertEqual(len(output), 2)
        output_mrd = output[0]
        output_mpd = output[1]

        # Assert MRD length
        self.assertEqual(len(output_mrd), 3)

        # Assert MPD length
        self.assertEqual(len(output_mpd), 5)

        # Each entry is indexed as (feature_maps, score): the layout the
        # earlier per-layer shape checks relied on. Verify it holds for every
        # sub-discriminator, reporting failures per entry via subTest.
        for group_name, group in (("mrd", output_mrd), ("mpd", output_mpd)):
            for index, entry in enumerate(group):
                with self.subTest(discriminator=group_name, index=index):
                    feature_maps, score = entry[0], entry[1]
                    self.assertIsInstance(score, torch.Tensor)
                    for layer_output in feature_maps:
                        self.assertIsInstance(layer_output, torch.Tensor)

        # NOTE(review): exact per-layer feature-map shapes used to live here as
        # commented-out code. They depend on the input length and model config;
        # re-add them as live assertions once the expected shapes are confirmed.


if __name__ == "__main__":
    # Executing this file directly hands control to unittest's command-line
    # runner, which discovers and runs the TestCase classes in this module.
    unittest.main(module="__main__")