import unittest

import torch

from models.config import PreprocessingConfigUnivNet as PreprocessingConfig
from models.config import VocoderModelConfig
from models.vocoder.univnet import Discriminator, Generator

# NOTE(review): reference feature-map shapes for the default configs with
# batch_size=1 and in_length=100. They are used only by the skipped
# shape test below and have not been re-verified against the current
# Discriminator implementation — confirm before un-skipping.
# One inner list per MRD sub-discriminator output, one shape per feature map.
_MRD_FMAP_SHAPES = [
    [
        torch.Size([32, 1, 513]),
        torch.Size([32, 1, 257]),
        torch.Size([32, 1, 129]),
        torch.Size([32, 1, 65]),
        torch.Size([32, 1, 65]),
        torch.Size([1, 1, 65]),
    ],
    [
        torch.Size([32, 1, 1025]),
        torch.Size([32, 1, 513]),
        torch.Size([32, 1, 257]),
        torch.Size([32, 1, 129]),
        torch.Size([32, 1, 129]),
        torch.Size([1, 1, 129]),
    ],
    [
        torch.Size([32, 1, 257]),
        torch.Size([32, 1, 129]),
        torch.Size([32, 1, 65]),
        torch.Size([32, 1, 33]),
        torch.Size([32, 1, 33]),
        torch.Size([1, 1, 33]),
    ],
]

# One inner list per MPD sub-discriminator output, one shape per feature map.
_MPD_FMAP_SHAPES = [
    [
        torch.Size([1, 64, 4267, 2]),
        torch.Size([1, 128, 1423, 2]),
        torch.Size([1, 256, 475, 2]),
        torch.Size([1, 512, 159, 2]),
        torch.Size([1, 1024, 159, 2]),
        torch.Size([1, 1, 159, 2]),
    ],
    [
        torch.Size([1, 64, 2845, 3]),
        torch.Size([1, 128, 949, 3]),
        torch.Size([1, 256, 317, 3]),
        torch.Size([1, 512, 106, 3]),
        torch.Size([1, 1024, 106, 3]),
        torch.Size([1, 1, 106, 3]),
    ],
    [
        torch.Size([1, 64, 1707, 5]),
        torch.Size([1, 128, 569, 5]),
        torch.Size([1, 256, 190, 5]),
        torch.Size([1, 512, 64, 5]),
        torch.Size([1, 1024, 64, 5]),
        torch.Size([1, 1, 64, 5]),
    ],
    [
        torch.Size([1, 64, 1220, 7]),
        torch.Size([1, 128, 407, 7]),
        torch.Size([1, 256, 136, 7]),
        torch.Size([1, 512, 46, 7]),
        torch.Size([1, 1024, 46, 7]),
        torch.Size([1, 1, 46, 7]),
    ],
    [
        torch.Size([1, 64, 776, 11]),
        torch.Size([1, 128, 259, 11]),
        torch.Size([1, 256, 87, 11]),
        torch.Size([1, 512, 29, 11]),
        torch.Size([1, 1024, 29, 11]),
        torch.Size([1, 1, 29, 11]),
    ],
]


# One of the most important test cases for UnivNet: an integration test that
# feeds the Generator's output straight into the Discriminator.
class TestDiscriminator(unittest.TestCase):
    """Integration tests for the UnivNet ``Discriminator``."""

    def setUp(self):
        # Seed before building the models and the random input so that
        # weight initialisation and the mel input are reproducible.
        torch.manual_seed(0)

        self.model_config = VocoderModelConfig()
        self.preprocess_config = PreprocessingConfig("english_only")

        self.generator = Generator(self.model_config, self.preprocess_config)
        self.model = Discriminator(self.model_config)

        self.batch_size = 1
        self.in_length = 100

        # Random mel-spectrogram-like input:
        # (batch, n_mel_channels, in_length).
        self.c = torch.randn(
            self.batch_size,
            self.preprocess_config.stft.n_mel_channels,
            self.in_length,
        )

    def _run_forward(self):
        """Run Generator -> Discriminator on ``self.c`` without autograd."""
        # Only the forward pass is tested, so skip building a graph.
        with torch.no_grad():
            x = self.generator(self.c)
            return self.model(x)

    def test_forward(self):
        """The Discriminator returns (MRD outputs, MPD outputs) of sizes 3 and 5."""
        output = self._run_forward()

        self.assertEqual(len(output), 2)
        # Assert MRD length
        self.assertEqual(len(output[0]), 3)
        # Assert MPD length
        self.assertEqual(len(output[1]), 5)

    @unittest.skip(
        "reference feature-map shapes have not been verified against "
        "the current Discriminator",
    )
    def test_forward_feature_map_shapes(self):
        """Compare every MRD/MPD feature-map shape to the reference tables."""
        output_mrd, output_mpd = self._run_forward()

        checks = (
            ("MRD", output_mrd, _MRD_FMAP_SHAPES),
            ("MPD", output_mpd, _MPD_FMAP_SHAPES),
        )
        for disc_name, sub_outputs, expected_shapes in checks:
            self.assertEqual(len(sub_outputs), len(expected_shapes))
            for key, (sub_output, fmap_dims) in enumerate(
                zip(sub_outputs, expected_shapes),
            ):
                # Each sub-discriminator output is presumed to be
                # (feature_maps, score); only the feature maps are checked.
                fmaps = sub_output[0]
                with self.subTest(discriminator=disc_name, key=key):
                    self.assertEqual([fmap.shape for fmap in fmaps], fmap_dims)


if __name__ == "__main__":
    unittest.main()