# Spaces: Sleeping -- page-status text captured alongside the source; not part of the module.
import unittest

import torch

from training.loss.fast_speech_2_loss_gen import FastSpeech2LossGen
class TestFastSpeech2LossGen(unittest.TestCase):
    """Smoke test for ``FastSpeech2LossGen.forward`` on seeded random inputs."""

    def setUp(self):
        """Build the loss object under test before each test method."""
        self.loss_gen = FastSpeech2LossGen()

    def test_forward(self):
        """Check that ``forward`` yields ten tensor losses, none negative or NaN."""
        # Reproducible results
        torch.random.manual_seed(0)

        # Batch of one. Sequence-like axes all have length 11; the exact
        # shapes differ per argument, as written below.
        # NOTE: the order of the torch.randn calls determines the sampled
        # values under the fixed seed, so it must not be rearranged.
        src_masks = torch.zeros((1, 11), dtype=torch.bool)
        mel_masks = torch.zeros((1, 11), dtype=torch.bool)

        mel_targets = torch.randn((1, 11, 11))
        mel_predictions = torch.randn((1, 11, 11))

        log_duration_predictions = torch.randn((1, 11))

        u_prosody_ref = torch.randn((1, 11))
        u_prosody_pred = torch.randn((1, 11))

        p_prosody_ref = torch.randn((1, 11, 11))
        p_prosody_pred = torch.randn((1, 11, 11))

        # NOTE(review): durations come from randn, so they can be negative
        # and non-integer -- confirm the loss is meant to tolerate that.
        durations = torch.randn((1, 11))
        pitch_predictions = torch.randn((1, 11))
        p_targets = torch.randn((1, 11))

        attn_logprob = torch.randn((1, 1, 11, 11))
        attn_soft = torch.randn((1, 11, 11))
        attn_hard = torch.randn((1, 11, 11))

        step = 20000

        # NOTE(review): both lengths are 1 although the tensors above span
        # 11 steps -- confirm this is intentional.
        src_lens = torch.ones((1,), dtype=torch.long)
        mel_lens = torch.ones((1,), dtype=torch.long)

        energy_pred = torch.randn((1, 11))
        energy_target = torch.randn((1, 11))

        # Unpacking into exactly ten names also verifies the arity of the
        # returned sequence (a mismatch raises ValueError).
        (
            total_loss,
            mel_loss,
            ssim_loss,
            duration_loss,
            u_prosody_loss,
            p_prosody_loss,
            pitch_loss,
            ctc_loss,
            bin_loss,
            energy_loss,
        ) = self.loss_gen.forward(
            src_masks,
            mel_masks,
            mel_targets,
            mel_predictions,
            log_duration_predictions,
            u_prosody_ref,
            u_prosody_pred,
            p_prosody_ref,
            p_prosody_pred,
            durations,
            pitch_predictions,
            p_targets,
            attn_logprob,
            attn_soft,
            attn_hard,
            step,
            src_lens,
            mel_lens,
            energy_pred,
            energy_target,
        )

        named_losses = {
            "total_loss": total_loss,
            "mel_loss": mel_loss,
            "ssim_loss": ssim_loss,
            "duration_loss": duration_loss,
            "u_prosody_loss": u_prosody_loss,
            "p_prosody_loss": p_prosody_loss,
            "pitch_loss": pitch_loss,
            "ctc_loss": ctc_loss,
            "bin_loss": bin_loss,
            "energy_loss": energy_loss,
        }

        # One sub-test per loss so that a failure names the offending term
        # instead of reporting a bare "False is not true" for the whole set.
        for name, loss in named_losses.items():
            with self.subTest(loss=name):
                self.assertIsInstance(loss, torch.Tensor)
                # ``>= 0`` is False for NaN, so NaN losses are rejected too.
                self.assertTrue(
                    torch.all(loss >= 0),
                    msg=f"{name} is negative or NaN: {loss}",
                )
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()