Spaces:
Runtime error
Runtime error
Upload vocoder/train.py with huggingface_hub
Browse files- vocoder/train.py +127 -0
vocoder/train.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from vocoder.models.fatchord_version import WaveRNN
|
2 |
+
from vocoder.vocoder_dataset import VocoderDataset, collate_vocoder
|
3 |
+
from vocoder.distribution import discretized_mix_logistic_loss
|
4 |
+
from vocoder.display import stream, simple_table
|
5 |
+
from vocoder.gen_wavernn import gen_testset
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
from pathlib import Path
|
8 |
+
from torch import optim
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import vocoder.hparams as hp
|
11 |
+
import numpy as np
|
12 |
+
import time
|
13 |
+
import torch
|
14 |
+
import platform
|
15 |
+
|
16 |
+
def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_truth: bool,
          save_every: int, backup_every: int, force_restart: bool):
    """
    Train a WaveRNN vocoder, resuming from existing weights when they are available.

    :param run_id: name of the run. Weights are stored at <models_dir>/<run_id>/<run_id>.pt
    :param syn_dir: synthesizer data directory. Its "audio" subdirectory is always used for the
    waveforms; "train.txt" and "mels" are used as well when ground_truth is True
    :param voc_dir: vocoder data directory. "synthesized.txt" and "mels_gta" are read from it
    when ground_truth is False
    :param models_dir: parent directory of the run directory. It is created if missing
    :param ground_truth: if True, train on the mels found in syn_dir rather than on the ones
    found in voc_dir
    :param save_every: interval in steps at which the weights file is overwritten. 0 disables it
    :param backup_every: interval in steps at which model.checkpoint() is called. 0 disables it
    :param force_restart: if True, existing weights are ignored and the weights file is saved
    again from the freshly initialized model
    :raises AssertionError: if the product of hp.voc_upsample_factors is not hp.hop_length
    """
    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    # Instantiate the model
    print("Initializing the model...")
    model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims,
        fc_dims=hp.voc_fc_dims,
        bits=hp.bits,
        pad=hp.voc_pad,
        upsample_factors=hp.voc_upsample_factors,
        feat_dims=hp.num_mels,
        compute_dims=hp.voc_compute_dims,
        res_out_dims=hp.voc_res_out_dims,
        res_blocks=hp.voc_res_blocks,
        hop_length=hp.hop_length,
        sample_rate=hp.sample_rate,
        mode=hp.voc_mode
    )

    # Query CUDA availability once; the same flag drives model placement, batch placement and
    # memory pinning below.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters(), lr=hp.voc_lr)
    loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss

    # Load the weights
    model_dir = models_dir.joinpath(run_id)
    # parents=True so that a models_dir that does not exist yet is not an error
    model_dir.mkdir(parents=True, exist_ok=True)
    weights_fpath = model_dir.joinpath(run_id + ".pt")
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of WaveRNN from scratch\n")
        model.save(weights_fpath, optimizer)
    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        print("WaveRNN weights loaded from step %d" % model.step)

    # Initialize the dataset
    if ground_truth:
        metadata_fpath = syn_dir.joinpath("train.txt")
        mel_dir = syn_dir.joinpath("mels")
    else:
        metadata_fpath = voc_dir.joinpath("synthesized.txt")
        mel_dir = voc_dir.joinpath("mels_gta")
    wav_dir = syn_dir.joinpath("audio")
    dataset = VocoderDataset(metadata_fpath, mel_dir, wav_dir)

    # Pinning memory is only meaningful when batches are copied to a GPU
    test_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=True,
                             pin_memory=use_cuda)
    # All the arguments are constant across epochs, so the loader is built once. A new shuffled
    # pass over the dataset starts every time it is iterated.
    data_loader = DataLoader(dataset,
                             collate_fn=collate_vocoder,
                             batch_size=hp.voc_batch_size,
                             num_workers=2 if platform.system() != "Windows" else 0,
                             shuffle=True,
                             pin_memory=use_cuda)
    n_batches = len(data_loader)

    # Begin the training
    simple_table([('Batch size', hp.voc_batch_size),
                  ('LR', hp.voc_lr),
                  ('Sequence Len', hp.voc_seq_len)])

    for epoch in range(1, 350):
        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(data_loader, 1):
            if use_cuda:
                x, m, y = x.cuda(), m.cuda(), y.cuda()

            # Forward pass
            y_hat = model(x, m)
            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            # NOTE(review): the target gets its trailing axis in both modes, which matches the
            # extra axis given to y_hat in RAW mode - confirm against the loss functions.
            y = y.unsqueeze(-1)

            # Backward pass
            loss = loss_func(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running statistics for the progress line
            running_loss += loss.item()
            speed = i / (time.time() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000

            if backup_every != 0 and step % backup_every == 0:
                model.checkpoint(model_dir, optimizer)

            if save_every != 0 and step % save_every == 0:
                model.save(weights_fpath, optimizer)

            msg = f"| Epoch: {epoch} ({i}/{n_batches}) | " \
                  f"Loss: {avg_loss:.4f} | {speed:.1f} " \
                  f"steps/s | Step: {k}k | "
            stream(msg)

        # NOTE(review): test samples are generated once per epoch, after the last batch -
        # confirm this placement against the upstream file.
        gen_testset(model, test_loader, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
                    hp.voc_target, hp.voc_overlap, model_dir)
        print("")
|