Kremon96 committed on
Commit
102d284
·
verified ·
1 Parent(s): df438fc

Delete vocoder

Browse files
vocoder/LICENSE.txt DELETED
@@ -1,22 +0,0 @@
1
- MIT License
2
-
3
- Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
4
- Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
5
-
6
- Permission is hereby granted, free of charge, to any person obtaining a copy
7
- of this software and associated documentation files (the "Software"), to deal
8
- in the Software without restriction, including without limitation the rights
9
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
- copies of the Software, and to permit persons to whom the Software is
11
- furnished to do so, subject to the following conditions:
12
-
13
- The above copyright notice and this permission notice shall be included in all
14
- copies or substantial portions of the Software.
15
-
16
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
- SOFTWARE.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vocoder/audio.py DELETED
@@ -1,108 +0,0 @@
1
- import math
2
- import numpy as np
3
- import librosa
4
- import vocoder.hparams as hp
5
- from scipy.signal import lfilter
6
- import soundfile as sf
7
-
8
-
9
def label_2_float(x, bits):
    """Map integer labels in ``[0, 2**bits - 1]`` linearly onto ``[-1, 1]``.

    :param x: label value(s); a scalar or anything supporting arithmetic
    :param bits: bit depth of the label space
    :return: ``2 * x / (2**bits - 1) - 1``
    """
    return 2 * x / (2**bits - 1.) - 1.
11
-
12
-
13
def float_2_label(x, bits):
    """Map floats in ``[-1, 1]`` linearly onto ``[0, 2**bits - 1]``.

    The result is clipped to the label range but not rounded.

    :param x: array-like object exposing ``.max()`` and ``.clip()`` (e.g. a numpy array)
    :param bits: bit depth of the label space
    :return: scaled and clipped values, same shape as ``x``
    :raises AssertionError: if any absolute value in ``x`` exceeds 1.0
    """
    assert abs(x).max() <= 1.0, "float_2_label expects values within [-1, 1]"
    x = (x + 1.) * (2**bits - 1) / 2
    return x.clip(0, 2**bits - 1)
17
-
18
-
19
def load_wav(path):
    """Load an audio file with librosa at the sample rate ``hp.sample_rate``.

    :param path: location of the audio file; converted with ``str()`` before loading
    :return: the first element of the tuple returned by ``librosa.load``
    """
    wav, _ = librosa.load(str(path), sr=hp.sample_rate)
    return wav
21
-
22
-
23
def save_wav(x, path):
    """Write a waveform to ``path`` with soundfile at ``hp.sample_rate``.

    :param x: array exposing ``.astype``; it is cast to float32 before writing
    :param path: destination handed unchanged to ``soundfile.write``
    """
    sf.write(path, x.astype(np.float32), hp.sample_rate)
25
-
26
-
27
def split_signal(x):
    """Split a signed 16-bit style signal into coarse (high) and fine (low) bytes.

    The input is offset by ``2**15`` to make it non-negative, then divided into
    the quotient and remainder of 256.

    :param x: integer sample value(s), scalar or array
    :return: tuple ``(coarse, fine)``
    """
    unsigned = x + 2**15
    coarse = unsigned // 256
    fine = unsigned % 256
    return coarse, fine
32
-
33
-
34
def combine_signal(coarse, fine):
    """Recombine coarse and fine bytes into a signed value (inverse of ``split_signal``).

    :param coarse: high byte value(s)
    :param fine: low byte value(s)
    :return: ``coarse * 256 + fine - 2**15``
    """
    return coarse * 256 + fine - 2**15
36
-
37
-
38
def encode_16bits(x):
    """Scale a float signal by ``2**15`` and convert it to int16.

    Values are clipped to ``[-2**15, 2**15 - 1]`` before the cast so that
    out-of-range input saturates instead of wrapping around.

    :param x: float sample value(s)
    :return: numpy array of dtype int16
    """
    return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)
40
-
41
-
42
# Cached mel filterbank; stays None until linear_to_mel() builds it on first use.
mel_basis = None
43
-
44
-
45
def linear_to_mel(spectrogram):
    """Project a linear-frequency spectrogram onto the mel filterbank.

    The filterbank is built once via ``build_mel_basis()`` and cached in the
    module-level ``mel_basis`` variable.

    :param spectrogram: array multiplied on the left by the mel basis with ``np.dot``
    :return: ``np.dot(mel_basis, spectrogram)``
    """
    global mel_basis
    if mel_basis is None:
        mel_basis = build_mel_basis()
    return np.dot(mel_basis, spectrogram)
50
-
51
-
52
def build_mel_basis():
    """Build the mel filterbank from the vocoder hyperparameters.

    ``sr`` and ``n_fft`` are passed by keyword: recent librosa releases only
    accept keyword arguments for ``librosa.filters.mel``, while older releases
    accept the same keyword names.

    :return: the filterbank returned by ``librosa.filters.mel``
    """
    return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft,
                               n_mels=hp.num_mels, fmin=hp.fmin)
54
-
55
-
56
def normalize(S):
    """Rescale dB values from ``[hp.min_level_db, 0]`` to ``[0, 1]``, clipping outliers.

    :param S: value(s) in decibels
    :return: ``(S - min_level_db) / -min_level_db`` clipped to ``[0, 1]``
    """
    return np.clip((S - hp.min_level_db) / -hp.min_level_db, 0, 1)
58
-
59
-
60
def denormalize(S):
    """Undo ``normalize``: map values in ``[0, 1]`` back to ``[hp.min_level_db, 0]``.

    The input is clipped to ``[0, 1]`` first.

    :param S: normalized value(s)
    :return: value(s) in decibels
    """
    return (np.clip(S, 0, 1) * -hp.min_level_db) + hp.min_level_db
62
-
63
-
64
def amp_to_db(x):
    """Convert amplitudes to decibels: ``20 * log10(x)``.

    Inputs are floored at ``1e-5`` so the result never goes below -100 dB
    and zero amplitudes do not produce ``-inf``.

    :param x: amplitude value(s)
    :return: value(s) in decibels
    """
    return 20 * np.log10(np.maximum(1e-5, x))
66
-
67
-
68
def db_to_amp(x):
    """Convert decibels to amplitudes: ``10 ** (x / 20)``.

    :param x: value(s) in decibels
    :return: amplitude value(s)
    """
    return np.power(10.0, x * 0.05)
70
-
71
-
72
def spectrogram(y):
    """Compute a normalized linear-frequency magnitude spectrogram in dB.

    Pipeline: STFT -> magnitude -> dB -> subtract ``hp.ref_level_db`` -> ``normalize``.

    :param y: waveform passed to ``stft``
    :return: normalized spectrogram
    """
    D = stft(y)
    S = amp_to_db(np.abs(D)) - hp.ref_level_db
    return normalize(S)
76
-
77
-
78
def melspectrogram(y):
    """Compute a normalized mel spectrogram in dB.

    Pipeline: STFT -> magnitude -> mel projection -> dB -> ``normalize``.
    Unlike ``spectrogram``, no reference level is subtracted here.

    :param y: waveform passed to ``stft``
    :return: normalized mel spectrogram
    """
    D = stft(y)
    S = amp_to_db(linear_to_mel(np.abs(D)))
    return normalize(S)
82
-
83
-
84
def stft(y):
    """Run ``librosa.stft`` with the FFT, hop and window sizes from the hyperparameters.

    :param y: waveform
    :return: the STFT matrix returned by librosa
    """
    return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=hp.hop_length,
                        win_length=hp.win_length)
86
-
87
-
88
def pre_emphasis(x):
    """Apply the pre-emphasis FIR filter ``y[n] = x[n] - hp.preemphasis * x[n-1]``.

    :param x: input signal
    :return: filtered signal from ``scipy.signal.lfilter``
    """
    return lfilter([1, -hp.preemphasis], [1], x)
