Plachta committed on
Commit 9340499
1 Parent(s): cd2f227

Chunk-wise inference & streaming output
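The change splits long inputs into fixed-size chunks, converts them one at a time, stitches neighbouring outputs with a cos² crossfade over a 64-frame overlap, and yields each stitched piece as MP3 bytes so playback can start before the full conversion finishes. Below is a minimal, self-contained sketch of just the crossfade stitching; the array names, sizes, and the 256-sample overlap are illustrative and not taken from app.py.

import numpy as np

def crossfade(chunk1, chunk2, overlap):
    # Same scheme as the helper added in app.py: fade the tail of chunk1 out and
    # the head of chunk2 in with complementary cos^2 windows (weights sum to 1).
    # Note: chunk2 is modified in place, mirroring the app's helper.
    fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
    fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
    chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
    return chunk2

if __name__ == "__main__":
    overlap = 256  # illustrative overlap length in samples
    a = np.random.randn(4096).astype(np.float32)  # stand-in for one converted chunk
    b = np.random.randn(4096).astype(np.float32)  # stand-in for the next chunk
    # Emit the first chunk without its overlapped tail, keep the tail,
    # then crossfade it into the head of the next chunk before emitting that -
    # the same bookkeeping the streaming loop in the diff performs.
    prev_tail = a[-overlap:]
    stitched = np.concatenate([a[:-overlap], crossfade(prev_tail, b, overlap)])
    print(stitched.shape)  # (4096 - 256 + 4096,) = (7936,)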

Files changed (1)
  1. app.py +94 -27
app.py CHANGED
@@ -6,6 +6,8 @@ from modules.commons import build_model, load_checkpoint, recursive_munch
 import yaml
 from hf_utils import load_custom_model_from_hf
 import spaces
+import numpy as np
+from pydub import AudioSegment
 
 # Load model and configuration
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -111,6 +113,19 @@ def adjust_f0_semitones(f0_sequence, n_semitones):
     factor = 2 ** (n_semitones / 12)
     return f0_sequence * factor
 
+def crossfade(chunk1, chunk2, overlap):
+    fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
+    fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
+    chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
+    return chunk2
+
+# streaming and chunk processing related params
+max_context_window = sr // hop_length * 30
+overlap_frame_len = 64
+overlap_wave_len = overlap_frame_len * hop_length
+max_wave_len_per_chunk = 24000 * 20
+bitrate = "320k"
+
 @spaces.GPU
 @torch.no_grad()
 @torch.inference_mode()
@@ -134,17 +149,23 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
         S_ori = cosyvoice_frontend.extract_speech_token(ref_waves_16k)[0]
     elif speech_tokenizer_type == 'facodec':
         converted_waves_24k = torchaudio.functional.resample(source_audio, sr, 24000)
-        wave_lengths_24k = torch.LongTensor([converted_waves_24k.size(1)]).to(converted_waves_24k.device)
         waves_input = converted_waves_24k.unsqueeze(1)
-        z = codec_encoder.encoder(waves_input)
-        (
-            quantized,
-            codes
-        ) = codec_encoder.quantizer(
-            z,
-            waves_input,
-        )
-        S_alt = torch.cat([codes[1], codes[0]], dim=1)
+        wave_input_chunks = [
+            waves_input[..., i:i + max_wave_len_per_chunk] for i in range(0, waves_input.size(-1), max_wave_len_per_chunk)
+        ]
+        S_alt_chunks = []
+        for i, chunk in enumerate(wave_input_chunks):
+            z = codec_encoder.encoder(chunk)
+            (
+                quantized,
+                codes
+            ) = codec_encoder.quantizer(
+                z,
+                chunk,
+            )
+            S_alt = torch.cat([codes[1], codes[0]], dim=1)
+            S_alt_chunks.append(S_alt)
+        S_alt = torch.cat(S_alt_chunks, dim=-1)
 
         # S_ori should be extracted in the same way
         waves_24k = torchaudio.functional.resample(ref_audio, sr, 24000)
@@ -207,26 +228,72 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
     # Length regulation
     cond = inference_module.length_regulator(S_alt, ylens=target_lengths, n_quantizers=int(n_quantizers), f0=shifted_f0_alt)[0]
     prompt_condition = inference_module.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=int(n_quantizers), f0=F0_ori)[0]
-    cat_condition = torch.cat([prompt_condition, cond], dim=1)
-
-    # Voice Conversion
-    vc_target = inference_module.cfm.inference(cat_condition, torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
-                                               mel2, style2, None, diffusion_steps, inference_cfg_rate=inference_cfg_rate)
-    vc_target = vc_target[:, :, mel2.size(-1):]
-
-    # Convert to waveform
-    # if f0_condition:
-    #     f04vocoder = torch.nn.functional.interpolate(shifted_f0_alt.unsqueeze(1), size=vc_target.size(-1),
-    #                                                  mode='nearest').squeeze(1)
-    # else:
-    f04vocoder = None
-    vc_wave = hift_gen.inference(vc_target, f0=f04vocoder)
 
-    return sr, vc_wave.squeeze(0).cpu().numpy()
+    max_source_window = max_context_window - mel2.size(2)
+    # split source condition (cond) into chunks
+    processed_frames = 0
+    generated_wave_chunks = []
+    # generate chunk by chunk and stream the output
+    while processed_frames < cond.size(1):
+        chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
+        is_last_chunk = processed_frames + max_source_window >= cond.size(1)
+        cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
+        # Voice Conversion
+        vc_target = inference_module.cfm.inference(cat_condition,
+                                                   torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
+                                                   mel2, style2, None, diffusion_steps,
+                                                   inference_cfg_rate=inference_cfg_rate)
+        vc_target = vc_target[:, :, mel2.size(-1):]
+        vc_wave = hift_gen.inference(vc_target, f0=None)
+        if processed_frames == 0:
+            if is_last_chunk:
+                output_wave = vc_wave[0].cpu().numpy()
+                generated_wave_chunks.append(output_wave)
+                output_wave = (output_wave * 32768.0).astype(np.int16)
+                mp3_bytes = AudioSegment(
+                    output_wave.tobytes(), frame_rate=sr,
+                    sample_width=output_wave.dtype.itemsize, channels=1
+                ).export(format="mp3", bitrate=bitrate).read()
+                yield mp3_bytes
+                break
+            output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
+            generated_wave_chunks.append(output_wave)
+            previous_chunk = vc_wave[0, -overlap_wave_len:]
+            processed_frames += vc_target.size(2) - overlap_frame_len
+            output_wave = (output_wave * 32768.0).astype(np.int16)
+            mp3_bytes = AudioSegment(
+                output_wave.tobytes(), frame_rate=sr,
+                sample_width=output_wave.dtype.itemsize, channels=1
+            ).export(format="mp3", bitrate=bitrate).read()
+            yield mp3_bytes
+        elif is_last_chunk:
+            output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
+            generated_wave_chunks.append(output_wave)
+            processed_frames += vc_target.size(2) - overlap_frame_len
+            output_wave = (output_wave * 32768.0).astype(np.int16)
+            mp3_bytes = AudioSegment(
+                output_wave.tobytes(), frame_rate=sr,
+                sample_width=output_wave.dtype.itemsize, channels=1
+            ).export(format="mp3", bitrate=bitrate).read()
+            yield mp3_bytes
+            break
+        else:
+            output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len)
+            generated_wave_chunks.append(output_wave)
+            previous_chunk = vc_wave[0, -overlap_wave_len:]
+            processed_frames += vc_target.size(2) - overlap_frame_len
+            output_wave = (output_wave * 32768.0).astype(np.int16)
+            mp3_bytes = AudioSegment(
+                output_wave.tobytes(), frame_rate=sr,
+                sample_width=output_wave.dtype.itemsize, channels=1
+            ).export(format="mp3", bitrate=bitrate).read()
+            yield mp3_bytes
 
 
 if __name__ == "__main__":
-    description = "Zero-shot voice conversion with in-context learning. Check out our [GitHub repository](https://github.com/Plachtaa/seed-vc) for details and updates."
+    description = ("Zero-shot voice conversion with in-context learning. Check out our [GitHub repository](https://github.com/Plachtaa/seed-vc) "
+                   "for details and updates.<br>Note that any reference audio will be forcefully clipped to 25s if beyond this length.<br> "
+                   "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.")
     inputs = [
         gr.Audio(type="filepath", label="Source Audio"),
         gr.Audio(type="filepath", label="Reference Audio"),
@@ -244,7 +311,7 @@ if __name__ == "__main__":
         ["examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav",
          "examples/reference/teio_0.wav", 100, 1.0, 0.7, 3, True, True, 0],]
 
-    outputs = gr.Audio(label="Output Audio")
+    outputs = gr.Audio(label="Output Audio", streaming=True, format='mp3')
 
     gr.Interface(fn=voice_conversion,
                  description=description,
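
On the UI side, the output component becomes gr.Audio(label="Output Audio", streaming=True, format='mp3') and voice_conversion is now a generator, so Gradio plays each yielded MP3 chunk as it arrives. The following is a minimal sketch of that generator-plus-streaming-output pattern in isolation: the tone generator, SR constant, and 0.5 s chunk size are made up for the demo, it assumes a Gradio version whose streaming Audio output accepts yielded MP3 bytes (as the updated app relies on), and pydub needs ffmpeg available for MP3 export.

import numpy as np
import gradio as gr
from pydub import AudioSegment  # MP3 export requires ffmpeg on the system

SR = 22050  # hypothetical sample rate for this demo only

def stream_tone(duration_s):
    """Yield ~0.5 s MP3 chunks of a 440 Hz tone, mimicking the chunk-by-chunk yields in voice_conversion."""
    chunk_s = 0.5
    t0 = 0.0
    while t0 < duration_s:
        t = t0 + np.arange(int(SR * chunk_s)) / SR
        wave = (0.3 * np.sin(2 * np.pi * 440.0 * t) * 32768.0).astype(np.int16)
        # Encode the int16 chunk to MP3 bytes, exactly the AudioSegment(...).export(...).read()
        # pattern used in the diff above.
        mp3_bytes = AudioSegment(
            wave.tobytes(), frame_rate=SR,
            sample_width=wave.dtype.itemsize, channels=1
        ).export(format="mp3", bitrate="320k").read()
        yield mp3_bytes
        t0 += chunk_s

demo = gr.Interface(
    fn=stream_tone,
    inputs=gr.Number(value=3.0, label="Duration (s)"),
    outputs=gr.Audio(label="Output Audio", streaming=True, format="mp3"),
)

if __name__ == "__main__":
    demo.launch()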