# MegaTTS3 demo Space (Hugging Face Spaces, running on ZeroGPU)
import multiprocessing as mp
import torch
import os
from functools import partial
import gradio as gr
import traceback
from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
import spaces
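
# Fetch the pretrained checkpoints from the Hub into ./checkpoints at startup.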
os.system('huggingface-cli download ByteDance/MegaTTS3 --local-dir ./checkpoints --repo-type model')
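
# Build the inference pipeline once at import time. Note that on ZeroGPU
# Spaces a CUDA device is generally only attached inside `spaces.GPU`-decorated
# calls, so this availability check may report CPU in the main process.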
CUDA_AVAILABLE = torch.cuda.is_available()
infer_pipe = MegaTTS3DiTInfer(device='cuda' if CUDA_AVAILABLE else 'cpu')
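
# On ZeroGPU Spaces, GPU work must run inside a function decorated with
# `spaces.GPU` (the reason for `import spaces` above). The decorator below is
# an assumption inferred from the Space's "Running on Zero" badge.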
@spaces.GPU
def forward_gpu(file_content, wav_path, latent_file, inp_text, time_step, p_w, t_w):
    resource_context = infer_pipe.preprocess(file_content, latent_file)
    wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=time_step, p_w=p_w, t_w=t_w)
    return wav_bytes
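
# Validates the uploaded reference pair, trims the reference audio, runs
# inference, and pushes the result (or None on failure) to output_queue.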
def model_worker(input_queue, output_queue, device_id):
    task = input_queue.get()
    inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
    if inp_npy_path is None or inp_audio_path is None:
        output_queue.put(None)
        raise gr.Error("Please provide both a .wav and a .npy file")
    # The .wav and .npy must share a base name, since the latent was
    # pre-extracted from that specific reference clip.
    if inp_audio_path.split('/')[-1][:-4] != inp_npy_path.split('/')[-1][:-4]:
        output_queue.put(None)
        raise gr.Error("The .npy and .wav file names do not match")
    if len(inp_text) > 200:
        output_queue.put(None)
        raise gr.Error("Input text is too long (max 200 characters)")
    try:
        convert_to_wav(inp_audio_path)
        wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
        cut_wav(wav_path, max_len=24)  # cap the reference clip at 24 seconds
        with open(wav_path, 'rb') as file:
            file_content = file.read()
        wav_bytes = forward_gpu(file_content, wav_path, inp_npy_path, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
        output_queue.put(wav_bytes)
    except Exception as e:
        traceback.print_exc()
        print(task, str(e))
        output_queue.put(None)
        raise gr.Error("Generation failed")
def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes):
    input_queue = mp_manager.Queue()
    print("Push task to the input queue |", inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
    input_queue.put((inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w))
    output_queue = mp_manager.Queue()
    model_worker(input_queue, output_queue, 0)
    return output_queue.get()
if __name__ == '__main__':
    # 'spawn' avoids forking a process that may already have initialized CUDA.
    mp.set_start_method('spawn', force=True)
    mp_manager = mp.Manager()
    num_workers = 1
    devices = [0]
    processes = []
    api_interface = gr.Interface(
        fn=partial(main, processes=processes),
        inputs=[
            gr.Audio(type="filepath", label="Upload .wav"),
            gr.File(type="filepath", label="Upload .npy"),
            "text",
            gr.Number(label="infer timestep", value=32),
            gr.Number(label="Intelligibility Weight", value=1.4),
            gr.Number(label="Similarity Weight", value=3.0),
        ],
        outputs=[gr.Audio(label="Synthesized Audio")],
        title="MegaTTS3",
        examples=[
            ['./official_test_case/范闲.wav', './official_test_case/范闲.npy', "你好呀,我是范闲,我是庆国十年来风雨画卷的见证者。", 32, 1.4, 3.0],
            ['./official_test_case/周杰伦1.wav', './official_test_case/周杰伦1.npy', "有的时候嘛,我去台湾开演唱会的时候,会很喜欢来一碗卤肉饭的。", 32, 1.4, 3.0],
            ['./official_test_case/english_talk_zhou.wav', './official_test_case/english_talk_zhou.npy', "Let us do some exercise and practice more.", 32, 1.4, 3.0],
        ],
        cache_examples=True,
        description="Upload a speech clip as a reference for timbre, "
                    "upload the pre-extracted latent file, "
                    "input the target text, and receive the cloned voice. "
                    "Tip: a generation should complete within 120s (check whether your input text is too long). "
                    "Please use the system gently; excessive load or languages other than English or Chinese may cause crashes and disrupt access for other users.",
        concurrency_limit=1,
    )
    api_interface.launch(server_name='0.0.0.0', server_port=7860, debug=True)
    for p in processes:
        p.join()
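
# A minimal client-side sketch for calling the launched app programmatically.
# Assumptions (not from the original file): the `gradio_client` package is
# installed, the app is reachable at localhost:7860, and `ref.wav`/`ref.npy`
# are hypothetical file names for a matching reference clip and latent.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   wav = client.predict(
#       handle_file("ref.wav"),   # reference speech clip
#       handle_file("ref.npy"),   # pre-extracted latent for the same clip
#       "Text to synthesize.",    # target text (<= 200 characters)
#       32, 1.4, 3.0,             # infer timestep, p_w, t_w
#   )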