Spaces:
Running
Running
File size: 1,448 Bytes
9d61c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import unittest
import torch
from training.loss import BinLoss
class TestBinLoss(unittest.TestCase):
    """Unit tests for ``training.loss.BinLoss``.

    The hand-computed expectations below take the loss to be the negative
    log of ``soft_attention`` at the positions where ``hard_attention == 1``,
    summed and divided by the number of such positions. NOTE(review): that
    formula is inferred from these expectations, not from the BinLoss source
    (not visible here) — confirm against training/loss.py.
    """

    def setUp(self):
        # A fresh loss module for every test method.
        self.bin_loss = BinLoss()

    def test_forward_hard_attention(self):
        """Forward value matches -sum(log(soft[selected])) / num_selected."""
        hard_attention = torch.tensor([1, 0, 1, 0])
        soft_attention = torch.tensor([0.9, 0.1, 0.8, 0.2])

        loss = self.bin_loss(hard_attention, soft_attention)

        # The hard mask selects positions 0 and 2 (soft values 0.9 and 0.8);
        # the summed negative log is divided by the 2 selected positions.
        expected_loss = -torch.log(torch.tensor([0.9, 0.8])).sum() / 2
        self.assertAlmostEqual(loss.item(), expected_loss.item())

    def test_forward_soft_attention(self):
        """Loss is differentiable w.r.t. soft_attention at the selected positions.

        Checks the forward value and that the gradient is -1 / (2 * p) at the
        positions selected by the hard mask and zero elsewhere.
        """
        hard_attention = torch.tensor([1, 0, 1, 0])
        soft_attention = torch.tensor([0.9, 0.1, 0.8, 0.2], requires_grad=True)

        loss = self.bin_loss(hard_attention, soft_attention)

        # The reference value is a plain constant: it needs no autograd graph,
        # and calling backward() on it would only fill an unused leaf's .grad.
        expected_loss = -torch.log(torch.tensor([0.9, 0.8])).sum() / 2
        self.assertAlmostEqual(loss.item(), expected_loss.item())

        loss.backward()
        grad = soft_attention.grad
        self.assertIsNotNone(grad)
        assert grad is not None  # narrows Optional[Tensor] for type checkers

        # d/dp [-log(p) / 2] = -1 / (2p): -1/1.8 ~= -0.5556 and -1/1.6 = -0.625.
        # Positions where hard_attention == 0 are expected to get zero gradient.
        torch.testing.assert_close(
            grad,
            torch.tensor([-0.5556, 0.0000, -0.6250, 0.0000]),
            rtol=1e-5,
            atol=1e-4,
        )
if __name__ == "__main__":
    # Only when executed as a script (not when imported by a test runner):
    # collect and run the TestCase classes defined in this module.
    unittest.main(module="__main__")
|