Spaces:
Running
Running
File size: 4,290 Bytes
9d61c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import math
import unittest
import torch
from models.config import VocoderModelConfig
from models.vocoder.univnet import MultiPeriodDiscriminator
class TestMultiPeriodDiscriminator(unittest.TestCase):
    """Shape tests for the forward pass of ``MultiPeriodDiscriminator``."""

    def setUp(self):
        """Build a default-config discriminator and a random input batch.

        The input is drawn with ``torch.randn`` and has shape
        ``(batch_size, channels, time_steps) == (2, 1, 100)``.
        """
        self.batch_size = 2
        self.channels = 1
        self.time_steps = 100
        self.model_config = VocoderModelConfig()
        self.model = MultiPeriodDiscriminator(self.model_config)
        self.x = torch.randn(self.batch_size, self.channels, self.time_steps)

    def test_forward(self):
        """Check the output count and the shape of the leading feature maps.

        The model is expected to return one ``(fmap, x)`` pair per configured
        period.  For every pair, the first five entries of ``fmap`` are
        compared against two independent expectations:

        * an explicit table of hand-written sizes, and
        * a closed form in which the channel count is ``2 ** (i + 6)`` and
          dim 2 shrinks by ``ceil(previous / stride)`` from one feature map
          to the next.

        Any entries of ``fmap`` beyond those listed in the table are not
        checked here.
        """
        mpd_config = self.model_config.mpd
        periods = mpd_config.periods

        # Channel count of each checked feature map (equals 2 ** (i + 6)).
        fmap_channels = [64, 128, 256, 512, 1024]
        # Expected size of dim 2, one row per period and one column per
        # checked feature map.  The values are hard-coded for the default
        # configuration built in ``setUp`` and ``time_steps == 100``.
        fmap_heights = [
            [17, 6, 2, 1, 1],
            [12, 4, 2, 1, 1],
            [7, 3, 1, 1, 1],
            [5, 2, 1, 1, 1],
            [4, 2, 1, 1, 1],
        ]

        output = self.model(self.x)
        self.assertEqual(len(output), len(periods))
        # The table above is static: fail with a readable message (instead of
        # an IndexError) if the configured periods no longer match it.
        self.assertEqual(
            len(periods),
            len(fmap_heights),
            "expected-shape table is out of sync with mpd.periods",
        )

        for period, heights, result in zip(periods, fmap_heights, output):
            fmap, x = result[0], result[1]

            with self.subTest(period=period):
                self.assertEqual(len(x), self.batch_size)
                # Iterate over the expected shapes, not over the number of
                # periods: the two counts are unrelated quantities.
                self.assertGreaterEqual(len(fmap), len(heights))

                # Seed the closed-form check with the first expected size.
                dim_2nd = heights[0]
                for i, (n_channels, height) in enumerate(
                    zip(fmap_channels, heights),
                ):
                    # Assert the shape of the feature maps explicitly
                    self.assertEqual(
                        fmap[i].shape,
                        torch.Size([self.batch_size, n_channels, height, period]),
                    )
                    # Assert the shape of the feature maps via the closed form
                    self.assertEqual(
                        fmap[i].shape,
                        torch.Size(
                            [self.batch_size, 2 ** (i + 6), dim_2nd, period],
                        ),
                    )
                    dim_2nd = math.ceil(dim_2nd / mpd_config.stride)
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|