Spaces:
Sleeping
Sleeping
File size: 1,476 Bytes
9d61c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import unittest
import torch
from models.tts.delightful_tts.acoustic_model.embedding import Embedding
class TestEmbedding(unittest.TestCase):
    """Unit tests for the acoustic-model ``Embedding`` lookup layer."""

    # Sizes used to construct the module under test. Shared by every test so
    # the valid index range and the expected output width always agree with
    # the constructor arguments.
    NUM_EMBEDDINGS = 100
    EMBEDDING_DIM = 50

    def setUp(self):
        """Build a fresh ``Embedding`` before each test."""
        self.embedding = Embedding(
            num_embeddings=self.NUM_EMBEDDINGS,
            embedding_dim=self.EMBEDDING_DIM,
        )

    def _random_indices(self, batch_size, seq_len):
        """Return a ``(batch_size, seq_len)`` tensor of in-range lookup indices."""
        return torch.randint(
            low=0,
            high=self.NUM_EMBEDDINGS,
            size=(batch_size, seq_len),
        )

    def test_forward_output_shape(self):
        """Forward maps ``(batch, seq)`` indices to ``(batch, seq, embedding_dim)``."""
        # A batch of 10 sequences, each of length 20.
        idx = self._random_indices(batch_size=10, seq_len=20)
        output = self.embedding(idx)
        self.assertEqual(output.shape, (10, 20, self.EMBEDDING_DIM))

    def test_forward_output_values(self):
        """Every output vector equals the ``embeddings`` row selected by its index."""
        # Includes both boundary indices: 0 and NUM_EMBEDDINGS - 1.
        idx = torch.LongTensor([[0, 50], [self.NUM_EMBEDDINGS - 1, 1]])
        output = self.embedding(idx)
        # Check every position rather than a sample. torch.equal also requires
        # matching sizes, so an accidental broadcast cannot make the check pass
        # the way `torch.all(a == b)` could.
        for row in range(idx.shape[0]):
            for col in range(idx.shape[1]):
                with self.subTest(row=row, col=col):
                    expected = self.embedding.embeddings[idx[row, col].item()]
                    self.assertTrue(torch.equal(output[row, col], expected))

    def test_dtype(self):
        """Forward returns float32 embeddings."""
        idx = self._random_indices(batch_size=10, seq_len=20)
        output = self.embedding(idx)
        self.assertEqual(output.dtype, torch.float32)
if __name__ == "__main__":
    # Allow running this module directly as a script, in addition to
    # collection by a test runner; hand control to unittest's CLI runner.
    unittest.main()
|