WinWut commited on
Commit
df4fb1d
·
1 Parent(s): cca25f4

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +667 -0
model.py ADDED
@@ -0,0 +1,667 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Imports
2
+
3
+ from __future__ import print_function, division
4
+ import tensorflow as tf
5
+ from glob import glob
6
+ import scipy
7
+ import soundfile as sf
8
+ import matplotlib.pyplot as plt
9
+ from IPython.display import clear_output
10
+ from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Concatenate, Conv2D, Conv2DTranspose, GlobalAveragePooling2D, UpSampling2D, LeakyReLU, ReLU, Add, Multiply, Lambda, Dot, BatchNormalization, Activation, ZeroPadding2D, Cropping2D, Cropping1D
11
+ from tensorflow.keras.models import Sequential, Model, load_model
12
+ from tensorflow.keras.optimizers import Adam
13
+ from tensorflow.keras.initializers import TruncatedNormal, he_normal
14
+ import tensorflow.keras.backend as K
15
+ import datetime
16
+ import numpy as np
17
+ import random
18
+ import matplotlib.pyplot as plt
19
+ import collections
20
+ from PIL import Image
21
+ from skimage.transform import resize
22
+ import imageio
23
+ import librosa
24
+ import librosa.display
25
+ from librosa.feature import melspectrogram
26
+ import os
27
+ import time
28
+ import IPython
29
+
30
+ #Hyperparameters
31
+
32
+ hop=192 #hop size (window size = 6*hop)
33
+ sr=16000 #sampling rate
34
+ min_level_db=-100 #reference values to normalize data
35
+ ref_level_db=20
36
+
37
+ shape=24 #length of time axis of split specrograms to feed to generator
38
+ vec_len=128 #length of vector generated by siamese vector
39
+ bs = 16 #batch size
40
+ delta = 2. #constant for siamese loss
41
+
42
+ #There seems to be a problem with Tensorflow STFT, so we'll be using pytorch to handle offline mel-spectrogram generation and waveform reconstruction
43
+ #For waveform reconstruction, a gradient-based method is used:
44
+
45
+ ''' Decorsière, Rémi, Peter L. Søndergaard, Ewen N. MacDonald, and Torsten Dau.
46
+ "Inversion of auditory spectrograms, traditional spectrograms, and other envelope representations."
47
+ IEEE/ACM Transactions on Audio, Speech, and Language Processing 23, no. 1 (2014): 46-56.'''
48
+
49
+ #ORIGINAL CODE FROM https://github.com/yoyololicon/spectrogram-inversion
50
+
51
+ import torch
52
+ import torch.nn as nn
53
+ import torch.nn.functional as F
54
+ from tqdm import tqdm
55
+ from functools import partial
56
+ import math
57
+ import heapq
58
+ from torchaudio.transforms import MelScale, Spectrogram
59
+
60
+
61
+ specobj = Spectrogram(n_fft=6*hop, win_length=6*hop, hop_length=hop, pad=0, power=2, normalized=True)
62
+ specfunc = specobj.forward
63
+ melobj = MelScale(n_mels=hop, sample_rate=sr, f_min=0.,n_stft=577)
64
+ melfunc = melobj.forward
65
+
66
+ def melspecfunc(waveform):
67
+ specgram = specfunc(waveform)
68
+ mel_specgram = melfunc(specgram)
69
+ return mel_specgram
70
+
71
+ def spectral_convergence(input, target):
72
+ return 20 * ((input - target).norm().log10() - target.norm().log10())
73
+
74
+ def GRAD(spec, transform_fn, samples=None, init_x0=None, maxiter=1000, tol=1e-6, verbose=1, evaiter=10, lr=0.003):
75
+
76
+ spec = torch.Tensor(spec)
77
+ samples = (spec.shape[-1]*hop)-hop
78
+
79
+ if init_x0 is None:
80
+ init_x0 = spec.new_empty((1,samples)).normal_(std=1e-6)
81
+ x = nn.Parameter(init_x0)
82
+ T = spec
83
+
84
+ criterion = nn.L1Loss()
85
+ optimizer = torch.optim.Adam([x], lr=lr)
86
+
87
+ bar_dict = {}
88
+ metric_func = spectral_convergence
89
+ bar_dict['spectral_convergence'] = 0
90
+ metric = 'spectral_convergence'
91
+
92
+ init_loss = None
93
+ with tqdm(total=maxiter, disable=not verbose) as pbar:
94
+ for i in range(maxiter):
95
+ optimizer.zero_grad()
96
+ V = transform_fn(x)
97
+ loss = criterion(V, T)
98
+ loss.backward()
99
+ optimizer.step()
100
+ lr = lr*0.9999
101
+ for param_group in optimizer.param_groups:
102
+ param_group['lr'] = lr
103
+
104
+ if i % evaiter == evaiter - 1:
105
+ with torch.no_grad():
106
+ V = transform_fn(x)
107
+ bar_dict[metric] = metric_func(V, spec).item()
108
+ l2_loss = criterion(V, spec).item()
109
+ pbar.set_postfix(**bar_dict, loss=l2_loss)
110
+ pbar.update(evaiter)
111
+
112
+ return x.detach().view(-1).cpu()
113
+
114
+ def normalize(S):
115
+ return np.clip((((S - min_level_db) / -min_level_db)*2.)-1., -1, 1)
116
+
117
+ def denormalize(S):
118
+ return (((np.clip(S, -1, 1)+1.)/2.) * -min_level_db) + min_level_db
119
+
120
+ def prep(wv,hop=192):
121
+ S = np.array(torch.squeeze(melspecfunc(torch.Tensor(wv).view(1,-1))).detach().cpu())
122
+ S = librosa.power_to_db(S)-ref_level_db
123
+ return normalize(S)
124
+
125
+ def deprep(S):
126
+ S = denormalize(S)+ref_level_db
127
+ S = librosa.db_to_power(S)
128
+ wv = GRAD(np.expand_dims(S,0), melspecfunc, maxiter=2000, evaiter=10, tol=1e-8)
129
+ return np.array(np.squeeze(wv))
130
+
131
+ #Helper functions
132
+
133
+ #Generate spectrograms from waveform array
134
+ def tospec(data):
135
+ specs=np.empty(data.shape[0], dtype=object)
136
+ for i in range(data.shape[0]):
137
+ x = data[i]
138
+ S=prep(x)
139
+ S = np.array(S, dtype=np.float32)
140
+ specs[i]=np.expand_dims(S, -1)
141
+ print(specs.shape)
142
+ return specs
143
+
144
+ #Generate multiple spectrograms with a determined length from single wav file
145
+ def tospeclong(path, length=4*16000):
146
+ x, sr = librosa.load(path,sr=16000)
147
+ x,_ = librosa.effects.trim(x)
148
+ loudls = librosa.effects.split(x, top_db=50)
149
+ xls = np.array([])
150
+ for interv in loudls:
151
+ xls = np.concatenate((xls,x[interv[0]:interv[1]]))
152
+ x = xls
153
+ num = x.shape[0]//length
154
+ specs=np.empty(num, dtype=object)
155
+ for i in range(num-1):
156
+ a = x[i*length:(i+1)*length]
157
+ S = prep(a)
158
+ S = np.array(S, dtype=np.float32)
159
+ try:
160
+ sh = S.shape
161
+ specs[i]=S
162
+ except AttributeError:
163
+ print('spectrogram failed')
164
+ print(specs.shape)
165
+ return specs
166
+
167
+ #Waveform array from path of folder containing wav files
168
+ def audio_array(path):
169
+ ls = glob(f'{path}/*.wav')
170
+ adata = []
171
+ for i in range(len(ls)):
172
+ try:
173
+ x, sr = tf.audio.decode_wav(tf.io.read_file(ls[i]), 1)
174
+ except:
175
+ print(ls[i],"is broken")
176
+ continue
177
+ x = np.array(x, dtype=np.float32)
178
+ adata.append(x)
179
+ return np.array(adata)
180
+
181
+ #Concatenate spectrograms in array along the time axis
182
+ def testass(a):
183
+ but=False
184
+ con = np.array([])
185
+ nim = a.shape[0]
186
+ for i in range(nim):
187
+ im = a[i]
188
+ im = np.squeeze(im)
189
+ if not but:
190
+ con=im
191
+ but=True
192
+ else:
193
+ con = np.concatenate((con,im), axis=1)
194
+ return np.squeeze(con)
195
+
196
+ #Split spectrograms in chunks with equal size
197
+ def splitcut(data):
198
+ ls = []
199
+ mini = 0
200
+ minifinal = 10*shape #max spectrogram length
201
+ for i in range(data.shape[0]-1):
202
+ if data[i].shape[1]<=data[i+1].shape[1]:
203
+ mini = data[i].shape[1]
204
+ else:
205
+ mini = data[i+1].shape[1]
206
+ if mini>=3*shape and mini<minifinal:
207
+ minifinal = mini
208
+ for i in range(data.shape[0]):
209
+ x = data[i]
210
+ if x.shape[1]>=3*shape:
211
+ for n in range(x.shape[1]//minifinal):
212
+ ls.append(x[:,n*minifinal:n*minifinal+minifinal,:])
213
+ ls.append(x[:,-minifinal:,:])
214
+ return np.array(ls)
215
+
216
+ #Adding Spectral Normalization to convolutional layers
217
+
218
+ from tensorflow.python.keras.utils import conv_utils
219
+ from tensorflow.python.ops import array_ops
220
+ from tensorflow.python.ops import math_ops
221
+ from tensorflow.python.ops import sparse_ops
222
+ from tensorflow.python.ops import gen_math_ops
223
+ from tensorflow.python.ops import standard_ops
224
+ from tensorflow.python.eager import context
225
+ from tensorflow.python.framework import tensor_shape
226
+
227
+ def l2normalize(v, eps=1e-12):
228
+ return v / (tf.norm(v) + eps)
229
+
230
+
231
+ class ConvSN2D(tf.keras.layers.Conv2D):
232
+
233
+ def __init__(self, filters, kernel_size, power_iterations=1, **kwargs):
234
+ super(ConvSN2D, self).__init__(filters, kernel_size, **kwargs)
235
+ self.power_iterations = power_iterations
236
+
237
+
238
+ def build(self, input_shape):
239
+ super(ConvSN2D, self).build(input_shape)
240
+
241
+ if self.data_format == 'channels_first':
242
+ channel_axis = 1
243
+ else:
244
+ channel_axis = -1
245
+
246
+ self.u = self.add_weight(self.name + '_u',
247
+ shape=tuple([1, self.kernel.shape.as_list()[-1]]),
248
+ initializer=tf.initializers.RandomNormal(0, 1),
249
+ trainable=False
250
+ )
251
+
252
+ def compute_spectral_norm(self, W, new_u, W_shape):
253
+ for _ in range(self.power_iterations):
254
+
255
+ new_v = l2normalize(tf.matmul(new_u, tf.transpose(W)))
256
+ new_u = l2normalize(tf.matmul(new_v, W))
257
+
258
+ sigma = tf.matmul(tf.matmul(new_v, W), tf.transpose(new_u))
259
+ W_bar = W/sigma
260
+
261
+ with tf.control_dependencies([self.u.assign(new_u)]):
262
+ W_bar = tf.reshape(W_bar, W_shape)
263
+
264
+ return W_bar
265
+
266
+ def convolution_op(self, inputs, kernel):
267
+ if self.padding == "causal":
268
+ tf_padding = "VALID" # Causal padding handled in `call`.
269
+ elif isinstance(self.padding, str):
270
+ tf_padding = self.padding.upper()
271
+ else:
272
+ tf_padding = self.padding
273
+
274
+ return tf.nn.convolution(
275
+ inputs,
276
+ kernel,
277
+ strides=list(self.strides),
278
+ padding=tf_padding,
279
+ dilations=list(self.dilation_rate),
280
+ )
281
+ def call(self, inputs):
282
+ W_shape = self.kernel.shape.as_list()
283
+ W_reshaped = tf.reshape(self.kernel, (-1, W_shape[-1]))
284
+ new_kernel = self.compute_spectral_norm(W_reshaped, self.u, W_shape)
285
+ outputs = self.convolution_op(inputs, new_kernel)
286
+
287
+ if self.use_bias:
288
+ if self.data_format == 'channels_first':
289
+ outputs = tf.nn.bias_add(outputs, self.bias, data_format='NCHW')
290
+ else:
291
+ outputs = tf.nn.bias_add(outputs, self.bias, data_format='NHWC')
292
+ if self.activation is not None:
293
+ return self.activation(outputs)
294
+
295
+ return outputs
296
+
297
+
298
+ class ConvSN2DTranspose(tf.keras.layers.Conv2DTranspose):
299
+
300
+ def __init__(self, filters, kernel_size, power_iterations=1, **kwargs):
301
+ super(ConvSN2DTranspose, self).__init__(filters, kernel_size, **kwargs)
302
+ self.power_iterations = power_iterations
303
+
304
+
305
+ def build(self, input_shape):
306
+ super(ConvSN2DTranspose, self).build(input_shape)
307
+
308
+ if self.data_format == 'channels_first':
309
+ channel_axis = 1
310
+ else:
311
+ channel_axis = -1
312
+
313
+ self.u = self.add_weight(self.name + '_u',
314
+ shape=tuple([1, self.kernel.shape.as_list()[-1]]),
315
+ initializer=tf.initializers.RandomNormal(0, 1),
316
+ trainable=False
317
+ )
318
+
319
+ def compute_spectral_norm(self, W, new_u, W_shape):
320
+ for _ in range(self.power_iterations):
321
+
322
+ new_v = l2normalize(tf.matmul(new_u, tf.transpose(W)))
323
+ new_u = l2normalize(tf.matmul(new_v, W))
324
+
325
+ sigma = tf.matmul(tf.matmul(new_v, W), tf.transpose(new_u))
326
+ W_bar = W/sigma
327
+
328
+ with tf.control_dependencies([self.u.assign(new_u)]):
329
+ W_bar = tf.reshape(W_bar, W_shape)
330
+
331
+ return W_bar
332
+
333
+ def call(self, inputs):
334
+ W_shape = self.kernel.shape.as_list()
335
+ W_reshaped = tf.reshape(self.kernel, (-1, W_shape[-1]))
336
+ new_kernel = self.compute_spectral_norm(W_reshaped, self.u, W_shape)
337
+
338
+ inputs_shape = array_ops.shape(inputs)
339
+ batch_size = inputs_shape[0]
340
+ if self.data_format == 'channels_first':
341
+ h_axis, w_axis = 2, 3
342
+ else:
343
+ h_axis, w_axis = 1, 2
344
+
345
+ height, width = inputs_shape[h_axis], inputs_shape[w_axis]
346
+ kernel_h, kernel_w = self.kernel_size
347
+ stride_h, stride_w = self.strides
348
+
349
+ if self.output_padding is None:
350
+ out_pad_h = out_pad_w = None
351
+ else:
352
+ out_pad_h, out_pad_w = self.output_padding
353
+
354
+ out_height = conv_utils.deconv_output_length(height,
355
+ kernel_h,
356
+ padding=self.padding,
357
+ output_padding=out_pad_h,
358
+ stride=stride_h,
359
+ dilation=self.dilation_rate[0])
360
+ out_width = conv_utils.deconv_output_length(width,
361
+ kernel_w,
362
+ padding=self.padding,
363
+ output_padding=out_pad_w,
364
+ stride=stride_w,
365
+ dilation=self.dilation_rate[1])
366
+ if self.data_format == 'channels_first':
367
+ output_shape = (batch_size, self.filters, out_height, out_width)
368
+ else:
369
+ output_shape = (batch_size, out_height, out_width, self.filters)
370
+
371
+ output_shape_tensor = array_ops.stack(output_shape)
372
+ outputs = K.conv2d_transpose(
373
+ inputs,
374
+ new_kernel,
375
+ output_shape_tensor,
376
+ strides=self.strides,
377
+ padding=self.padding,
378
+ data_format=self.data_format,
379
+ dilation_rate=self.dilation_rate)
380
+
381
+ if not context.executing_eagerly():
382
+ out_shape = self.compute_output_shape(inputs.shape)
383
+ outputs.set_shape(out_shape)
384
+
385
+ if self.use_bias:
386
+ outputs = tf.nn.bias_add(
387
+ outputs,
388
+ self.bias,
389
+ data_format=conv_utils.convert_data_format(self.data_format, ndim=4))
390
+
391
+ if self.activation is not None:
392
+ return self.activation(outputs)
393
+ return outputs
394
+
395
+
396
+ class DenseSN(Dense):
397
+
398
+ def build(self, input_shape):
399
+ super(DenseSN, self).build(input_shape)
400
+
401
+ self.u = self.add_weight(self.name + '_u',
402
+ shape=tuple([1, self.kernel.shape.as_list()[-1]]),
403
+ initializer=tf.initializers.RandomNormal(0, 1),
404
+ trainable=False)
405
+
406
+ def compute_spectral_norm(self, W, new_u, W_shape):
407
+ new_v = l2normalize(tf.matmul(new_u, tf.transpose(W)))
408
+ new_u = l2normalize(tf.matmul(new_v, W))
409
+ sigma = tf.matmul(tf.matmul(new_v, W), tf.transpose(new_u))
410
+ W_bar = W/sigma
411
+ with tf.control_dependencies([self.u.assign(new_u)]):
412
+ W_bar = tf.reshape(W_bar, W_shape)
413
+ return W_bar
414
+
415
+ def call(self, inputs):
416
+ W_shape = self.kernel.shape.as_list()
417
+ W_reshaped = tf.reshape(self.kernel, (-1, W_shape[-1]))
418
+ new_kernel = self.compute_spectral_norm(W_reshaped, self.u, W_shape)
419
+ rank = len(inputs.shape)
420
+ if rank > 2:
421
+ outputs = standard_ops.tensordot(inputs, new_kernel, [[rank - 1], [0]])
422
+ if not context.executing_eagerly():
423
+ shape = inputs.shape.as_list()
424
+ output_shape = shape[:-1] + [self.units]
425
+ outputs.set_shape(output_shape)
426
+ else:
427
+ inputs = math_ops.cast(inputs, self._compute_dtype)
428
+ if K.is_sparse(inputs):
429
+ outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, new_kernel)
430
+ else:
431
+ outputs = gen_math_ops.mat_mul(inputs, new_kernel)
432
+ if self.use_bias:
433
+ outputs = tf.nn.bias_add(outputs, self.bias)
434
+ if self.activation is not None:
435
+ return self.activation(outputs)
436
+ return outputs
437
+
438
+ #Networks Architecture
439
+
440
+ init = tf.keras.initializers.he_uniform()
441
+
442
+ def conv2d(layer_input, filters, kernel_size=4, strides=2, padding='same', leaky=True, bnorm=True, sn=True):
443
+ if leaky:
444
+ Activ = LeakyReLU(alpha=0.2)
445
+ else:
446
+ Activ = ReLU()
447
+ if sn:
448
+ d = ConvSN2D(filters, kernel_size=kernel_size, strides=strides, padding=padding, kernel_initializer=init, use_bias=False)(layer_input)
449
+ else:
450
+ d = Conv2D(filters, kernel_size=kernel_size, strides=strides, padding=padding, kernel_initializer=init, use_bias=False)(layer_input)
451
+ if bnorm:
452
+ d = BatchNormalization()(d)
453
+ d = Activ(d)
454
+ return d
455
+
456
+ def deconv2d(layer_input, layer_res, filters, kernel_size=4, conc=True, scalev=False, bnorm=True, up=True, padding='same', strides=2):
457
+ if up:
458
+ u = UpSampling2D((1,2))(layer_input)
459
+ u = ConvSN2D(filters, kernel_size, strides=(1,1), kernel_initializer=init, use_bias=False, padding=padding)(u)
460
+ else:
461
+ u = ConvSN2DTranspose(filters, kernel_size, strides=strides, kernel_initializer=init, use_bias=False, padding=padding)(layer_input)
462
+ if bnorm:
463
+ u = BatchNormalization()(u)
464
+ u = LeakyReLU(alpha=0.2)(u)
465
+ if conc:
466
+ u = Concatenate()([u,layer_res])
467
+ return u
468
+
469
+ #Extract function: splitting spectrograms
470
+ def extract_image(im):
471
+ im1 = Cropping2D(((0,0), (0, 2*(im.shape[2]//3))))(im)
472
+ im2 = Cropping2D(((0,0), (im.shape[2]//3,im.shape[2]//3)))(im)
473
+ im3 = Cropping2D(((0,0), (2*(im.shape[2]//3), 0)))(im)
474
+ return im1,im2,im3
475
+
476
+ #Assemble function: concatenating spectrograms
477
+ def assemble_image(lsim):
478
+ im1,im2,im3 = lsim
479
+ imh = Concatenate(2)([im1,im2,im3])
480
+ return imh
481
+
482
+ #U-NET style architecture
483
+ def build_generator(input_shape):
484
+ h,w,c = input_shape
485
+ inp = Input(shape=input_shape)
486
+ #downscaling
487
+ g0 = tf.keras.layers.ZeroPadding2D((0,1))(inp)
488
+ g1 = conv2d(g0, 256, kernel_size=(h,3), strides=1, padding='valid')
489
+ g2 = conv2d(g1, 256, kernel_size=(1,9), strides=(1,2))
490
+ g3 = conv2d(g2, 256, kernel_size=(1,7), strides=(1,2))
491
+ #upscaling
492
+ g4 = deconv2d(g3,g2, 256, kernel_size=(1,7), strides=(1,2))
493
+ g5 = deconv2d(g4,g1, 256, kernel_size=(1,9), strides=(1,2), bnorm=False)
494
+ g6 = ConvSN2DTranspose(1, kernel_size=(h,1), strides=(1,1), kernel_initializer=init, padding='valid', activation='tanh')(g5)
495
+ return Model(inp,g6, name='G')
496
+
497
+ #Siamese Network
498
+ def build_siamese(input_shape):
499
+ h,w,c = input_shape
500
+ inp = Input(shape=input_shape)
501
+ g1 = conv2d(inp, 256, kernel_size=(h,3), strides=1, padding='valid', sn=False)
502
+ g2 = conv2d(g1, 256, kernel_size=(1,9), strides=(1,2), sn=False)
503
+ g3 = conv2d(g2, 256, kernel_size=(1,7), strides=(1,2), sn=False)
504
+ g4 = Flatten()(g3)
505
+ g5 = Dense(vec_len)(g4)
506
+ return Model(inp, g5, name='S')
507
+
508
+ #Discriminator (Critic) Network
509
+ def build_critic(input_shape):
510
+ h,w,c = input_shape
511
+ inp = Input(shape=input_shape)
512
+ g1 = conv2d(inp, 512, kernel_size=(h,3), strides=1, padding='valid', bnorm=False)
513
+ g2 = conv2d(g1, 512, kernel_size=(1,9), strides=(1,2), bnorm=False)
514
+ g3 = conv2d(g2, 512, kernel_size=(1,7), strides=(1,2), bnorm=False)
515
+ g4 = Flatten()(g3)
516
+ g4 = DenseSN(1, kernel_initializer=init)(g4)
517
+ return Model(inp, g4, name='C')
518
+
519
+ #Load past models from path to resume training or test
520
+ save_model_path = '/content/drive/MyDrive/weights' #@param {type:"string"}
521
+ def load(path):
522
+ gen = build_generator((hop,shape,1))
523
+ siam = build_siamese((hop,shape,1))
524
+ critic = build_critic((hop,3*shape,1))
525
+ gen.load_weights(path+'/gen.h5')
526
+ critic.load_weights(path+'/critic.h5')
527
+ siam.load_weights(path+'/siam.h5')
528
+ return gen,critic,siam
529
+
530
+ #Build models
531
+ def build():
532
+ gen = build_generator((hop,shape,1))
533
+ siam = build_siamese((hop,shape,1))
534
+ critic = build_critic((hop,3*shape,1)) #the discriminator accepts as input spectrograms of triple the width of those generated by the generator
535
+ return gen,critic,siam
536
+
537
+ #Show results mid-training
538
+ def save_test_image_full(path):
539
+ a = testgena()
540
+ print(a.shape)
541
+ ab = gen(a, training=False)
542
+ ab = testass(ab)
543
+ a = testass(a)
544
+ abwv = deprep(ab)
545
+ awv = deprep(a)
546
+ sf.write(path+'/new_file.wav', abwv, sr)
547
+ IPython.display.display(IPython.display.Audio(np.squeeze(abwv), rate=sr))
548
+ IPython.display.display(IPython.display.Audio(np.squeeze(awv), rate=sr))
549
+ fig, axs = plt.subplots(ncols=2)
550
+ axs[0].imshow(np.flip(a, -2), cmap=None)
551
+ axs[0].axis('off')
552
+ axs[0].set_title('Source')
553
+ axs[1].imshow(np.flip(ab, -2), cmap=None)
554
+ axs[1].axis('off')
555
+ axs[1].set_title('Generated')
556
+ plt.show()
557
+
558
+ #Save in training loop
559
+ def save_end(epoch,gloss,closs,mloss,n_save=3,save_path=save_model_path): #use custom save_path (i.e. Drive '../content/drive/My Drive/')
560
+ if epoch % n_save == 0:
561
+ print('Saving...')
562
+ path = f'{save_path}/MELGANVC-{str(gloss)[:9]}-{str(closs)[:9]}-{str(mloss)[:9]}'
563
+ os.mkdir(path)
564
+ gen.save_weights(path+'/gen.h5')
565
+ critic.save_weights(path+'/critic.h5')
566
+ siam.save_weights(path+'/siam.h5')
567
+ save_test_image_full(path)
568
+
569
+ #Get models and optimizers
570
+ def get_networks(shape, load_model=False, path=None):
571
+ if not load_model:
572
+ gen,critic,siam = build()
573
+ else:
574
+ gen,critic,siam = load(path)
575
+ print('Built networks')
576
+
577
+ opt_gen = Adam(0.0001, 0.5)
578
+ opt_disc = Adam(0.0001, 0.5)
579
+
580
+ return gen,critic,siam, [opt_gen,opt_disc]
581
+
582
+ #Set learning rate
583
+ def update_lr(lr):
584
+ opt_gen.learning_rate = lr
585
+ opt_disc.learning_rate = lr
586
+
587
+ #Build models and initialize optimizers
588
+ load_model_path='MELGANVC-0.4886211-0.5750153-0-20230612T163214Z-001\MELGANVC-0.4886211-0.5750153-0' #@param {type:"string"}
589
+ #If load_model=True, specify the path where the models are saved
590
+
591
+ gen,critic,siam, [opt_gen,opt_disc] = get_networks(shape, load_model=True,path=load_model_path)
592
+
593
+ #After Training, use these functions to convert data with the generator and save the results
594
+
595
+ #Assembling generated Spectrogram chunks into final Spectrogram
596
+ def specass(a,spec):
597
+ but=False
598
+ con = np.array([])
599
+ nim = a.shape[0]
600
+ for i in range(nim-1):
601
+ im = a[i]
602
+ im = np.squeeze(im)
603
+ if not but:
604
+ con=im
605
+ but=True
606
+ else:
607
+ con = np.concatenate((con,im), axis=1)
608
+ diff = spec.shape[1]-(nim*shape)
609
+ a = np.squeeze(a)
610
+ con = np.concatenate((con,a[-1,:,-diff:]), axis=1)
611
+ return np.squeeze(con)
612
+
613
+ #Splitting input spectrogram into different chunks to feed to the generator
614
+ def chopspec(spec):
615
+ dsa=[]
616
+ for i in range(spec.shape[1]//shape):
617
+ im = spec[:,i*shape:i*shape+shape]
618
+ im = np.reshape(im, (im.shape[0],im.shape[1],1))
619
+ dsa.append(im)
620
+ imlast = spec[:,-shape:]
621
+ imlast = np.reshape(imlast, (imlast.shape[0],imlast.shape[1],1))
622
+ dsa.append(imlast)
623
+ return np.array(dsa, dtype=np.float32)
624
+
625
+ #Converting from source Spectrogram to target Spectrogram
626
+ def towave(spec, name, path='../content/', show=False):
627
+ specarr = chopspec(spec)
628
+ print(specarr.shape)
629
+ a = specarr
630
+ print('Generating...')
631
+ ab = gen(a, training=False)
632
+ print('Assembling and Converting...')
633
+ a = specass(a,spec)
634
+ ab = specass(ab,spec)
635
+ awv = deprep(a)
636
+ abwv = deprep(ab)
637
+ print('Saving...')
638
+ pathfin = f'{path}/{name}'
639
+ os.mkdir(pathfin)
640
+ sf.write(pathfin+'/AB.wav', abwv, sr)
641
+ sf.write(pathfin+'/A.wav', awv, sr)
642
+ print('Saved WAV!')
643
+ IPython.display.display(IPython.display.Audio(np.squeeze(abwv), rate=sr))
644
+ IPython.display.display(IPython.display.Audio(np.squeeze(awv), rate=sr))
645
+ if show:
646
+ fig, axs = plt.subplots(ncols=2)
647
+ axs[0].imshow(np.flip(a, -2), cmap=None)
648
+ axs[0].axis('off')
649
+ axs[0].set_title('Source')
650
+ axs[1].imshow(np.flip(ab, -2), cmap=None)
651
+ axs[1].axis('off')
652
+ axs[1].set_title('Generated')
653
+ plt.show()
654
+ return abwv
655
+
656
+ #Wav to wav conversion
657
+
658
+ wv, sr = librosa.load("sltsp.wav", sr=16000) #Load waveform
659
+ print(wv.shape)
660
+ speca = prep(wv) #Waveform to Spectrogram
661
+
662
+ plt.figure(figsize=(50,1)) #Show Spectrogram
663
+ plt.imshow(np.flip(speca, axis=0), cmap=None)
664
+ plt.axis('off')
665
+ plt.show()
666
+
667
+ abwv = towave(speca, name='FILENAME2', path='songs_gen') #Convert and save wav