90
-
91
-
92
def de_emphasis(x):
    """Apply the inverse of ``pre_emphasis``: ``y[n] = x[n] + hp.preemphasis * y[n-1]``.

    :param x: pre-emphasized signal
    :return: filtered signal from ``scipy.signal.lfilter``
    """
    return lfilter([1], [1, -hp.preemphasis], x)
94
-
95
-
96
def encode_mu_law(x, mu):
    """Compand a signal with mu-law and quantize it to ``mu`` levels.

    :param x: signal value(s); the test range used by this module is ``[-1, 1]``
    :param mu: number of quantization levels (the companding constant used is ``mu - 1``)
    :return: float array of label values in ``[0, mu - 1]`` (floored, not cast to int)
    """
    mu = mu - 1
    fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
    # Shift [-1, 1] to [0, mu] and round to the nearest level.
    return np.floor((fx + 1) / 2 * mu + 0.5)
100
-
101
-
102
def decode_mu_law(y, mu, from_labels=True):
    """Expand a mu-law companded signal back to the linear domain.

    :param y: companded value(s); labels if ``from_labels`` is true, otherwise
        floats that are used as-is
    :param mu: number of quantization levels (the companding constant used is ``mu - 1``)
    :param from_labels: when true, ``y`` is first converted to floats with
        ``label_2_float`` using ``log2(mu)`` bits
    :return: expanded signal
    """
    if from_labels:
        y = label_2_float(y, math.log2(mu))
    mu = mu - 1
    x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)
    return x
108
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vocoder/display.py DELETED
@@ -1,127 +0,0 @@
1
- import time
2
- import numpy as np
3
- import sys
4
-
5
-
6
def progbar(i, n, size=16):
    """Render a textual progress bar.

    Cells at positions ``0..done`` (inclusive) are filled, where
    ``done = (i * size) // n``, so the bar always shows at least one filled
    cell for non-negative ``i``.

    :param i: current step
    :param n: total number of steps (must be non-zero)
    :param size: number of cells in the bar
    :return: string of ``size`` characters
    """
    done = (i * size) // n
    # A distinct loop name avoids shadowing the ``i`` parameter.
    return ''.join('█' if pos <= done else '░' for pos in range(size))
12
-
13
-
14
def stream(message):
    """Write ``message`` to stdout after a carriage return, without a newline.

    If the stream cannot encode the text, the message is retried with all
    non-ASCII characters removed.

    :param message: text to display
    """
    try:
        sys.stdout.write("\r{%s}" % message)
    except UnicodeError:
        # Remove non-ASCII characters from message
        message = ''.join(ch for ch in message if ord(ch) < 128)
        sys.stdout.write("\r{%s}" % message)
21
-
22
-
23
def simple_table(item_tuples):
    """Print a two-row ASCII table: one row of headings, one row of values.

    Within each column the shorter of heading and value is centred to the
    width of the longer one; when the difference is odd, the extra space goes
    on the right. Padding and borders are generated by repetition, so columns
    of any width line up.

    :param item_tuples: iterable of ``(heading, value)`` pairs; both are
        converted with ``str()``
    """
    headings, cells = [], []

    for item in item_tuples:
        heading, cell = str(item[0]), str(item[1])

        gap = abs(len(heading) - len(cell))
        pad_left = ' ' * (gap // 2)
        pad_right = ' ' * (gap - gap // 2)

        if len(heading) < len(cell):
            heading = pad_left + heading + pad_right
        else:
            cell = pad_left + cell + pad_right

        headings.append(heading)
        cells.append(cell)

    border, head, body = '', '', ''

    for heading, cell in zip(headings, cells):
        temp_head = f'| {heading} '
        temp_body = f'| {cell} '

        border += '+' + '-' * (len(temp_head) - 1)
        head += temp_head
        body += temp_body

    # Close the right-hand edge only when there is at least one column.
    if headings:
        head += '|'
        body += '|'
        border += '+'

    print(border)
    print(head)
    print(border)
    print(body)
    print(border)
    print(' ')
72
-
73
-
74
def time_since(started):
    """Format the wall-clock time elapsed since ``started``.

    :param started: a timestamp previously obtained from ``time.time()``
    :return: ``'{m}m {s}s'``, or ``'{h}h {m}m {s}s'`` once an hour has passed
    """
    elapsed = time.time() - started
    m = int(elapsed // 60)
    s = int(elapsed % 60)
    if m >= 60:
        h, m = divmod(m, 60)
        return f'{h}h {m}m {s}s'
    return f'{m}m {s}s'
84
-
85
-
86
def save_attention(attn, path):
    """Save the transpose of ``attn`` as an image at ``{path}.png``.

    matplotlib is imported lazily so that importing this module does not
    require it.

    :param attn: 2-D array exposing ``.T``
    :param path: output path without the ``.png`` suffix
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(12, 6))
    plt.imshow(attn.T, interpolation='nearest', aspect='auto')
    fig.savefig(f'{path}.png', bbox_inches='tight')
    plt.close(fig)
93
-
94
-
95
def save_spectrogram(M, path, length=None):
    """Save ``M`` as an image at ``{path}.png``, flipped along axis 0.

    matplotlib is imported lazily so that importing this module does not
    require it.

    :param M: 2-D array
    :param path: output path without the ``.png`` suffix
    :param length: if truthy, only the first ``length`` columns are drawn
    """
    import matplotlib.pyplot as plt

    M = np.flip(M, axis=0)
    if length:
        M = M[:, :length]
    fig = plt.figure(figsize=(12, 6))
    plt.imshow(M, interpolation='nearest', aspect='auto')
    fig.savefig(f'{path}.png', bbox_inches='tight')
    plt.close(fig)
104
-
105
-
106
def plot(array):
    """Plot ``array`` on a wide (30x5) figure with large grey axis styling.

    The figure is neither shown nor returned by this function.

    :param array: data passed to ``plt.plot``
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(30, 5))
    ax = fig.add_subplot(111)
    ax.xaxis.label.set_color('grey')
    ax.yaxis.label.set_color('grey')
    ax.xaxis.label.set_fontsize(23)
    ax.yaxis.label.set_fontsize(23)
    ax.tick_params(axis='x', colors='grey', labelsize=23)
    ax.tick_params(axis='y', colors='grey', labelsize=23)
    plt.plot(array)
118
-
119
-
120
def plot_spec(M):
    """Display ``M`` as an image, flipped along axis 0, in an 18x4 figure.

    :param M: 2-D array
    """
    import matplotlib.pyplot as plt

    M = np.flip(M, axis=0)
    plt.figure(figsize=(18, 4))
    plt.imshow(M, interpolation='nearest', aspect='auto')
    plt.show()
127
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vocoder/distribution.py DELETED
@@ -1,132 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn.functional as F
4
-
5
-
6
def log_sum_exp(x):
    """Numerically stable ``log(sum(exp(x)))`` over the last axis.

    The per-row maximum is subtracted before exponentiating to prevent
    overflow, then added back to the result.

    :param x: tensor; the reduction runs over its last dimension
    :return: tensor with the last dimension removed
    """
    # TF ordering
    axis = x.dim() - 1
    m, _ = torch.max(x, dim=axis, keepdim=True)
    return m.squeeze(axis) + torch.log(torch.sum(torch.exp(x - m), dim=axis))
13
-
14
-
15
# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
def discretized_mix_logistic_loss(y_hat, y, num_classes=65536,
                                  log_scale_min=None, reduce=True):
    """Negative log-likelihood under a discretized mixture of logistics.

    :param y_hat: predicted mixture parameters, 3-D, with the channel axis last
        (B x T x C); C must be a multiple of 3 and is split into mixture
        logits, means and log-scales
    :param y: targets, expanded to the shape of the means (B x T x 1 per the
        original comments)
    :param num_classes: number of quantization classes; sets the bin half-width
        ``1 / (num_classes - 1)``
    :param log_scale_min: lower clamp for the log-scales; defaults to ``log(1e-14)``
    :param reduce: if true return the mean loss, otherwise the per-element loss
        with a trailing singleton dimension
    :return: loss tensor
    :raises AssertionError: if ``y_hat`` is not 3-D or C is not a multiple of 3
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    # Check the rank before permuting so bad input fails with the intended assertion.
    assert y_hat.dim() == 3
    y_hat = y_hat.permute(0, 2, 1)
    assert y_hat.size(1) % 3 == 0
    nr_mix = y_hat.size(1) // 3

    # (B x T x C)
    y_hat = y_hat.transpose(1, 2)

    # unpack parameters. (B, T, num_mixtures) x 3
    logit_probs = y_hat[:, :, :nr_mix]
    means = y_hat[:, :, nr_mix:2 * nr_mix]
    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)

    # B x T x 1 -> B x T x num_mixtures
    y = y.expand_as(means)

    centered_y = y - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
    cdf_min = torch.sigmoid(min_in)

    # log probability for edge case of 0 (before scaling)
    # equivalent: torch.log(F.sigmoid(plus_in))
    log_cdf_plus = plus_in - F.softplus(plus_in)

    # log probability for edge case of 255 (before scaling)
    # equivalent: (1 - F.sigmoid(min_in)).log()
    log_one_minus_cdf_min = -F.softplus(min_in)

    # probability for all other cases
    cdf_delta = cdf_plus - cdf_min

    mid_in = inv_stdv * centered_y
    # log probability in the center of the bin, to be used in extreme cases
    # (not actually used in our code)
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    # tf equivalent:
    #
    # log_probs = tf.where(x < -0.999, log_cdf_plus,
    #                      tf.where(x > 0.999, log_one_minus_cdf_min,
    #                               tf.where(cdf_delta > 1e-5,
    #                                        tf.log(tf.maximum(cdf_delta, 1e-12)),
    #                                        log_pdf_mid - np.log(127.5))))
    #
    # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
    # for num_classes=65536 case? 1e-7? not sure..
    inner_inner_cond = (cdf_delta > 1e-5).float()

    inner_inner_out = inner_inner_cond * \
        torch.log(torch.clamp(cdf_delta, min=1e-12)) + \
        (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
    inner_cond = (y > 0.999).float()
    inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
    cond = (y < -0.999).float()
    log_probs = cond * log_cdf_plus + (1. - cond) * inner_out

    # Weight each component by its (log) mixture probability.
    log_probs = log_probs + F.log_softmax(logit_probs, -1)

    if reduce:
        return -torch.mean(log_sum_exp(log_probs))
    else:
        return -log_sum_exp(log_probs).unsqueeze(-1)
85
-
86
-
87
def sample_from_discretized_mix_logistic(y, log_scale_min=None):
    """
    Sample from discretized mixture of logistic distributions

    Args:
        y (Tensor): B x C x T, where C is a multiple of 3 (mixture logits,
            means and log-scales)
        log_scale_min (float): Log scale minimum value; defaults to log(1e-14)

    Returns:
        Tensor: sample in range of [-1, 1].
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    assert y.size(1) % 3 == 0
    nr_mix = y.size(1) // 3

    # B x T x C
    y = y.transpose(1, 2)
    logit_probs = y[:, :, :nr_mix]

    # sample mixture indicator from softmax
    # (adds -log(-log(u)) noise to the logits and takes the argmax)
    temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
    temp = logit_probs.data - torch.log(- torch.log(temp))
    _, argmax = temp.max(dim=-1)

    # (B, T) -> (B, T, nr_mix)
    one_hot = to_one_hot(argmax, nr_mix)
    # select logistic parameters
    means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1)
    log_scales = torch.clamp(torch.sum(
        y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)
    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))

    x = torch.clamp(x, min=-1., max=1.)

    return x
