Spaces:
Running
Running
File size: 3,681 Bytes
9d61c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import unittest
import torch
from training.loss.fast_speech_2_loss_gen import FastSpeech2LossGen
class TestFastSpeech2LossGen(unittest.TestCase):
    """Smoke tests for ``FastSpeech2LossGen``."""

    def setUp(self):
        """Create a fresh loss object using its default constructor arguments."""
        self.loss_gen = FastSpeech2LossGen()

    def test_forward(self):
        """Run ``forward`` on small random inputs and sanity-check every loss.

        Each of the ten returned values must be a ``torch.Tensor`` whose
        elements are all ``>= 0``. A NaN element also fails that comparison.
        Each loss is checked in its own ``subTest``, so a failure names the
        offending loss instead of reporting one opaque boolean.
        """
        # Fixed seed so the random inputs (and therefore the losses) are
        # reproducible. The draw order below must stay stable for that reason.
        torch.random.manual_seed(0)

        # Batch size 1, 11 positions along every sequence-like axis.
        src_masks = torch.zeros((1, 11), dtype=torch.bool)
        mel_masks = torch.zeros((1, 11), dtype=torch.bool)
        mel_targets = torch.randn((1, 11, 11))
        mel_predictions = torch.randn((1, 11, 11))
        log_duration_predictions = torch.randn((1, 11))
        u_prosody_ref = torch.randn((1, 11))
        u_prosody_pred = torch.randn((1, 11))
        p_prosody_ref = torch.randn((1, 11, 11))
        p_prosody_pred = torch.randn((1, 11, 11))
        durations = torch.randn((1, 11))
        pitch_predictions = torch.randn((1, 11))
        p_targets = torch.randn((1, 11))
        attn_logprob = torch.randn((1, 1, 11, 11))
        attn_soft = torch.randn((1, 11, 11))
        attn_hard = torch.randn((1, 11, 11))
        step = 20000
        # NOTE(review): both lengths are 1 although the tensors above have 11
        # positions and the masks are all False -- confirm this is intended.
        src_lens = torch.ones((1,), dtype=torch.long)
        mel_lens = torch.ones((1,), dtype=torch.long)
        energy_pred = torch.randn((1, 11))
        energy_target = torch.randn((1, 11))

        # Names of the returned values, in the order forward() returns them.
        loss_names = (
            "total_loss",
            "mel_loss",
            "ssim_loss",
            "duration_loss",
            "u_prosody_loss",
            "p_prosody_loss",
            "pitch_loss",
            "ctc_loss",
            "bin_loss",
            "energy_loss",
        )

        # NOTE(review): forward() is called directly, as before. If
        # FastSpeech2LossGen is an nn.Module, calling self.loss_gen(...) would
        # also run module hooks -- confirm which is intended.
        outputs = tuple(
            self.loss_gen.forward(
                src_masks,
                mel_masks,
                mel_targets,
                mel_predictions,
                log_duration_predictions,
                u_prosody_ref,
                u_prosody_pred,
                p_prosody_ref,
                p_prosody_pred,
                durations,
                pitch_predictions,
                p_targets,
                attn_logprob,
                attn_soft,
                attn_hard,
                step,
                src_lens,
                mel_lens,
                energy_pred,
                energy_target,
            ),
        )

        # Same arity guarantee the original tuple-unpacking gave.
        self.assertEqual(len(outputs), len(loss_names))

        for name, loss in zip(loss_names, outputs):
            with self.subTest(loss=name):
                self.assertIsInstance(loss, torch.Tensor)
                # Element-wise, so non-scalar losses are handled too.
                self.assertTrue(
                    bool(torch.all(loss >= 0)),
                    f"{name} has negative or NaN elements: {loss}",
                )
# Allow running this module directly: `python <this file>`.
if __name__ == "__main__":
    unittest.main()
|