Spaces:
Sleeping
Sleeping
import unittest | |
import torch | |
from torch import nn | |
from models.config import PreprocessingConfigUnivNet as PreprocessingConfig | |
from models.config import VocoderModelConfig | |
from models.vocoder.univnet.generator import Generator | |
class TestUnivNet(unittest.TestCase):
    """Output-shape and weight-norm-removal tests for the UnivNet ``Generator``."""

    def setUp(self):
        """Build a default generator and a random conditioning batch."""
        self.batch_size = 3
        self.in_length = 100

        self.model_config = VocoderModelConfig()
        self.preprocess_config = PreprocessingConfig("english_only")

        self.generator = Generator(self.model_config, self.preprocess_config)

        # Random input of shape (batch_size, n_mel_channels, in_length).
        self.c = torch.randn(
            self.batch_size,
            self.preprocess_config.stft.n_mel_channels,
            self.in_length,
        )

    def _expected_audio_shape(self):
        """Return the waveform shape expected for ``self.c``.

        Both the ``forward`` and ``infer`` tests use this value, so they agree
        on the upsampling factor instead of one hard-coding it.
        NOTE(review): assumes the generator's total upsampling factor equals
        ``stft.hop_length`` (the original forward test hard-coded 256) — confirm
        against the generator config.
        """
        return (
            self.batch_size,
            1,
            self.in_length * self.preprocess_config.stft.hop_length,
        )

    def _assert_no_weight_norm(self, generator):
        """Assert that no ``nn.Conv1d`` inside ``generator`` still carries weight norm.

        Two weight-norm implementations are checked:
        - the legacy ``torch.nn.utils.weight_norm`` hook, which registers
          ``weight_g`` / ``weight_v`` attributes;
        - the parametrization-based
          ``torch.nn.utils.parametrizations.weight_norm``, which registers no
          such attributes and would otherwise slip past the legacy check.
        """
        from torch.nn.utils import parametrize

        convs = [m for m in generator.modules() if isinstance(m, nn.Conv1d)]
        # Guard against a vacuous pass when there is nothing to inspect.
        self.assertTrue(convs, "generator contains no nn.Conv1d modules to check")

        for module in convs:
            self.assertFalse(hasattr(module, "weight_g"))
            self.assertFalse(hasattr(module, "weight_v"))
            self.assertFalse(parametrize.is_parametrized(module, "weight"))

    def test_forward(self):
        """``forward`` maps (B, n_mels, T) to a (B, 1, T * hop_length) waveform."""
        output = self.generator(self.c)

        self.assertEqual(output.shape, self._expected_audio_shape())

    def test_generator_inference_output_shape(self):
        """``infer`` with full-length ``mel_lens`` yields the same waveform shape."""
        mel_lens = torch.tensor([self.in_length] * self.batch_size)
        output = self.generator.infer(self.c, mel_lens)

        self.assertEqual(output.shape, self._expected_audio_shape())

    def test_eval(self):
        """``eval(inference=True)`` should leave no weight norm on any Conv1d."""
        generator = Generator(self.model_config, self.preprocess_config)
        generator.eval(inference=True)

        self._assert_no_weight_norm(generator)

    def test_remove_weight_norm(self):
        """``remove_weight_norm`` should leave no weight norm on any Conv1d."""
        generator = Generator(self.model_config, self.preprocess_config)
        generator.remove_weight_norm()

        self._assert_no_weight_norm(generator)
if __name__ == "__main__":
    # Allow running this module directly as a script, outside a test runner.
    unittest.main()