124
-
125
-
126
def to_one_hot(tensor, n, fill_with=1.):
    """One-hot encode an index tensor along a new last axis.

    The result is allocated directly on the device of ``tensor``, so it works
    for CPU and for any accelerator device without a separate transfer.

    :param tensor: integer (long) index tensor of any shape
    :param n: size of the new last axis (number of classes)
    :param fill_with: value written at each index position
    :return: float32 tensor of shape ``tensor.size() + (n,)``
    """
    # we perform one hot encode with respect to the last axis
    one_hot = torch.zeros(tensor.size() + (n,), dtype=torch.float32,
                          device=tensor.device)
    one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
    return one_hot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vocoder/gen_wavernn.py DELETED
@@ -1,31 +0,0 @@
1
- from vocoder.models.fatchord_version import WaveRNN
2
- from vocoder.audio import *
3
-
4
-
5
def gen_testset(model: WaveRNN, test_set, samples, batched, target, overlap, save_path):
    """Generate audio for the first ``samples`` items of ``test_set`` and save it.

    For every item both the ground-truth target and the model output are
    written as wav files under ``save_path``; file names are prefixed with the
    model step count in thousands.

    ``hp`` and the audio helpers used here come from ``from vocoder.audio import *``.

    :param model: vocoder providing ``get_step()`` and ``generate()``
    :param test_set: iterable yielding ``(m, x)`` pairs; ``m`` is passed to
        ``model.generate`` and ``x[0].numpy()`` is used as the target signal
    :param samples: maximum number of items to process
    :param batched: forwarded to ``model.generate``; also selects the output file name
    :param target: forwarded to ``model.generate``
    :param overlap: forwarded to ``model.generate``
    :param save_path: directory object exposing ``joinpath`` (e.g. ``pathlib.Path``)
    """
    k = model.get_step() // 1000

    for i, (m, x) in enumerate(test_set, 1):
        if i > samples:
            break

        print('\n| Generating: %i/%i' % (i, samples))

        x = x[0].numpy()

        bits = 16 if hp.voc_mode == 'MOL' else hp.bits

        # Convert the stored labels back to a float waveform.
        if hp.mu_law and hp.voc_mode != 'MOL':
            x = decode_mu_law(x, 2**bits, from_labels=True)
        else:
            x = label_2_float(x, bits)

        save_wav(x, save_path.joinpath("%dk_steps_%d_target.wav" % (k, i)))

        batch_str = "gen_batched_target%d_overlap%d" % (target, overlap) if batched else \
            "gen_not_batched"
        save_str = save_path.joinpath("%dk_steps_%d_%s.wav" % (k, i, batch_str))

        wav = model.generate(m, batched, target, overlap, hp.mu_law)
        save_wav(wav, save_str)
31
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vocoder/hparams.py DELETED
@@ -1,44 +0,0 @@
1
- from synthesizer.hparams import hparams as _syn_hp
2
-
3
-
4
# Audio settings------------------------------------------------------------------------
# Match the values of the synthesizer
sample_rate = _syn_hp.sample_rate
n_fft = _syn_hp.n_fft
num_mels = _syn_hp.num_mels
hop_length = _syn_hp.hop_size
win_length = _syn_hp.win_size
fmin = _syn_hp.fmin
min_level_db = _syn_hp.min_level_db
ref_level_db = _syn_hp.ref_level_db
mel_max_abs_value = _syn_hp.max_abs_value
preemphasis = _syn_hp.preemphasis
apply_preemphasis = _syn_hp.preemphasize

bits = 9                            # bit depth of signal
mu_law = True                       # Recommended to suppress noise if using raw bits in hp.voc_mode
                                    # below


