File size: 2,301 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import unittest

from numba import prange
import numpy as np

from models.tts.delightful_tts.acoustic_model.mas import b_mas, mas_width1


class TestMasWidth1(unittest.TestCase):
    """Unit tests for ``mas_width1`` applied to a single 2-D attention map."""

    # Fixed seed: an unseeded random fixture makes any failure impossible to
    # reproduce, and np.random.rand can occasionally draw an exact 0.0.
    SEED = 0

    def setUp(self):
        """Build a reproducible 5x5 attention map with values in [0, 1)."""
        rng = np.random.default_rng(self.SEED)
        self.attn_map = rng.random((5, 5))

    def test_mas_width1(self):
        """Check type, shape and binary content of the ``mas_width1`` output."""
        opt = mas_width1(self.attn_map)

        # The result must be a numpy ndarray ...
        self.assertIsInstance(opt, np.ndarray)

        # ... with exactly the same shape as the input attention map.
        self.assertEqual(opt.shape, self.attn_map.shape)

        # Every entry must be 0 or 1 (a hard alignment mask, as per the
        # function description).
        self.assertTrue(np.isin(opt, (0, 1)).all())

        # At least one position must be selected, since an optimal path
        # always exists.
        self.assertTrue((opt == 1).any())


class TestBMas(unittest.TestCase):
    """Unit tests for the batched ``b_mas`` wrapper."""

    # Fixed seed so a failing run can be reproduced exactly.
    SEED = 0

    def setUp(self):
        """Build a reproducible batch of two 5x5 attention maps plus lengths."""
        rng = np.random.default_rng(self.SEED)
        # Batch of 2 attention maps, each 5x5. The singleton second axis is
        # indexed as ``[b, 0, ...]`` in the test below.
        self.b_attn_map = rng.random((2, 1, 5, 5))
        # Valid extent of each map along the last axis (input lengths).
        self.in_lens = np.array([3, 4])
        # Valid extent of each map along the third axis (output lengths).
        self.out_lens = np.array([4, 5])

    def test_b_mas(self):
        """Check type, shape, binary content and per-sample counts of ``b_mas``."""
        attn_out = b_mas(self.b_attn_map, self.in_lens, self.out_lens, width=1)

        # The result must be a numpy ndarray ...
        self.assertIsInstance(attn_out, np.ndarray)

        # ... with the same shape as the batched input (this also pins the
        # batch dimension).
        self.assertEqual(attn_out.shape, self.b_attn_map.shape)

        # Every entry must be 0 or 1.
        self.assertTrue(np.isin(attn_out, (0, 1)).all())

        # One length pair per batch element.
        self.assertEqual(attn_out.shape[0], len(self.in_lens))

        # Inside each sample's valid [:out_len, :in_len] window, the number of
        # selected positions must equal out_len. Plain Python iteration is used
        # on purpose: numba's ``prange`` only parallelises inside jitted code
        # and is merely a ``range`` alias here.
        for b, (in_len, out_len) in enumerate(zip(self.in_lens, self.out_lens)):
            with self.subTest(batch=b):
                window = attn_out[b, 0, :out_len, :in_len]
                self.assertEqual(np.sum(window), out_len)


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module when executed
    # directly; ``module="__main__"`` is unittest.main's default, spelled out.
    unittest.main(module="__main__")