Uniaff committed on
Commit 52a7f35 · verified · 1 Parent(s): 5967f17

Update seedvc.py

Files changed (1)
  1. seedvc.py +358 -374
seedvc.py CHANGED
@@ -1,374 +1,358 @@
- import spaces
- import gradio as gr
- import torch
- import torchaudio
- import librosa
- from modules.commons import build_model, load_checkpoint, recursive_munch
- import yaml
- from hf_utils import load_custom_model_from_hf
- import numpy as np
- from pydub import AudioSegment
-
- # Load model and configuration
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
-                                                                  "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
-                                                                  "config_dit_mel_seed_uvit_whisper_small_wavenet.yml")
- # dit_checkpoint_path = "E:/DiT_epoch_00018_step_801000.pth"
- # dit_config_path = "configs/config_dit_mel_seed_uvit_whisper_small_encoder_wavenet.yml"
- config = yaml.safe_load(open(dit_config_path, 'r'))
- model_params = recursive_munch(config['model_params'])
- model = build_model(model_params, stage='DiT')
- hop_length = config['preprocess_params']['spect_params']['hop_length']
- sr = config['preprocess_params']['sr']
-
- # Load checkpoints
- model, _, _, _ = load_checkpoint(model, None, dit_checkpoint_path,
-                                  load_only_params=True, ignore_modules=[], is_distributed=False)
- for key in model:
-     model[key].eval()
-     model[key].to(device)
- model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
-
- # Load additional modules
- from modules.campplus.DTDNN import CAMPPlus
-
- campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
- campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
- campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
- campplus_model.eval()
- campplus_model.to(device)
-
- from modules.bigvgan import bigvgan
-
- bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
-
- # remove weight norm in the model and set to eval mode
- bigvgan_model.remove_weight_norm()
- bigvgan_model = bigvgan_model.eval().to(device)
-
- ckpt_path, config_path = load_custom_model_from_hf("Plachta/FAcodec", 'pytorch_model.bin', 'config.yml')
-
- codec_config = yaml.safe_load(open(config_path))
- codec_model_params = recursive_munch(codec_config['model_params'])
- codec_encoder = build_model(codec_model_params, stage="codec")
-
- ckpt_params = torch.load(ckpt_path, map_location="cpu")
-
- for key in codec_encoder:
-     codec_encoder[key].load_state_dict(ckpt_params[key], strict=False)
- _ = [codec_encoder[key].eval() for key in codec_encoder]
- _ = [codec_encoder[key].to(device) for key in codec_encoder]
-
- # whisper
- from transformers import AutoFeatureExtractor, WhisperModel
-
- whisper_name = model_params.speech_tokenizer.whisper_name if hasattr(model_params.speech_tokenizer,
-                                                                      'whisper_name') else "openai/whisper-small"
- whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(device)
- del whisper_model.decoder
- whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)
-
- # Generate mel spectrograms
- mel_fn_args = {
-     "n_fft": config['preprocess_params']['spect_params']['n_fft'],
-     "win_size": config['preprocess_params']['spect_params']['win_length'],
-     "hop_size": config['preprocess_params']['spect_params']['hop_length'],
-     "num_mels": config['preprocess_params']['spect_params']['n_mels'],
-     "sampling_rate": sr,
-     "fmin": 0,
-     "fmax": None,
-     "center": False
- }
- from modules.audio import mel_spectrogram
-
- to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
-
- # f0 conditioned model
- dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
-                                                                  "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth",
-                                                                  "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml")
-
- config = yaml.safe_load(open(dit_config_path, 'r'))
- model_params = recursive_munch(config['model_params'])
- model_f0 = build_model(model_params, stage='DiT')
- hop_length = config['preprocess_params']['spect_params']['hop_length']
- sr = config['preprocess_params']['sr']
-
- # Load checkpoints
- model_f0, _, _, _ = load_checkpoint(model_f0, None, dit_checkpoint_path,
-                                     load_only_params=True, ignore_modules=[], is_distributed=False)
- for key in model_f0:
-     model_f0[key].eval()
-     model_f0[key].to(device)
- model_f0.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
-
- # f0 extractor
- from modules.rmvpe import RMVPE
-
- model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
- rmvpe = RMVPE(model_path, is_half=False, device=device)
-
- mel_fn_args_f0 = {
-     "n_fft": config['preprocess_params']['spect_params']['n_fft'],
-     "win_size": config['preprocess_params']['spect_params']['win_length'],
-     "hop_size": config['preprocess_params']['spect_params']['hop_length'],
-     "num_mels": config['preprocess_params']['spect_params']['n_mels'],
-     "sampling_rate": sr,
-     "fmin": 0,
-     "fmax": None,
-     "center": False
- }
- to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0)
- bigvgan_44k_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x', use_cuda_kernel=False)
-
- # remove weight norm in the model and set to eval mode
- bigvgan_44k_model.remove_weight_norm()
- bigvgan_44k_model = bigvgan_44k_model.eval().to(device)
-
- def adjust_f0_semitones(f0_sequence, n_semitones):
-     factor = 2 ** (n_semitones / 12)
-     return f0_sequence * factor
-
- def crossfade(chunk1, chunk2, overlap):
-     fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
-     fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
-     chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
-     return chunk2
-
- # streaming and chunk processing related params
- max_context_window = sr // hop_length * 30
- overlap_frame_len = 16
- overlap_wave_len = overlap_frame_len * hop_length
- bitrate = "320k"
-
- @spaces.GPU
- @torch.no_grad()
- @torch.inference_mode()
- def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, f0_condition, auto_f0_adjust, pitch_shift):
-     inference_module = model if not f0_condition else model_f0
-     mel_fn = to_mel if not f0_condition else to_mel_f0
-     bigvgan_fn = bigvgan_model if not f0_condition else bigvgan_44k_model
-     sr = 22050 if not f0_condition else 44100
-     # Load audio
-     source_audio = librosa.load(source, sr=sr)[0]
-     ref_audio = librosa.load(target, sr=sr)[0]
-
-     # Process audio
-     source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
-     ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(device)
-
-     # Resample
-     ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
-     converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
-     # if source audio less than 30 seconds, whisper can handle in one forward
-     if converted_waves_16k.size(-1) <= 16000 * 30:
-         alt_inputs = whisper_feature_extractor([converted_waves_16k.squeeze(0).cpu().numpy()],
-                                                return_tensors="pt",
-                                                return_attention_mask=True,
-                                                sampling_rate=16000)
-         alt_input_features = whisper_model._mask_input_features(
-             alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
-         alt_outputs = whisper_model.encoder(
-             alt_input_features.to(whisper_model.encoder.dtype),
-             head_mask=None,
-             output_attentions=False,
-             output_hidden_states=False,
-             return_dict=True,
-         )
-         S_alt = alt_outputs.last_hidden_state.to(torch.float32)
-         S_alt = S_alt[:, :converted_waves_16k.size(-1) // 320 + 1]
-     else:
-         overlapping_time = 5  # 5 seconds
-         S_alt_list = []
-         buffer = None
-         traversed_time = 0
-         while traversed_time < converted_waves_16k.size(-1):
-             if buffer is None:  # first chunk
-                 chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
-             else:
-                 chunk = torch.cat([buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]], dim=-1)
-             alt_inputs = whisper_feature_extractor([chunk.squeeze(0).cpu().numpy()],
-                                                    return_tensors="pt",
-                                                    return_attention_mask=True,
-                                                    sampling_rate=16000)
-             alt_input_features = whisper_model._mask_input_features(
-                 alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
-             alt_outputs = whisper_model.encoder(
-                 alt_input_features.to(whisper_model.encoder.dtype),
-                 head_mask=None,
-                 output_attentions=False,
-                 output_hidden_states=False,
-                 return_dict=True,
-             )
-             S_alt = alt_outputs.last_hidden_state.to(torch.float32)
-             S_alt = S_alt[:, :chunk.size(-1) // 320 + 1]
-             if traversed_time == 0:
-                 S_alt_list.append(S_alt)
-             else:
-                 S_alt_list.append(S_alt[:, 50 * overlapping_time:])
-             buffer = chunk[:, -16000 * overlapping_time:]
-             traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
-         S_alt = torch.cat(S_alt_list, dim=1)
-
-     ori_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
-     ori_inputs = whisper_feature_extractor([ori_waves_16k.squeeze(0).cpu().numpy()],
-                                            return_tensors="pt",
-                                            return_attention_mask=True)
-     ori_input_features = whisper_model._mask_input_features(
-         ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
-     with torch.no_grad():
-         ori_outputs = whisper_model.encoder(
-             ori_input_features.to(whisper_model.encoder.dtype),
-             head_mask=None,
-             output_attentions=False,
-             output_hidden_states=False,
-             return_dict=True,
-         )
-     S_ori = ori_outputs.last_hidden_state.to(torch.float32)
-     S_ori = S_ori[:, :ori_waves_16k.size(-1) // 320 + 1]
-
-     mel = mel_fn(source_audio.to(device).float())
-     mel2 = mel_fn(ref_audio.to(device).float())
-
-     target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
-     target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
-
-     feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
-                                               num_mel_bins=80,
-                                               dither=0,
-                                               sample_frequency=16000)
-     feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
-     style2 = campplus_model(feat2.unsqueeze(0))
-
-     if f0_condition:
-         F0_ori = rmvpe.infer_from_audio(ref_waves_16k[0], thred=0.5)
-         F0_alt = rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.5)
-
-         F0_ori = torch.from_numpy(F0_ori).to(device)[None]
-         F0_alt = torch.from_numpy(F0_alt).to(device)[None]
-
-         voiced_F0_ori = F0_ori[F0_ori > 1]
-         voiced_F0_alt = F0_alt[F0_alt > 1]
-
-         log_f0_alt = torch.log(F0_alt + 1e-5)
-         voiced_log_f0_ori = torch.log(voiced_F0_ori + 1e-5)
-         voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5)
-         median_log_f0_ori = torch.median(voiced_log_f0_ori)
-         median_log_f0_alt = torch.median(voiced_log_f0_alt)
-
-         # shift alt log f0 level to ori log f0 level
-         shifted_log_f0_alt = log_f0_alt.clone()
-         if auto_f0_adjust:
-             shifted_log_f0_alt[F0_alt > 1] = log_f0_alt[F0_alt > 1] - median_log_f0_alt + median_log_f0_ori
-         shifted_f0_alt = torch.exp(shifted_log_f0_alt)
-         if pitch_shift != 0:
-             shifted_f0_alt[F0_alt > 1] = adjust_f0_semitones(shifted_f0_alt[F0_alt > 1], pitch_shift)
-     else:
-         F0_ori = None
-         F0_alt = None
-         shifted_f0_alt = None
-
-     # Length regulation
-     cond, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt)
-     prompt_condition, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori)
-
-     max_source_window = max_context_window - mel2.size(2)
-     # split source condition (cond) into chunks
-     processed_frames = 0
-     generated_wave_chunks = []
-     # generate chunk by chunk and stream the output
-     while processed_frames < cond.size(1):
-         chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
-         is_last_chunk = processed_frames + max_source_window >= cond.size(1)
-         cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
-         # Voice Conversion
-         vc_target = inference_module.cfm.inference(cat_condition,
-                                                    torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
-                                                    mel2, style2, None, diffusion_steps,
-                                                    inference_cfg_rate=inference_cfg_rate)
-         vc_target = vc_target[:, :, mel2.size(-1):]
-         vc_wave = bigvgan_fn(vc_target)[0]
-         if processed_frames == 0:
-             if is_last_chunk:
-                 output_wave = vc_wave[0].cpu().numpy()
-                 generated_wave_chunks.append(output_wave)
-                 output_wave = (output_wave * 32768.0).astype(np.int16)
-                 mp3_bytes = AudioSegment(
-                     output_wave.tobytes(), frame_rate=sr,
-                     sample_width=output_wave.dtype.itemsize, channels=1
-                 ).export(format="mp3", bitrate=bitrate).read()
-                 yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
-                 break
-             output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
-             generated_wave_chunks.append(output_wave)
-             previous_chunk = vc_wave[0, -overlap_wave_len:]
-             processed_frames += vc_target.size(2) - overlap_frame_len
-             output_wave = (output_wave * 32768.0).astype(np.int16)
-             mp3_bytes = AudioSegment(
-                 output_wave.tobytes(), frame_rate=sr,
-                 sample_width=output_wave.dtype.itemsize, channels=1
-             ).export(format="mp3", bitrate=bitrate).read()
-             yield mp3_bytes, None
-         elif is_last_chunk:
-             output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
-             generated_wave_chunks.append(output_wave)
-             processed_frames += vc_target.size(2) - overlap_frame_len
-             output_wave = (output_wave * 32768.0).astype(np.int16)
-             mp3_bytes = AudioSegment(
-                 output_wave.tobytes(), frame_rate=sr,
-                 sample_width=output_wave.dtype.itemsize, channels=1
-             ).export(format="mp3", bitrate=bitrate).read()
-             yield mp3_bytes, (sr, np.concatenate(generated_wave_chunks))
-             break
-         else:
-             output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len)
-             generated_wave_chunks.append(output_wave)
-             previous_chunk = vc_wave[0, -overlap_wave_len:]
-             processed_frames += vc_target.size(2) - overlap_frame_len
-             output_wave = (output_wave * 32768.0).astype(np.int16)
-             mp3_bytes = AudioSegment(
-                 output_wave.tobytes(), frame_rate=sr,
-                 sample_width=output_wave.dtype.itemsize, channels=1
-             ).export(format="mp3", bitrate=bitrate).read()
-             yield mp3_bytes, None
-
-
- if __name__ == "__main__":
-     description = ("State-of-the-Art zero-shot voice conversion/singing voice conversion. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) "
-                    "for details and updates.<br>Note that any reference audio will be forcefully clipped to 25s if beyond this length.<br> "
-                    "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.<br> "
-                    "无需训练的 zero-shot 语音/歌声转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)<br>"
-                    "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。<br>若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。")
-     inputs = [
-         gr.Audio(type="filepath", label="Source Audio / 源音频"),
-         gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
-         gr.Slider(minimum=1, maximum=200, value=25, step=1, label="Diffusion Steps / 扩散步数", info="25 by default, 50~100 for best quality / 默认为 25,50~100 为最佳质量"),
-         gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整", info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
-         gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate", info="has subtle influence / 有微小影响"),
-         gr.Checkbox(label="Use F0 conditioned model / 启用F0输入", value=False, info="Must set to true for singing voice conversion / 歌声转换时必须勾选"),
-         gr.Checkbox(label="Auto F0 adjust / 自动F0调整", value=True,
-                     info="Roughly adjust F0 to match target voice. Only works when F0 conditioned model is used. / 粗略调整 F0 以匹配目标音色,仅在勾选 '启用F0输入' 时生效"),
-         gr.Slider(label='Pitch shift / 音调变换', minimum=-24, maximum=24, step=1, value=0, info="Pitch shift in semitones, only works when F0 conditioned model is used / 半音数的音高变换,仅在勾选 '启用F0输入' 时生效"),
-     ]
-
-     examples = [["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 25, 1.0, 0.7, False, True, 0],
-                 ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 25, 1.0, 0.7, False, True, 0],
-                 ["examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav",
-                  "examples/reference/kobe_0.wav", 50, 1.0, 0.7, True, False, -6],
-                 ["examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav",
-                  "examples/reference/trump_0.wav", 50, 1.0, 0.7, True, False, -12],
-                 ]
-
-     outputs = [gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
-                gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav')]
-
-     gr.Interface(fn=voice_conversion,
-                  description=description,
-                  inputs=inputs,
-                  outputs=outputs,
-                  title="Seed Voice Conversion",
-                  examples=examples,
-                  cache_examples=False,
-                  ).launch()
+ import torch
+ import torchaudio
+ import librosa
+ from modules.commons import build_model, load_checkpoint, recursive_munch
+ import yaml
+ from hf_utils import load_custom_model_from_hf
+ import numpy as np
+
+ # Load models and configuration
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load the DiT configuration and model
+ dit_checkpoint_path, dit_config_path = load_custom_model_from_hf(
+     "Plachta/Seed-VC",
+     "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
+     "config_dit_mel_seed_uvit_whisper_small_wavenet.yml"
+ )
+
+ config = yaml.safe_load(open(dit_config_path, 'r'))
+ model_params = recursive_munch(config['model_params'])
+ model = build_model(model_params, stage='DiT')
+ hop_length = config['preprocess_params']['spect_params']['hop_length']
+ sr = config['preprocess_params']['sr']
+
+ # Load the model checkpoint
+ model, _, _, _ = load_checkpoint(
+     model, None, dit_checkpoint_path,
+     load_only_params=True, ignore_modules=[], is_distributed=False
+ )
+ for key in model:
+     model[key].eval()
+     model[key].to(device)
+ model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
+
+ # Load the additional CAMPPlus model
+ from modules.campplus.DTDNN import CAMPPlus
+
+ campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
+ campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
+ campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
+ campplus_model.eval()
+ campplus_model.to(device)
+
+ # Load the BigVGAN model
+ from modules.bigvgan import bigvgan
+
+ bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
+ bigvgan_model.remove_weight_norm()
+ bigvgan_model = bigvgan_model.eval().to(device)
+
+ # Load the FAcodec model
+ ckpt_path, config_path = load_custom_model_from_hf("Plachta/FAcodec", 'pytorch_model.bin', 'config.yml')
+
+ codec_config = yaml.safe_load(open(config_path))
+ codec_model_params = recursive_munch(codec_config['model_params'])
+ codec_encoder = build_model(codec_model_params, stage="codec")
+
+ ckpt_params = torch.load(ckpt_path, map_location="cpu")
+
+ for key in codec_encoder:
+     codec_encoder[key].load_state_dict(ckpt_params[key], strict=False)
+ _ = [codec_encoder[key].eval() for key in codec_encoder]
+ _ = [codec_encoder[key].to(device) for key in codec_encoder]
+
+ # Load the Whisper model
+ from transformers import AutoFeatureExtractor, WhisperModel
+
+ whisper_name = model_params.speech_tokenizer.whisper_name if hasattr(model_params.speech_tokenizer, 'whisper_name') else "openai/whisper-small"
+ whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(device)
+ del whisper_model.decoder
+ whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)
+
+ # Mel spectrogram generation
+ mel_fn_args = {
+     "n_fft": config['preprocess_params']['spect_params']['n_fft'],
+     "win_size": config['preprocess_params']['spect_params']['win_length'],
+     "hop_size": config['preprocess_params']['spect_params']['hop_length'],
+     "num_mels": config['preprocess_params']['spect_params']['n_mels'],
+     "sampling_rate": sr,
+     "fmin": 0,
+     "fmax": None,
+     "center": False
+ }
+ from modules.audio import mel_spectrogram
+
+ to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
+
+ # F0-conditioned model
+ dit_checkpoint_path_f0, dit_config_path_f0 = load_custom_model_from_hf(
+     "Plachta/Seed-VC",
+     "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth",
+     "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml"
+ )
+
+ config_f0 = yaml.safe_load(open(dit_config_path_f0, 'r'))
+ model_params_f0 = recursive_munch(config_f0['model_params'])
+ model_f0 = build_model(model_params_f0, stage='DiT')
+ hop_length_f0 = config_f0['preprocess_params']['spect_params']['hop_length']
+ sr_f0 = config_f0['preprocess_params']['sr']
+
+ # Load the F0-conditioned model checkpoint
+ model_f0, _, _, _ = load_checkpoint(
+     model_f0, None, dit_checkpoint_path_f0,
+     load_only_params=True, ignore_modules=[], is_distributed=False
+ )
+ for key in model_f0:
+     model_f0[key].eval()
+     model_f0[key].to(device)
+ model_f0.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
+
+ # Load the RMVPE F0 extractor
+ from modules.rmvpe import RMVPE
+
+ model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
+ rmvpe = RMVPE(model_path, is_half=False, device=device)
+
+ # Mel spectrogram parameters for the F0-conditioned model
+ mel_fn_args_f0 = {
+     "n_fft": config_f0['preprocess_params']['spect_params']['n_fft'],
+     "win_size": config_f0['preprocess_params']['spect_params']['win_length'],
+     "hop_size": config_f0['preprocess_params']['spect_params']['hop_length'],
+     "num_mels": config_f0['preprocess_params']['spect_params']['n_mels'],
+     "sampling_rate": sr_f0,
+     "fmin": 0,
+     "fmax": None,
+     "center": False
+ }
+ to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0)
+
+ # Load the 44 kHz BigVGAN model
+ bigvgan_44k_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x', use_cuda_kernel=False)
+ bigvgan_44k_model.remove_weight_norm()
+ bigvgan_44k_model = bigvgan_44k_model.eval().to(device)
+
+ def adjust_f0_semitones(f0_sequence, n_semitones):
+     factor = 2 ** (n_semitones / 12)
+     return f0_sequence * factor
+
+ def crossfade(chunk1, chunk2, overlap):
+     fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
+     fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
+     chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
+     return chunk2
+
+ # Streaming and chunk processing parameters
+ max_context_window = sr // hop_length * 30
+ overlap_frame_len = 16
+ overlap_wave_len = overlap_frame_len * hop_length
+ bitrate = "320k"
+
+ @torch.no_grad()
+ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, f0_condition, auto_f0_adjust, pitch_shift):
+     """
+     Perform voice conversion.
+
+     Parameters:
+     - source (str): Path to the source audio file.
+     - target (str): Path to the reference audio file (the voice to convert to).
+     - diffusion_steps (int): Number of diffusion steps.
+     - length_adjust (float): Length adjustment factor.
+     - inference_cfg_rate (float): CFG rate for inference.
+     - f0_condition (bool): Whether to use the F0-conditioned model.
+     - auto_f0_adjust (bool): Whether to automatically adjust F0.
+     - pitch_shift (int): Pitch shift in semitones.
+
+     Returns:
+     - tuple: (sample rate, numpy array of audio data)
+     """
+     inference_module = model_f0 if f0_condition else model
+     mel_fn = to_mel_f0 if f0_condition else to_mel
+     bigvgan_fn = bigvgan_44k_model if f0_condition else bigvgan_model
+     sr_used = sr_f0 if f0_condition else sr
+
+     # Load audio
+     source_audio, _ = librosa.load(source, sr=sr_used)
+     ref_audio, _ = librosa.load(target, sr=sr_used)
+
+     # Limit the reference audio length
+     ref_audio = ref_audio[:sr_used * 25]
+
+     # Convert audio to tensors
+     source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
+     ref_audio = torch.tensor(ref_audio).unsqueeze(0).float().to(device)
+
+     # Resample for Whisper
+     ref_waves_16k = torchaudio.functional.resample(ref_audio, sr_used, 16000)
+     converted_waves_16k = torchaudio.functional.resample(source_audio, sr_used, 16000)
+
+     # Extract features with Whisper
+     if converted_waves_16k.size(-1) <= 16000 * 30:
+         alt_inputs = whisper_feature_extractor([converted_waves_16k.squeeze(0).cpu().numpy()],
+                                                return_tensors="pt",
+                                                return_attention_mask=True,
+                                                sampling_rate=16000)
+         alt_input_features = whisper_model._mask_input_features(
+             alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
+         alt_outputs = whisper_model.encoder(
+             alt_input_features.to(whisper_model.encoder.dtype),
+             head_mask=None,
+             output_attentions=False,
+             output_hidden_states=False,
+             return_dict=True,
+         )
+         S_alt = alt_outputs.last_hidden_state.to(torch.float32)
+         S_alt = S_alt[:, :converted_waves_16k.size(-1) // 320 + 1]
+     else:
+         # Process long audio in chunks
+         overlapping_time = 5  # seconds
+         S_alt_list = []
+         buffer = None
+         traversed_time = 0
+         while traversed_time < converted_waves_16k.size(-1):
+             if buffer is None:
+                 chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
+             else:
+                 chunk = torch.cat([buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]], dim=-1)
+             alt_inputs = whisper_feature_extractor([chunk.squeeze(0).cpu().numpy()],
+                                                    return_tensors="pt",
+                                                    return_attention_mask=True,
+                                                    sampling_rate=16000)
+             alt_input_features = whisper_model._mask_input_features(
+                 alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
+             alt_outputs = whisper_model.encoder(
+                 alt_input_features.to(whisper_model.encoder.dtype),
+                 head_mask=None,
+                 output_attentions=False,
+                 output_hidden_states=False,
+                 return_dict=True,
+             )
+             S_alt = alt_outputs.last_hidden_state.to(torch.float32)
+             S_alt = S_alt[:, :chunk.size(-1) // 320 + 1]
+             if traversed_time == 0:
+                 S_alt_list.append(S_alt)
+             else:
+                 S_alt_list.append(S_alt[:, 50 * overlapping_time:])
+             buffer = chunk[:, -16000 * overlapping_time:]
+             traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
+         S_alt = torch.cat(S_alt_list, dim=1)
+
+     # Extract features from the reference audio
+     ori_waves_16k = torchaudio.functional.resample(ref_audio, sr_used, 16000)
+     ori_inputs = whisper_feature_extractor([ori_waves_16k.squeeze(0).cpu().numpy()],
+                                            return_tensors="pt",
+                                            return_attention_mask=True)
+     ori_input_features = whisper_model._mask_input_features(
+         ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
+     with torch.no_grad():
+         ori_outputs = whisper_model.encoder(
+             ori_input_features.to(whisper_model.encoder.dtype),
+             head_mask=None,
+             output_attentions=False,
+             output_hidden_states=False,
+             return_dict=True,
+         )
+     S_ori = ori_outputs.last_hidden_state.to(torch.float32)
+     S_ori = S_ori[:, :ori_waves_16k.size(-1) // 320 + 1]
+
+     mel = mel_fn(source_audio.to(device).float())
+     mel2 = mel_fn(ref_audio.to(device).float())
+
+     target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
+     target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
+
+     # Extract speaker style with CAMPPlus
+     feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
+                                               num_mel_bins=80,
+                                               dither=0,
+                                               sample_frequency=16000)
+     feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
+     style2 = campplus_model(feat2.unsqueeze(0))
+
+     if f0_condition:
+         # Extract F0 with RMVPE
+         F0_ori = rmvpe.infer_from_audio(ref_waves_16k[0], thred=0.5)
+         F0_alt = rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.5)
+
+         F0_ori = torch.from_numpy(F0_ori).to(device)[None]
+         F0_alt = torch.from_numpy(F0_alt).to(device)[None]
+
+         voiced_F0_ori = F0_ori[F0_ori > 1]
+         voiced_F0_alt = F0_alt[F0_alt > 1]
+
+         log_f0_alt = torch.log(F0_alt + 1e-5)
+         voiced_log_f0_ori = torch.log(voiced_F0_ori + 1e-5)
+         voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5)
+         median_log_f0_ori = torch.median(voiced_log_f0_ori)
+         median_log_f0_alt = torch.median(voiced_log_f0_alt)
+
+         # Adjust F0
+         shifted_log_f0_alt = log_f0_alt.clone()
+         if auto_f0_adjust:
+             shifted_log_f0_alt[F0_alt > 1] = log_f0_alt[F0_alt > 1] - median_log_f0_alt + median_log_f0_ori
+         shifted_f0_alt = torch.exp(shifted_log_f0_alt)
+         if pitch_shift != 0:
+             shifted_f0_alt[F0_alt > 1] = adjust_f0_semitones(shifted_f0_alt[F0_alt > 1], pitch_shift)
+     else:
+         F0_ori = None
+         F0_alt = None
+         shifted_f0_alt = None
+
+     # Length regulation
+     cond, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(
+         S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt
+     )
+     prompt_condition, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(
+         S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori
+     )
+
+     max_source_window = max_context_window - mel2.size(2)
+     processed_frames = 0
+     generated_wave_chunks = []
+
+     # Generate audio chunk by chunk
+     while processed_frames < cond.size(1):
+         chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
+         is_last_chunk = processed_frames + max_source_window >= cond.size(1)
+         cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
+         vc_target = inference_module.cfm.inference(
+             cat_condition,
+             torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
+             mel2, style2, None, diffusion_steps,
+             inference_cfg_rate=inference_cfg_rate
+         )
+         vc_target = vc_target[:, :, mel2.size(-1):]
+         vc_wave = bigvgan_fn(vc_target)[0]
+         if processed_frames == 0:
+             if is_last_chunk:
+                 output_wave = vc_wave[0].cpu().numpy()
+                 generated_wave_chunks.append(output_wave)
+                 break
+             output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
+             generated_wave_chunks.append(output_wave)
+             previous_chunk = vc_wave[0, -overlap_wave_len:]
+             processed_frames += vc_target.size(2) - overlap_frame_len
+         elif is_last_chunk:
+             output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
+             generated_wave_chunks.append(output_wave)
+             processed_frames += vc_target.size(2) - overlap_frame_len
+             break
+         else:
+             output_wave = crossfade(
+                 previous_chunk.cpu().numpy(),
+                 vc_wave[0, :-overlap_wave_len].cpu().numpy(),
+                 overlap_wave_len
+             )
+             generated_wave_chunks.append(output_wave)
+             previous_chunk = vc_wave[0, -overlap_wave_len:]
+             processed_frames += vc_target.size(2) - overlap_frame_len
+
+     # Concatenate all chunks into one waveform
+     full_output_wave = np.concatenate(generated_wave_chunks)
+
+     # Normalize audio
+     max_val = np.max(np.abs(full_output_wave))
+     if max_val > 1.0:
+         full_output_wave = full_output_wave / max_val
+
+     return sr_used, full_output_wave
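
Since this revision drops the Gradio app and the function now simply returns (sample_rate, waveform), here is a minimal usage sketch, not part of the commit: it assumes the seed-vc repository modules are importable (so that importing seedvc loads all models) and that soundfile is installed for writing the result; the example paths come from the repo's examples folder and all parameter values are illustrative.

# Usage sketch (assumption: run from the seed-vc repo root with soundfile installed).
import soundfile as sf

from seedvc import voice_conversion  # importing seedvc loads all models once

# Convert the example source clip to the example reference voice.
sr_out, wave = voice_conversion(
    source="examples/source/yae_0.wav",
    target="examples/reference/dingzhen_0.wav",
    diffusion_steps=25,      # 25 by default, 50~100 for best quality
    length_adjust=1.0,       # <1.0 speeds up speech, >1.0 slows it down
    inference_cfg_rate=0.7,
    f0_condition=False,      # set True for singing voice conversion
    auto_f0_adjust=True,
    pitch_shift=0,           # semitones; only used by the F0-conditioned model
)

# Write the converted waveform to disk at the returned sample rate.
sf.write("converted.wav", wave, sr_out)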