# File: PeechTTSv22050/models/helpers/tests/test_pitch_phoneme_averaging.py
# Repository page header: nickovchinnikov's picture
# Commit message: Init
# Commit: 9d61c9b
import unittest
import torch
from models.helpers import (
pitch_phoneme_averaging,
)
class TestPitchPhonemeAveraging(unittest.TestCase):
    """Sanity checks for ``pitch_phoneme_averaging``."""

    def test_pitch_phoneme_averaging(self):
        """Verify the type, shape and value range of the returned tensor."""
        # Two rows of four duration entries each; the zeros sit at the end
        # of every row.
        durations = torch.tensor(
            [[5, 1, 3, 0], [2, 4, 0, 0]],
            dtype=torch.float32,
        )
        batch_size, phoneme_count = durations.shape

        # The random pitch input is as wide as the largest per-row
        # duration total.
        row_totals = durations.sum(dim=1).int()
        longest_row = int(row_totals.max().item())
        pitches = torch.rand(batch_size, longest_row)

        averaged = pitch_phoneme_averaging(durations, pitches, phoneme_count)

        # The helper must hand back a tensor shaped like ``durations``.
        self.assertIsInstance(averaged, torch.Tensor)
        self.assertEqual(averaged.shape, durations.shape)

        # ``torch.rand`` samples from [0, 1), so the output is expected to
        # stay inside [0, 1] as well.
        in_unit_interval = (averaged >= 0) & (averaged <= 1)
        self.assertTrue(torch.all(in_unit_interval))
# Execute the test suite when this file is run directly as a script.
if __name__ == "__main__":
    unittest.main()