soiz1 committed (verified)
Commit 91ccc0d · Parent(s): cfa9320

Update app.py

Files changed (1):
  1. app.py +4 -338

app.py CHANGED
@@ -1,342 +1,5 @@
- import os
- import spaces
  import gradio as gr
- import torch
- import torchaudio
- import librosa
- from modules.commons import build_model, load_checkpoint, recursive_munch
- import yaml
- from hf_utils import load_custom_model_from_hf
- import numpy as np
- from pydub import AudioSegment

- # Load model and configuration
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
-                                                                  "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
-                                                                  "config_dit_mel_seed_uvit_whisper_small_wavenet.yml")
- # dit_checkpoint_path = "E:/DiT_epoch_00018_step_801000.pth"
- # dit_config_path = "configs/config_dit_mel_seed_uvit_whisper_small_encoder_wavenet.yml"
- config = yaml.safe_load(open(dit_config_path, 'r'))
- model_params = recursive_munch(config['model_params'])
- model = build_model(model_params, stage='DiT')
- hop_length = config['preprocess_params']['spect_params']['hop_length']
- sr = config['preprocess_params']['sr']
-
- # Load checkpoints
- model, _, _, _ = load_checkpoint(model, None, dit_checkpoint_path,
-                                  load_only_params=True, ignore_modules=[], is_distributed=False)
- for key in model:
-     model[key].eval()
-     model[key].to(device)
- model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
-
- # Load additional modules
- from modules.campplus.DTDNN import CAMPPlus
-
- campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
- campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
- campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
- campplus_model.eval()
- campplus_model.to(device)
-
- from modules.bigvgan import bigvgan
-
- bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
-
- # remove weight norm in the model and set to eval mode
- bigvgan_model.remove_weight_norm()
- bigvgan_model = bigvgan_model.eval().to(device)
-
- ckpt_path, config_path = load_custom_model_from_hf("Plachta/FAcodec", 'pytorch_model.bin', 'config.yml')
-
- codec_config = yaml.safe_load(open(config_path))
- codec_model_params = recursive_munch(codec_config['model_params'])
- codec_encoder = build_model(codec_model_params, stage="codec")
-
- ckpt_params = torch.load(ckpt_path, map_location="cpu")
-
- for key in codec_encoder:
-     codec_encoder[key].load_state_dict(ckpt_params[key], strict=False)
- _ = [codec_encoder[key].eval() for key in codec_encoder]
- _ = [codec_encoder[key].to(device) for key in codec_encoder]
-
- # whisper
- from transformers import AutoFeatureExtractor, WhisperModel
-
- whisper_name = model_params.speech_tokenizer.whisper_name if hasattr(model_params.speech_tokenizer,
-                                                                      'whisper_name') else "openai/whisper-small"
- whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(device)
- del whisper_model.decoder
- whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)
-
- # Generate mel spectrograms
- mel_fn_args = {
-     "n_fft": config['preprocess_params']['spect_params']['n_fft'],
-     "win_size": config['preprocess_params']['spect_params']['win_length'],
-     "hop_size": config['preprocess_params']['spect_params']['hop_length'],
-     "num_mels": config['preprocess_params']['spect_params']['n_mels'],
-     "sampling_rate": sr,
-     "fmin": 0,
-     "fmax": None,
-     "center": False
- }
- from modules.audio import mel_spectrogram
-
- to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
-
- # f0 conditioned model
- dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
-                                                                  "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth",
-                                                                  "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml")
-
- config = yaml.safe_load(open(dit_config_path, 'r'))
- model_params = recursive_munch(config['model_params'])
- model_f0 = build_model(model_params, stage='DiT')
- hop_length = config['preprocess_params']['spect_params']['hop_length']
- sr = config['preprocess_params']['sr']
-
- # Load checkpoints
- model_f0, _, _, _ = load_checkpoint(model_f0, None, dit_checkpoint_path,
-                                     load_only_params=True, ignore_modules=[], is_distributed=False)
- for key in model_f0:
-     model_f0[key].eval()
-     model_f0[key].to(device)
- model_f0.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
-
- # f0 extractor
- from modules.rmvpe import RMVPE
-
- model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
- rmvpe = RMVPE(model_path, is_half=False, device=device)
-
- mel_fn_args_f0 = {
-     "n_fft": config['preprocess_params']['spect_params']['n_fft'],
-     "win_size": config['preprocess_params']['spect_params']['win_length'],
-     "hop_size": config['preprocess_params']['spect_params']['hop_length'],
-     "num_mels": config['preprocess_params']['spect_params']['n_mels'],
-     "sampling_rate": sr,
-     "fmin": 0,
-     "fmax": None,
-     "center": False
- }
- to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0)
- bigvgan_44k_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x', use_cuda_kernel=False)
-
- # remove weight norm in the model and set to eval mode
- bigvgan_44k_model.remove_weight_norm()
- bigvgan_44k_model = bigvgan_44k_model.eval().to(device)
-
- def adjust_f0_semitones(f0_sequence, n_semitones):
-     factor = 2 ** (n_semitones / 12)
-     return f0_sequence * factor
-
- def crossfade(chunk1, chunk2, overlap):
-     fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
-     fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
-     chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
-     return chunk2
-
- # streaming and chunk processing related params
- bitrate = "320k"
- overlap_frame_len = 16
- @spaces.GPU
- @torch.no_grad()
- @torch.inference_mode()
- def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, f0_condition, auto_f0_adjust, pitch_shift):
-     inference_module = model if not f0_condition else model_f0
-     mel_fn = to_mel if not f0_condition else to_mel_f0
-     bigvgan_fn = bigvgan_model if not f0_condition else bigvgan_44k_model
-     sr = 22050 if not f0_condition else 44100
-     hop_length = 256 if not f0_condition else 512
-     max_context_window = sr // hop_length * 30
-     overlap_wave_len = overlap_frame_len * hop_length
-     # Load audio
-     source_audio = librosa.load(source, sr=sr)[0]
-     ref_audio = librosa.load(target, sr=sr)[0]
-
-     # Process audio
-     source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
-     ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(device)
-
-     # Resample
-     ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
-     converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
-     # if source audio less than 30 seconds, whisper can handle in one forward
-     if converted_waves_16k.size(-1) <= 16000 * 30:
-         alt_inputs = whisper_feature_extractor([converted_waves_16k.squeeze(0).cpu().numpy()],
-                                                return_tensors="pt",
-                                                return_attention_mask=True,
-                                                sampling_rate=16000)
-         alt_input_features = whisper_model._mask_input_features(
-             alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
-         alt_outputs = whisper_model.encoder(
-             alt_input_features.to(whisper_model.encoder.dtype),
-             head_mask=None,
-             output_attentions=False,
-             output_hidden_states=False,
-             return_dict=True,
-         )
-         S_alt = alt_outputs.last_hidden_state.to(torch.float32)
-         S_alt = S_alt[:, :converted_waves_16k.size(-1) // 320 + 1]
-     else:
-         overlapping_time = 5 # 5 seconds
-         S_alt_list = []
-         buffer = None
-         traversed_time = 0
-         while traversed_time < converted_waves_16k.size(-1):
-             if buffer is None: # first chunk
-                 chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
-             else:
-                 chunk = torch.cat([buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]], dim=-1)
-             alt_inputs = whisper_feature_extractor([chunk.squeeze(0).cpu().numpy()],
-                                                    return_tensors="pt",
-                                                    return_attention_mask=True,
-                                                    sampling_rate=16000)
-             alt_input_features = whisper_model._mask_input_features(
-                 alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
-             alt_outputs = whisper_model.encoder(
-                 alt_input_features.to(whisper_model.encoder.dtype),
-                 head_mask=None,
-                 output_attentions=False,
-                 output_hidden_states=False,
-                 return_dict=True,
-             )
-             S_alt = alt_outputs.last_hidden_state.to(torch.float32)
-             S_alt = S_alt[:, :chunk.size(-1) // 320 + 1]
-             if traversed_time == 0:
-                 S_alt_list.append(S_alt)
-             else:
-                 S_alt_list.append(S_alt[:, 50 * overlapping_time:])
-             buffer = chunk[:, -16000 * overlapping_time:]
-             traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
-         S_alt = torch.cat(S_alt_list, dim=1)
-
-     ori_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
-     ori_inputs = whisper_feature_extractor([ori_waves_16k.squeeze(0).cpu().numpy()],
-                                            return_tensors="pt",
-                                            return_attention_mask=True)
-     ori_input_features = whisper_model._mask_input_features(
-         ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
-     with torch.no_grad():
-         ori_outputs = whisper_model.encoder(
-             ori_input_features.to(whisper_model.encoder.dtype),
-             head_mask=None,
-             output_attentions=False,
-             output_hidden_states=False,
-             return_dict=True,
-         )
-     S_ori = ori_outputs.last_hidden_state.to(torch.float32)
-     S_ori = S_ori[:, :ori_waves_16k.size(-1) // 320 + 1]
-
-     mel = mel_fn(source_audio.to(device).float())
-     mel2 = mel_fn(ref_audio.to(device).float())
-
-     target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
-     target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
-
-     feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
-                                               num_mel_bins=80,
-                                               dither=0,
-                                               sample_frequency=16000)
-     feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
-     style2 = campplus_model(feat2.unsqueeze(0))
-
-     if f0_condition:
-         F0_ori = rmvpe.infer_from_audio(ref_waves_16k[0], thred=0.5)
-         F0_alt = rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.5)
-
-         F0_ori = torch.from_numpy(F0_ori).to(device)[None]
-         F0_alt = torch.from_numpy(F0_alt).to(device)[None]
-
-         voiced_F0_ori = F0_ori[F0_ori > 1]
-         voiced_F0_alt = F0_alt[F0_alt > 1]
-
-         log_f0_alt = torch.log(F0_alt + 1e-5)
-         voiced_log_f0_ori = torch.log(voiced_F0_ori + 1e-5)
-         voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5)
-         median_log_f0_ori = torch.median(voiced_log_f0_ori)
-         median_log_f0_alt = torch.median(voiced_log_f0_alt)
-
-         # shift alt log f0 level to ori log f0 level
-         shifted_log_f0_alt = log_f0_alt.clone()
-         if auto_f0_adjust:
-             shifted_log_f0_alt[F0_alt > 1] = log_f0_alt[F0_alt > 1] - median_log_f0_alt + median_log_f0_ori
-         shifted_f0_alt = torch.exp(shifted_log_f0_alt)
-         if pitch_shift != 0:
-             shifted_f0_alt[F0_alt > 1] = adjust_f0_semitones(shifted_f0_alt[F0_alt > 1], pitch_shift)
-     else:
-         F0_ori = None
-         F0_alt = None
-         shifted_f0_alt = None
-
-     # Length regulation
-     cond, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt)
-     prompt_condition, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori)
-
-     max_source_window = max_context_window - mel2.size(2)
-     # split source condition (cond) into chunks
-     processed_frames = 0
-     generated_wave_chunks = []
-     # generate chunk by chunk and stream the output
-     while processed_frames < cond.size(1):
-         chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
-         is_last_chunk = processed_frames + max_source_window >= cond.size(1)
-         cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
-         with torch.autocast(device_type='cuda', dtype=torch.float16):
-             # Voice Conversion
-             vc_target = inference_module.cfm.inference(cat_condition,
-                                                        torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
-                                                        mel2, style2, None, diffusion_steps,
-                                                        inference_cfg_rate=inference_cfg_rate)
-             vc_target = vc_target[:, :, mel2.size(-1):]
-         vc_wave = bigvgan_fn(vc_target.float())[0]
-         if processed_frames == 0:
-             if is_last_chunk:
-                 output_wave = vc_wave[0].cpu().numpy()
-                 generated_wave_chunks.append(output_wave)
-                 output_wave = (output_wave * 32768.0).astype(np.int16)
-                 mp3_bytes = AudioSegment(
-                     output_wave.tobytes(), frame_rate=sr,
-                     sample_width=output_wave.dtype.itemsize, channels=1
-                 ).export(format="mp3", bitrate=bitrate).read()
-                 yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
-                 break
-             output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
-             generated_wave_chunks.append(output_wave)
-             previous_chunk = vc_wave[0, -overlap_wave_len:]
-             processed_frames += vc_target.size(2) - overlap_frame_len
-             output_wave = (output_wave * 32768.0).astype(np.int16)
-             mp3_bytes = AudioSegment(
-                 output_wave.tobytes(), frame_rate=sr,
-                 sample_width=output_wave.dtype.itemsize, channels=1
-             ).export(format="mp3", bitrate=bitrate).read()
-             yield mp3_bytes, None
-         elif is_last_chunk:
-             output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
-             generated_wave_chunks.append(output_wave)
-             processed_frames += vc_target.size(2) - overlap_frame_len
-             output_wave = (output_wave * 32768.0).astype(np.int16)
-             mp3_bytes = AudioSegment(
-                 output_wave.tobytes(), frame_rate=sr,
-                 sample_width=output_wave.dtype.itemsize, channels=1
-             ).export(format="mp3", bitrate=bitrate).read()
-             yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
-             break
-         else:
-             output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len)
-             generated_wave_chunks.append(output_wave)
-             previous_chunk = vc_wave[0, -overlap_wave_len:]
-             processed_frames += vc_target.size(2) - overlap_frame_len
-             output_wave = (output_wave * 32768.0).astype(np.int16)
-             mp3_bytes = AudioSegment(
-                 output_wave.tobytes(), frame_rate=sr,
-                 sample_width=output_wave.dtype.itemsize, channels=1
-             ).export(format="mp3", bitrate=bitrate).read()
-             yield mp3_bytes, None
-
-
  gallery_data = [
      {"name": "sikokumetan", "webp": "default/sikokumetan.webp", "mp3": "default/sikokumetan.mp3"}
  ]
@@ -347,6 +10,9 @@ def auto_set_reference(selected_image):
              return item["mp3"]
      return ""

+ def handle_gallery_selection(selected_image):
+     return auto_set_reference(selected_image)
+
  if __name__ == "__main__":
      description = ("Zero-shot音声変換モデル(学習不要)。ローカルでの利用方法は[GitHubリポジトリ](https://github.com/Plachtaa/seed-vc)をご覧ください。"
                     "参考音声が25秒を超える場合、自動的に25秒にクリップされます。"
@@ -356,7 +22,7 @@ if __name__ == "__main__":
          gr.Audio(type="filepath", label="元音声"),
          gr.Audio(type="filepath", label="参考音声"),
          gr.Gallery(label="ギャラリー", value=[item["webp"] for item in gallery_data],
-                    interactive=True, elem_id="gallery"),
+                    interactive=True, elem_id="gallery").change(handle_gallery_selection, inputs="gallery", outputs="参考音声"),
          gr.Slider(minimum=1, maximum=200, value=10, step=1, label="拡散ステップ数", info="デフォルトは10、50~100が最適な品質"),
          gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="長さ調整", info="1.0未満で速度を上げ、1.0以上で速度を遅くします"),
          gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="推論CFG率", info="わずかな影響があります"),
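For reference, the gallery-to-reference wiring that this commit adds (handle_gallery_selection plus the .change(...) call inside the gr.Interface input list) can also be expressed with Gradio's Blocks API, which attaches the event to the component directly. The sketch below is illustrative only and not code from this commit: it assumes Gradio's gr.Gallery.select event and gr.SelectData payload, and set_reference_from_gallery is a hypothetical helper.

import gradio as gr

gallery_data = [
    {"name": "sikokumetan", "webp": "default/sikokumetan.webp", "mp3": "default/sikokumetan.mp3"}
]

def set_reference_from_gallery(evt: gr.SelectData):
    # Map the clicked gallery thumbnail back to its reference mp3 path.
    return gallery_data[evt.index]["mp3"]

with gr.Blocks() as demo:
    gallery = gr.Gallery(label="ギャラリー", value=[item["webp"] for item in gallery_data], elem_id="gallery")
    reference_audio = gr.Audio(type="filepath", label="参考音声")
    # .select fires when a thumbnail is clicked; the handler's return value
    # fills the reference audio component, mirroring auto_set_reference.
    gallery.select(set_reference_from_gallery, inputs=None, outputs=reference_audio)

if __name__ == "__main__":
    demo.launch()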