PeechTTSv22050 / models /vocoder /univnet /tests /test_multi_period_discriminator.py
nickovchinnikov's picture
Init
9d61c9b
import math
import unittest
import torch
from models.config import VocoderModelConfig
from models.vocoder.univnet import MultiPeriodDiscriminator
class TestMultiPeriodDiscriminator(unittest.TestCase):
    """Shape tests for ``MultiPeriodDiscriminator`` built from the default config.

    The hard-coded expected shapes assume ``time_steps == 100`` and the default
    ``VocoderModelConfig().mpd`` settings (five periods); update the tables in
    ``test_forward`` if either changes.
    """

    def setUp(self):
        """Build a default-config model and a random ``(batch, channels, time)`` input."""
        self.batch_size = 2
        self.channels = 1
        self.time_steps = 100
        self.model_config = VocoderModelConfig()
        self.model = MultiPeriodDiscriminator(self.model_config)
        self.x = torch.randn(self.batch_size, self.channels, self.time_steps)

    def test_forward(self):
        """Check the per-period outputs and the shape of every feature map.

        Each entry of the model output is indexed as ``[0]`` -> list of feature
        maps and ``[1]`` -> ``x``, whose first dimension must be the batch size.
        Every feature map is checked twice: against an explicit expected shape
        and against a shape derived from the configured stride.
        """
        output = self.model(self.x)
        periods = self.model_config.mpd.periods
        stride = self.model_config.mpd.stride

        # One output entry per period.
        self.assertEqual(len(output), len(periods))

        # Expected feature-map heights: one row per period (same order as
        # ``periods``), one column per layer. Width always equals the period.
        fmap_heights = [
            [17, 6, 2, 1, 1],
            [12, 4, 2, 1, 1],
            [7, 3, 1, 1, 1],
            [5, 2, 1, 1, 1],
            [4, 2, 1, 1, 1],
        ]
        # Expected channel count of each layer's feature map.
        fmap_channels = [64, 128, 256, 512, 1024]

        # Fail loudly (instead of via IndexError / silent truncation) if the
        # config's number of periods no longer matches the table above.
        self.assertEqual(len(periods), len(fmap_heights))

        for disc_out, period, heights in zip(output, periods, fmap_heights):
            fmap = disc_out[0]
            x = disc_out[1]
            self.assertEqual(len(x), self.batch_size, msg=f"period={period}")
            # The layer loop below must cover every feature map, so pin the
            # count explicitly rather than assuming it equals len(periods).
            self.assertEqual(len(fmap), len(heights), msg=f"period={period}")

            # Height predicted from the stride, seeded with the first layer's.
            dim_2nd = heights[0]
            for i, (feature, channels, height) in enumerate(
                zip(fmap, fmap_channels, heights)
            ):
                with self.subTest(period=period, layer=i):
                    # Explicit expected shape.
                    self.assertEqual(
                        feature.shape,
                        torch.Size([self.batch_size, channels, height, period]),
                    )
                    # Derived shape: channels double each layer starting at 64
                    # (2 ** 6), and the height shrinks by ceil-division by the
                    # configured stride.
                    self.assertEqual(
                        feature.shape,
                        torch.Size([self.batch_size, 2 ** (i + 6), dim_2nd, period]),
                    )
                dim_2nd = math.ceil(dim_2nd / stride)
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()