File size: 1,920 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import unittest

import torch

from models.config import PreprocessingConfigUnivNet as PreprocessingConfig
from models.config import VocoderModelConfig
from models.helpers.tools import get_mask_from_lengths
from models.vocoder.univnet import Generator, TracedGenerator


class TestTracedUnivNet(unittest.TestCase):
    """Output-shape checks for ``TracedGenerator`` wrapping a ``Generator``."""

    def setUp(self):
        """Create the configs, the wrapped generator and random mel inputs."""
        self.batch_size = 3
        self.in_length = 100
        self.mel_channels = 80

        self.model_config = VocoderModelConfig()
        self.preprocess_config = PreprocessingConfig("english_only")

        self.generator = Generator(self.model_config, self.preprocess_config)

        # Example inputs are prepared here but are not handed to
        # TracedGenerator, which is built from the generator alone.
        self.example_inputs = (self._random_mel(),)
        self.traced_generator = TracedGenerator(self.generator)

        self.c = self._random_mel()
        self.mel_lens = self._full_lengths()

    def _random_mel(self):
        """Return a random tensor of shape (batch, n_mel_channels, in_length)."""
        mel_shape = (
            self.batch_size,
            self.preprocess_config.stft.n_mel_channels,
            self.in_length,
        )
        return torch.randn(*mel_shape)

    def _full_lengths(self):
        """Return one length per batch item, each equal to ``in_length``."""
        return torch.tensor([self.in_length for _ in range(self.batch_size)])

    def _assert_audio_shape(self, audio):
        """Check that ``audio`` has 256 samples per input frame, one channel."""
        samples_per_frame = 256
        expected_shape = (
            self.batch_size,
            1,
            self.in_length * samples_per_frame,
        )
        self.assertEqual(audio.shape, expected_shape)

    def test_forward(self):
        """The forward pass on the raw input yields the expected shape."""
        audio = self.traced_generator(self.c, self.mel_lens)
        self._assert_audio_shape(audio)

    def test_forward_with_masked_c(self):
        """The forward pass on a pre-masked input yields the expected shape."""
        lengths = self._full_lengths()

        # Fill the masked positions of the mel-spectrogram with the
        # generator's mask value before running the forward pass.
        mask = get_mask_from_lengths(lengths).unsqueeze(1)
        masked_c = self.c.masked_fill(
            mask,
            self.traced_generator.mel_mask_value,
        )

        audio = self.traced_generator(masked_c, lengths)
        self._assert_audio_shape(audio)