import os
import unittest

from lightning.pytorch import Trainer
import torch

from models.config import VocoderFinetuningConfig, VocoderPretrainingConfig
from models.vocoder.univnet import UnivNet

# NOTE: setting CUDA_LAUNCH_BLOCKING=1 forces synchronous CUDA kernel launches
# so that any CUDA error is reported at the call site rather than asynchronously
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

class TestUnivNet(unittest.TestCase):
    def test_optim_pretraining(self):
        module = UnivNet()

        self.assertIsInstance(module.train_config, VocoderPretrainingConfig)

        optimizer_configs = module.configure_optimizers()

        for optimizer_config in optimizer_configs:
            optimizer = optimizer_config["optimizer"]
            lr_scheduler = optimizer_config["lr_scheduler"]

            # Each optimizer should be AdamW paired with an ExponentialLR scheduler
            self.assertIsInstance(optimizer, torch.optim.AdamW)
            self.assertIsInstance(lr_scheduler, torch.optim.lr_scheduler.ExponentialLR)

    def test_finetuning(self):
        module = UnivNet(fine_tuning=True)

        self.assertIsInstance(module.train_config, VocoderFinetuningConfig)

    def test_train_step(self):
        trainer = Trainer(
            # Save checkpoints to the `default_root_dir` directory
            default_root_dir="checkpoints/vocoder",
            limit_train_batches=1,
            max_epochs=1,
            accelerator="cpu",
        )

        # Loading the pretrained weights is currently disabled.
        # NOTE: this path points to the checkpoint shipped in the repo
        # and works only for the v0.1.0 checkpoint.
        # This code will be removed in the future!
        # checkpoint_path = "models/checkpoints/assets/v0.1.0/vocoder_pretrained.pt"

        module = UnivNet(batch_size=1, acc_grad_steps=1)

        train_dataloader = module.train_dataloader(2, cache=False, mem_cache=False)

        result = trainer.fit(model=module, train_dataloaders=train_dataloader)

        self.assertIsNone(result)

    # def test_load_from_checkpoint(self):
    #     try:
    #         UnivNet.load_from_checkpoint(
    #             "./checkpoints/vocoder.ckpt",
    #         )
    #     except Exception as e:
    #         self.fail(f"Loading from checkpoint raised an exception: {e}")
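
# Minimal sketch: allow running this test module directly with the standard
# unittest runner (module path/layout is an assumption; adjust to the repo).
if __name__ == "__main__":
    unittest.main()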