Spaces:
Running
Running
File size: 4,666 Bytes
9d61c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import unittest
import torch
from models.config import PreprocessingConfigUnivNet as PreprocessingConfig
from models.config import VocoderModelConfig
from models.vocoder.univnet import Discriminator, Generator
# One of the most important test cases for UnivNet.
# Integration test: the Generator's output is fed into the Discriminator.
class TestDiscriminator(unittest.TestCase):
    """Integration test: feed Generator output into the UnivNet Discriminator."""

    def setUp(self):
        # Seed the global RNG so the random weight init and the random mel
        # input below are identical on every run; otherwise a failing run
        # cannot be reproduced.
        torch.manual_seed(0)

        self.model_config = VocoderModelConfig()
        self.preprocess_config = PreprocessingConfig("english_only")
        self.generator = Generator(self.model_config, self.preprocess_config)
        self.model = Discriminator(self.model_config)

        self.batch_size = 1
        self.in_length = 100
        # Random conditioning input of shape
        # (batch_size, n_mel_channels, in_length).
        self.c = torch.randn(
            self.batch_size,
            self.preprocess_config.stft.n_mel_channels,
            self.in_length,
        )

    def test_forward(self):
        """Check the top-level structure of the Discriminator output.

        The output must be a pair ``(mrd_output, mpd_output)``.
        ``mrd_output`` has 3 entries (MRD) and ``mpd_output`` has
        5 entries (MPD).
        """
        # Only the forward outputs are inspected (no backward pass), so skip
        # building the autograd graph to save memory and time.
        with torch.no_grad():
            x = self.generator(self.c)
            output = self.model(x)

        self.assertEqual(len(output), 2)

        # Assert MRD length
        self.assertEqual(len(output[0]), 3)

        # Assert MPD length
        self.assertEqual(len(output[1]), 5)

        # TODO: restore per-layer feature-map shape checks for MRD and MPD.
        # The previously hard-coded expected shapes appeared stale for this
        # setup (e.g. a leading dim of 32 while batch_size is 1), so they
        # were removed. Reintroduce them derived from the model/preprocessing
        # config rather than as literal torch.Size values.
# Allow running this test module directly with the interpreter
# in addition to discovery via a test runner.
if __name__ == "__main__":
    unittest.main()
|