# WAVERNN / VOCODER --------------------------------------------------------------------------------
voc_mode = 'RAW'                    # either 'RAW' (softmax on raw bits) or 'MOL' (sample from
                                    # mixture of logistics)
voc_upsample_factors = (5, 5, 8)    # NB - this needs to correctly factorise hop_length
voc_rnn_dims = 512
voc_fc_dims = 512
voc_compute_dims = 128
voc_res_out_dims = 128
voc_res_blocks = 10

# Training
voc_batch_size = 100
voc_lr = 1e-4
voc_gen_at_checkpoint = 5           # number of samples to generate at each checkpoint
voc_pad = 2                         # this will pad the input so that the resnet can 'see' wider
                                    # than input length
voc_seq_len = hop_length * 5        # must be a multiple of hop_length

# Generating / Synthesizing
voc_gen_batched = True              # very fast (realtime+) single utterance batched generation
voc_target = 8000                   # target number of samples to be generated in each batch entry
voc_overlap = 400                   # number of samples for crossfading between batches
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vocoder/inference.py DELETED
@@ -1,64 +0,0 @@
1
- from vocoder.models.fatchord_version import WaveRNN
2
- from vocoder import hparams as hp
3
- import torch
4
-
5
-
6
- _model = None # type: WaveRNN
7
-
8
def load_model(weights_fpath, verbose=True):
    """Build the WaveRNN vocoder from the hyperparameters and load its weights.

    The model is stored in the module-level ``_model`` and put in eval mode;
    the chosen device (CUDA when available, otherwise CPU) is stored in
    ``_device``.

    :param weights_fpath: path to a checkpoint containing a ``'model_state'`` entry
    :param verbose: if true, print progress messages
    """
    global _model, _device

    if verbose:
        print("Building Wave-RNN")
    _model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims,
        fc_dims=hp.voc_fc_dims,
        bits=hp.bits,
        pad=hp.voc_pad,
        upsample_factors=hp.voc_upsample_factors,
        feat_dims=hp.num_mels,
        compute_dims=hp.voc_compute_dims,
        res_out_dims=hp.voc_res_out_dims,
        res_blocks=hp.voc_res_blocks,
        hop_length=hp.hop_length,
        sample_rate=hp.sample_rate,
        mode=hp.voc_mode
    )

    _device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    _model = _model.to(_device)

    if verbose:
        print("Loading model weights at %s" % weights_fpath)
    # map_location remaps the stored tensors onto the selected device.
    checkpoint = torch.load(weights_fpath, map_location=_device)
    _model.load_state_dict(checkpoint['model_state'])
    _model.eval()
39
-
40
-
41
def is_loaded():
    """Return True once ``load_model`` has populated the module-level model."""
    return _model is not None
43
-
44
-
45
def infer_waveform(mel, normalize=True, batched=True, target=8000, overlap=800,
                   progress_callback=None):
    """
    Infers the waveform of a mel spectrogram output by the synthesizer (the format must match
    that of the synthesizer!)

    :param mel: numpy array holding the mel spectrogram; a leading batch axis is added here
    :param normalize: if true, divide ``mel`` by ``hp.mel_max_abs_value`` first
    :param batched: forwarded to ``WaveRNN.generate``
    :param target: forwarded to ``WaveRNN.generate``
    :param overlap: forwarded to ``WaveRNN.generate``
    :param progress_callback: forwarded to ``WaveRNN.generate``
    :return: the waveform returned by ``WaveRNN.generate``
    :raises Exception: if ``load_model`` has not been called yet
    """
    if _model is None:
        raise Exception("Please load Wave-RNN in memory before using it")

    if normalize:
        mel = mel / hp.mel_max_abs_value
    mel = torch.from_numpy(mel[None, ...])
    wav = _model.generate(mel, batched, target, overlap, hp.mu_law, progress_callback)
    return wav
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vocoder/models/deepmind_version.py DELETED
@@ -1,170 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from utils.display import *
5
- from utils.dsp import *
6
-
7
-
8
class WaveRNN(nn.Module):
    """Dual-softmax recurrent model predicting a coarse and a fine output per step.

    The hidden state is split into two halves: the first drives the coarse
    output head (O1 -> O2), the second the fine output head (O3 -> O4). The
    fine half additionally sees the current coarse sample as input.

    :param hidden_size: size of the full hidden state (each half is ``hidden_size // 2``)
    :param quantisation: number of classes of each output head
    """

    def __init__(self, hidden_size=896, quantisation=256):
        super().__init__()

        self.hidden_size = hidden_size
        self.split_size = hidden_size // 2

        # The main matmul
        self.R = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)

        # Output fc layers
        self.O1 = nn.Linear(self.split_size, self.split_size)
        self.O2 = nn.Linear(self.split_size, quantisation)
        self.O3 = nn.Linear(self.split_size, self.split_size)
        self.O4 = nn.Linear(self.split_size, quantisation)

        # Input fc layers
        self.I_coarse = nn.Linear(2, 3 * self.split_size, bias=False)
        self.I_fine = nn.Linear(3, 3 * self.split_size, bias=False)

        # biases for the gates
        self.bias_u = nn.Parameter(torch.zeros(self.hidden_size))
        self.bias_r = nn.Parameter(torch.zeros(self.hidden_size))
        self.bias_e = nn.Parameter(torch.zeros(self.hidden_size))

        # display num params
        self.num_params()

    def forward(self, prev_y, prev_hidden, current_coarse):
        """Run one recurrent step.

        :param prev_y: previous (coarse, fine) pair, shape (B, 2)
        :param prev_hidden: previous hidden state, shape (B, hidden_size)
        :param current_coarse: current coarse sample, shape (B, 1)
        :return: tuple ``(out_coarse, out_fine, hidden)`` with the two sets of
            output logits and the new hidden state
        """
        # Main matmul - the projection is split 3 ways
        R_hidden = self.R(prev_hidden)
        R_u, R_r, R_e = torch.split(R_hidden, self.hidden_size, dim=1)

        # Project the prev input
        coarse_input_proj = self.I_coarse(prev_y)
        I_coarse_u, I_coarse_r, I_coarse_e = \
            torch.split(coarse_input_proj, self.split_size, dim=1)

        # Project the prev input and current coarse sample
        fine_input = torch.cat([prev_y, current_coarse], dim=1)
        fine_input_proj = self.I_fine(fine_input)
        I_fine_u, I_fine_r, I_fine_e = \
            torch.split(fine_input_proj, self.split_size, dim=1)

        # concatenate for the gates
        I_u = torch.cat([I_coarse_u, I_fine_u], dim=1)
        I_r = torch.cat([I_coarse_r, I_fine_r], dim=1)
        I_e = torch.cat([I_coarse_e, I_fine_e], dim=1)

        # Compute all gates for coarse and fine
        u = torch.sigmoid(R_u + I_u + self.bias_u)
        r = torch.sigmoid(R_r + I_r + self.bias_r)
        e = torch.tanh(r * R_e + I_e + self.bias_e)
        hidden = u * prev_hidden + (1. - u) * e

        # Split the hidden state
        hidden_coarse, hidden_fine = torch.split(hidden, self.split_size, dim=1)

        # Compute outputs
        out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
        out_fine = self.O4(F.relu(self.O3(hidden_fine)))

        return out_coarse, out_fine, hidden

    def generate(self, seq_len):
        """Sample ``seq_len`` steps autoregressively, starting from zeros.

        Tensors are created on the device of the model parameters.
        ``stream`` and ``combine_signal`` are expected from the module's
        star imports.

        :param seq_len: number of samples to generate
        :return: tuple ``(output, coarse, fine)`` of numpy arrays, where
            ``output = combine_signal(coarse, fine)``
        """
        import time  # local import so the clock does not depend on a star import

        device = self.bias_u.device

        with torch.no_grad():
            # First split up the biases for the gates
            b_coarse_u, b_fine_u = torch.split(self.bias_u, self.split_size)
            b_coarse_r, b_fine_r = torch.split(self.bias_r, self.split_size)
            b_coarse_e, b_fine_e = torch.split(self.bias_e, self.split_size)

            # Lists for the two output seqs
            c_outputs, f_outputs = [], []

            # Some initial inputs
            out_coarse = torch.zeros(1, dtype=torch.long, device=device)
            out_fine = torch.zeros(1, dtype=torch.long, device=device)

            # We'll need a hidden state
            hidden = self.init_hidden()

            # Need a clock for display
            start = time.time()

            # Loop for generation
            for i in range(seq_len):

                # Split into two hidden states
                hidden_coarse, hidden_fine = \
                    torch.split(hidden, self.split_size, dim=1)

                # Scale and concat previous predictions
                # (127.5 == 255 / 2: maps class indices 0..255 onto [-1, 1])
                out_coarse = out_coarse.unsqueeze(0).float() / 127.5 - 1.
                out_fine = out_fine.unsqueeze(0).float() / 127.5 - 1.
                prev_outputs = torch.cat([out_coarse, out_fine], dim=1)

                # Project input
                coarse_input_proj = self.I_coarse(prev_outputs)
                I_coarse_u, I_coarse_r, I_coarse_e = \
                    torch.split(coarse_input_proj, self.split_size, dim=1)

                # Project hidden state and split 6 ways
                R_hidden = self.R(hidden)
                R_coarse_u, R_fine_u, \
                    R_coarse_r, R_fine_r, \
                    R_coarse_e, R_fine_e = torch.split(R_hidden, self.split_size, dim=1)

                # Compute the coarse gates
                u = torch.sigmoid(R_coarse_u + I_coarse_u + b_coarse_u)
                r = torch.sigmoid(R_coarse_r + I_coarse_r + b_coarse_r)
                e = torch.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e)
                hidden_coarse = u * hidden_coarse + (1. - u) * e

                # Compute the coarse output
                out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
                posterior = F.softmax(out_coarse, dim=1)
                distrib = torch.distributions.Categorical(posterior)
                out_coarse = distrib.sample()
                c_outputs.append(out_coarse)

                # Project the [prev outputs and predicted coarse sample]
                coarse_pred = out_coarse.float() / 127.5 - 1.
                fine_input = torch.cat([prev_outputs, coarse_pred.unsqueeze(0)], dim=1)
                fine_input_proj = self.I_fine(fine_input)
                I_fine_u, I_fine_r, I_fine_e = \
                    torch.split(fine_input_proj, self.split_size, dim=1)

                # Compute the fine gates
                u = torch.sigmoid(R_fine_u + I_fine_u + b_fine_u)
                r = torch.sigmoid(R_fine_r + I_fine_r + b_fine_r)
                e = torch.tanh(r * R_fine_e + I_fine_e + b_fine_e)
                hidden_fine = u * hidden_fine + (1. - u) * e

                # Compute the fine output
                out_fine = self.O4(F.relu(self.O3(hidden_fine)))
                posterior = F.softmax(out_fine, dim=1)
                distrib = torch.distributions.Categorical(posterior)
                out_fine = distrib.sample()
                f_outputs.append(out_fine)

                # Put the hidden state back together
                hidden = torch.cat([hidden_coarse, hidden_fine], dim=1)

                # Display progress (guard against a zero elapsed time on coarse clocks)
                speed = (i + 1) / max(time.time() - start, 1e-9)
                stream('Gen: %i/%i -- Speed: %i' % (i + 1, seq_len, speed))

            coarse = torch.stack(c_outputs).squeeze(1).cpu().numpy()
            fine = torch.stack(f_outputs).squeeze(1).cpu().numpy()
            output = combine_signal(coarse, fine)

        return output, coarse, fine

    def init_hidden(self, batch_size=1):
        """Return a zero hidden state of shape (batch_size, hidden_size) on the model's device."""
        return torch.zeros(batch_size, self.hidden_size, device=self.bias_u.device)

    def num_params(self):
        """Print the number of trainable parameters, in millions."""
        parameters = sum(p.numel() for p in self.parameters() if p.requires_grad) / 1_000_000
        print('Trainable Parameters: %.3f million' % parameters)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vocoder/models/fatchord_version.py DELETED
