# PeechTTSv22050 / models/vocoder/univnet/tests/test_traced_generator.py
# Author: nickovchinnikov — commit 9d61c9b ("Init")
import unittest
import torch
from models.config import PreprocessingConfigUnivNet as PreprocessingConfig
from models.config import VocoderModelConfig
from models.helpers.tools import get_mask_from_lengths
from models.vocoder.univnet import Generator, TracedGenerator
class TestTracedUnivNet(unittest.TestCase):
    """Output-shape tests for ``TracedGenerator`` wrapping a UnivNet ``Generator``."""

    # Audio samples expected per input mel frame in the shape assertions below.
    # NOTE(review): presumably this is the vocoder's hop length — confirm it
    # matches the preprocessing config rather than hard-coding it here.
    SAMPLES_PER_FRAME = 256

    def setUp(self):
        self.batch_size = 3
        self.in_length = 100
        self.mel_channels = 80

        self.model_config = VocoderModelConfig()
        self.preprocess_config = PreprocessingConfig("english_only")

        self.generator = Generator(self.model_config, self.preprocess_config)
        self.traced_generator = TracedGenerator(self.generator)

        # Random mel input laid out as (batch, n_mel_channels, frames).
        self.c = torch.randn(
            self.batch_size,
            self.preprocess_config.stft.n_mel_channels,
            self.in_length,
        )
        # Every item in the batch is full length.
        self.mel_lens = torch.tensor([self.in_length] * self.batch_size)

    def _expected_shape(self):
        """Return the (batch, 1, samples) shape the generator should emit."""
        return (self.batch_size, 1, self.in_length * self.SAMPLES_PER_FRAME)

    def test_forward(self):
        """Full-length inputs produce ``in_length * SAMPLES_PER_FRAME`` samples."""
        output = self.traced_generator(self.c, self.mel_lens)

        self.assertEqual(output.shape, self._expected_shape())

    def test_forward_with_masked_c(self):
        """A ragged batch pre-masked with ``mel_mask_value`` keeps the full shape."""
        # Uniform full lengths leave no padded frames, so the masking path would
        # never be exercised. Use ragged lengths instead. The first (longest)
        # length equals in_length, so the mask spans every frame of self.c.
        mel_lens = torch.tensor(
            [self.in_length - 10 * i for i in range(self.batch_size)]
        )

        # Mask the input mel-spectrogram tensor; unsqueeze(1) lets the
        # (batch, frames) mask broadcast over the mel-channel axis.
        mel_mask = get_mask_from_lengths(mel_lens).unsqueeze(1)
        # Guard against a vacuous test: the mask must select at least one frame.
        self.assertTrue(bool(mel_mask.any()))
        c = self.c.masked_fill(mel_mask, self.traced_generator.mel_mask_value)

        output = self.traced_generator(c, mel_lens)

        self.assertEqual(output.shape, self._expected_shape())