# PeechTTSv22050 / models/vocoder/univnet/tests/test_discriminator_r.py
# Author: nickovchinnikov — commit 9d61c9b ("Init")
import unittest
import torch
from models.config import VocoderModelConfig
from models.vocoder.univnet.discriminator_r import DiscriminatorR
class TestDiscriminatorR(unittest.TestCase):
    """Shape tests for the UnivNet resolution discriminator ``DiscriminatorR``."""

    def setUp(self):
        # Seed the RNG so the randomly initialised weights and the random test
        # inputs are identical on every run.
        torch.manual_seed(0)

        # STFT resolution handed to DiscriminatorR; presumably
        # (n_fft, hop_length, win_length) — NOTE(review): confirm the order.
        self.resolution = (1024, 256, 1024)
        self.model_config = VocoderModelConfig()
        self.model = DiscriminatorR(self.resolution, self.model_config)

        # A one-sided STFT of a real signal has n_fft // 2 + 1 frequency bins
        # (513 for n_fft = 1024). Deriving it keeps the expectations in sync
        # with the resolution instead of hard-coding 513 in several places.
        n_fft = self.resolution[0]
        self.n_freq = n_fft // 2 + 1

    def test_forward(self):
        """The forward pass returns the feature maps and a (batch, n_freq) output."""
        x = torch.randn(1, 1024)

        fmap, output = self.model(x)

        # Presumably one feature map per intermediate conv layer plus the final
        # projection — NOTE(review): confirm against DiscriminatorR.
        self.assertEqual(len(fmap), 6)
        self.assertEqual(output.shape, (1, self.n_freq))
        # Random noise input must never yield NaN/Inf discriminator scores.
        self.assertTrue(torch.isfinite(output).all())

    def test_spectrogram(self):
        """``spectrogram`` maps a batch of waveforms to magnitude spectrograms."""
        batch, num_samples = 4, 16384
        x = torch.randn(batch, 1, num_samples)

        mag = self.model.spectrogram(x)

        # 64 frames == 16384 samples / 256 (resolution[1], presumably the hop
        # length) for the resolution configured in setUp.
        self.assertEqual(mag.shape, (batch, self.n_freq, 64))
if __name__ == "__main__":
    # Allow running this module directly, e.g. ``python test_discriminator_r.py``.
    unittest.main()