AiMimicry committed on
Commit 30a73fc · 1 Parent(s): 576ea34

Update app.py

Files changed (1)
  1. app.py +106 -539
app.py CHANGED
@@ -1,542 +1,109 @@
  import os
- import glob
- import re
- import sys
- import argparse
- import logging
- import json
- import subprocess
- import warnings
- import random
- import functools
-
+ import io
+ import gradio as gr
  import librosa
  import numpy as np
- from scipy.io.wavfile import read
- import torch
- from torch.nn import functional as F
- from modules.commons import sequence_mask
- from hubert import hubert_model
-
- MATPLOTLIB_FLAG = False
-
- logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
- logger = logging
-
- f0_bin = 256
- f0_max = 1100.0
- f0_min = 50.0
- f0_mel_min = 1127 * np.log(1 + f0_min / 700)
- f0_mel_max = 1127 * np.log(1 + f0_max / 700)
-
-
- # def normalize_f0(f0, random_scale=True):
- #     f0_norm = f0.clone()  # create a copy of the input Tensor
- #     batch_size, _, frame_length = f0_norm.shape
- #     for i in range(batch_size):
- #         means = torch.mean(f0_norm[i, 0, :])
- #         if random_scale:
- #             factor = random.uniform(0.8, 1.2)
- #         else:
- #             factor = 1
- #         f0_norm[i, 0, :] = (f0_norm[i, 0, :] - means) * factor
- #     return f0_norm
- # def normalize_f0(f0, random_scale=True):
- #     means = torch.mean(f0[:, 0, :], dim=1, keepdim=True)
- #     if random_scale:
- #         factor = torch.Tensor(f0.shape[0], 1).uniform_(0.8, 1.2).to(f0.device)
- #     else:
- #         factor = torch.ones(f0.shape[0], 1, 1).to(f0.device)
- #     f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
- #     return f0_norm
-
- def deprecated(func):
-     """This is a decorator which can be used to mark functions
-     as deprecated. It will result in a warning being emitted
-     when the function is used."""
-     @functools.wraps(func)
-     def new_func(*args, **kwargs):
-         warnings.simplefilter('always', DeprecationWarning)  # turn off filter
-         warnings.warn("Call to deprecated function {}.".format(func.__name__),
-                       category=DeprecationWarning,
-                       stacklevel=2)
-         warnings.simplefilter('default', DeprecationWarning)  # reset filter
-         return func(*args, **kwargs)
-     return new_func
-
- def normalize_f0(f0, x_mask, uv, random_scale=True):
-     # calculate means based on x_mask
-     uv_sum = torch.sum(uv, dim=1, keepdim=True)
-     uv_sum[uv_sum == 0] = 9999
-     means = torch.sum(f0[:, 0, :] * uv, dim=1, keepdim=True) / uv_sum
-
-     if random_scale:
-         factor = torch.Tensor(f0.shape[0], 1).uniform_(0.8, 1.2).to(f0.device)
-     else:
-         factor = torch.ones(f0.shape[0], 1).to(f0.device)
-     # normalize f0 based on means and factor
-     f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
-     if torch.isnan(f0_norm).any():
-         exit(0)
-     return f0_norm * x_mask
-
- def compute_f0_uv_torchcrepe(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512, device=None, cr_threshold=0.05):
-     from modules.crepe import CrepePitchExtractor
-     x = wav_numpy
-     if p_len is None:
-         p_len = x.shape[0] // hop_length
-     else:
-         assert abs(p_len - x.shape[0] // hop_length) < 4, "pad length error"
-
-     f0_min = 50
-     f0_max = 1100
-     F0Creper = CrepePitchExtractor(hop_length=hop_length, f0_min=f0_min, f0_max=f0_max, device=device, threshold=cr_threshold)
-     f0, uv = F0Creper(x[None, :].float(), sampling_rate, pad_to=p_len)
-     return f0, uv
-
- def plot_data_to_numpy(x, y):
-     global MATPLOTLIB_FLAG
-     if not MATPLOTLIB_FLAG:
-         import matplotlib
-         matplotlib.use("Agg")
-         MATPLOTLIB_FLAG = True
-         mpl_logger = logging.getLogger('matplotlib')
-         mpl_logger.setLevel(logging.WARNING)
-     import matplotlib.pylab as plt
-     import numpy as np
-
-     fig, ax = plt.subplots(figsize=(10, 2))
-     plt.plot(x)
-     plt.plot(y)
-     plt.tight_layout()
-
-     fig.canvas.draw()
-     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-     plt.close()
-     return data
-
-
-
- def interpolate_f0(f0):
-
-     data = np.reshape(f0, (f0.size, 1))
-
-     vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
-     vuv_vector[data > 0.0] = 1.0
-     vuv_vector[data <= 0.0] = 0.0
-
-     ip_data = data
-
-     frame_number = data.size
-     last_value = 0.0
-     for i in range(frame_number):
-         if data[i] <= 0.0:
-             j = i + 1
-             for j in range(i + 1, frame_number):
-                 if data[j] > 0.0:
-                     break
-             if j < frame_number - 1:
-                 if last_value > 0.0:
-                     step = (data[j] - data[i - 1]) / float(j - i)
-                     for k in range(i, j):
-                         ip_data[k] = data[i - 1] + step * (k - i + 1)
-                 else:
-                     for k in range(i, j):
-                         ip_data[k] = data[j]
-             else:
-                 for k in range(i, frame_number):
-                     ip_data[k] = last_value
-         else:
-             ip_data[i] = data[i]  # this may not be necessary
-             last_value = data[i]
-
-     return ip_data[:, 0], vuv_vector[:, 0]
-
-
- def compute_f0_parselmouth(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
-     import parselmouth
-     x = wav_numpy
-     if p_len is None:
-         p_len = x.shape[0] // hop_length
-     else:
-         assert abs(p_len - x.shape[0] // hop_length) < 4, "pad length error"
-     time_step = hop_length / sampling_rate * 1000
-     f0_min = 50
-     f0_max = 1100
-     f0 = parselmouth.Sound(x, sampling_rate).to_pitch_ac(
-         time_step=time_step / 1000, voicing_threshold=0.6,
-         pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
-
-     pad_size = (p_len - len(f0) + 1) // 2
-     if pad_size > 0 or p_len - len(f0) - pad_size > 0:
-         f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode='constant')
-     return f0
-
- def resize_f0(x, target_len):
-     source = np.array(x)
-     source[source < 0.001] = np.nan
-     target = np.interp(np.arange(0, len(source) * target_len, len(source)) / target_len, np.arange(0, len(source)), source)
-     res = np.nan_to_num(target)
-     return res
-
- def compute_f0_dio(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
-     import pyworld
-     if p_len is None:
-         p_len = wav_numpy.shape[0] // hop_length
-     f0, t = pyworld.dio(
-         wav_numpy.astype(np.double),
-         fs=sampling_rate,
-         f0_ceil=800,
-         frame_period=1000 * hop_length / sampling_rate,
-     )
-     f0 = pyworld.stonemask(wav_numpy.astype(np.double), f0, t, sampling_rate)
-     for index, pitch in enumerate(f0):
-         f0[index] = round(pitch, 1)
-     return resize_f0(f0, p_len)
-
- def f0_to_coarse(f0):
-     is_torch = isinstance(f0, torch.Tensor)
-     f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
-     f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
-
-     f0_mel[f0_mel <= 1] = 1
-     f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
-     f0_coarse = (f0_mel + 0.5).int() if is_torch else np.rint(f0_mel).astype(np.int)
-     assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
-     return f0_coarse
-
-
- def get_hubert_model():
-     vec_path = "hubert/checkpoint_best_legacy_500.pt"
-     print("load model(s) from {}".format(vec_path))
-     from fairseq import checkpoint_utils
-     models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
-         [vec_path],
-         suffix="",
-     )
-     model = models[0]
-     model.eval()
-     return model
-
- def get_hubert_content(hmodel, wav_16k_tensor):
-     feats = wav_16k_tensor
-     if feats.dim() == 2:  # double channels
-         feats = feats.mean(-1)
-     assert feats.dim() == 1, feats.dim()
-     feats = feats.view(1, -1)
-     padding_mask = torch.BoolTensor(feats.shape).fill_(False)
-     inputs = {
-         "source": feats.to(wav_16k_tensor.device),
-         "padding_mask": padding_mask.to(wav_16k_tensor.device),
-         "output_layer": 9,  # layer 9
-     }
-     with torch.no_grad():
-         logits = hmodel.extract_features(**inputs)
-         feats = hmodel.final_proj(logits[0])
-     return feats.transpose(1, 2)
-
-
- def get_content(cmodel, y):
-     with torch.no_grad():
-         c = cmodel.extract_features(y.squeeze(1))[0]
-     c = c.transpose(1, 2)
-     return c
-
-
-
- def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
-     assert os.path.isfile(checkpoint_path)
-     checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
-     iteration = checkpoint_dict['iteration']
-     learning_rate = checkpoint_dict['learning_rate']
-     if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
-         optimizer.load_state_dict(checkpoint_dict['optimizer'])
-     saved_state_dict = checkpoint_dict['model']
-     if hasattr(model, 'module'):
-         state_dict = model.module.state_dict()
-     else:
-         state_dict = model.state_dict()
-     new_state_dict = {}
-     for k, v in state_dict.items():
-         try:
-             # assert "dec" in k or "disc" in k
-             # print("load", k)
-             new_state_dict[k] = saved_state_dict[k]
-             assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
-         except:
-             print("error, %s is not in the checkpoint" % k)
-             logger.info("%s is not in the checkpoint" % k)
-             new_state_dict[k] = v
-     if hasattr(model, 'module'):
-         model.module.load_state_dict(new_state_dict)
-     else:
-         model.load_state_dict(new_state_dict)
-     print("load ")
-     logger.info("Loaded checkpoint '{}' (iteration {})".format(
-         checkpoint_path, iteration))
-     return model, optimizer, learning_rate, iteration
-
-
- def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
-     logger.info("Saving model and optimizer state at iteration {} to {}".format(
-         iteration, checkpoint_path))
-     if hasattr(model, 'module'):
-         state_dict = model.module.state_dict()
-     else:
-         state_dict = model.state_dict()
-     torch.save({'model': state_dict,
-                 'iteration': iteration,
-                 'optimizer': optimizer.state_dict(),
-                 'learning_rate': learning_rate}, checkpoint_path)
-
- def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
-     """Freeing up space by deleting saved ckpts
-
-     Arguments:
-     path_to_models  --  Path to the model directory
-     n_ckpts_to_keep --  Number of ckpts to keep, excluding G_0.pth and D_0.pth
-     sort_by_time    --  True -> chronologically delete ckpts
-                         False -> lexicographically delete ckpts
-     """
-     ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
-     name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1)))
-     time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
-     sort_key = time_key if sort_by_time else name_key
-     x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], key=sort_key)
-     to_del = [os.path.join(path_to_models, fn) for fn in
-               (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
-     del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
-     del_routine = lambda x: [os.remove(x), del_info(x)]
-     rs = [del_routine(fn) for fn in to_del]
-
- def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
-     for k, v in scalars.items():
-         writer.add_scalar(k, v, global_step)
-     for k, v in histograms.items():
-         writer.add_histogram(k, v, global_step)
-     for k, v in images.items():
-         writer.add_image(k, v, global_step, dataformats='HWC')
-     for k, v in audios.items():
-         writer.add_audio(k, v, global_step, audio_sampling_rate)
-
-
- def latest_checkpoint_path(dir_path, regex="G_*.pth"):
-     f_list = glob.glob(os.path.join(dir_path, regex))
-     f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
-     x = f_list[-1]
-     print(x)
-     return x
-
-
- def plot_spectrogram_to_numpy(spectrogram):
-     global MATPLOTLIB_FLAG
-     if not MATPLOTLIB_FLAG:
-         import matplotlib
-         matplotlib.use("Agg")
-         MATPLOTLIB_FLAG = True
-         mpl_logger = logging.getLogger('matplotlib')
-         mpl_logger.setLevel(logging.WARNING)
-     import matplotlib.pylab as plt
-     import numpy as np
-
-     fig, ax = plt.subplots(figsize=(10, 2))
-     im = ax.imshow(spectrogram, aspect="auto", origin="lower",
-                    interpolation='none')
-     plt.colorbar(im, ax=ax)
-     plt.xlabel("Frames")
-     plt.ylabel("Channels")
-     plt.tight_layout()
-
-     fig.canvas.draw()
-     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-     plt.close()
-     return data
-
-
- def plot_alignment_to_numpy(alignment, info=None):
-     global MATPLOTLIB_FLAG
-     if not MATPLOTLIB_FLAG:
-         import matplotlib
-         matplotlib.use("Agg")
-         MATPLOTLIB_FLAG = True
-         mpl_logger = logging.getLogger('matplotlib')
-         mpl_logger.setLevel(logging.WARNING)
-     import matplotlib.pylab as plt
-     import numpy as np
-
-     fig, ax = plt.subplots(figsize=(6, 4))
-     im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
-                    interpolation='none')
-     fig.colorbar(im, ax=ax)
-     xlabel = 'Decoder timestep'
-     if info is not None:
-         xlabel += '\n\n' + info
-     plt.xlabel(xlabel)
-     plt.ylabel('Encoder timestep')
-     plt.tight_layout()
-
-     fig.canvas.draw()
-     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-     plt.close()
-     return data
-
-
- def load_wav_to_torch(full_path):
-     sampling_rate, data = read(full_path)
-     return torch.FloatTensor(data.astype(np.float32)), sampling_rate
-
-
- def load_filepaths_and_text(filename, split="|"):
-     with open(filename, encoding='utf-8') as f:
-         filepaths_and_text = [line.strip().split(split) for line in f]
-     return filepaths_and_text
-
-
- def get_hparams(init=True):
-     parser = argparse.ArgumentParser()
-     parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
-                         help='JSON file for configuration')
-     parser.add_argument('-m', '--model', type=str, required=True,
-                         help='Model name')
-
-     args = parser.parse_args()
-     model_dir = os.path.join("./logs", args.model)
-
-     if not os.path.exists(model_dir):
-         os.makedirs(model_dir)
-
-     config_path = args.config
-     config_save_path = os.path.join(model_dir, "config.json")
-     if init:
-         with open(config_path, "r") as f:
-             data = f.read()
-         with open(config_save_path, "w") as f:
-             f.write(data)
-     else:
-         with open(config_save_path, "r") as f:
-             data = f.read()
-     config = json.loads(data)
-
-     hparams = HParams(**config)
-     hparams.model_dir = model_dir
-     return hparams
-
-
- def get_hparams_from_dir(model_dir):
-     config_save_path = os.path.join(model_dir, "config.json")
-     with open(config_save_path, "r") as f:
-         data = f.read()
-     config = json.loads(data)
-
-     hparams = HParams(**config)
-     hparams.model_dir = model_dir
-     return hparams
-
-
- def get_hparams_from_file(config_path):
-     with open(config_path, "r") as f:
-         data = f.read()
-     config = json.loads(data)
-
-     hparams = HParams(**config)
-     return hparams
-
-
- def check_git_hash(model_dir):
-     source_dir = os.path.dirname(os.path.realpath(__file__))
-     if not os.path.exists(os.path.join(source_dir, ".git")):
-         logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
-             source_dir
-         ))
-         return
-
-     cur_hash = subprocess.getoutput("git rev-parse HEAD")
-
-     path = os.path.join(model_dir, "githash")
-     if os.path.exists(path):
-         saved_hash = open(path).read()
-         if saved_hash != cur_hash:
-             logger.warn("git hash values are different. {}(saved) != {}(current)".format(
-                 saved_hash[:8], cur_hash[:8]))
-     else:
-         open(path, "w").write(cur_hash)
-
-
- def get_logger(model_dir, filename="train.log"):
-     global logger
-     logger = logging.getLogger(os.path.basename(model_dir))
-     logger.setLevel(logging.DEBUG)
-
-     formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
-     if not os.path.exists(model_dir):
-         os.makedirs(model_dir)
-     h = logging.FileHandler(os.path.join(model_dir, filename))
-     h.setLevel(logging.DEBUG)
-     h.setFormatter(formatter)
-     logger.addHandler(h)
-     return logger
-
-
- def repeat_expand_2d(content, target_len):
-     # content : [h, t]
-
-     src_len = content.shape[-1]
-     target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device)
-     temp = torch.arange(src_len + 1) * target_len / src_len
-     current_pos = 0
-     for i in range(target_len):
-         if i < temp[current_pos + 1]:
-             target[:, i] = content[:, current_pos]
-         else:
-             current_pos += 1
-             target[:, i] = content[:, current_pos]
-
-     return target
-
-
- def mix_model(model_paths, mix_rate, mode):
-     mix_rate = torch.FloatTensor(mix_rate) / 100
-     model_tem = torch.load(model_paths[0])
-     models = [torch.load(path)["model"] for path in model_paths]
-     if mode == 0:
-         mix_rate = F.softmax(mix_rate, dim=0)
-     for k in model_tem["model"].keys():
-         model_tem["model"][k] = torch.zeros_like(model_tem["model"][k])
-         for i, model in enumerate(models):
-             model_tem["model"][k] += model[k] * mix_rate[i]
-     torch.save(model_tem, os.path.join(os.path.curdir, "output.pth"))
-     return os.path.join(os.path.curdir, "output.pth")
-
- class HParams():
-     def __init__(self, **kwargs):
-         for k, v in kwargs.items():
-             if type(v) == dict:
-                 v = HParams(**v)
-             self[k] = v
-
-     def keys(self):
-         return self.__dict__.keys()
-
-     def items(self):
-         return self.__dict__.items()
-
-     def values(self):
-         return self.__dict__.values()
-
-     def __len__(self):
-         return len(self.__dict__)
-
-     def __getitem__(self, key):
-         return getattr(self, key)
-
-     def __setitem__(self, key, value):
-         return setattr(self, key, value)
-
-     def __contains__(self, key):
-         return key in self.__dict__
-
-     def __repr__(self):
-         return self.__dict__.__repr__()
+ import utils
+ from inference.infer_tool import Svc
+ import logging
+ import soundfile
+ import asyncio
+ import argparse
+ import gradio.processing_utils as gr_processing_utils
+ logging.getLogger('numba').setLevel(logging.WARNING)
+ logging.getLogger('markdown_it').setLevel(logging.WARNING)
+ logging.getLogger('urllib3').setLevel(logging.WARNING)
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
+
+ limitation = os.getenv("SYSTEM") == "spaces"  # limit audio length in huggingface spaces
+
+ audio_postprocess_ori = gr.Audio.postprocess
+
+ def audio_postprocess(self, y):
+     data = audio_postprocess_ori(self, y)
+     if data is None:
+         return None
+     return gr_processing_utils.encode_url_or_file_to_base64(data["name"])
+
+
+ gr.Audio.postprocess = audio_postprocess
+ def create_vc_fn(model, sid):
+     def vc_fn(input_audio, vc_transform, auto_f0, fmp):
+         if input_audio is None:
+             return "You need to upload an audio", None
+         sampling_rate, audio = input_audio
+         duration = audio.shape[0] / sampling_rate
+         if duration > 20 and limitation:
+             return "Please upload an audio file that is less than 20 seconds. If you need to generate a longer audio file, please use Colab.", None
+         audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
+         if len(audio.shape) > 1:
+             audio = librosa.to_mono(audio.transpose(1, 0))
+         if sampling_rate != 16000:
+             audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
+         raw_path = io.BytesIO()
+         soundfile.write(raw_path, audio, 16000, format="wav")
+         raw_path.seek(0)
+         out_audio, out_sr = model.infer(sid, vc_transform, raw_path,
+                                         auto_predict_f0=auto_f0, F0_mean_pooling=fmp
+                                         )
+         return "Success", (44100, out_audio.cpu().numpy())
+     return vc_fn
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--device', type=str, default='cpu')
+     parser.add_argument('--api', action="store_true", default=False)
+     parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
+     args = parser.parse_args()
+     hubert_model = utils.get_hubert_model().to(args.device)
+     models = []
+     voices = []
+     for f in os.listdir("models"):
+         name = f
+         model = Svc(fr"models/{f}/{f}.pth", f"models/{f}/config.json", device=args.device)
+         cover = f"models/{f}/cover.jpg" if os.path.exists(f"models/{f}/cover.jpg") else None
+         models.append((name, cover, create_vc_fn(model, name)))
+     with gr.Blocks() as app:
+         gr.Markdown(
+             "# <center> Sovits Models\n"
+             "## <center> The input audio should be clean and pure voice without background music.\n"
+             "[![Original Repo](https://badgen.net/badge/icon/github?icon=github&label=Original%20Repo)](https://github.com/svc-develop-team/so-vits-svc)"
+         )
+
+         with gr.Tabs():
+             for (name, cover, vc_fn) in models:
+                 with gr.TabItem(name):
+                     with gr.Row():
+                         gr.Markdown(
+                             '<div align="center">' +
+                             (f'<img style="width:auto;height:300px;" src="file/{cover}">' if cover else '') +
+                             '</div>'
+                         )
+                     with gr.Row():
+                         with gr.Column():
+                             vc_input = gr.Audio(label="Input audio" + (' (less than 20 seconds)' if limitation else ''))
+                             vc_transform = gr.Number(label="vc_transform", value=0)
+                             auto_f0 = gr.Checkbox(label="auto_f0", value=False)
+                             fmp = gr.Checkbox(label="fmp", value=False)
+                             vc_submit = gr.Button("Generate", variant="primary")
+
+                         with gr.Column():
+                             vc_output1 = gr.Textbox(label="Output Message")
+                             vc_output2 = gr.Audio(label="Output Audio")
+                         vc_submit.click(vc_fn, [vc_input, vc_transform, auto_f0, fmp], [vc_output1, vc_output2])
+
+         """
+         for category, link in others.items():
+             with gr.TabItem(category):
+                 gr.Markdown(
+                     f'''
+                     <center>
+                         <h2>Click to Go</h2>
+                         <a href="{link}">
+                             <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-xl-dark.svg"
+                         </a>
+                     </center>
+                     '''
+                 )
+         """
+     app.queue(concurrency_count=1, api_open=args.api).launch(share=args.share)
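
For anyone adapting this Space, the preprocessing that vc_fn performs before handing audio to Svc.infer can be read in isolation. The sketch below mirrors those steps outside Gradio; the helper name prepare_raw_audio is ours, everything else follows the added code above.

import io

import librosa
import numpy as np
import soundfile


def prepare_raw_audio(sampling_rate: int, audio: np.ndarray) -> io.BytesIO:
    # Gradio hands back integer PCM; rescale to float32 in [-1, 1].
    # (Assumes an integer dtype, which gr.Audio provides by default.)
    audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
    # Collapse stereo to mono; librosa expects channels-first, hence the transpose.
    if len(audio.shape) > 1:
        audio = librosa.to_mono(audio.transpose(1, 0))
    # The content encoder consumes 16 kHz audio, so resample anything else.
    if sampling_rate != 16000:
        audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
    # Write an in-memory WAV that Svc.infer can read like a file.
    raw_path = io.BytesIO()
    soundfile.write(raw_path, audio, 16000, format="wav")
    raw_path.seek(0)
    return raw_path

With models laid out as models/<name>/<name>.pth plus a sibling config.json (and an optional cover.jpg), the updated app can then be launched with, for example: python app.py --device cpu --share. Per the wiring in vc_fn, vc_transform is the pitch transpose in semitones, auto_f0 toggles auto_predict_f0, and fmp toggles F0_mean_pooling on Svc.infer.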