PeechTTSv22050 / models /vocoder /univnet /tests /test_multi_resolution_discriminator.py
nickovchinnikov's picture
Init
9d61c9b
import unittest
import torch
from models.config import VocoderModelConfig
from models.vocoder.univnet import MultiResolutionDiscriminator
class TestMultiResolutionDiscriminator(unittest.TestCase):
    """Smoke test for the forward pass of ``MultiResolutionDiscriminator``."""

    def setUp(self):
        """Build a default-configured discriminator and a random input tensor."""
        # Not read anywhere in this test case; presumably mirrors the
        # per-resolution settings of the model config -- TODO confirm.
        self.resolution = [
            (1024, 256, 1024),
            (2048, 512, 2048),
        ]

        self.model_config = VocoderModelConfig()
        self.model = MultiResolutionDiscriminator(self.model_config)

        # Random tensor of shape (1, 1024) fed to the model in test_forward.
        input_shape = (1, 1024)
        self.x = torch.randn(*input_shape)

    def test_forward(self):
        """Call the model on the fixture input and check the result count."""
        discriminator_outputs = self.model(self.x)

        expected_count = 3
        self.assertEqual(len(discriminator_outputs), expected_count)

        # NOTE(review): the detailed feature-map shape assertions below are
        # disabled in the original file. The reason is not visible here;
        # confirm the expected shapes against the model before re-enabling.
        #
        # fmap_dims = [
        #     [
        #         torch.Size([32, 1, 513]),
        #         torch.Size([32, 1, 257]),
        #         torch.Size([32, 1, 129]),
        #         torch.Size([32, 1, 65]),
        #         torch.Size([32, 1, 65]),
        #         torch.Size([1, 65]),
        #     ],
        #     [
        #         torch.Size([32, 1, 1025]),
        #         torch.Size([32, 1, 513]),
        #         torch.Size([32, 1, 257]),
        #         torch.Size([32, 1, 129]),
        #         torch.Size([32, 1, 129]),
        #         torch.Size([1, 129]),
        #     ],
        #     [
        #         torch.Size([32, 1, 257]),
        #         torch.Size([32, 1, 129]),
        #         torch.Size([32, 1, 65]),
        #         torch.Size([32, 1, 33]),
        #         torch.Size([32, 1, 33]),
        #         torch.Size([1, 33]),
        #     ],
        # ]
        # init_powers_max_min = [(9, 6), (10, 7), (8, 5)]
        # for key in range(len(output)):
        #     fmap = output[key][0]
        #     first_dim, second_dim = 32, 1
        #     p_max, p_min = init_powers_max_min[key]
        #
        #     def dim_3rd(p: int):
        #         return max(2**p + 1, 2**p_min + 1)
        #
        #     fmap_dim = fmap_dims[key]
        #     # Assert the shape of the feature maps
        #     for i, fmap_ in enumerate(fmap[:-1]):
        #         # Assert the feature map shape explicitly
        #         self.assertEqual(fmap_.shape, fmap_dim[i])
        #         self.assertEqual(
        #             fmap_.shape,
        #             torch.Size([first_dim, second_dim, dim_3rd(p_max - i)]),
        #         )
        #     self.assertEqual(
        #         fmap[-1].shape,
        #         torch.Size([second_dim, second_dim, 2**p_min + 1]),
        #     )
# Run this module's test cases when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()