# PeechTTSv22050/training/loss/tests/test_univnet_loss.py
# Last commit: 9d61c9b "Init" (nickovchinnikov)
import unittest
import torch
from training.loss import UnivnetLoss
class TestUnivnetLoss(unittest.TestCase):
    """Smoke test for ``UnivnetLoss.forward`` on random inputs."""

    def setUp(self):
        # Seed the global RNG so the random inputs built in each test are reproducible.
        torch.random.manual_seed(42)
        self.loss_module = UnivnetLoss()

    def test_forward(self):
        """``forward`` returns a 6-tuple of tensors, each of which is non-negative."""
        # Random "real" and "generated" waveforms plus one (tensor, tensor) pair per
        # discriminator output list. What the two items of each pair represent is
        # defined by UnivnetLoss -- NOTE(review): confirm against its implementation.
        # The order of the torch.randn calls is kept fixed so the seeded inputs
        # stay the same from run to run.
        audio = torch.randn(1, 1, 22050)
        fake_audio = torch.randn(1, 1, 22050)
        res_fake = [(torch.randn(1, 1, 22050), torch.randn(1))]
        period_fake = [(torch.randn(1, 1, 22050), torch.randn(1))]
        res_real = [(torch.randn(1, 1, 22050), torch.randn(1))]
        period_real = [(torch.randn(1, 1, 22050), torch.randn(1))]

        output = self.loss_module.forward(
            audio,
            fake_audio,
            res_fake,
            period_fake,
            res_real,
            period_real,
        )

        # The output must be a tuple with exactly six entries.
        self.assertIsInstance(output, tuple)
        self.assertEqual(len(output), 6)

        names = (
            "total_loss_gen",
            "total_loss_disc",
            "stft_loss",
            "score_loss",
            "esr_loss",
            "snr_loss",
        )
        for name, loss in zip(names, output):
            # subTest names the offending loss instead of a bare "False is not true".
            with self.subTest(loss=name):
                self.assertIsInstance(loss, torch.Tensor)
                # Check each loss on its own. Packing them with torch.tensor([...])
                # only works for one-element tensors and hides which entry failed.
                # NaN compares False with >= 0, so a NaN loss also fails here.
                self.assertTrue(
                    bool(torch.all(loss >= 0)),
                    msg=f"{name} is negative: {loss}",
                )
if __name__ == "__main__":
    # Executing this file directly hands control to unittest's command-line runner,
    # which discovers and runs the TestCase classes defined in this module.
    unittest.main(module=__name__)