# PeechTTSv22050 / training / loss / tests / test_bin_loss.py
# nickovchinnikov's picture
# Init
# 9d61c9b
import unittest
import torch
from training.loss import BinLoss
class TestBinLoss(unittest.TestCase):
    """Unit tests for ``BinLoss`` on a tiny four-element attention pair.

    Both tests use the same inputs and expect the loss to equal the negative
    log of the soft-attention values at the positions where the hard
    attention is 1 (here 0.9 and 0.8), divided by the number of such
    positions (2).
    """

    def setUp(self):
        """Create a fresh ``BinLoss`` instance before every test."""
        self.bin_loss = BinLoss()

    def test_forward_hard_attention(self):
        """The forward pass reproduces the hand-computed loss value."""
        hard_attention = torch.tensor([1, 0, 1, 0])
        soft_attention = torch.tensor([0.9, 0.1, 0.8, 0.2])

        loss = self.bin_loss(hard_attention, soft_attention)

        # Reference value: only indices 0 and 2 are selected by the hard
        # attention, so average -log(0.9) and -log(0.8).
        expected_loss = -(torch.log(torch.tensor([0.9, 0.8]))).sum() / 2

        # float32 carries roughly 7 significant digits; comparing to 5 decimal
        # places keeps the check meaningful without being sensitive to the
        # order in which the implementation reduces its terms.
        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=5)

    def test_forward_soft_attention(self):
        """The loss back-propagates the expected gradient to the soft attention."""
        hard_attention = torch.tensor([1, 0, 1, 0])
        soft_attention = torch.tensor([0.9, 0.1, 0.8, 0.2], requires_grad=True)

        loss = self.bin_loss(hard_attention, soft_attention)
        loss.backward()

        # The reference value needs no autograd graph: only its scalar value
        # is compared, so it is built from a plain tensor.
        expected_loss = -(torch.log(torch.tensor([0.9, 0.8]))).sum() / 2
        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=5)

        grad = soft_attention.grad
        self.assertIsNotNone(grad)
        if grad is not None:  # narrows Optional[Tensor] for static type checkers
            # d/dx [-log(x) / 2] = -1 / (2x): -0.5556 at x=0.9, -0.6250 at
            # x=0.8. The positions where the hard attention is 0 are expected
            # to receive zero gradient.
            expected_grad = torch.tensor([-0.5556, 0.0000, -0.6250, 0.0000])
            self.assertTrue(torch.allclose(grad, expected_grad, atol=1e-4))
if __name__ == "__main__":
    # Executed directly as a script: let unittest discover and run the
    # test cases defined in this module.
    unittest.main()