File size: 1,476 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import unittest

import torch

from models.tts.delightful_tts.acoustic_model.embedding import Embedding


class TestEmbedding(unittest.TestCase):
    """Unit tests for the acoustic-model ``Embedding`` lookup layer."""

    # Table size used to build the layer under test. Kept as class constants
    # so the generated index ranges always stay within the table
    # (every index must be < NUM_EMBEDDINGS).
    NUM_EMBEDDINGS = 100
    EMBEDDING_DIM = 50

    def setUp(self):
        """Build a fresh ``Embedding`` instance for every test."""
        self.embedding = Embedding(
            num_embeddings=self.NUM_EMBEDDINGS,
            embedding_dim=self.EMBEDDING_DIM,
        )

    def _random_indices(self, batch_size=10, seq_len=20):
        """Return a (batch_size, seq_len) tensor of valid random lookup indices."""
        return torch.randint(
            low=0,
            high=self.NUM_EMBEDDINGS,
            size=(batch_size, seq_len),
        )

    def test_forward_output_shape(self):
        """A (batch, seq) index tensor yields a (batch, seq, embedding_dim) output."""
        # Batch size 10, sequence length 20.
        idx = self._random_indices(batch_size=10, seq_len=20)

        output = self.embedding(idx)

        self.assertEqual(output.shape, (10, 20, self.EMBEDDING_DIM))

    def test_forward_output_values(self):
        """Every output vector equals the ``embeddings`` row selected by its index."""
        # Covers both ends of the valid range (0 and NUM_EMBEDDINGS - 1).
        idx = torch.tensor(
            [[0, 50], [self.NUM_EMBEDDINGS - 1, 1]],
            dtype=torch.long,
        )

        output = self.embedding(idx)

        # Check every position, not just a sample, so a mis-indexed lookup
        # in any cell is caught and reported with its coordinates.
        for batch, row in enumerate(idx.tolist()):
            for step, token in enumerate(row):
                with self.subTest(batch=batch, step=step, token=token):
                    expected = self.embedding.embeddings[token]
                    self.assertTrue(torch.equal(output[batch, step], expected))

    def test_dtype(self):
        """The forward pass produces float32 vectors."""
        # NOTE(review): presumably relies on torch's default dtype being
        # float32 when the layer's parameters are created — confirm.
        output = self.embedding(self._random_indices())

        self.assertEqual(output.dtype, torch.float32)


if __name__ == "__main__":
    # When executed as a script, run every TestCase defined in this module.
    unittest.main(module="__main__")