@@ -1,434 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from vocoder.distribution import sample_from_discretized_mix_logistic
5
- from vocoder.display import *
6
- from vocoder.audio import *
7
-
8
-
9
class ResBlock(nn.Module):
    """Residual block made of two bias-free, kernel-size-1 1-D convolutions.

    Each convolution is followed by batch normalisation; a ReLU sits between
    the two stages and the block input is added back to the result.
    """

    def __init__(self, dims):
        super().__init__()
        # Layers are created in the original order so that parameter
        # initialisation and state_dict key order stay the same.
        self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
        self.batch_norm1 = nn.BatchNorm1d(dims)
        self.batch_norm2 = nn.BatchNorm1d(dims)

    def forward(self, x):
        """Apply conv -> BN -> ReLU -> conv -> BN and add the skip connection."""
        hidden = F.relu(self.batch_norm1(self.conv1(x)))
        hidden = self.batch_norm2(self.conv2(hidden))
        return hidden + x
25
-
26
-
27
class MelResNet(nn.Module):
    """Stack of residual blocks that turns mel frames into auxiliary features.

    The input convolution uses a kernel of ``2 * pad + 1`` frames and no
    padding, so the output is ``2 * pad`` frames shorter than the input.
    """

    def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
        super().__init__()
        self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=2 * pad + 1, bias=False)
        self.batch_norm = nn.BatchNorm1d(compute_dims)
        self.layers = nn.ModuleList([ResBlock(compute_dims) for _ in range(res_blocks)])
        self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)

    def forward(self, x):
        """Run the input conv, the residual stack and the output projection."""
        out = F.relu(self.batch_norm(self.conv_in(x)))
        for block in self.layers:
            out = block(out)
        return self.conv_out(out)
45
-
46
-
47
class Stretch2d(nn.Module):
    """Nearest-neighbour upsampling of a 4-D tensor by integer factors.

    Every element is repeated ``y_scale`` times along dim 2 and ``x_scale``
    times along dim 3.
    """

    def __init__(self, x_scale, y_scale):
        super().__init__()
        self.x_scale = x_scale
        self.y_scale = y_scale

    def forward(self, x):
        """Return ``x`` of shape (b, c, h, w) stretched to (b, c, h*y_scale, w*x_scale)."""
        stretched_rows = x.repeat_interleave(self.y_scale, dim=2)
        return stretched_rows.repeat_interleave(self.x_scale, dim=3)
