# Spaces:
# Build error
# Build error
import logging
import os
import subprocess
import sys

import gradio as gr
import torch
from scipy.io.wavfile import write

# Project-local modules
import commons
import utils
from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
from models import SynthesizerTrn
from text import text_to_sequence
from text.symbols import symbols
# Configure logging: INFO level on the root logger, plus a module-level
# logger shared by every function in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Compile the monotonic_align extension module
def compile_monotonic_align():
    """Build the ``monotonic_align`` extension in place.

    Runs ``setup.py build_ext --inplace`` with the ``monotonic_align``
    directory (relative to the current working directory) as the working
    directory, using the interpreter that is running this script.

    Raises:
        RuntimeError: If the directory is missing, the interpreter cannot be
            launched, or the build exits with a non-zero status.
    """
    # NOTE(review): `models` is imported when this file is loaded; if that
    # import itself needs the compiled extension, this build step runs too
    # late -- confirm against models.py.
    build_cmd = [sys.executable or "python", "setup.py", "build_ext", "--inplace"]
    try:
        # check=True turns a failed build into CalledProcessError. The former
        # os.system call only returned a status code, so the except branch
        # could never fire and failures were logged as success.
        subprocess.run(build_cmd, cwd="monotonic_align", check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        logger.error(f"Failed to compile monotonic_align: {e}")
        raise RuntimeError("Compilation of monotonic_align failed.") from e
    logger.info("Successfully compiled monotonic_align.")
# Load the configuration and the model
def load_config_and_model(config_path, checkpoint_path):
    """Load hyperparameters and build a synthesizer with checkpoint weights.

    Args:
        config_path: Path to the hyperparameter config file.
        checkpoint_path: Path to the model checkpoint file.

    Returns:
        A ``(hps, net_g)`` tuple: the object returned by
        ``utils.get_hparams_from_file`` and the ``SynthesizerTrn`` instance,
        switched to eval mode and passed through ``utils.load_checkpoint``.

    Raises:
        FileNotFoundError: If either path is not an existing regular file.
    """
    # isfile rather than exists: a directory at either path is rejected here
    # with a clear message instead of failing later inside the loaders.
    if not os.path.isfile(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    # Load hyperparameters
    hps = utils.get_hparams_from_file(config_path)
    logger.info("Loaded hyperparameters from config file.")

    # Initialise the model from the hyperparameters
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model,
    )
    net_g.eval()
    logger.info("Initialized SynthesizerTrn model.")

    # Load pretrained weights (no optimizer is passed)
    utils.load_checkpoint(checkpoint_path, net_g, None)
    logger.info(f"Loaded model checkpoint from {checkpoint_path}.")
    return hps, net_g
# Text-to-speech synthesis
def text_to_speech(content, hps, net_g, *, noise_scale=0.667, noise_scale_w=0.8, length_scale=1):
    """Synthesize audio for ``content`` with the given model.

    Args:
        content: Text to synthesize; must be a non-blank string.
        hps: Hyperparameters; ``data.text_cleaners``, ``data.add_blank`` and
            ``data.sampling_rate`` are read.
        net_g: Model whose ``infer`` method produces the audio.
        noise_scale: Forwarded to ``net_g.infer``.
        noise_scale_w: Forwarded to ``net_g.infer``.
        length_scale: Forwarded to ``net_g.infer``.

    Returns:
        A ``(sampling_rate, audio)`` tuple, where ``audio`` is a NumPy array.

    Raises:
        ValueError: If ``content`` is not a string or is empty/whitespace.
        RuntimeError: If conversion or inference fails; the original
            exception is attached as ``__cause__``.
    """
    # Type check first so .strip() is only called on real strings.
    if not isinstance(content, str) or not content.strip():
        raise ValueError("Input text is empty or invalid.")
    try:
        # Convert the text into a sequence of symbol ids
        token_ids = text_to_sequence(content, hps.data.text_cleaners)
        if hps.data.add_blank:
            token_ids = commons.intersperse(token_ids, 0)
        token_ids = torch.LongTensor(token_ids)

        # Model inference, without gradient tracking
        with torch.no_grad():
            x_tst = token_ids.unsqueeze(0)
            x_tst_lengths = torch.LongTensor([token_ids.size(0)])
            generated = net_g.infer(
                x_tst,
                x_tst_lengths,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            # .cpu() keeps the NumPy conversion valid if the tensor is not
            # already in host memory.
            audio = generated.detach().cpu().float().numpy()
        return hps.data.sampling_rate, audio
    except Exception as e:
        logger.error(f"Error during text-to-speech synthesis: {e}")
        raise RuntimeError("Failed to generate audio.") from e
# Gradio interface
def create_gradio_interface(hps, net_g):
    """Build the Gradio Blocks app that wraps ``text_to_speech``.

    Args:
        hps: Hyperparameters forwarded to ``text_to_speech``.
        net_g: Model forwarded to ``text_to_speech``.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """

    def safe_syn(content):
        """Run synthesis for the UI; log any failure and return ``None``."""
        try:
            return text_to_speech(content, hps, net_g)
        except Exception as e:
            # Deliberately best-effort: a failed request must not take down
            # the UI. exc_info keeps the traceback in the log.
            logger.error(f"Error in Gradio interface: {e}", exc_info=True)
            return None

    with gr.Blocks() as app:
        with gr.Tabs():
            with gr.TabItem("Basic"):
                input1 = gr.Textbox(label="Input Text", placeholder="Enter text here...")
                submit = gr.Button("Convert", variant="primary")
                output1 = gr.Audio(label="Output Audio")
                # Registered inside the Blocks context so the event is
                # attached to this app.
                submit.click(safe_syn, input1, output1)
    return app
# Entry point
def main():
    """Compile the extension, load the model and launch the Gradio app.

    Any exception raised along the way is logged at CRITICAL level with its
    traceback, and the process exits with status 1.
    """
    config_path = "configs/steins_gate_base.json"
    checkpoint_path = "G_265000.pth"
    try:
        # Compile monotonic_align
        compile_monotonic_align()

        # Load the configuration and the model
        hps, net_g = load_config_and_model(config_path, checkpoint_path)

        # Create and start the Gradio interface
        app = create_gradio_interface(hps, net_g)
        logger.info("Starting Gradio interface...")
        app.launch()
    except Exception as e:
        logger.critical(f"Fatal error: {e}", exc_info=True)
        # sys.exit rather than the `exit` builtin, which is injected by the
        # `site` module and is not guaranteed to exist.
        sys.exit(1)
# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()