Spaces:
Sleeping
Sleeping
File size: 910 Bytes
9d61c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import unittest
from unittest.mock import Mock
import torch
from models.tts.delightful_tts.reference_encoder.STL import STL
class TestSTL(unittest.TestCase):
    """Shape checks for the forward pass of the ``STL`` module."""

    def setUp(self):
        """Create a mocked model config, the ``STL`` under test and a random input."""
        # Only the two config fields assigned below are set explicitly on the mock.
        self.model_config = Mock()
        self.model_config.encoder.n_hidden = 512
        self.model_config.reference_encoder.token_num = 32

        self.stl = STL(self.model_config)

        self.batch_size = 10
        self.n_hidden = self.model_config.encoder.n_hidden

        # Random input batch of shape (batch_size, n_hidden // 2).
        self.x = torch.rand(self.batch_size, self.n_hidden // 2)

    def test_forward(self):
        """The forward pass returns a tensor shaped (batch_size, 1, num_units)."""
        output = self.stl(self.x)

        # assertIsInstance reports the offending type on failure, unlike
        # assertTrue(torch.is_tensor(...)), which only says "False is not true".
        self.assertIsInstance(output, torch.Tensor)

        # Validate the output size; compare as plain tuples so a failure
        # message shows both sides in the same form.
        expected_shape = (self.batch_size, 1, self.stl.attention.num_units)
        self.assertEqual(tuple(output.shape), expected_shape)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|