58
-
59
-
60
class UpsampleNetwork(nn.Module):
    """Upsample mel frames and compute matching auxiliary features.

    The mels go through a chain of (Stretch2d, averaging Conv2d) pairs, one
    pair per entry of ``upsample_scales``. In parallel a MelResNet produces
    auxiliary features that are stretched by the total scale in one step.
    """

    def __init__(self, feat_dims, upsample_scales, compute_dims,
                 res_blocks, res_out_dims, pad):
        super().__init__()
        # np.cumproduct was removed in NumPy 2.0; the product of all scales is
        # what is needed here. Cast to a plain int for slicing/shape arithmetic.
        total_scale = int(np.prod(upsample_scales))
        # Number of upsampled steps to trim on each side of the mels so they
        # line up with the resnet output (which loses `pad` frames per side).
        self.indent = pad * total_scale
        self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad)
        self.resnet_stretch = Stretch2d(total_scale, 1)
        self.up_layers = nn.ModuleList()
        for scale in upsample_scales:
            k_size = (1, scale * 2 + 1)
            padding = (0, scale)
            stretch = Stretch2d(scale, 1)
            conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
            # Start as a moving average over the kernel width.
            conv.weight.data.fill_(1. / k_size[1])
            self.up_layers.append(stretch)
            self.up_layers.append(conv)

    def forward(self, m):
        """Return ``(upsampled_mels, aux)``, both transposed to (batch, time, features).

        ``m`` is fed to a Conv1d first, so it is presumably laid out as
        (batch, feat_dims, frames) - confirm against the caller.
        """
        aux = self.resnet(m).unsqueeze(1)
        aux = self.resnet_stretch(aux)
        aux = aux.squeeze(1)
        m = m.unsqueeze(1)
        for f in self.up_layers:
            m = f(m)
        m = m.squeeze(1)
        # With pad == 0 the slice [0:-0] would be empty, so only trim when needed.
        if self.indent:
            m = m[:, :, self.indent:-self.indent]
        return m.transpose(1, 2), aux.transpose(1, 2)
86
-
87
-
88
class WaveRNN(nn.Module):
    """WaveRNN vocoder conditioned on upsampled mel spectrograms.

    ``mode`` selects the output layer: 'RAW' predicts a categorical
    distribution over ``2 ** bits`` classes, 'MOL' predicts 30 values that are
    passed to ``sample_from_discretized_mix_logistic`` at generation time.
    """

    def __init__(self, rnn_dims, fc_dims, bits, pad, upsample_factors,
                 feat_dims, compute_dims, res_out_dims, res_blocks,
                 hop_length, sample_rate, mode='RAW'):
        super().__init__()
        self.mode = mode
        self.pad = pad
        if self.mode == 'RAW':
            self.n_classes = 2 ** bits
        elif self.mode == 'MOL':
            self.n_classes = 30
        else:
            # The exception used to be constructed but never raised, which
            # postponed the failure to a confusing AttributeError on n_classes.
            raise RuntimeError("Unknown model mode value - ", self.mode)

        self.rnn_dims = rnn_dims
        # The resnet output is split into four equal slices (see forward()).
        self.aux_dims = res_out_dims // 4
        self.hop_length = hop_length
        self.sample_rate = sample_rate

        self.upsample = UpsampleNetwork(feat_dims, upsample_factors, compute_dims, res_blocks, res_out_dims, pad)
        self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
        self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
        self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
        self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
        self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
        self.fc3 = nn.Linear(fc_dims, self.n_classes)

        # Training step counter, stored as a parameter so it is saved with the weights.
        self.step = nn.Parameter(torch.zeros(1).long(), requires_grad=False)
        self.num_params()

    def forward(self, x, mels):
        """Teacher-forced pass used for training; returns the fc3 outputs.

        ``x`` is unsqueezed on its last axis and concatenated with the
        upsampled mels on dim 2, so it is presumably (batch, time) - confirm
        against the data loader. Each call increments ``self.step``.
        """
        self.step += 1
        bsize = x.size(0)
        # Create the initial hidden states where the input lives instead of
        # guessing the device from torch.cuda.is_available().
        h1 = torch.zeros(1, bsize, self.rnn_dims, device=x.device)
        h2 = torch.zeros(1, bsize, self.rnn_dims, device=x.device)
        mels, aux = self.upsample(mels)

        aux_idx = [self.aux_dims * i for i in range(5)]
        a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
        a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
        a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
        a4 = aux[:, :, aux_idx[3]:aux_idx[4]]

        x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
        x = self.I(x)
        res = x
        x, _ = self.rnn1(x, h1)

        x = x + res
        res = x
        x = torch.cat([x, a2], dim=2)
        x, _ = self.rnn2(x, h2)

        x = x + res
        x = torch.cat([x, a3], dim=2)
        x = F.relu(self.fc1(x))

        x = torch.cat([x, a4], dim=2)
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def generate(self, mels, batched, target, overlap, mu_law, progress_callback=None):
        """Autoregressively synthesise a waveform from a mel spectrogram.

        Args:
            mels: Mel spectrogram tensor; its last axis is treated as frames.
            batched: If true, fold the sequence into overlapping segments that
                are generated in parallel and crossfaded back together.
            target: Segment length passed to fold_with_overlap().
            overlap: Overlap length used for folding and crossfading.
            mu_law: Decode the output with decode_mu_law(); only honoured in
                'RAW' mode.
            progress_callback: Called as ``(i, seq_len, b_size, gen_rate)``
                every 100 steps; defaults to gen_display().

        Returns:
            1-D float64 numpy array, truncated to ``(frames - 1) * hop_length``
            samples, with a linear fade-out at the end.
        """
        mu_law = mu_law if self.mode == 'RAW' else False
        progress_callback = progress_callback or self.gen_display

        # Remember the mode so an inference model is not left in train mode.
        was_training = self.training
        self.eval()
        output = []
        start = time.time()
        rnn1 = self.get_gru_cell(self.rnn1)
        rnn2 = self.get_gru_cell(self.rnn2)
        # Run on whatever device the weights are on.
        device = next(self.parameters()).device

        with torch.no_grad():
            mels = mels.to(device)
            wave_len = (mels.size(-1) - 1) * self.hop_length
            mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side='both')
            mels, aux = self.upsample(mels.transpose(1, 2))

            if batched:
                mels = self.fold_with_overlap(mels, target, overlap)
                aux = self.fold_with_overlap(aux, target, overlap)

            b_size, seq_len, _ = mels.size()

            h1 = torch.zeros(b_size, self.rnn_dims, device=device)
            h2 = torch.zeros(b_size, self.rnn_dims, device=device)
            x = torch.zeros(b_size, 1, device=device)

            d = self.aux_dims
            aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)]

            for i in range(seq_len):

                m_t = mels[:, i, :]

                a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)

                x = torch.cat([x, m_t, a1_t], dim=1)
                x = self.I(x)
                h1 = rnn1(x, h1)

                x = x + h1
                inp = torch.cat([x, a2_t], dim=1)
                h2 = rnn2(inp, h2)

                x = x + h2
                x = torch.cat([x, a3_t], dim=1)
                x = F.relu(self.fc1(x))

                x = torch.cat([x, a4_t], dim=1)
                x = F.relu(self.fc2(x))

                logits = self.fc3(x)

                if self.mode == 'MOL':
                    sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
                    output.append(sample.view(-1))
                    x = sample.transpose(0, 1).to(device)

                elif self.mode == 'RAW':
                    posterior = F.softmax(logits, dim=1)
                    distrib = torch.distributions.Categorical(posterior)

                    # Map the sampled class index back to the [-1, 1] range.
                    sample = 2 * distrib.sample().float() / (self.n_classes - 1.) - 1.
                    output.append(sample)
                    x = sample.unsqueeze(-1)
                else:
                    raise RuntimeError("Unknown model mode value - ", self.mode)

                if i % 100 == 0:
                    gen_rate = (i + 1) / (time.time() - start) * b_size / 1000
                    progress_callback(i, seq_len, b_size, gen_rate)

        output = torch.stack(output).transpose(0, 1)
        output = output.cpu().numpy()
        output = output.astype(np.float64)

        if batched:
            output = self.xfade_and_unfold(output, target, overlap)
        else:
            output = output[0]

        if mu_law:
            output = decode_mu_law(output, self.n_classes, False)
        if hp.apply_preemphasis:
            output = de_emphasis(output)

        # Fade-out at the end to avoid signal cutting out suddenly. The fade is
        # shortened for outputs shorter than 20 hops, which used to raise a
        # broadcasting error.
        output = output[:wave_len]
        fade_len = min(len(output), 20 * self.hop_length)
        if fade_len > 0:
            output[-fade_len:] *= np.linspace(1, 0, fade_len)

        self.train(was_training)

        return output

    def gen_display(self, i, seq_len, b_size, gen_rate):
        """Default progress callback: stream a progress bar and the generation rate."""
        pbar = progbar(i, seq_len)
        msg = f'| {pbar} {i*b_size}/{seq_len*b_size} | Batch Size: {b_size} | Gen Rate: {gen_rate:.1f}kHz | '
        stream(msg)

    def get_gru_cell(self, gru):
        """Build an nn.GRUCell that shares the layer-0 weights and biases of ``gru``."""
        gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
        gru_cell.weight_hh.data = gru.weight_hh_l0.data
        gru_cell.weight_ih.data = gru.weight_ih_l0.data
        gru_cell.bias_hh.data = gru.bias_hh_l0.data
        gru_cell.bias_ih.data = gru.bias_ih_l0.data
        return gru_cell

    def pad_tensor(self, x, pad, side='both'):
        """Zero-pad a (batch, time, channels) tensor along the time axis.

        ``side`` is 'both', 'before' or 'after'. The result is allocated on
        the same device as ``x``.
        """
        # NB - this is just a quick method i need right now
        # i.e., it won't generalise to other shapes/dims
        b, t, c = x.size()
        total = t + 2 * pad if side == 'both' else t + pad
        padded = torch.zeros(b, total, c, device=x.device)
        if side == 'before' or side == 'both':
            padded[:, pad:pad + t, :] = x
        elif side == 'after':
            padded[:, :t, :] = x
        return padded

    def fold_with_overlap(self, x, target, overlap):
        ''' Fold the tensor with overlap for quick batched inference.
            Overlap will be used for crossfading in xfade_and_unfold()

        Args:
            x (tensor)    : Upsampled conditioning features.
                            shape=(1, timesteps, features)
            target (int)  : Target timesteps for each index of batch
            overlap (int) : Timesteps for both xfade and rnn warmup

        Return:
            (tensor) : shape=(num_folds, target + 2 * overlap, features)

        Details:
            x = [[h1, h2, ... hn]]

            Where each h is a vector of conditioning features

            Eg: target=2, overlap=1 with x.size(1)=10

            folded = [[h1, h2, h3, h4],
                      [h4, h5, h6, h7],
                      [h7, h8, h9, h10]]
        '''

        _, total_len, features = x.size()

        # Calculate variables needed
        num_folds = (total_len - overlap) // (target + overlap)
        extended_len = num_folds * (overlap + target) + overlap
        remaining = total_len - extended_len

        # Pad if some time steps poking out
        if remaining != 0:
            num_folds += 1
            padding = target + 2 * overlap - remaining
            x = self.pad_tensor(x, padding, side='after')

        folded = torch.zeros(num_folds, target + 2 * overlap, features, device=x.device)

        # Get the values for the folded tensor
        for i in range(num_folds):
            start = i * (target + overlap)
            end = start + target + 2 * overlap
            folded[i] = x[:, start:end, :]

        return folded

    def xfade_and_unfold(self, y, target, overlap):
        ''' Applies a crossfade and unfolds into a 1d array.

        Note that ``y`` is modified in place and that ``target`` is recomputed
        from the shape of ``y``.

        Args:
            y (ndarray)   : Batched sequences of audio samples
                            shape=(num_folds, target + 2 * overlap)
                            dtype=np.float64
            overlap (int) : Timesteps for both xfade and rnn warmup

        Return:
            (ndarray) : audio samples in a 1d array
                        shape=(total_len)
                        dtype=np.float64

        Details:
            y = [[seq1],
                 [seq2],
                 [seq3]]

            Apply a gain envelope at both ends of the sequences

            y = [[seq1_in, seq1_target, seq1_out],
                 [seq2_in, seq2_target, seq2_out],
                 [seq3_in, seq3_target, seq3_out]]

            Stagger and add up the groups of samples:

            [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
        '''

        num_folds, length = y.shape
        target = length - 2 * overlap
        total_len = num_folds * (target + overlap) + overlap

        # Need some silence for the rnn warmup
        silence_len = overlap // 2
        fade_len = overlap - silence_len
        silence = np.zeros((silence_len), dtype=np.float64)

        # Equal power crossfade
        t = np.linspace(-1, 1, fade_len, dtype=np.float64)
        fade_in = np.sqrt(0.5 * (1 + t))
        fade_out = np.sqrt(0.5 * (1 - t))

        # Concat the silence to the fades
        fade_in = np.concatenate([silence, fade_in])
        fade_out = np.concatenate([fade_out, silence])

        # Apply the gain to the overlap samples
        y[:, :overlap] *= fade_in
        y[:, -overlap:] *= fade_out

        unfolded = np.zeros((total_len), dtype=np.float64)

        # Loop to add up all the samples
        for i in range(num_folds):
            start = i * (target + overlap)
            end = start + target + 2 * overlap
            unfolded[start:end] += y[i]

        return unfolded

    def get_step(self):
        """Return the training step counter as a Python number."""
        return self.step.data.item()

    def checkpoint(self, model_dir, optimizer):
        """Save a backup named after the current step count (in thousands)."""
        k_steps = self.get_step() // 1000
        self.save(model_dir.joinpath("checkpoint_%dk_steps.pt" % k_steps), optimizer)

    def log(self, path, msg):
        """Append ``msg`` as a line to the text file at ``path``."""
        with open(path, 'a') as f:
            print(msg, file=f)

    def load(self, path, optimizer):
        """Load model (and, when present, optimizer) state from ``path``."""
        # Map storages onto the model's device so that a checkpoint written on
        # a GPU can still be loaded on a CPU-only machine.
        checkpoint = torch.load(path, map_location=next(self.parameters()).device)
        if "optimizer_state" in checkpoint:
            self.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
        else:
            # Backwards compatibility
            self.load_state_dict(checkpoint)

    def save(self, path, optimizer):
        """Write model and optimizer state dicts to ``path``."""
        torch.save({
            "model_state": self.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        }, path)

    def num_params(self, print_out=True):
        """Count the trainable parameters and optionally print the total in millions."""
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        if print_out:
            print('Trainable Parameters: %.3fM' % parameters)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vocoder/train.py DELETED
