File size: 1,501 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import unittest

import numpy as np
import torch

from models.helpers.tools import initialize_embeddings


class TestInitializeEmbeddings(unittest.TestCase):
    """Unit tests for ``models.helpers.tools.initialize_embeddings``."""

    # Exact message expected from initialize_embeddings for any shape that
    # is not 2-D; compared verbatim against str(exception) below.
    NON_2D_MESSAGE = "Can only initialize 2-D embedding matrices ..."

    # Allowed absolute deviation between the sample std and the expected std.
    STD_TOLERANCE = 0.1

    def setUp(self):
        # Seed both RNGs so every run draws the same values. Which library
        # initialize_embeddings samples from is not visible here, so both
        # torch and numpy are seeded.
        torch.manual_seed(0)
        np.random.seed(0)

    def test_initialize_embeddings(self):
        """A valid 2-D shape yields a float32 torch.Tensor of that shape."""
        shape = (5, 10)
        result = initialize_embeddings(shape)

        self.assertIsInstance(result, torch.Tensor)
        self.assertEqual(tuple(result.shape), shape)
        self.assertEqual(result.dtype, torch.float32)

    def test_initialize_embeddings_std(self):
        """The sample standard deviation is close to ``sqrt(2 / shape[1])``.

        A tall matrix is used so the estimate is tight: with 20,000 draws the
        standard error of the sample std is roughly 0.002, far below the
        tolerance, whereas with only a few dozen draws the check fails
        intermittently. Assumes the scale depends only on ``shape[1]``, as
        the expected value below does.
        """
        shape = (2000, 10)
        result = initialize_embeddings(shape)

        expected_stddev = np.sqrt(2 / shape[1])
        self.assertAlmostEqual(
            result.std().item(), expected_stddev, delta=self.STD_TOLERANCE,
        )

    # NOTE(review): AssertionError suggests the helper validates with an
    # `assert` statement, which `python -O` strips; skip rather than fail
    # spuriously in that mode.
    @unittest.skipUnless(__debug__, "assert statements are disabled under -O")
    def test_initialize_embeddings_rejects_non_2d_shapes(self):
        """Shapes with other than two dimensions raise AssertionError."""
        for bad_shape in [(5, 10, 15), ()]:
            with self.subTest(shape=bad_shape):
                with self.assertRaises(AssertionError) as context:
                    initialize_embeddings(bad_shape)
                self.assertEqual(str(context.exception), self.NON_2D_MESSAGE)


# Script entry point: running this file directly executes the tests above,
# while importing it (e.g. from a test collector) does not.
if __name__ == "__main__":
    unittest.main(verbosity=1)