File size: 3,933 Bytes
7353fd4
2ef7c17
00eb13e
2ef7c17
 
 
7353fd4
2ef7c17
7353fd4
 
 
 
 
 
 
2ef7c17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import subprocess
import logging
import torch
import gradio as gr
from scipy.io.wavfile import write

# Project-local modules
import commons
import utils
from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence

# Configure logging: set the root logger to INFO and create the logger used
# by the functions in this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Compile the monotonic_align module
def compile_monotonic_align():
    """Build the ``monotonic_align`` extension in place.

    Runs ``setup.py build_ext --inplace`` inside the ``monotonic_align``
    directory, using the interpreter that is running this script so the
    extension is built for the same Python.

    Raises:
        RuntimeError: If the build exits with a non-zero status or the
            command cannot be started (for example, the directory is
            missing).
    """
    import sys  # local import: only needed to locate the current interpreter

    try:
        # check=True turns a failed build into CalledProcessError; the
        # previous os.system call returned the status and never raised, so
        # failures were logged as success. cwd= replaces the shell "cd".
        subprocess.run(
            [sys.executable, "setup.py", "build_ext", "--inplace"],
            cwd="monotonic_align",
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as e:
        logger.error(f"Failed to compile monotonic_align: {e}")
        raise RuntimeError("Compilation of monotonic_align failed.") from e

    logger.info("Successfully compiled monotonic_align.")

# Load the configuration and the model
def load_config_and_model(config_path, checkpoint_path):
    """Read hyperparameters and build a ``SynthesizerTrn`` with loaded weights.

    Args:
        config_path: Path to the hyperparameter config file.
        checkpoint_path: Path to the model checkpoint file.

    Returns:
        Tuple ``(hps, net_g)``: the hyperparameters returned by
        ``utils.get_hparams_from_file`` and the model, on which ``eval()``
        has been called and into which the checkpoint has been loaded.

    Raises:
        FileNotFoundError: If either path does not exist (the config path
            is checked first).
    """
    # Fail fast on missing inputs, config first.
    required_files = (
        (config_path, f"Config file not found: {config_path}"),
        (checkpoint_path, f"Checkpoint file not found: {checkpoint_path}"),
    )
    for path, message in required_files:
        if not os.path.exists(path):
            raise FileNotFoundError(message)

    # Hyperparameters
    hparams = utils.get_hparams_from_file(config_path)
    logger.info("Loaded hyperparameters from config file.")

    # Positional constructor arguments, derived from the config
    n_symbols = len(symbols)
    half_filter_bins = hparams.data.filter_length // 2 + 1
    segment_hops = hparams.train.segment_size // hparams.data.hop_length

    synthesizer = SynthesizerTrn(
        n_symbols,
        half_filter_bins,
        segment_hops,
        **hparams.model,
    )
    synthesizer.eval()
    logger.info("Initialized SynthesizerTrn model.")

    # Pretrained weights; no optimizer is passed
    utils.load_checkpoint(checkpoint_path, synthesizer, None)
    logger.info(f"Loaded model checkpoint from {checkpoint_path}.")

    return hparams, synthesizer

# Text-to-speech synthesis
def text_to_speech(
    content,
    hps,
    net_g,
    *,
    noise_scale=0.667,
    noise_scale_w=0.8,
    length_scale=1,
):
    """Synthesize audio for ``content`` with the given model.

    Args:
        content: Non-empty string to synthesize.
        hps: Hyperparameters; ``hps.data.text_cleaners``,
            ``hps.data.add_blank`` and ``hps.data.sampling_rate`` are read.
        net_g: Model whose ``infer`` method is called on the encoded text.
        noise_scale: Forwarded to ``net_g.infer``.
        noise_scale_w: Forwarded to ``net_g.infer``.
        length_scale: Forwarded to ``net_g.infer``.
            The three defaults match the previously hard-coded values.

    Returns:
        Tuple ``(sampling_rate, audio)`` where ``audio`` is a float NumPy
        array taken from element ``[0, 0]`` of the first ``infer`` output.

    Raises:
        ValueError: If ``content`` is empty or not a string.
        RuntimeError: If encoding or inference fails; the original
            exception is attached as ``__cause__``.
    """
    if not content or not isinstance(content, str):
        raise ValueError("Input text is empty or invalid.")

    try:
        # Convert the text to a sequence of symbol ids
        stn_tst = text_to_sequence(content, hps.data.text_cleaners)
        if hps.data.add_blank:
            stn_tst = commons.intersperse(stn_tst, 0)
        stn_tst = torch.LongTensor(stn_tst)

        # Inference only: no gradients needed
        with torch.no_grad():
            x_tst = stn_tst.unsqueeze(0)
            x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
            audio = net_g.infer(
                x_tst,
                x_tst_lengths,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0].data.float().numpy()

        return hps.data.sampling_rate, audio
    except Exception as e:
        # logger.exception records the traceback, which the plain message lost
        logger.exception(f"Error during text-to-speech synthesis: {e}")
        raise RuntimeError("Failed to generate audio.") from e

# Gradio interface
def create_gradio_interface(hps, net_g):
    """Build the Gradio Blocks app wrapping ``text_to_speech``.

    Args:
        hps: Hyperparameters passed through to ``text_to_speech``.
        net_g: Model passed through to ``text_to_speech``.

    Returns:
        The ``gr.Blocks`` app: a "Basic" tab with a text box, a "Convert"
        button and an audio output. The app is not launched here.
    """

    def safe_syn(content):
        # Best-effort handler: log any failure and hand the UI None.
        result = None
        try:
            result = text_to_speech(content, hps, net_g)
        except Exception as err:
            logger.error(f"Error in Gradio interface: {err}")
        return result

    with gr.Blocks() as app:
        with gr.Tabs():
            with gr.TabItem("Basic"):
                text_input = gr.Textbox(
                    label="Input Text",
                    placeholder="Enter text here...",
                )
                convert_button = gr.Button("Convert", variant="primary")
                audio_output = gr.Audio(label="Output Audio")
            # Wire the button: text box in, audio component out
            convert_button.click(safe_syn, text_input, audio_output)

    return app

# Main entry point
def main(
    config_path="configs/steins_gate_base.json",
    checkpoint_path="G_265000.pth",
):
    """Compile the extension, load the model and launch the Gradio app.

    Args:
        config_path: Hyperparameter config file; defaults to the path that
            was previously hard-coded.
        checkpoint_path: Model checkpoint file; defaults to the path that
            was previously hard-coded.

    Raises:
        SystemExit: With status 1 if any step fails; the failure is logged
            with its traceback first.
    """
    try:
        # Compile monotonic_align
        compile_monotonic_align()

        # Load the configuration and the model
        hps, net_g = load_config_and_model(config_path, checkpoint_path)

        # Build and start the Gradio interface
        app = create_gradio_interface(hps, net_g)
        logger.info("Starting Gradio interface...")
        app.launch()
    except Exception as e:
        logger.critical(f"Fatal error: {e}", exc_info=True)
        # The bare exit() helper comes from the site module and may be
        # absent; raising SystemExit gives the same exit status everywhere.
        raise SystemExit(1) from e

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()