import os
import subprocess
import sys
import logging

import torch
import gradio as gr
from scipy.io.wavfile import write

# Project-local modules
import commons
import utils
from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def compile_monotonic_align(source_dir="monotonic_align"):
    """Build the ``monotonic_align`` extension in place.

    Runs ``setup.py build_ext --inplace`` inside *source_dir* using the
    interpreter that is running this script.

    Args:
        source_dir: Directory containing the extension's ``setup.py``.

    Raises:
        RuntimeError: If the build exits with a non-zero status or cannot be
            started (for example, *source_dir* does not exist).
    """
    # os.system() never raises CalledProcessError, so the previous error
    # branch was unreachable and failed builds were reported as successes.
    # check=True raises on a non-zero exit status. cwd= replaces the shell
    # "cd ... && cd .." sequence. sys.executable builds the extension for
    # the same interpreter that will import it.
    try:
        subprocess.run(
            [sys.executable, "setup.py", "build_ext", "--inplace"],
            cwd=source_dir,
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as e:
        logger.error("Failed to compile monotonic_align: %s", e)
        raise RuntimeError("Compilation of monotonic_align failed.") from e
    logger.info("Successfully compiled monotonic_align.")


def load_config_and_model(config_path, checkpoint_path):
    """Load hyperparameters and a ``SynthesizerTrn`` in eval mode.

    Args:
        config_path: Path to the config file read by
            ``utils.get_hparams_from_file``.
        checkpoint_path: Path to the generator checkpoint passed to
            ``utils.load_checkpoint``.

    Returns:
        ``(hps, net_g)``: the hyperparameter object and the model with the
        checkpoint weights loaded.

    Raises:
        FileNotFoundError: If either path is not an existing file.
    """
    if not os.path.isfile(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    # Load hyperparameters
    hps = utils.get_hparams_from_file(config_path)
    logger.info("Loaded hyperparameters from config file.")

    # Initialize the model. The positional args are presumably the symbol
    # vocabulary size, the one-sided STFT bin count and the segment size in
    # frames. TODO confirm against the SynthesizerTrn signature.
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model,
    )
    net_g.eval()
    logger.info("Initialized SynthesizerTrn model.")

    # Load pretrained weights. The third argument is passed as None
    # (presumably the optimizer slot, unused for inference; verify in utils).
    utils.load_checkpoint(checkpoint_path, net_g, None)
    logger.info("Loaded model checkpoint from %s.", checkpoint_path)

    return hps, net_g


def text_to_speech(content, hps, net_g, *, noise_scale=0.667, noise_scale_w=0.8, length_scale=1):
    """Synthesize speech for *content*.

    Args:
        content: Input text. It must be a string containing at least one
            non-whitespace character.
        hps: Hyperparameters; reads ``hps.data.text_cleaners``,
            ``hps.data.add_blank`` and ``hps.data.sampling_rate``.
        net_g: Model exposing ``infer`` (as returned by load_config_and_model).
        noise_scale: Forwarded to ``net_g.infer``.
        noise_scale_w: Forwarded to ``net_g.infer``.
        length_scale: Forwarded to ``net_g.infer``.

    Returns:
        ``(sampling_rate, audio)``, where ``audio`` is a NumPy array produced
        from a float tensor.

    Raises:
        ValueError: If *content* is not a string or is blank.
        RuntimeError: If text processing or inference fails. The original
            exception is chained as ``__cause__``.
    """
    if not isinstance(content, str) or not content.strip():
        raise ValueError("Input text is empty or invalid.")

    try:
        # Convert text to a symbol-id sequence
        seq = text_to_sequence(content, hps.data.text_cleaners)
        if hps.data.add_blank:
            # Interleave id 0 between symbols when the config enables add_blank.
            seq = commons.intersperse(seq, 0)
        stn_tst = torch.tensor(seq, dtype=torch.long)

        # Model inference
        with torch.no_grad():
            x_tst = stn_tst.unsqueeze(0)  # leading dim of size 1 (single-item batch)
            x_tst_lengths = torch.tensor([stn_tst.size(0)], dtype=torch.long)
            # [0] selects infer's first output; [0, 0] indexes its first two
            # dims (presumably batch item 0, channel 0 — confirm).
            audio = net_g.infer(
                x_tst,
                x_tst_lengths,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0].float().cpu().numpy()
    except Exception as e:
        logger.exception("Error during text-to-speech synthesis: %s", e)
        raise RuntimeError("Failed to generate audio.") from e

    return hps.data.sampling_rate, audio


def create_gradio_interface(hps, net_g):
    """Build the Gradio Blocks app: a text box, a Convert button and an audio output.

    Args:
        hps: Hyperparameters forwarded to text_to_speech.
        net_g: Model forwarded to text_to_speech.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """

    def safe_syn(content):
        # Best-effort handler: on any failure, log the traceback and return
        # None so the UI shows no audio instead of crashing the callback.
        try:
            return text_to_speech(content, hps, net_g)
        except Exception as e:
            logger.exception("Error in Gradio interface: %s", e)
            return None

    app = gr.Blocks()
    with app:
        with gr.Tabs():
            with gr.TabItem("Basic"):
                input1 = gr.Textbox(label="Input Text", placeholder="Enter text here...")
                submit = gr.Button("Convert", variant="primary")
                output1 = gr.Audio(label="Output Audio")
                submit.click(safe_syn, input1, output1)
    return app


def main(config_path="configs/steins_gate_base.json", checkpoint_path="G_265000.pth"):
    """Compile the extension, load the model and launch the Gradio UI.

    Args:
        config_path: Config file path (defaults to the previously hard-coded value).
        checkpoint_path: Checkpoint path (defaults to the previously hard-coded value).

    Any ``Exception`` raised during startup or while the app is running is
    logged with its traceback, and the process exits with status 1.
    """
    try:
        # Compile monotonic_align
        compile_monotonic_align()

        # Load config and model
        hps, net_g = load_config_and_model(config_path, checkpoint_path)

        # Create and launch the Gradio interface
        app = create_gradio_interface(hps, net_g)
        logger.info("Starting Gradio interface...")
        app.launch()
    except Exception as e:
        logger.critical("Fatal error: %s", e, exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()