AiMimicry committed on
Commit 576ea34 · 1 Parent(s): 3aa0c37

Update app.py

Files changed (1)
  1. app.py +539 -106
app.py CHANGED
@@ -1,109 +1,542 @@
  import os
- import io
- import gradio as gr
  import librosa
  import numpy as np
- import utils
- from inference.infer_tool import Svc
- import logging
- import soundfile
- import asyncio
- import argparse
- import gradio.processing_utils as gr_processing_utils
- logging.getLogger('numba').setLevel(logging.WARNING)
- logging.getLogger('markdown_it').setLevel(logging.WARNING)
- logging.getLogger('urllib3').setLevel(logging.WARNING)
- logging.getLogger('matplotlib').setLevel(logging.WARNING)
-
- limitation = os.getenv("SYSTEM") == "spaces"  # limit audio length in huggingface spaces
-
- audio_postprocess_ori = gr.Audio.postprocess
-
- def audio_postprocess(self, y):
-     data = audio_postprocess_ori(self, y)
-     if data is None:
-         return None
-     return gr_processing_utils.encode_url_or_file_to_base64(data["name"])
-
-
- gr.Audio.postprocess = audio_postprocess
- def create_vc_fn(model, sid):
-     def vc_fn(input_audio, vc_transform, auto_f0, fmp):
-         if input_audio is None:
-             return "You need to upload an audio", None
-         sampling_rate, audio = input_audio
-         duration = audio.shape[0] / sampling_rate
-         if duration > 20 and limitation:
-             return "Please upload an audio file that is less than 20 seconds. If you need to generate a longer audio file, please use Colab.", None
-         audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
-         if len(audio.shape) > 1:
-             audio = librosa.to_mono(audio.transpose(1, 0))
-         if sampling_rate != 16000:
-             audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
-         raw_path = io.BytesIO()
-         soundfile.write(raw_path, audio, 16000, format="wav")
-         raw_path.seek(0)
-         out_audio, out_sr = model.infer(sid, vc_transform, raw_path,
-                                         auto_predict_f0=auto_f0, F0_mean_pooling=fmp
-                                         )
-         return "Success", (44100, out_audio.cpu().numpy())
-     return vc_fn
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser()
-     parser.add_argument('--device', type=str, default='cpu')
-     parser.add_argument('--api', action="store_true", default=False)
-     parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
-     args = parser.parse_args()
-     hubert_model = utils.get_hubert_model().to(args.device)
-     models = []
-     voices = []
-     for f in os.listdir("models"):
-         name = f
-         model = Svc(fr"models/{f}/{f}.pth", f"models/{f}/config.json", device=args.device)
-         cover = f"models/{f}/cover.jpg" if os.path.exists(f"models/{f}/cover.jpg") else None
-         models.append((name, cover, create_vc_fn(model, name)))
-     with gr.Blocks() as app:
-         gr.Markdown(
-             "# <center> Sovits Models\n"
-             "## <center> The input audio should be clean and pure voice without background music.\n"
-             "[![Original Repo](https://badgen.net/badge/icon/github?icon=github&label=Original%20Repo)](https://github.com/svc-develop-team/so-vits-svc)"
-         )
-
-         with gr.Tabs():
-             for (name, cover, vc_fn) in models:
-                 with gr.TabItem(name):
-                     with gr.Row():
-                         gr.Markdown(
-                             '<div align="center">'
-                             f'<img style="width:auto;height:300px;" src="file/{cover}">' if cover else ""
-                             '</div>'
-                         )
-                     with gr.Row():
-                         with gr.Column():
-                             vc_input = gr.Audio(label="Input audio"+' (less than 20 seconds)' if limitation else '')
-                             vc_transform = gr.Number(label="vc_transform", value=0)
-                             auto_f0 = gr.Checkbox(label="auto_f0", value=False)
-                             fmp = gr.Checkbox(label="fmp", value=False)
-                             vc_submit = gr.Button("Generate", variant="primary")
-
-                         with gr.Column():
-                             vc_output1 = gr.Textbox(label="Output Message")
-                             vc_output2 = gr.Audio(label="Output Audio")
-                         vc_submit.click(vc_fn, [vc_input, vc_transform, auto_f0, fmp], [vc_output1, vc_output2])
-
-         """
-         for category, link in others.items():
-             with gr.TabItem(category):
-                 gr.Markdown(
-                     f'''
-                     <center>
-                     <h2>Click to Go</h2>
-                     <a href="{link}">
-                     <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-xl-dark.svg"
-                     </a>
-                     </center>
-                     '''
-                 )
-         """
-     app.queue(concurrency_count=1, api_open=args.api).launch(share=args.share)
  import os
+ import glob
+ import re
+ import sys
+ import argparse
+ import logging
+ import json
+ import subprocess
+ import warnings
+ import random
+ import functools
+
  import librosa
  import numpy as np
+ from scipy.io.wavfile import read
+ import torch
+ from torch.nn import functional as F
+ from modules.commons import sequence_mask
+ from hubert import hubert_model
+
+ MATPLOTLIB_FLAG = False
+
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
+ logger = logging
+
+ f0_bin = 256
+ f0_max = 1100.0
+ f0_min = 50.0
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+
+
+ # def normalize_f0(f0, random_scale=True):
+ #     f0_norm = f0.clone()  # create a copy of the input Tensor
+ #     batch_size, _, frame_length = f0_norm.shape
+ #     for i in range(batch_size):
+ #         means = torch.mean(f0_norm[i, 0, :])
+ #         if random_scale:
+ #             factor = random.uniform(0.8, 1.2)
+ #         else:
+ #             factor = 1
+ #         f0_norm[i, 0, :] = (f0_norm[i, 0, :] - means) * factor
+ #     return f0_norm
+ # def normalize_f0(f0, random_scale=True):
+ #     means = torch.mean(f0[:, 0, :], dim=1, keepdim=True)
+ #     if random_scale:
+ #         factor = torch.Tensor(f0.shape[0], 1).uniform_(0.8, 1.2).to(f0.device)
+ #     else:
+ #         factor = torch.ones(f0.shape[0], 1, 1).to(f0.device)
+ #     f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
+ #     return f0_norm
+
+ def deprecated(func):
+     """This is a decorator which can be used to mark functions
+     as deprecated. It will result in a warning being emitted
+     when the function is used."""
+     @functools.wraps(func)
+     def new_func(*args, **kwargs):
+         warnings.simplefilter('always', DeprecationWarning)  # turn off filter
+         warnings.warn("Call to deprecated function {}.".format(func.__name__),
+                       category=DeprecationWarning,
+                       stacklevel=2)
+         warnings.simplefilter('default', DeprecationWarning)  # reset filter
+         return func(*args, **kwargs)
+     return new_func
+
+ def normalize_f0(f0, x_mask, uv, random_scale=True):
+     # calculate means based on x_mask
+     uv_sum = torch.sum(uv, dim=1, keepdim=True)
+     uv_sum[uv_sum == 0] = 9999
+     means = torch.sum(f0[:, 0, :] * uv, dim=1, keepdim=True) / uv_sum
+
+     if random_scale:
+         factor = torch.Tensor(f0.shape[0], 1).uniform_(0.8, 1.2).to(f0.device)
+     else:
+         factor = torch.ones(f0.shape[0], 1).to(f0.device)
+     # normalize f0 based on means and factor
+     f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
+     if torch.isnan(f0_norm).any():
+         exit(0)
+     return f0_norm * x_mask
+
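For reference, a minimal usage sketch of `normalize_f0` with synthetic tensors; the shapes (`f0` as `[batch, 1, frames]`, `x_mask` as `[batch, 1, frames]`, `uv` as `[batch, frames]`) are inferred from the indexing above:

```python
import torch

batch, frames = 2, 8
f0 = torch.rand(batch, 1, frames) * 300 + 100   # synthetic pitch contour in Hz
uv = (torch.rand(batch, frames) > 0.3).float()  # 1.0 on voiced frames
x_mask = torch.ones(batch, 1, frames)

# Subtracts each utterance's voiced-frame mean; random_scale=False skips the
# 0.8-1.2 training-time jitter.
f0_norm = normalize_f0(f0, x_mask, uv, random_scale=False)
print(f0_norm.shape)  # torch.Size([2, 1, 8])
```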
+ def compute_f0_uv_torchcrepe(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512, device=None, cr_threshold=0.05):
+     from modules.crepe import CrepePitchExtractor
+     x = wav_numpy
+     if p_len is None:
+         p_len = x.shape[0] // hop_length
+     else:
+         assert abs(p_len - x.shape[0] // hop_length) < 4, "pad length error"
+
+     f0_min = 50
+     f0_max = 1100
+     F0Creper = CrepePitchExtractor(hop_length=hop_length, f0_min=f0_min, f0_max=f0_max, device=device, threshold=cr_threshold)
+     f0, uv = F0Creper(x[None, :].float(), sampling_rate, pad_to=p_len)
+     return f0, uv
+
+ def plot_data_to_numpy(x, y):
+     global MATPLOTLIB_FLAG
+     if not MATPLOTLIB_FLAG:
+         import matplotlib
+         matplotlib.use("Agg")
+         MATPLOTLIB_FLAG = True
+         mpl_logger = logging.getLogger('matplotlib')
+         mpl_logger.setLevel(logging.WARNING)
+     import matplotlib.pylab as plt
+     import numpy as np
+
+     fig, ax = plt.subplots(figsize=(10, 2))
+     plt.plot(x)
+     plt.plot(y)
+     plt.tight_layout()
+
+     fig.canvas.draw()
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+     plt.close()
+     return data
+
+
+
+ def interpolate_f0(f0):
+
+     data = np.reshape(f0, (f0.size, 1))
+
+     vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
+     vuv_vector[data > 0.0] = 1.0
+     vuv_vector[data <= 0.0] = 0.0
+
+     ip_data = data
+
+     frame_number = data.size
+     last_value = 0.0
+     for i in range(frame_number):
+         if data[i] <= 0.0:
+             j = i + 1
+             for j in range(i + 1, frame_number):
+                 if data[j] > 0.0:
+                     break
+             if j < frame_number - 1:
+                 if last_value > 0.0:
+                     step = (data[j] - data[i - 1]) / float(j - i)
+                     for k in range(i, j):
+                         ip_data[k] = data[i - 1] + step * (k - i + 1)
+                 else:
+                     for k in range(i, j):
+                         ip_data[k] = data[j]
+             else:
+                 for k in range(i, frame_number):
+                     ip_data[k] = last_value
+         else:
+             ip_data[i] = data[i]  # this may not be necessary
+             last_value = data[i]
+
+     return ip_data[:, 0], vuv_vector[:, 0]
+
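A quick sketch of what `interpolate_f0` does to a contour with unvoiced gaps (zeros), assuming the function above is in scope:

```python
import numpy as np

f0 = np.array([0.0, 220.0, 0.0, 0.0, 230.0, 0.0], dtype=np.float32)
filled, vuv = interpolate_f0(f0)
print(filled)  # [220. 220. 225. 230. 230. 230.]: interior gaps are bridged
               # linearly, edge gaps hold the nearest voiced value
print(vuv)     # [0. 1. 0. 0. 1. 0.]: the original voiced/unvoiced flags
```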
+
+ def compute_f0_parselmouth(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
+     import parselmouth
+     x = wav_numpy
+     if p_len is None:
+         p_len = x.shape[0] // hop_length
+     else:
+         assert abs(p_len - x.shape[0] // hop_length) < 4, "pad length error"
+     time_step = hop_length / sampling_rate * 1000
+     f0_min = 50
+     f0_max = 1100
+     f0 = parselmouth.Sound(x, sampling_rate).to_pitch_ac(
+         time_step=time_step / 1000, voicing_threshold=0.6,
+         pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
+
+     pad_size = (p_len - len(f0) + 1) // 2
+     if pad_size > 0 or p_len - len(f0) - pad_size > 0:
+         f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode='constant')
+     return f0
+
+ def resize_f0(x, target_len):
+     source = np.array(x)
+     source[source < 0.001] = np.nan
+     target = np.interp(np.arange(0, len(source) * target_len, len(source)) / target_len, np.arange(0, len(source)), source)
+     res = np.nan_to_num(target)
+     return res
+
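A sketch of `resize_f0` on a short contour: near-zero (unvoiced) frames are masked to NaN before interpolation, so they come back as zeros instead of being smeared into their voiced neighbours:

```python
import numpy as np

f0 = np.array([100.0, 110.0, 0.0, 120.0])
print(resize_f0(f0, 8))
# roughly [100. 105. 110. 0. 0. 0. 120. 120.]: voiced spans are linearly
# resampled while the unvoiced frame stays a zero region
```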
+ def compute_f0_dio(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
+     import pyworld
+     if p_len is None:
+         p_len = wav_numpy.shape[0] // hop_length
+     f0, t = pyworld.dio(
+         wav_numpy.astype(np.double),
+         fs=sampling_rate,
+         f0_ceil=800,
+         frame_period=1000 * hop_length / sampling_rate,
+     )
+     f0 = pyworld.stonemask(wav_numpy.astype(np.double), f0, t, sampling_rate)
+     for index, pitch in enumerate(f0):
+         f0[index] = round(pitch, 1)
+     return resize_f0(f0, p_len)
+
+ def f0_to_coarse(f0):
+     is_torch = isinstance(f0, torch.Tensor)
+     f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
+     f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
+
+     f0_mel[f0_mel <= 1] = 1
+     f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
+     f0_coarse = (f0_mel + 0.5).int() if is_torch else np.rint(f0_mel).astype(int)
+     assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
+     return f0_coarse
+
+
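A worked example, assuming `f0_to_coarse` and the mel constants above are in scope: Hz is mapped to mel via 1127 * ln(1 + f0 / 700), then rescaled so that [f0_min, f0_max] covers buckets 1..255, with 0 Hz (unvoiced) pinned to bucket 1:

```python
import numpy as np

f0 = np.array([0.0, 50.0, 440.0, 1100.0])
print(f0_to_coarse(f0))  # roughly [1, 1, 122, 255]
```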
+ def get_hubert_model():
+     vec_path = "hubert/checkpoint_best_legacy_500.pt"
+     print("load model(s) from {}".format(vec_path))
+     from fairseq import checkpoint_utils
+     models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+         [vec_path],
+         suffix="",
+     )
+     model = models[0]
+     model.eval()
+     return model
+
+ def get_hubert_content(hmodel, wav_16k_tensor):
+     feats = wav_16k_tensor
+     if feats.dim() == 2:  # double channels
+         feats = feats.mean(-1)
+     assert feats.dim() == 1, feats.dim()
+     feats = feats.view(1, -1)
+     padding_mask = torch.BoolTensor(feats.shape).fill_(False)
+     inputs = {
+         "source": feats.to(wav_16k_tensor.device),
+         "padding_mask": padding_mask.to(wav_16k_tensor.device),
+         "output_layer": 9,  # layer 9
+     }
+     with torch.no_grad():
+         logits = hmodel.extract_features(**inputs)
+         feats = hmodel.final_proj(logits[0])
+     return feats.transpose(1, 2)
+
+
+ def get_content(cmodel, y):
+     with torch.no_grad():
+         c = cmodel.extract_features(y.squeeze(1))[0]
+         c = c.transpose(1, 2)
+     return c
+
+
+
+ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
+     assert os.path.isfile(checkpoint_path)
+     checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
+     iteration = checkpoint_dict['iteration']
+     learning_rate = checkpoint_dict['learning_rate']
+     if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
+         optimizer.load_state_dict(checkpoint_dict['optimizer'])
+     saved_state_dict = checkpoint_dict['model']
+     if hasattr(model, 'module'):
+         state_dict = model.module.state_dict()
+     else:
+         state_dict = model.state_dict()
+     new_state_dict = {}
+     for k, v in state_dict.items():
+         try:
+             # assert "dec" in k or "disc" in k
+             # print("load", k)
+             new_state_dict[k] = saved_state_dict[k]
+             assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
+         except Exception:
+             print("error, %s is not in the checkpoint" % k)
+             logger.info("%s is not in the checkpoint" % k)
+             new_state_dict[k] = v
+     if hasattr(model, 'module'):
+         model.module.load_state_dict(new_state_dict)
+     else:
+         model.load_state_dict(new_state_dict)
+     print("load ")
+     logger.info("Loaded checkpoint '{}' (iteration {})".format(
+         checkpoint_path, iteration))
+     return model, optimizer, learning_rate, iteration
+
+
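Note that the loop above loads tolerantly: any tensor missing from the file, or stored with a mismatched shape, falls back to the model's freshly initialized value. A minimal call sketch with a stand-in module and a hypothetical checkpoint path:

```python
import torch.nn as nn

net = nn.Linear(4, 4)  # stand-in for the real network
net, _, lr, it = load_checkpoint("logs/44k/G_10000.pth", net,
                                 optimizer=None, skip_optimizer=True)
print(f"resumed at iteration {it}, lr={lr}")
```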
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
+     logger.info("Saving model and optimizer state at iteration {} to {}".format(
+         iteration, checkpoint_path))
+     if hasattr(model, 'module'):
+         state_dict = model.module.state_dict()
+     else:
+         state_dict = model.state_dict()
+     torch.save({'model': state_dict,
+                 'iteration': iteration,
+                 'optimizer': optimizer.state_dict(),
+                 'learning_rate': learning_rate}, checkpoint_path)
+
+ def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
+     """Freeing up space by deleting saved ckpts
+
+     Arguments:
+     path_to_models   -- Path to the model directory
+     n_ckpts_to_keep  -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
+     sort_by_time     -- True -> chronologically delete ckpts
+                         False -> lexicographically delete ckpts
+     """
+     ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
+     name_key = (lambda _f: int(re.compile(r'._(\d+)\.pth').match(_f).group(1)))
+     time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
+     sort_key = time_key if sort_by_time else name_key
+     x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], key=sort_key)
+     to_del = [os.path.join(path_to_models, fn) for fn in
+               (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
+     del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
+     del_routine = lambda x: [os.remove(x), del_info(x)]
+     rs = [del_routine(fn) for fn in to_del]
+
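Usage is a single call; the sketch below keeps the four newest `G_*.pth` and `D_*.pth` files by modification time and deletes the rest, always sparing `G_0.pth` and `D_0.pth`:

```python
clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=4, sort_by_time=True)
```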
+ def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
+     for k, v in scalars.items():
+         writer.add_scalar(k, v, global_step)
+     for k, v in histograms.items():
+         writer.add_histogram(k, v, global_step)
+     for k, v in images.items():
+         writer.add_image(k, v, global_step, dataformats='HWC')
+     for k, v in audios.items():
+         writer.add_audio(k, v, global_step, audio_sampling_rate)
+
+
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
+     f_list = glob.glob(os.path.join(dir_path, regex))
+     f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
+     x = f_list[-1]
+     print(x)
+     return x
+
+
+ def plot_spectrogram_to_numpy(spectrogram):
+     global MATPLOTLIB_FLAG
+     if not MATPLOTLIB_FLAG:
+         import matplotlib
+         matplotlib.use("Agg")
+         MATPLOTLIB_FLAG = True
+         mpl_logger = logging.getLogger('matplotlib')
+         mpl_logger.setLevel(logging.WARNING)
+     import matplotlib.pylab as plt
+     import numpy as np
+
+     fig, ax = plt.subplots(figsize=(10, 2))
+     im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+                    interpolation='none')
+     plt.colorbar(im, ax=ax)
+     plt.xlabel("Frames")
+     plt.ylabel("Channels")
+     plt.tight_layout()
+
+     fig.canvas.draw()
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+     plt.close()
+     return data
+
+
+ def plot_alignment_to_numpy(alignment, info=None):
+     global MATPLOTLIB_FLAG
+     if not MATPLOTLIB_FLAG:
+         import matplotlib
+         matplotlib.use("Agg")
+         MATPLOTLIB_FLAG = True
+         mpl_logger = logging.getLogger('matplotlib')
+         mpl_logger.setLevel(logging.WARNING)
+     import matplotlib.pylab as plt
+     import numpy as np
+
+     fig, ax = plt.subplots(figsize=(6, 4))
+     im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
+                    interpolation='none')
+     fig.colorbar(im, ax=ax)
+     xlabel = 'Decoder timestep'
+     if info is not None:
+         xlabel += '\n\n' + info
+     plt.xlabel(xlabel)
+     plt.ylabel('Encoder timestep')
+     plt.tight_layout()
+
+     fig.canvas.draw()
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+     plt.close()
+     return data
+
+
+ def load_wav_to_torch(full_path):
+     sampling_rate, data = read(full_path)
+     return torch.FloatTensor(data.astype(np.float32)), sampling_rate
+
+
+ def load_filepaths_and_text(filename, split="|"):
+     with open(filename, encoding='utf-8') as f:
+         filepaths_and_text = [line.strip().split(split) for line in f]
+     return filepaths_and_text
+
+
+ def get_hparams(init=True):
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
+                         help='JSON file for configuration')
+     parser.add_argument('-m', '--model', type=str, required=True,
+                         help='Model name')
+
+     args = parser.parse_args()
+     model_dir = os.path.join("./logs", args.model)
+
+     if not os.path.exists(model_dir):
+         os.makedirs(model_dir)
+
+     config_path = args.config
+     config_save_path = os.path.join(model_dir, "config.json")
+     if init:
+         with open(config_path, "r") as f:
+             data = f.read()
+         with open(config_save_path, "w") as f:
+             f.write(data)
+     else:
+         with open(config_save_path, "r") as f:
+             data = f.read()
+     config = json.loads(data)
+
+     hparams = HParams(**config)
+     hparams.model_dir = model_dir
+     return hparams
+
+
+ def get_hparams_from_dir(model_dir):
+     config_save_path = os.path.join(model_dir, "config.json")
+     with open(config_save_path, "r") as f:
+         data = f.read()
+     config = json.loads(data)
+
+     hparams = HParams(**config)
+     hparams.model_dir = model_dir
+     return hparams
+
+
+ def get_hparams_from_file(config_path):
+     with open(config_path, "r") as f:
+         data = f.read()
+     config = json.loads(data)
+
+     hparams = HParams(**config)
+     return hparams
+
+
+ def check_git_hash(model_dir):
+     source_dir = os.path.dirname(os.path.realpath(__file__))
+     if not os.path.exists(os.path.join(source_dir, ".git")):
+         logger.warning("{} is not a git repository, therefore hash value comparison will be ignored.".format(
+             source_dir
+         ))
+         return
+
+     cur_hash = subprocess.getoutput("git rev-parse HEAD")
+
+     path = os.path.join(model_dir, "githash")
+     if os.path.exists(path):
+         saved_hash = open(path).read()
+         if saved_hash != cur_hash:
+             logger.warning("git hash values are different. {}(saved) != {}(current)".format(
+                 saved_hash[:8], cur_hash[:8]))
+     else:
+         open(path, "w").write(cur_hash)
+
+
+ def get_logger(model_dir, filename="train.log"):
+     global logger
+     logger = logging.getLogger(os.path.basename(model_dir))
+     logger.setLevel(logging.DEBUG)
+
+     formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
+     if not os.path.exists(model_dir):
+         os.makedirs(model_dir)
+     h = logging.FileHandler(os.path.join(model_dir, filename))
+     h.setLevel(logging.DEBUG)
+     h.setFormatter(formatter)
+     logger.addHandler(h)
+     return logger
+
+
+ def repeat_expand_2d(content, target_len):
+     # content : [h, t]
+
+     src_len = content.shape[-1]
+     target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device)
+     temp = torch.arange(src_len + 1) * target_len / src_len
+     current_pos = 0
+     for i in range(target_len):
+         if i < temp[current_pos + 1]:
+             target[:, i] = content[:, current_pos]
+         else:
+             current_pos += 1
+             target[:, i] = content[:, current_pos]
+
+     return target
+
+
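A quick sketch of `repeat_expand_2d`: each source frame of an `[h, t]` matrix is repeated a proportional number of times until the time axis reaches `target_len`:

```python
import torch

content = torch.tensor([[1.0, 2.0, 3.0]])  # h=1, t=3
print(repeat_expand_2d(content, 6))        # tensor([[1., 1., 2., 2., 3., 3.]])
```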
+ def mix_model(model_paths, mix_rate, mode):
+     mix_rate = torch.FloatTensor(mix_rate) / 100
+     model_tem = torch.load(model_paths[0])
+     models = [torch.load(path)["model"] for path in model_paths]
+     if mode == 0:
+         mix_rate = F.softmax(mix_rate, dim=0)
+     for k in model_tem["model"].keys():
+         model_tem["model"][k] = torch.zeros_like(model_tem["model"][k])
+         for i, model in enumerate(models):
+             model_tem["model"][k] += model[k] * mix_rate[i]
+     torch.save(model_tem, os.path.join(os.path.curdir, "output.pth"))
+     return os.path.join(os.path.curdir, "output.pth")
+
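A call sketch with hypothetical checkpoint paths: the percentages are divided by 100, softmax-normalized when `mode == 0`, and every weight tensor becomes the weighted sum of the source checkpoints, written to `./output.pth`:

```python
mixed = mix_model(["models/A/A.pth", "models/B/B.pth"], mix_rate=[60, 40], mode=0)
print(mixed)  # ./output.pth
```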
+ class HParams():
+     def __init__(self, **kwargs):
+         for k, v in kwargs.items():
+             if type(v) == dict:
+                 v = HParams(**v)
+             self[k] = v
+
+     def keys(self):
+         return self.__dict__.keys()
+
+     def items(self):
+         return self.__dict__.items()
+
+     def values(self):
+         return self.__dict__.values()
+
+     def __len__(self):
+         return len(self.__dict__)
+
+     def __getitem__(self, key):
+         return getattr(self, key)
+
+     def __setitem__(self, key, value):
+         return setattr(self, key, value)
+
+     def __contains__(self, key):
+         return key in self.__dict__
+
+     def __repr__(self):
+         return self.__dict__.__repr__()
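Finally, a sketch of how `HParams` behaves: nested dicts are wrapped recursively, so config values are reachable both as attributes and as dict-style keys:

```python
hps = HParams(train={"batch_size": 16}, model={"hidden_channels": 192})
print(hps.train.batch_size)          # 16
print(hps["model"].hidden_channels)  # 192
print("train" in hps, len(hps))      # True 2
```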