keithhon committed
Commit 4dd64e8 · 1 Parent(s): 6e8c3d6

Upload vocoder/train.py with huggingface_hub

Files changed (1)
  1. vocoder/train.py +127 -0
vocoder/train.py ADDED
@@ -0,0 +1,127 @@
+ from vocoder.models.fatchord_version import WaveRNN
+ from vocoder.vocoder_dataset import VocoderDataset, collate_vocoder
+ from vocoder.distribution import discretized_mix_logistic_loss
+ from vocoder.display import stream, simple_table
+ from vocoder.gen_wavernn import gen_testset
+ from torch.utils.data import DataLoader
+ from pathlib import Path
+ from torch import optim
+ import torch.nn.functional as F
+ import vocoder.hparams as hp
+ import numpy as np
+ import time
+ import torch
+ import platform
+
+ def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_truth: bool,
+           save_every: int, backup_every: int, force_restart: bool):
+     # Check to make sure the hop length is correctly factorised
+     assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length
+
+     # Instantiate the model
+     print("Initializing the model...")
+     model = WaveRNN(
+         rnn_dims=hp.voc_rnn_dims,
+         fc_dims=hp.voc_fc_dims,
+         bits=hp.bits,
+         pad=hp.voc_pad,
+         upsample_factors=hp.voc_upsample_factors,
+         feat_dims=hp.num_mels,
+         compute_dims=hp.voc_compute_dims,
+         res_out_dims=hp.voc_res_out_dims,
+         res_blocks=hp.voc_res_blocks,
+         hop_length=hp.hop_length,
+         sample_rate=hp.sample_rate,
+         mode=hp.voc_mode
+     )
+
+     if torch.cuda.is_available():
+         model = model.cuda()
+         device = torch.device('cuda')
+     else:
+         device = torch.device('cpu')
+
+     # Initialize the optimizer
+     optimizer = optim.Adam(model.parameters())
+     for p in optimizer.param_groups:
+         p["lr"] = hp.voc_lr
+     loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss
+
+     # Load the weights
+     model_dir = models_dir.joinpath(run_id)
+     model_dir.mkdir(exist_ok=True)
+     weights_fpath = model_dir.joinpath(run_id + ".pt")
+     if force_restart or not weights_fpath.exists():
+         print("\nStarting the training of WaveRNN from scratch\n")
+         model.save(weights_fpath, optimizer)
+     else:
+         print("\nLoading weights at %s" % weights_fpath)
+         model.load(weights_fpath, optimizer)
+         print("WaveRNN weights loaded from step %d" % model.step)
+
+     # Initialize the dataset
+     metadata_fpath = syn_dir.joinpath("train.txt") if ground_truth else \
+         voc_dir.joinpath("synthesized.txt")
+     mel_dir = syn_dir.joinpath("mels") if ground_truth else voc_dir.joinpath("mels_gta")
+     wav_dir = syn_dir.joinpath("audio")
+     dataset = VocoderDataset(metadata_fpath, mel_dir, wav_dir)
+     test_loader = DataLoader(dataset,
+                              batch_size=1,
+                              shuffle=True,
+                              pin_memory=True)
+
+     # Begin the training
+     simple_table([('Batch size', hp.voc_batch_size),
+                   ('LR', hp.voc_lr),
+                   ('Sequence Len', hp.voc_seq_len)])
+
+     for epoch in range(1, 350):
+         data_loader = DataLoader(dataset,
+                                  collate_fn=collate_vocoder,
+                                  batch_size=hp.voc_batch_size,
+                                  num_workers=2 if platform.system() != "Windows" else 0,
+                                  shuffle=True,
+                                  pin_memory=True)
+         start = time.time()
+         running_loss = 0.
+
+         for i, (x, y, m) in enumerate(data_loader, 1):
+             if torch.cuda.is_available():
+                 x, m, y = x.cuda(), m.cuda(), y.cuda()
+
+             # Forward pass
+             y_hat = model(x, m)
+             if model.mode == 'RAW':
+                 # Cross entropy expects the class dimension second
+                 y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
+             elif model.mode == 'MOL':
+                 y = y.float()
+
+             y = y.unsqueeze(-1)
+
+             # Backward pass
+             loss = loss_func(y_hat, y)
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+
+             running_loss += loss.item()
+             speed = i / (time.time() - start)
+             avg_loss = running_loss / i
+
+             step = model.get_step()
+             k = step // 1000
+
+             if backup_every != 0 and step % backup_every == 0:
+                 model.checkpoint(model_dir, optimizer)
+
+             if save_every != 0 and step % save_every == 0:
+                 model.save(weights_fpath, optimizer)
+
+             msg = f"| Epoch: {epoch} ({i}/{len(data_loader)}) | " \
+                   f"Loss: {avg_loss:.4f} | {speed:.1f} " \
+                   f"steps/s | Step: {k}k | "
+             stream(msg)
+
+         # Generate sample outputs from the test set at the end of each epoch
+         gen_testset(model, test_loader, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
+                     hp.voc_target, hp.voc_overlap, model_dir)
+         print("")
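
A note on the assert at the top of `train()`: WaveRNN's upsampling network stretches each mel frame by the product of its upsample factors, so that product has to equal the STFT hop length exactly. A minimal sketch of the check, assuming the (5, 5, 8) factors and 200-sample hop common in WaveRNN configurations (illustrative values, not read from this repo's hparams):

```python
import numpy as np

# Illustrative values; in the script they come from vocoder/hparams.py.
voc_upsample_factors = (5, 5, 8)
hop_length = 200

# The upsample layers stretch each mel frame by 5 * 5 * 8 = 200 samples,
# which must match the number of audio samples per mel frame (the hop).
assert np.cumprod(voc_upsample_factors)[-1] == hop_length
```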
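The mode-dependent reshaping before the loss is easy to misread: in "RAW" mode the model emits per-sample logits over 2**bits classes, and `F.cross_entropy` wants that class dimension second, while the integer targets keep a trailing singleton dimension. A shape-only sketch with invented batch and sequence sizes (the real tensors come from `collate_vocoder`):

```python
import torch
import torch.nn.functional as F

batch, seq_len, bits = 4, 1000, 9
n_classes = 2 ** bits  # 512 quantization levels in "RAW" mode

# Stand-ins for the model output and target; shapes only, values random.
y_hat = torch.randn(batch, seq_len, n_classes)     # (B, T, C) logits
y = torch.randint(0, n_classes, (batch, seq_len))  # (B, T) class indices

# Same reshaping as the training loop: classes to dim 1, trailing unit dim.
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)        # -> (B, C, T, 1)
y = y.unsqueeze(-1)                                # -> (B, T, 1)

loss = F.cross_entropy(y_hat, y)                   # K-dimensional cross entropy
print(loss.shape)                                  # scalar
```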
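For reference, a hypothetical invocation of `train()` matching the signature above; every path and interval here is an assumption to adapt to your layout, not something this commit prescribes:

```python
from pathlib import Path
from vocoder.train import train

train(run_id="my_run",
      syn_dir=Path("datasets/SV2TTS/synthesizer"),  # synthesizer output (assumed layout)
      voc_dir=Path("datasets/SV2TTS/vocoder"),      # GTA mels, used when ground_truth=False
      models_dir=Path("saved_models"),
      ground_truth=False,   # False: train on synthesizer-generated (GTA) mels
      save_every=1000,      # overwrite the main checkpoint every 1000 steps
      backup_every=25000,   # write a separate backup checkpoint every 25000 steps
      force_restart=False)  # resume from the existing checkpoint if present
```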