import unittest

import torch
from torch import nn

from models.config import PreprocessingConfigUnivNet as PreprocessingConfig
from models.config import VocoderModelConfig
from models.vocoder.univnet.generator import Generator


class TestUnivNet(unittest.TestCase):
    """Output-shape and weight-norm tests for the UnivNet vocoder ``Generator``."""

    def setUp(self):
        self.batch_size = 3
        self.in_length = 100

        self.model_config = VocoderModelConfig()
        self.preprocess_config = PreprocessingConfig("english_only")

        self.generator = Generator(self.model_config, self.preprocess_config)

        # Random conditioning input of shape (batch_size, n_mel_channels, in_length)
        self.c = torch.randn(
            self.batch_size,
            self.preprocess_config.stft.n_mel_channels,
            self.in_length,
        )

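    def test_forward(self):
        """forward() should upsample every mel frame to hop_length audio samples."""
        output = self.generator(self.c)

        # Use the configured hop length instead of a hard-coded 256 so this
        # test stays consistent with test_generator_inference_output_shape.
        expected_shape = (
            self.batch_size,
            1,
            self.in_length * self.preprocess_config.stft.hop_length,
        )
        self.assertEqual(output.shape, expected_shape)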

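    def test_generator_inference_output_shape(self):
        """infer() should return (batch_size, 1, in_length * hop_length) audio."""
        # Every item in the batch uses the full, unpadded mel length
        mel_lens = torch.tensor([self.in_length] * self.batch_size)

        output = self.generator.infer(self.c, mel_lens)

        # One audio channel, hop_length samples per mel frame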
        expected_shape = (
            self.batch_size,
            1,
            self.in_length * self.preprocess_config.stft.hop_length,
        )
        self.assertEqual(output.shape, expected_shape)

    def test_eval(self):
        """eval(inference=True) should leave no weight-norm params on Conv1d layers."""
        generator = Generator(
            self.model_config,
            self.preprocess_config,
        )

        generator.eval(inference=True)
        for module in generator.modules():
            if isinstance(module, nn.Conv1d):
                self.assertFalse(hasattr(module, "weight_g"))
                self.assertFalse(hasattr(module, "weight_v"))

    def test_remove_weight_norm(self):
        """remove_weight_norm() should leave no weight-norm params on Conv1d layers."""
        generator = Generator(
            self.model_config,
            self.preprocess_config,
        )

        generator.remove_weight_norm()
        for module in generator.modules():
            if isinstance(module, nn.Conv1d):
                self.assertFalse(hasattr(module, "weight_g"))
                self.assertFalse(hasattr(module, "weight_v"))


if __name__ == "__main__":
    unittest.main()