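"""Gradio demo for OpenVoice voice cloning.

Synthesizes English speech in a selectable style with the base speaker TTS,
then re-colors it with the tone color extracted from a user-supplied
reference clip.
"""
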
import os
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from openvoice.api import BaseSpeakerTTS, ToneColorConverter
import openvoice.se_extractor as se_extractor

# Constants
CKPT_BASE_PATH = "checkpoints"
EN_SUFFIX = f"{CKPT_BASE_PATH}/base_speakers/EN"
CONVERTER_SUFFIX = f"{CKPT_BASE_PATH}/converter"
OUTPUT_DIR = "outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Download necessary files
def download_from_hf_hub(filename, local_dir="./"):
    os.makedirs(local_dir, exist_ok=True)
    hf_hub_download(repo_id="myshell-ai/OpenVoice", filename=filename, local_dir=local_dir)

for file in [f"{CONVERTER_SUFFIX}/checkpoint.pth", f"{CONVERTER_SUFFIX}/config.json",
             f"{EN_SUFFIX}/checkpoint.pth", f"{EN_SUFFIX}/config.json",
             f"{EN_SUFFIX}/en_default_se.pth", f"{EN_SUFFIX}/en_style_se.pth"]:
    download_from_hf_hub(file)

# Initialize models
pt_device = "cpu"
en_base_speaker_tts = BaseSpeakerTTS(f"{EN_SUFFIX}/config.json", device=pt_device)
en_base_speaker_tts.load_ckpt(f"{EN_SUFFIX}/checkpoint.pth")

tone_color_converter = ToneColorConverter(f"{CONVERTER_SUFFIX}/config.json", device=pt_device)
tone_color_converter.load_ckpt(f"{CONVERTER_SUFFIX}/checkpoint.pth")

# Pre-computed source speaker embeddings; map_location keeps loading
# working on CPU-only hosts even if the tensors were saved from a GPU.
en_source_default_se = torch.load(f"{EN_SUFFIX}/en_default_se.pth", map_location=pt_device)
en_source_style_se = torch.load(f"{EN_SUFFIX}/en_style_se.pth", map_location=pt_device)

# Main prediction function
def predict(prompt, style, audio_file_pth, tau):
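    """Synthesize `prompt` in the chosen base-speaker `style`, then convert
    the result's tone color toward the reference clip at `audio_file_pth`.

    Returns a (status message, output wav path or None) tuple for the Gradio outputs.
    """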
    if len(prompt) < 2 or len(prompt) > 200:
        return "Text should be between 2 and 200 characters.", None

    try:
        target_se, _ = se_extractor.get_se(audio_file_pth, tone_color_converter, target_dir=OUTPUT_DIR, vad=True)
    except Exception as e:
        return f"Error getting target tone color: {str(e)}", None

    src_path = f"{OUTPUT_DIR}/tmp.wav"
    en_base_speaker_tts.tts(prompt, src_path, speaker=style, language="English")

    save_path = f"{OUTPUT_DIR}/output.wav"
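    # Re-color the synthesized speech toward the reference speaker's timbre;
    # tau controls how strongly the target tone color is applied.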
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=en_source_style_se if style != "default" else en_source_default_se,
        tgt_se=target_se,
        output_path=save_path,
        tau=tau
    )

    return "Voice cloning completed successfully.", save_path

# Gradio interface
def create_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# OpenVoice: Instant Voice Cloning with fine-tuning")
        
        with gr.Row():
            input_text = gr.Textbox(label="Text to speak", placeholder="Enter text here (2-200 characters)")
            style = gr.Dropdown(
                label="Style",
                choices=["default", "whispering", "cheerful", "terrified", "angry", "sad", "friendly"],
                value="default"
            )
        
        with gr.Row():
            reference_audio = gr.Audio(label="Reference Audio", type="filepath")
            tau_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Tau (Voice similarity)", info="Higher values make the output more similar to the reference voice")

        submit_button = gr.Button("Generate Voice")
        
        output_text = gr.Textbox(label="Status")
        output_audio = gr.Audio(label="Generated Audio")

        submit_button.click(
            predict,
            inputs=[input_text, style, reference_audio, tau_slider],
            outputs=[output_text, output_audio]
        )

    return demo

# Launch the demo
if __name__ == "__main__":
    demo = create_demo()
    demo.launch()