import unittest

import torch

from training.loss import BinLoss
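
# The tests below assume BinLoss implements a binarization loss over attention
# maps: the mean negative log of the soft attention probabilities at the
# positions where the hard (binary) alignment is 1. A minimal sketch of that
# assumed behavior (the actual implementation lives in training/loss.py):
#
#     class BinLoss(torch.nn.Module):
#         def forward(self, hard_attention, soft_attention):
#             # Select soft probabilities where the hard alignment fires,
#             # then average their negative log-likelihood.
#             log_probs = torch.log(soft_attention[hard_attention == 1])
#             return -log_probs.sum() / hard_attention.sum()
#
# Under that assumption, both tests expect -(log 0.9 + log 0.8) / 2.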


class TestBinLoss(unittest.TestCase):
    def setUp(self):
        self.bin_loss = BinLoss()

    def test_forward_hard_attention(self):
        # The hard alignment selects positions 0 and 2; the loss should be
        # the mean negative log of the soft probabilities at those positions.
        hard_attention = torch.tensor([1, 0, 1, 0])
        soft_attention = torch.tensor([0.9, 0.1, 0.8, 0.2])
        loss = self.bin_loss(hard_attention, soft_attention)
        expected_loss = -(torch.log(torch.tensor([0.9, 0.8]))).sum() / 2
        self.assertAlmostEqual(loss.item(), expected_loss.item())

    def test_forward_soft_attention(self):
        # Same alignment, but soft_attention now requires gradients so we can
        # check that the loss backpropagates correctly.
        hard_attention = torch.tensor([1, 0, 1, 0])
        soft_attention = torch.tensor([0.9, 0.1, 0.8, 0.2], requires_grad=True)
        loss = self.bin_loss(hard_attention, soft_attention)
        expected_loss = -(torch.log(torch.tensor([0.9, 0.8]))).sum() / 2
        loss.backward()
        self.assertAlmostEqual(loss.item(), expected_loss.item())
        # d/dp[-log(p) / 2] = -1 / (2 * p), i.e. -1/1.8 ~ -0.5556 and
        # -1/1.6 = -0.6250 at the selected positions; unselected positions
        # receive zero gradient.
        self.assertIsNotNone(soft_attention.grad)
        if soft_attention.grad is not None:
            self.assertTrue(
                torch.allclose(
                    soft_attention.grad,
                    torch.tensor([-0.5556, 0.0000, -0.6250, 0.0000]),
                    atol=1e-4,
                ),
            )


if __name__ == "__main__":
    unittest.main()
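
# A typical way to run this file directly, assuming it lives at
# tests/test_bin_loss.py (hypothetical path):
#     python -m unittest tests.test_bin_loss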