import unittest
import torch
from models.config import (
    AcousticENModelConfig,
    AcousticPretrainingConfig,
    PreprocessingConfigUnivNet as PreprocessingConfig,
)
from models.helpers.initializer import (
init_acoustic_model,
init_conformer,
init_forward_trains_params,
init_mask_input_embeddings_encoding_attn_mask,
)
from models.tts.delightful_tts.attention.conformer import Conformer
# The Conformer is used as the encoder of the AcousticModel and is crucial for training.
# This integration test documents the input and output shapes the Conformer expects.
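# A quick sketch (not executed) of the tensor shapes exercised below; the exact
# sizes come from AcousticENModelConfig and AcousticPretrainingConfig:
#
#   x:          [speaker_embed_dim, batch_size, speaker_embed_dim // 2]
#   embeddings: [speaker_embed_dim, batch_size, speaker_embed_dim + lang_embed_dim]
#   encoder(x, src_mask, embeddings=embeddings, encoding=encoding)
#       -> [speaker_embed_dim, batch_size, speaker_embed_dim // 2]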
class TestConformer(unittest.TestCase):
    def setUp(self):
        self.acoustic_pretraining_config = AcousticPretrainingConfig()
        self.model_config = AcousticENModelConfig()
        self.preprocess_config = PreprocessingConfig("english_only")

        # Based on the speaker.json mock
        n_speakers = 10

        # Add the Conformer as the encoder
        self.encoder, _ = init_conformer(self.model_config)

        # Add an AcousticModel instance
        self.acoustic_model, _ = init_acoustic_model(
            self.preprocess_config,
            self.model_config,
            n_speakers,
        )

        # Generate mock data for the forward pass
        self.forward_train_params = init_forward_trains_params(
            self.model_config,
            self.acoustic_pretraining_config,
            self.preprocess_config,
            n_speakers,
        )
    def test_initialization(self):
        """Test that a Conformer instance is correctly initialized."""
        self.assertIsInstance(self.encoder, Conformer)
    def test_forward(self):
        """Test that a Conformer instance can correctly perform a forward pass.

        This test case reuses the input-preparation code from the AcousticModel.
        """
        (
            src_mask,
            x,
            embeddings,
            encoding,
            _,
        ) = init_mask_input_embeddings_encoding_attn_mask(
            self.acoustic_model,
            self.forward_train_params,
            self.model_config,
        )
        # Assert the shape of x
        self.assertEqual(
            x.shape,
            torch.Size(
                [
                    self.model_config.speaker_embed_dim,
                    self.acoustic_pretraining_config.batch_size,
                    self.model_config.speaker_embed_dim // 2,
                ],
            ),
        )

        # Assert the shape of embeddings
        self.assertEqual(
            embeddings.shape,
            torch.Size(
                [
                    self.model_config.speaker_embed_dim,
                    self.acoustic_pretraining_config.batch_size,
                    self.model_config.speaker_embed_dim
                    + self.model_config.lang_embed_dim,
                ],
            ),
        )
        # Run the Conformer encoder
        # x: tensor containing the encoded sequences.
        # Shape: [speaker_embed_dim, batch_size, speaker_embed_dim // 2]
        x = self.encoder(x, src_mask, embeddings=embeddings, encoding=encoding)

        # Assert the shape of x after encoding
        self.assertEqual(
            x.shape,
            torch.Size(
                [
                    self.model_config.speaker_embed_dim,
                    self.acoustic_pretraining_config.batch_size,
                    self.model_config.speaker_embed_dim // 2,
                ],
            ),
        )


if __name__ == "__main__":
    unittest.main()