Spaces:
Sleeping
Sleeping
import unittest | |
import torch | |
from models.config import VocoderModelConfig | |
from models.vocoder.univnet.discriminator_r import DiscriminatorR | |
class TestDiscriminatorR(unittest.TestCase):
    """Shape checks for ``DiscriminatorR.forward`` and ``DiscriminatorR.spectrogram``."""

    def setUp(self) -> None:
        """Build a ``DiscriminatorR`` from a fixed resolution and the default config."""
        # NOTE(review): presumably (n_fft, hop_length, win_length) -- confirm
        # against the DiscriminatorR constructor.
        self.resolution = (1024, 256, 1024)
        self.model_config = VocoderModelConfig()
        self.model = DiscriminatorR(self.resolution, self.model_config)

    def test_forward(self) -> None:
        """The forward pass returns six feature maps and a ``(1, 513)`` output."""
        x = torch.randn(1, 1024)

        fmap, output = self.model(x)

        self.assertEqual(len(fmap), 6)
        # NOTE(review): per-layer feature-map shape assertions were previously
        # left disabled here; re-add them once the expected shapes have been
        # confirmed against DiscriminatorR.
        self.assertEqual(output.shape, (1, 513))

    def test_spectrogram(self) -> None:
        """``spectrogram`` maps a ``(4, 1, 16384)`` input to a ``(4, 513, 64)`` tensor."""
        x = torch.randn(4, 1, 16384)

        mag = self.model.spectrogram(x)

        self.assertEqual(mag.shape, (4, 513, 64))
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()