@@ -1,118 +0,0 @@
1
- import time
2
- from pathlib import Path
3
-
4
- import numpy as np
5
- import torch
6
- import torch.nn.functional as F
7
- from torch import optim
8
- from torch.utils.data import DataLoader
9
-
10
- import vocoder.hparams as hp
11
- from vocoder.display import stream, simple_table
12
- from vocoder.distribution import discretized_mix_logistic_loss
13
- from vocoder.gen_wavernn import gen_testset
14
- from vocoder.models.fatchord_version import WaveRNN
15
- from vocoder.vocoder_dataset import VocoderDataset, collate_vocoder
16
-
17
-
18
def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_truth: bool, save_every: int,
          backup_every: int, force_restart: bool):
    """Train the WaveRNN vocoder.

    Args:
        run_id: Name of the sub-directory of ``models_dir`` used for this run.
        syn_dir: Directory holding ``train.txt``, ``mels`` and ``audio``.
        voc_dir: Directory holding ``synthesized.txt`` and ``mels_gta``.
        models_dir: Root directory for model weights; created if missing.
        ground_truth: Use the metadata and mels from ``syn_dir`` instead of
            the ones from ``voc_dir``.
        save_every: Overwrite ``vocoder.pt`` every this many steps (0 disables).
        backup_every: Write a step-named checkpoint every this many steps
            (0 disables).
        force_restart: Start from scratch even if weights already exist.
    """
    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    # Instantiate the model
    print("Initializing the model...")
    model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims,
        fc_dims=hp.voc_fc_dims,
        bits=hp.bits,
        pad=hp.voc_pad,
        upsample_factors=hp.voc_upsample_factors,
        feat_dims=hp.num_mels,
        compute_dims=hp.voc_compute_dims,
        res_out_dims=hp.voc_res_out_dims,
        res_blocks=hp.voc_res_blocks,
        hop_length=hp.hop_length,
        sample_rate=hp.sample_rate,
        mode=hp.voc_mode
    )

    if torch.cuda.is_available():
        model = model.cuda()

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())
    for p in optimizer.param_groups:
        p["lr"] = hp.voc_lr
    loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss

    # Load the weights
    model_dir = models_dir / run_id
    # parents=True so a fresh models_dir does not abort the run.
    model_dir.mkdir(parents=True, exist_ok=True)
    weights_fpath = model_dir / "vocoder.pt"
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of WaveRNN from scratch\n")
        model.save(weights_fpath, optimizer)
    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        # Format the plain number rather than the underlying Parameter.
        print("WaveRNN weights loaded from step %d" % model.get_step())

    # Initialize the dataset
    metadata_fpath = syn_dir.joinpath("train.txt") if ground_truth else \
        voc_dir.joinpath("synthesized.txt")
    mel_dir = syn_dir.joinpath("mels") if ground_truth else voc_dir.joinpath("mels_gta")
    wav_dir = syn_dir.joinpath("audio")
    dataset = VocoderDataset(metadata_fpath, mel_dir, wav_dir)
    test_loader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Begin the training
    simple_table([('Batch size', hp.voc_batch_size),
                  ('LR', hp.voc_lr),
                  ('Sequence Len', hp.voc_seq_len)])

    for epoch in range(1, 350):
        data_loader = DataLoader(dataset, hp.voc_batch_size, shuffle=True, num_workers=2, collate_fn=collate_vocoder)
        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(data_loader, 1):
            if torch.cuda.is_available():
                x, m, y = x.cuda(), m.cuda(), y.cuda()

            # Forward pass
            y_hat = model(x, m)
            if model.mode == 'RAW':
                # Move the class axis to dim 1, as F.cross_entropy expects.
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            y = y.unsqueeze(-1)

            # Backward pass
            loss = loss_func(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            speed = i / (time.time() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000

            if backup_every != 0 and step % backup_every == 0:
                model.checkpoint(model_dir, optimizer)

            if save_every != 0 and step % save_every == 0:
                model.save(weights_fpath, optimizer)

            msg = f"| Epoch: {epoch} ({i}/{len(data_loader)}) | " \
                  f"Loss: {avg_loss:.4f} | {speed:.1f} " \
                  f"steps/s | Step: {k}k | "
            stream(msg)

        # Generate test samples at the end of every epoch.
        gen_testset(model, test_loader, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
                    hp.voc_target, hp.voc_overlap, model_dir)
        print("")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vocoder/vocoder_dataset.py DELETED
@@ -1,84 +0,0 @@
1
- from torch.utils.data import Dataset
2
- from pathlib import Path
3
- from vocoder import audio
4
- import vocoder.hparams as hp
5
- import numpy as np
6
- import torch
7
-
8
-
9
class VocoderDataset(Dataset):
    """Dataset of (mel spectrogram, quantized waveform) pairs for the vocoder."""

    def __init__(self, metadata_fpath: Path, mel_dir: Path, wav_dir: Path):
        """Index the samples listed in a ``|``-separated metadata file.

        Field 0 of each row is the wav file name and field 1 the mel file
        name. Rows whose fifth field is 0 are skipped (presumably a mel frame
        count - confirm against the code that writes the metadata).
        """
        print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, wav_dir))

        with metadata_fpath.open("r") as metadata_file:
            # Ignore blank lines (e.g. a trailing newline at the end of the
            # file), which used to raise an IndexError below.
            metadata = [line.split("|") for line in metadata_file if line.strip()]

        kept = [x for x in metadata if int(x[4])]
        gta_fpaths = [mel_dir.joinpath(x[1]) for x in kept]
        wav_fpaths = [wav_dir.joinpath(x[0]) for x in kept]
        self.samples_fpaths = list(zip(gta_fpaths, wav_fpaths))

        print("Found %d samples" % len(self.samples_fpaths))

    def __getitem__(self, index):
        """Return ``(mel, quant)`` as float32 and int64 numpy arrays."""
        mel_path, wav_path = self.samples_fpaths[index]

        # Load the mel spectrogram and adjust its range to [-1, 1]
        mel = np.load(mel_path).T.astype(np.float32) / hp.mel_max_abs_value

        # Load the wav
        wav = np.load(wav_path)
        if hp.apply_preemphasis:
            wav = audio.pre_emphasis(wav)
        wav = np.clip(wav, -1, 1)

        # Fix for missing padding   # TODO: settle on whether this is any useful
        r_pad = (len(wav) // hp.hop_length + 1) * hp.hop_length - len(wav)
        wav = np.pad(wav, (0, r_pad), mode='constant')
        assert len(wav) >= mel.shape[1] * hp.hop_length
        # Keep exactly hop_length samples per mel frame.
        wav = wav[:mel.shape[1] * hp.hop_length]
        assert len(wav) % hp.hop_length == 0

        # Quantize the wav
        if hp.voc_mode == 'RAW':
            if hp.mu_law:
                quant = audio.encode_mu_law(wav, mu=2 ** hp.bits)
            else:
                quant = audio.float_2_label(wav, bits=hp.bits)
        elif hp.voc_mode == 'MOL':
            quant = audio.float_2_label(wav, bits=16)
        else:
            # Fail with a clear message instead of an UnboundLocalError on `quant`.
            raise RuntimeError("Unknown model mode value - ", hp.voc_mode)

        return mel.astype(np.float32), quant.astype(np.int64)

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples_fpaths)
56
-
57
-
58
def collate_vocoder(batch):
    """Collate dataset items into randomly cropped training windows.

    Each item of ``batch`` is indexed as ``(mel, quantized_wav)``. A random
    mel window is cut from every item together with the matching
    ``hp.voc_seq_len + 1`` labels.

    Returns:
        ``(x, y, mels)``: ``x`` holds the first ``hp.voc_seq_len`` labels
        converted to floats, ``y`` the labels shifted by one step (converted
        to floats only in 'MOL' mode) and ``mels`` the stacked mel windows.
    """
    pad = hp.voc_pad
    hop = hp.hop_length
    seq_len = hp.voc_seq_len
    mel_win = seq_len // hop + 2 * pad

    mel_windows = []
    label_windows = []
    for item in batch:
        # Largest admissible start frame for this item's mel window.
        max_offset = item[0].shape[-1] - 2 - (mel_win + 2 * pad)
        mel_offset = np.random.randint(0, max_offset)
        # Sample index that lines up with the first non-padding mel frame.
        sig_offset = (mel_offset + pad) * hop
        mel_windows.append(item[0][:, mel_offset:mel_offset + mel_win])
        label_windows.append(item[1][sig_offset:sig_offset + seq_len + 1])

    mels = torch.tensor(np.stack(mel_windows).astype(np.float32))
    labels = torch.tensor(np.stack(label_windows).astype(np.int64)).long()

    # Inputs are the labels at steps 0..n-1, targets the labels at steps 1..n.
    x = labels[:, :seq_len]
    y = labels[:, 1:]

    bits = 16 if hp.voc_mode == 'MOL' else hp.bits

    x = audio.label_2_float(x.float(), bits)
    if hp.voc_mode == 'MOL':
        y = audio.label_2_float(y.float(), bits)

    return x, y, mels