# Spaces:
# Build error
# Build error
import torch | |
from fastapi import FastAPI | |
from pydantic import BaseModel | |
import numpy as np | |
import base64 | |
import io | |
from scipy.io.wavfile import write | |
import sounddevice as sd | |
# 自定义模块 | |
import commons | |
import utils | |
from models import SynthesizerTrn | |
from text.symbols import symbols | |
from text import text_to_sequence | |
# Startup diagnostics: report the installed PyTorch build and whether CUDA can be used.
print(torch.__version__)  # PyTorch version string
print(torch.cuda.is_available())  # True if a usable CUDA device is present
print(torch.version.cuda)  # CUDA version PyTorch was built with (None on CPU-only builds)
# ASGI application object; the __main__ block at the end of this file serves it with uvicorn.
app: FastAPI = FastAPI()
# 请求体模型 | |
class TextRequest(BaseModel): | |
text: str | |
# Model configuration file and trained generator checkpoint (relative paths).
config_path = "configs/steins_gate_base.json"
checkpoint_path = "G_265000.pth"

# Hyper-parameters read from config_path; hps.data, hps.train and hps.model are used below.
hps = utils.get_hparams_from_file(config_path)

# Construct the synthesizer. Positional arguments derived from the config:
#   1. number of text symbols (vocabulary size),
#   2. filter_length // 2 + 1 (presumably the STFT frequency-bin count),
#   3. segment_size // hop_length (training segment length, presumably in frames).
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model,
)
# Put the model in evaluation mode, then load the checkpoint weights into it.
# The third argument is None — presumably meaning no optimizer state is restored.
net_g = net_g.eval()
utils.load_checkpoint(checkpoint_path, net_g, None)
# Text-to-speech synthesis
def text_to_speech(content, *, noise_scale=0.667, noise_scale_w=0.8, length_scale=1):
    """Synthesize speech for ``content`` with the globally loaded ``net_g`` model.

    Args:
        content: Text to synthesize; converted to symbol ids with the
            cleaners listed in ``hps.data.text_cleaners``.
        noise_scale: Forwarded to ``net_g.infer`` (default 0.667, the value
            previously hard-coded here).
        noise_scale_w: Forwarded to ``net_g.infer`` (default 0.8).
        length_scale: Forwarded to ``net_g.infer`` (default 1).

    Returns:
        A ``(sampling_rate, audio)`` tuple: ``hps.data.sampling_rate`` and a
        float32 NumPy array taken from the model output at ``[0][0, 0]``.

    Raises:
        ValueError: If ``content`` produces no symbols after text cleaning.
    """
    sequence = text_to_sequence(content, hps.data.text_cleaners)
    if len(sequence) == 0:
        # Without this check, add_blank would turn [] into a lone blank token
        # and the model would be asked to synthesize from no real input.
        raise ValueError("input text produced no synthesizable symbols")
    if hps.data.add_blank:
        # Presumably inserts blank token id 0 between symbols, as the config requests.
        sequence = commons.intersperse(sequence, 0)
    tokens = torch.LongTensor(sequence)
    with torch.no_grad():
        x = tokens.unsqueeze(0)  # add a batch dimension of size 1
        x_lengths = torch.LongTensor([tokens.size(0)])
        output = net_g.infer(
            x,
            x_lengths,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
        )
        # First returned item, indexed [0, 0] (presumably batch 0, channel 0).
        # .cpu() keeps .numpy() working even if the model is moved to a GPU.
        audio = output[0][0, 0].float().cpu().numpy()
    return hps.data.sampling_rate, audio
# API route: text to speech
@app.post("/synthesize")
def synthesize(request: TextRequest):
    """Synthesize ``request.text`` and return it as a base64-encoded 16-bit PCM WAV file.

    Returns:
        ``{"audio": <str>}`` where the value is the complete WAV file bytes,
        base64-encoded and decoded to text.

    Raises:
        HTTPException: 400 if ``request.text`` is empty or whitespace-only.
    """
    from fastapi import HTTPException

    if not request.text.strip():
        raise HTTPException(status_code=400, detail="text must not be empty")

    sampling_rate, audio = text_to_speech(request.text)

    # Clip to [-1, 1] before scaling: out-of-range float samples would otherwise
    # overflow int16 on conversion and wrap around, producing loud clicks.
    pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)

    # Write the WAV into an in-memory buffer, then base64-encode its full contents.
    buffer = io.BytesIO()
    write(buffer, sampling_rate, pcm)
    audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return {"audio": audio_base64}
# Entry point
if __name__ == "__main__":
    import os

    import uvicorn

    # Bind address and port default to the previous hard-coded values. They can be
    # overridden with UVICORN_HOST / UVICORN_PORT (the same names the uvicorn CLI
    # reads), e.g. UVICORN_HOST=0.0.0.0 to accept connections from outside the
    # machine or container.
    host = os.environ.get("UVICORN_HOST", "127.0.0.1")
    port = int(os.environ.get("UVICORN_PORT", "8000"))
    uvicorn.run(app, host=host, port=port)