Spaces:
Sleeping
Sleeping
import os | |
import unittest | |
from lightning.pytorch import Trainer | |
import torch | |
from models.config import PreprocessingConfigHifiGAN as PreprocessingConfig | |
from models.config import PreprocessingConfigUnivNet | |
from models.tts.delightful_tts.delightful_tts import DelightfulTTS | |
from models.vocoder.univnet import UnivNet | |
checkpoint = "checkpoints/logs_44100_tts_80_logs_new3_lightning_logs_version_7_checkpoints_epoch=2450-step=183470.ckpt" | |
# NOTE: this is needed to avoid CUDA_LAUNCH_BLOCKING error | |
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" | |
class TestDelightfulTTS(unittest.TestCase): | |
def setUp(self): | |
# Create a dummy Trainer instance | |
self.trainer = Trainer() | |
self.preprocessing_config = PreprocessingConfig("multilingual") | |
def test_optim_finetuning(self): | |
module = DelightfulTTS(preprocess_config=self.preprocessing_config) | |
module.trainer = self.trainer | |
optimizer_config = module.configure_optimizers() | |
optimizer = optimizer_config["optimizer"] | |
lr_scheduler = optimizer_config["lr_scheduler"] | |
# Test the optimizer | |
self.assertIsInstance(optimizer, torch.optim.AdamW) | |
self.assertIsInstance(lr_scheduler, torch.optim.lr_scheduler.ExponentialLR) | |
def test_train_steps(self): | |
default_root_dir = "checkpoints/acoustic" | |
# checkpoint = "checkpoints/logs_new_training_libri-360-swa_multilingual_conf_epoch=146-step=33516.ckpt" | |
trainer = Trainer( | |
# Save checkpoints to the `default_root_dir` directory | |
default_root_dir=default_root_dir, | |
fast_dev_run=1, | |
limit_train_batches=1, | |
max_epochs=1, | |
accelerator="gpu", | |
) | |
module = DelightfulTTS( | |
preprocess_config=self.preprocessing_config, | |
) | |
train_dataloader = module.train_dataloader(cache=False) | |
# automatically restores model, epoch, step, LR schedulers, etc... | |
# trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt") | |
result = trainer.fit(model=module, train_dataloaders=train_dataloader) | |
self.assertIsNone(result) | |
def test_load_from_new_checkpoint(self): | |
try: | |
DelightfulTTS.load_from_checkpoint( | |
checkpoint, | |
strict=False, | |
preprocess_config=self.preprocessing_config, | |
) | |
except Exception as e: | |
self.fail(f"Loading from checkpoint raised an exception: {e}") | |
def test_forward(self): | |
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
device = torch.device("cpu") | |
module = DelightfulTTS.load_from_checkpoint( | |
checkpoint, | |
strict=False, | |
map_location=device, | |
preprocess_config=self.preprocessing_config, | |
) | |
univnet = UnivNet() | |
univnet = univnet.to(device) | |
text = """As the snake shook its head, a deafening shout behind Harry made both of them jump. | |
'DUDLEY! MR DURSLEY! COME AND LOOK AT THIS SNAKE! YOU WON'T BELIEVE WHAT IT'S DOING!' | |
Dudley came waddling towards them as fast as he could. | |
‘Out of the way, you,’ he said, punching Harry in the ribs. Caught by surprise, Harry fell hard on the concrete floor. What came next happened so fast no one saw how it happened – one second, Piers and Dudley were leaning right up close to the glass, the next, they had leapt back with howls of horror.""" | |
speaker = torch.tensor([2071], device=device) | |
mel_spec = module.forward( | |
text, | |
speaker, | |
) | |
self.assertIsInstance(mel_spec, torch.Tensor) | |