nickovchinnikov's picture
Init
9d61c9b
import os
import unittest
from lightning.pytorch import Trainer
import torch
from models.config import PreprocessingConfigHifiGAN as PreprocessingConfig
from models.config import PreprocessingConfigUnivNet
from models.tts.delightful_tts.delightful_tts import DelightfulTTS
from models.vocoder.univnet import UnivNet
checkpoint = "checkpoints/logs_44100_tts_80_logs_new3_lightning_logs_version_7_checkpoints_epoch=2450-step=183470.ckpt"
# NOTE: this is needed to avoid CUDA_LAUNCH_BLOCKING error
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
class TestDelightfulTTS(unittest.TestCase):
def setUp(self):
# Create a dummy Trainer instance
self.trainer = Trainer()
self.preprocessing_config = PreprocessingConfig("multilingual")
def test_optim_finetuning(self):
module = DelightfulTTS(preprocess_config=self.preprocessing_config)
module.trainer = self.trainer
optimizer_config = module.configure_optimizers()
optimizer = optimizer_config["optimizer"]
lr_scheduler = optimizer_config["lr_scheduler"]
# Test the optimizer
self.assertIsInstance(optimizer, torch.optim.AdamW)
self.assertIsInstance(lr_scheduler, torch.optim.lr_scheduler.ExponentialLR)
def test_train_steps(self):
default_root_dir = "checkpoints/acoustic"
# checkpoint = "checkpoints/logs_new_training_libri-360-swa_multilingual_conf_epoch=146-step=33516.ckpt"
trainer = Trainer(
# Save checkpoints to the `default_root_dir` directory
default_root_dir=default_root_dir,
fast_dev_run=1,
limit_train_batches=1,
max_epochs=1,
accelerator="gpu",
)
module = DelightfulTTS(
preprocess_config=self.preprocessing_config,
)
train_dataloader = module.train_dataloader(cache=False)
# automatically restores model, epoch, step, LR schedulers, etc...
# trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
result = trainer.fit(model=module, train_dataloaders=train_dataloader)
self.assertIsNone(result)
def test_load_from_new_checkpoint(self):
try:
DelightfulTTS.load_from_checkpoint(
checkpoint,
strict=False,
preprocess_config=self.preprocessing_config,
)
except Exception as e:
self.fail(f"Loading from checkpoint raised an exception: {e}")
def test_forward(self):
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
module = DelightfulTTS.load_from_checkpoint(
checkpoint,
strict=False,
map_location=device,
preprocess_config=self.preprocessing_config,
)
univnet = UnivNet()
univnet = univnet.to(device)
text = """As the snake shook its head, a deafening shout behind Harry made both of them jump.
'DUDLEY! MR DURSLEY! COME AND LOOK AT THIS SNAKE! YOU WON'T BELIEVE WHAT IT'S DOING!'
Dudley came waddling towards them as fast as he could.
‘Out of the way, you,’ he said, punching Harry in the ribs. Caught by surprise, Harry fell hard on the concrete floor. What came next happened so fast no one saw how it happened – one second, Piers and Dudley were leaning right up close to the glass, the next, they had leapt back with howls of horror."""
speaker = torch.tensor([2071], device=device)
mel_spec = module.forward(
text,
speaker,
)
self.assertIsInstance(mel_spec, torch.Tensor)