# File size: 3,732 Bytes
# 9b2107c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Search for a good WaveGrad noise schedule for a given number of inference iterations."""
import argparse
from itertools import product as cartesian_product

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from TTS.config import load_config
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.preprocess import load_wav_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.models import setup_model

def _build_arg_parser():
    """Create the command-line parser for the noise-schedule search.

    Options without which the script cannot run (config, data, output path and
    the iteration count) are marked ``required`` so argparse reports a clear
    usage error instead of the script crashing later on a ``None`` value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
    parser.add_argument("--config_path", type=str, required=True, help="Path to model config file.")
    parser.add_argument("--data_path", type=str, required=True, help="Path to data directory.")
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="path for output file including file name and extension.",
    )
    parser.add_argument(
        "--num_iter",
        type=int,
        required=True,
        help="Number of model inference iterations that you like to optimize noise schedule for.",
    )
    parser.add_argument("--use_cuda", action="store_true", help="enable CUDA.")
    parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.")
    parser.add_argument(
        "--search_depth",
        type=int,
        default=3,
        help="Search granularity. Increasing this increases the run-time exponentially.",
    )
    return parser


def _schedule_error(model, ap, loader, use_cuda=False):
    """Return the mel-spectrogram MSE of ``model`` averaged over every batch in ``loader``.

    The caller must already have set the noise schedule on ``model``. Each
    batch is vocoded with ``model.inference``. The generated waveforms are
    turned back into mel spectrograms with ``ap.melspectrogram``, and the
    element-wise mean squared error against the input mel is computed. The
    per-batch errors are then averaged. This way the whole evaluation set, not
    a single clip, decides which schedule is best.

    Args:
        model: vocoder exposing ``inference(mel) -> torch.Tensor``. The output
            is indexed as ``[item, 0]`` below; presumably (batch, 1, samples),
            TODO confirm.
        ap: audio processor providing ``melspectrogram(wav) -> np.ndarray``.
        loader: iterable of ``(mel, audio)`` batches; ``audio`` is ignored.
        use_cuda: move ``mel`` to the GPU before running inference.

    Returns:
        float: mean squared error averaged over all batches.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    batch_errors = []
    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for mel, _ in loader:
            y_hat = model.inference(mel.cuda() if use_cuda else mel)
            # .cpu() returns the tensor unchanged when it is already on the CPU.
            y_hat = y_hat.cpu().numpy()
            # The last mel frame is dropped, as in the original script.
            # NOTE(review): presumably to match the input frame count; confirm.
            mel_hat = torch.stack([torch.from_numpy(ap.melspectrogram(wav[0])[:, :-1]) for wav in y_hat])
            batch_errors.append(torch.mean((mel - mel_hat) ** 2).item())
    if not batch_errors:
        raise ValueError("No data samples to evaluate the noise schedule on.")
    return float(np.mean(batch_errors))


if __name__ == "__main__":
    # load config
    args = _build_arg_parser().parse_args()
    config = load_config(args.config_path)

    # setup audio processor
    ap = AudioProcessor(**config.audio)

    # load dataset (only the first `num_samples` items are used for evaluation)
    _, train_data = load_wav_data(args.data_path, 0)
    train_data = train_data[: args.num_samples]
    dataset = WaveGradDataset(
        ap=ap,
        items=train_data,
        seq_len=-1,
        hop_len=ap.hop_length,
        pad_short=config.pad_short,
        conv_pad=config.conv_pad,
        is_training=True,
        return_segments=False,
        use_noise_augment=False,
        use_cache=False,
        verbose=True,
    )
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=dataset.collate_full_clips,
        drop_last=False,
        num_workers=config.num_loader_workers,
        pin_memory=False,
    )

    # setup the model
    model = setup_model(config)
    if args.model_path:
        # Without this the search would tune a randomly initialised model.
        # NOTE(review): relies on the vocoder's
        # `load_checkpoint(config, checkpoint_path, eval=...)` API; confirm.
        model.load_checkpoint(config, args.model_path, eval=True)
    else:
        print(" > [!] No --model_path given; tuning the noise schedule of an untrained model.")
    if args.use_cuda:
        model.cuda()

    # setup optimization parameters:
    # every inference step i gets beta_i = exponents[i] * base_i, where the
    # exponents are log-spaced in [1e-6, 1e-1] and each base_i is drawn from
    # `search_depth` random multipliers. The search is therefore exhaustive
    # over search_depth ** num_iter candidates.
    base_values = sorted(10 * np.random.uniform(size=args.search_depth))
    print(f" > base values: {base_values}")
    exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
    best_error = float("inf")
    best_schedule = None  # pylint: disable=C0103
    total_search_iter = len(base_values) ** args.num_iter
    for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
        beta = exponents * base
        model.compute_noise_level(beta)
        # Score the schedule on ALL samples before comparing.
        error = _schedule_error(model, ap, loader, args.use_cuda)
        if error < best_error:
            best_error = error
            best_schedule = {"beta": beta}
            print(f" > Found a better schedule. - MSE: {error}")
            # Saved on every improvement so an interrupted search keeps its best result.
            np.save(args.output_path, best_schedule)