import unittest

import torch

from models.config import PreprocessingConfigUnivNet, get_lang_map
from training.preprocess import PreprocessLibriTTS
from training.preprocess.preprocess_libritts import PreprocessForAcousticResult


class TestPreprocessLibriTTS(unittest.TestCase):
    def setUp(self):
        torch.random.manual_seed(42)

        lang_map = get_lang_map("en")
        processing_lang_type = lang_map.processing_lang_type

        self.preprocess_libritts = PreprocessLibriTTS(
            PreprocessingConfigUnivNet(processing_lang_type),
        )
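
    # The preprocessing config used here presumably targets 22050 Hz audio,
    # which is why the 44.1 kHz fixtures below are expected to come out with
    # 22050 samples per second of input (e.g. output.wav.shape == [22050]
    # for a 1 s signal).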
    def test_acoustic(self):
        # Set the sampling rate and duration of the audio signal
        sr_actual = 44100
        duration = 1.0

        # Set the frequency of the pitch (in Hz)
        pitch_freq = 440.0

        # Generate a time vector for the audio signal
        t = torch.linspace(0, duration, int(sr_actual * duration))

        # Generate a sinusoidal waveform with the specified pitch frequency
        audio = torch.sin(2 * torch.pi * pitch_freq * t)
        audio = audio.unsqueeze(0)

        raw_text = "Hello, world!"

        output = self.preprocess_libritts.acoustic(
            (audio, sr_actual, raw_text, raw_text, 0, 0, "0"),
        )

        self.assertIsNotNone(output)

        if output is not None:
            self.assertIsInstance(output, PreprocessForAcousticResult)

            self.assertIsInstance(output.wav, torch.Tensor)
            self.assertIsInstance(output.mel, torch.Tensor)
            self.assertIsInstance(output.pitch, torch.Tensor)
            self.assertIsInstance(output.phones, torch.Tensor)
            self.assertIsInstance(output.raw_text, str)
            self.assertIsInstance(output.pitch_is_normalized, bool)

            self.assertEqual(output.wav.shape, torch.Size([22050]))
            self.assertEqual(output.mel.shape, torch.Size([100, 86]))
            self.assertEqual(output.pitch.shape, torch.Size([86]))
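
            # Shape note: 100 appears to be the number of mel bins and 86 the
            # frame count, consistent with the 22050-sample wav and an assumed
            # hop length of 256 (22050 // 256 == 86); the pitch track carries
            # one value per frame.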
            torch.testing.assert_close(
                output.phones,
                torch.tensor(
                    [
                        2.0,
                        10.0,
                        37.0,
                        14.0,
                        50.0,
                        17.0,
                        45.0,
                        62.0,
                        71.0,
                        24.0,
                        50.0,
                        118.0,
                        52.0,
                        14.0,
                        6.0,
                        60.0,
                        71.0,
                        3.0,
                    ],
                ),
            )

            self.assertEqual(output.raw_text, "Hello, world!")
            self.assertFalse(output.pitch_is_normalized)

    def test_acoustic_with_short_audio(self):
        audio = torch.randn(1, 22050)
        sr_actual = 22050
        raw_text = "Hello, world!"

        output = self.preprocess_libritts.acoustic(
            (audio, sr_actual, raw_text, raw_text, 0, 0, "0"),
        )

        self.assertIsNone(output)
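        # A clip of this length is presumably outside the accepted range, and
        # the preprocessor signals rejection by returning None rather than
        # raising, as it also does in test_acoustic_with_long_audio below.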

    def test_acoustic_with_complicated_text(self):
        # Set the sampling rate and duration of the audio signal
        sr_actual = 44100
        duration = 10.0

        # Set the frequency of the pitch (in Hz)
        pitch_freq = 440.0

        # Generate a time vector for the audio signal
        t = torch.linspace(0, duration, int(sr_actual * duration))

        # Generate a sinusoidal waveform with the specified pitch frequency
        audio = torch.sin(2 * torch.pi * pitch_freq * t).unsqueeze(0)

        raw_text = r"""Hello, world! Wow!!!!! This is amazing?????
        It’s a beautiful day…
        Mr. Smith paid $111 in U.S.A. on Dec. 17th. We paid $123 for this desk."""

        output = self.preprocess_libritts.acoustic(
            (audio, sr_actual, raw_text, raw_text, 0, 0, "0"),
        )

        self.assertIsNotNone(output)

        if output is not None:
            self.assertEqual(output.attn_prior.shape, torch.Size([226, 861]))
            self.assertEqual(output.mel.shape, torch.Size([100, 861]))
            self.assertEqual(
                output.normalized_text,
                "Hello, world! Wow! This is amazing?. It's a beautiful day.. mister Smith paid one hundred and eleven dollars in USA on december seventeenth. We paid one hundred and twenty three dollars for this desk.",
            )
            self.assertEqual(output.phones.shape, torch.Size([226]))
            self.assertEqual(len(output.phones_ipa), 224)
            self.assertEqual(output.wav.shape, torch.Size([220500]))

    def test_acoustic_with_long_audio(self):
        audio = torch.randn(1, 88200)
        sr_actual = 44100
        raw_text = "Hello, world!"

        output = self.preprocess_libritts.acoustic(
            (audio, sr_actual, raw_text, raw_text, 0, 0, "0"),
        )

        self.assertIsNone(output)

    def test_beta_binomial_prior_distribution(self):
        phoneme_count = 10
        mel_count = 20

        prior_dist = self.preprocess_libritts.beta_binomial_prior_distribution(
            phoneme_count,
            mel_count,
        )

        self.assertIsInstance(prior_dist, torch.Tensor)
        self.assertEqual(prior_dist.shape, (mel_count, phoneme_count))
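        # The prior has mel_count rows and phoneme_count columns, i.e. a soft
        # alignment prior over phoneme positions for each mel frame; an
        # illustrative sketch of one common construction follows after this
        # test class.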

    def test_preprocess_univnet(self):
        # Set the sampling rate and duration of the audio signal
        sr_actual = 44100
        duration = 10.0

        # Set the frequency of the pitch (in Hz)
        pitch_freq = 440.0

        # Generate a time vector for the audio signal
        t = torch.linspace(0, duration, int(sr_actual * duration))

        # Generate a sinusoidal waveform with the specified pitch frequency
        audio = torch.sin(2 * torch.pi * pitch_freq * t).unsqueeze(0)

        speaker_id = 10

        output = self.preprocess_libritts.univnet(
            (audio, sr_actual, "", "", speaker_id, 0, ""),
        )

        self.assertIsNotNone(output)

        if output is not None:
            self.assertIsInstance(output, tuple)
            self.assertEqual(len(output), 3)

            mel, audio_segment, speaker_id_output = output

            self.assertIsInstance(mel, torch.Tensor)
            self.assertIsInstance(audio_segment, torch.Tensor)
            self.assertIsInstance(speaker_id_output, int)

            self.assertEqual(mel.shape, torch.Size([100, 64]))
            self.assertEqual(speaker_id_output, speaker_id)
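

# Reference sketch (not exercised by the tests above): one common way to build
# a beta-binomial alignment prior with the (mel_count, phoneme_count) shape
# asserted in test_beta_binomial_prior_distribution. The helper name, the
# numpy/scipy usage, and the scaling_factor parameter are illustrative
# assumptions, not the repo's actual PreprocessLibriTTS implementation.
def beta_binomial_prior_sketch(
    phoneme_count: int,
    mel_count: int,
    scaling_factor: float = 1.0,
):
    # Local imports keep this self-contained reference out of the test imports.
    import numpy as np
    from scipy.stats import betabinom

    positions = np.arange(phoneme_count)
    rows = []
    for i in range(1, mel_count + 1):
        # For frame i the prior mass is concentrated around the proportional
        # phoneme position i / mel_count.
        a, b = scaling_factor * i, scaling_factor * (mel_count + 1 - i)
        rows.append(betabinom(phoneme_count - 1, a, b).pmf(positions))

    # Rows correspond to mel frames, columns to phoneme positions.
    return np.stack(rows)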


if __name__ == "__main__":
    unittest.main()