File size: 955 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import unittest

import torch

from training.loss import ForwardSumLoss


class TestForwardSumLoss(unittest.TestCase):
    """Tests for ``ForwardSumLoss`` imported from ``training.loss``."""

    def setUp(self):
        # A fresh loss module per test, constructed with its default arguments.
        self.forward_sum_loss = ForwardSumLoss()

    def test_forward(self):
        """Loss on a tiny seeded random input must match a known value (0.0).

        ``attn_logprob`` has shape ``(batch_size, num_channels, max_len, max_len)``
        and ``in_lens`` / ``out_lens`` carry one entry per element of its first
        dimension.

        NOTE(review): reading the first dim as the batch and the trailing dims
        as (out_len, in_len) is presumed from how the lengths are sized here,
        not verified against ``ForwardSumLoss`` -- TODO confirm. The expected
        0.0 is inherited from the original test. Every ``out_lens`` entry (1)
        is shorter than its ``in_lens`` entry (2); for a CTC-style alignment
        loss that zeroes infeasible cases, that alone would give exactly 0, in
        which case this test does not exercise a non-degenerate path -- TODO
        confirm and consider adding a case with out_len >= in_len.
        """
        # Reproducible results
        torch.random.manual_seed(0)

        batch_size = 2  # first dim of attn_logprob; one length entry per element
        num_channels = 1  # second dim of attn_logprob
        max_len = 2  # size of each of the two trailing dims of attn_logprob
        out_len = 1  # value of every out_lens entry

        attn_logprob = (
            torch.randn(batch_size, num_channels, max_len, max_len)
            .log_softmax(2)
            .detach()
            .requires_grad_()
        )

        in_lens = torch.full(size=(batch_size,), fill_value=max_len, dtype=torch.long)
        # Built as a constant on purpose: torch.randint's upper bound is
        # exclusive, so a draw over [out_len, max_len) could only yield out_len.
        out_lens = torch.full(size=(batch_size,), fill_value=out_len, dtype=torch.long)

        loss = self.forward_sum_loss(attn_logprob, in_lens, out_lens)
        expected_loss = torch.tensor([0.0])

        self.assertTrue(
            torch.allclose(loss, expected_loss),
            msg=f"expected loss {expected_loss}, got {loss}",
        )


if __name__ == "__main__":
    # Run directly as a script (not collected by a test runner): discover and
    # execute the TestCase classes defined in this module.
    unittest.main(module="__main__")