File size: 4,405 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import math
import unittest

import torch

from models.config import VocoderModelConfig
from models.vocoder.univnet import DiscriminatorP


class TestDiscriminatorP(unittest.TestCase):
    """Shape tests for the period discriminator ``DiscriminatorP``."""

    # Expected channel count of each returned feature map: one entry per conv
    # layer, followed by the single-channel post-conv map.
    _FMAP_CHANNELS = (64, 128, 256, 512, 1024, 1)

    def setUp(self):
        self.batch_size = 2
        self.channels = 1
        self.time_steps = 100
        self.period = 10
        self.model_config = VocoderModelConfig()

        self.x = torch.randn(self.batch_size, self.channels, self.time_steps)
        self.model = DiscriminatorP(self.period, self.model_config)

    def _assert_shapes(self, model, fmap, output, period, heights):
        """Assert the shapes produced by one ``DiscriminatorP`` forward pass.

        Args:
            model: The discriminator that produced ``fmap`` and ``output``;
                only ``model.convs`` is read, to count the conv layers.
            fmap: Feature maps returned by the forward pass.
            output: Final output tensor returned by the forward pass.
            period: Period the discriminator was built with; expected as the
                last dimension of every feature map.
            heights: Expected height (dim 2) of each feature map, in order.
        """
        n_convs = len(model.convs)

        # One feature map per conv layer plus one for the post-conv layer.
        self.assertEqual(len(fmap), n_convs + 1)

        # Explicit expected shapes. Comparing whole lists also reports a
        # length mismatch instead of silently skipping trailing entries.
        expected = [
            torch.Size([self.batch_size, channels, height, period])
            for channels, height in zip(self._FMAP_CHANNELS, heights)
        ]
        self.assertEqual([feature.shape for feature in fmap], expected)

        # Derived check over the conv layers: channels double starting at 64
        # and each layer's height is ceil(previous / stride). The bound is the
        # number of conv layers in this discriminator, not the number of MPD
        # periods (an unrelated config value that merely happened to match).
        # Once the height reaches 1 it stays 1 under ceil division, so the
        # recurrence remains consistent with the explicit table above.
        stride = self.model_config.mpd.stride
        height = heights[0]
        for i, feature in enumerate(fmap[:n_convs]):
            self.assertEqual(feature.shape[0], self.batch_size)
            self.assertEqual(feature.shape[1], 2 ** (i + 6))
            self.assertEqual(feature.shape[2], height)
            self.assertEqual(feature.shape[3], period)
            height = math.ceil(height / stride)

        # The final output is expected to be (batch, period).
        self.assertEqual(output.shape, (self.batch_size, period))

    def test_forward(self):
        """Input length is a multiple of the period, so no padding is needed."""
        fmap, output = self.model(self.x)
        self._assert_shapes(
            self.model, fmap, output, self.period, heights=(4, 2, 1, 1, 1, 1)
        )

    def test_forward_with_padding(self):
        """Input length is NOT a multiple of the period.

        99 samples with period 10 are expected to be padded up to 100 samples
        (a 10 x 10 grid), giving the same shapes as ``test_forward``.
        """
        x = torch.randn(self.batch_size, self.channels, self.time_steps - 1)
        fmap, output = self.model(x)
        self._assert_shapes(
            self.model, fmap, output, self.period, heights=(4, 2, 1, 1, 1, 1)
        )

    def test_forward_with_different_period(self):
        """A discriminator built with period 5 on a 99-sample input."""
        period = 5
        # Build the model with the target period rather than mutating
        # ``model.period`` after construction, so the real constructor path
        # is exercised.
        model = DiscriminatorP(period, self.model_config)
        x = torch.randn(self.batch_size, self.channels, self.time_steps - 1)

        fmap, output = model(x)
        self._assert_shapes(
            model, fmap, output, period, heights=(7, 3, 1, 1, 1, 1)
        )


if __name__ == "__main__":
    # Allow running this test module directly, in addition to via a test runner.
    unittest.main()