Spaces:
Running
Running
# coding=utf-8 | |
from io import BytesIO
import os
import secrets
import time
from typing import Optional

from dotenv import load_dotenv
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from funasr import AutoModel
import gradio as gr
import numpy as np
import torch
import torchaudio
# Pull variables from a .env file into os.environ before reading them.
load_dotenv()

# Shared secret that API clients must present as a Bearer token. Refuse to
# start without one rather than serve an unauthenticated API.
API_TOKEN = os.environ.get("API_TOKEN")
if API_TOKEN is None or API_TOKEN == "":
    raise RuntimeError("API_TOKEN environment variable is not set")
# Bearer-token authentication scheme (consumed via Depends(security)).
security = HTTPBearer()

app = FastAPI(
    title="SenseVoice API",
    description="语音识别 API 服务",
    version="1.0.0",
)

# CORS: accept requests from any origin, with any method and header.
# NOTE(review): allow_credentials=True combined with a "*" origin is usually
# a misconfiguration; this API authenticates with an explicit Bearer header,
# which should not need credentialed CORS. Confirm before tightening.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Speech recognizer: SenseVoiceSmall fetched from the Hugging Face hub, paired
# with a VAD model. max_single_segment_time=30000 is presumably milliseconds
# (~30 s per segment) — confirm against the FunASR documentation.
_model_config = dict(
    model="FunAudioLLM/SenseVoiceSmall",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
    hub="hf",
    device="cuda",
)
model = AutoModel(**_model_config)
# Token tables and formatting helpers reused from the original formatting code.
# The model output contains inline special tokens such as "<|zh|>",
# "<|HAPPY|>" and "<|BGM|>"; these tables map them to emojis (or to "" to
# drop them).
# NOTE: dict insertion order is significant. format_str / format_str_v2 apply
# replacements in this order, and format_str_v2 relies on it for emotion
# tie-breaking and for the order of prepended event emojis.

# Emotion tokens -> emoji appended at the end of a segment.
emo_dict = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

# Audio-event tokens -> emoji prepended to a segment.
# NOTE(review): "<|Cough|>" maps to "🤧" here but to "😷" in emoji_dict;
# kept unchanged to preserve output — confirm which one is intended.
event_dict = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

# Every known special token -> display replacement. The compound
# "<|nospeech|><|Event_UNK|>" key must stay ahead of "<|nospeech|>" and
# "<|Event_UNK|>" so it is matched before its parts are removed.
emoji_dict = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

# Language tokens, all collapsed into a single "<|lang|>" separator that
# format_str_v3 splits on.
lang_dict = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

# Emojis recognised as a trailing emotion / leading event marker.
emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷"}


def format_str(s):
    """Replace every special token in ``s`` with its ``emoji_dict`` value.

    Replacements are applied in ``emoji_dict`` order. The request handlers in
    this file call ``format_str_v3`` instead.
    """
    for token, replacement in emoji_dict.items():
        s = s.replace(token, replacement)
    return s


def format_str_v2(s):
    """Format one language segment of raw model output.

    All special tokens are counted and stripped. The most frequent emotion is
    appended as an emoji; ties go to "<|NEUTRAL|>" if it is among the most
    frequent, otherwise to the earlier ``emo_dict`` entry. Every event token
    present is prepended as an emoji, and single spaces next to emojis are
    removed.
    """
    token_counts = {}
    for token in emoji_dict:
        # Count, then remove, so shorter tokens are not re-counted inside an
        # already-removed compound token.
        token_counts[token] = s.count(token)
        s = s.replace(token, "")

    emo = "<|NEUTRAL|>"
    for candidate in emo_dict:
        if token_counts[candidate] > token_counts[emo]:
            emo = candidate

    # Each present event is prepended in turn, so later event_dict entries
    # end up further to the left.
    for event, event_emoji in event_dict.items():
        if token_counts[event] > 0:
            s = event_emoji + s
    s = s + emo_dict[emo]

    for emoji in emo_set.union(event_set):
        s = s.replace(" " + emoji, emoji)
        s = s.replace(emoji + " ", emoji)
    return s.strip()


def format_str_v3(s):
    """Turn raw model output into display text with emojis.

    The text is split at language tokens, each piece is formatted with
    ``format_str_v2``, and the pieces are concatenated:

    * a leading event emoji equal to the previous segment's event is dropped;
    * when a segment ends with the same emotion emoji that currently ends the
      text, the earlier emoji is removed so it appears only once, at the end.

    Args:
        s: Raw text as produced by the recognizer.

    Returns:
        The formatted text, stripped of surrounding whitespace.
    """
    def get_emo(text):
        # Trailing emotion emoji of ``text``, or None (also for "").
        return text[-1] if text and text[-1] in emo_set else None

    def get_event(text):
        # Leading event emoji of ``text``, or None (also for "").
        return text[0] if text and text[0] in event_set else None

    s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang in lang_dict:
        s = s.replace(lang, "<|lang|>")
    s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]

    new_s = " " + s_list[0]
    # Inspect the segment itself: new_s starts with the padding space, so
    # get_event(new_s) could never see the first segment's event.
    cur_ent_event = get_event(s_list[0])
    for segment in s_list[1:]:
        if not segment:
            continue
        event = get_event(segment)
        if event is not None and event == cur_ent_event:
            segment = segment[1:]
            if not segment:
                # The segment held only the repeated event: nothing to append
                # and the current event is unchanged.
                continue
        cur_ent_event = get_event(segment)
        emo = get_emo(segment)
        if emo is not None and emo == get_emo(new_s):
            new_s = new_s[:-1]
        new_s += segment.strip()
    # Carried over unchanged: replaces every literal "The." with a space; the
    # purpose is not evident from this file.
    new_s = new_s.replace("The.", " ")
    return new_s.strip()
async def process_audio(audio_data: bytes, language: str = "auto") -> str:
    """Decode an encoded audio file and return its formatted transcription.

    Args:
        audio_data: Raw bytes of an audio file in a format torchaudio can
            decode from a file-like object.
        language: Language hint forwarded to ``model.generate``.

    Returns:
        The recognized text after ``format_str_v3`` post-processing.

    Raises:
        HTTPException: status 500 wrapping any decoding or inference error.
    """
    try:
        # torchaudio.load returns a float tensor shaped (channels, frames).
        waveform, sample_rate = torchaudio.load(BytesIO(audio_data))
        # Average over the channel axis unconditionally so mono input also
        # becomes a 1-D signal instead of staying a (1, frames) array.
        waveform = waveform.mean(dim=0)
        # Resample to 16 kHz before inference.
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
        input_wav = waveform.numpy().astype(np.float32)

        # Model inference.
        text = model.generate(
            input=input_wav,
            cache={},
            language=language,
            use_itn=True,
            batch_size_s=500,
            merge_vad=True
        )
        # Format the raw output (special tokens -> emojis).
        return format_str_v3(text[0]["text"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"音频处理失败:{str(e)}") from e
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency that validates the request's Bearer token.

    Args:
        credentials: Bearer credentials extracted by the ``HTTPBearer``
            scheme; ``credentials.credentials`` is the token string.

    Returns:
        The same credentials object when the token matches ``API_TOKEN``.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when the token
            does not match.
    """
    # Constant-time comparison, so response timing does not reveal how much
    # of a guessed token was correct. Compare bytes because compare_digest
    # only accepts ASCII when given str arguments.
    if not secrets.compare_digest(
        credentials.credentials.encode("utf-8"), API_TOKEN.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"}
        )
    return credentials
# NOTE(review): no route decorator (e.g. @app.post(...)) is visible on this
# function, so FastAPI does not register it as an endpoint; confirm whether
# one was lost.
async def transcribe_audio(
    file: UploadFile = File(...),
    model: Optional[str] = "FunAudioLLM/SenseVoiceSmall",
    language: Optional[str] = "auto",
    token: HTTPAuthorizationCredentials = Depends(verify_token)
):
    """Transcribe an uploaded audio file (Bearer token required).

    Validation and processing errors are reported in the returned body via
    ``error_code`` / ``error_msg``; this function itself never raises.

    Args:
        file: Uploaded audio. Accepted extensions: .mp3, .wav, .flac, .ogg,
            .m4a (case-insensitive).
        model: Model name; only "FunAudioLLM/SenseVoiceSmall" is accepted.
            Inside this function it shadows the module-level ``model``
            object. Inference goes through ``process_audio``, which uses the
            global.
        language: One of auto/zh/en/yue/ja/ko/nospeech.
        token: Result of ``verify_token``; it is used only for its check.

    Returns:
        A dict with keys:
            "text": recognized text ("" on error),
            "error_code": 0 on success, 400 for invalid input, 500 on failure,
            "error_msg": "" on success, otherwise a message,
            "process_time": elapsed seconds.
    """
    start_time = time.perf_counter()

    def result(text="", error_code=0, error_msg=""):
        # Response body; elapsed time is measured when the body is built.
        return {
            "text": text,
            "error_code": error_code,
            "error_msg": error_msg,
            "process_time": time.perf_counter() - start_time,
        }

    try:
        # UploadFile.filename may be None; treat it as an unsupported format
        # (400) rather than failing with AttributeError (reported as 500).
        filename = (file.filename or "").lower()
        if not filename.endswith((".mp3", ".wav", ".flac", ".ogg", ".m4a")):
            return result(error_code=400, error_msg="不支持的音频格式")
        if model != "FunAudioLLM/SenseVoiceSmall":
            return result(error_code=400, error_msg="不支持的模型")
        if language not in ["auto", "zh", "en", "yue", "ja", "ko", "nospeech"]:
            return result(error_code=400, error_msg="不支持的语言")
        content = await file.read()
        text = await process_audio(content, language)
        return result(text=text)
    except Exception as e:
        # Any failure, including the HTTPException raised by process_audio,
        # is reported in the body with code 500.
        return result(error_code=500, error_msg=str(e))
def transcribe_audio_gradio(audio, language="auto"):
    """Gradio callback: transcribe audio from a ``gr.Audio(type="numpy")`` input.

    Args:
        audio: ``(sample_rate, samples)`` tuple, or None when nothing was
            recorded or uploaded.
        language: Language hint forwarded to ``model.generate``.

    Returns:
        The formatted transcription, or a user-facing message when input is
        missing or processing fails (errors are returned, not raised).
    """
    try:
        if audio is None:
            return "请上传音频文件"
        fs, input_wav = audio

        # Scale integer PCM by its own dtype's full-scale value; always using
        # int16's maximum mis-scales e.g. int32 samples. Float samples are
        # assumed to already lie in [-1, 1] — TODO confirm for the Gradio
        # version in use.
        if np.issubdtype(input_wav.dtype, np.integer):
            input_wav = input_wav.astype(np.float32) / np.iinfo(input_wav.dtype).max
        else:
            input_wav = input_wav.astype(np.float32)

        # Multi-channel input: average the last axis (channels) down to mono.
        if input_wav.ndim > 1:
            input_wav = input_wav.mean(-1)

        # Resample to 16 kHz.
        if fs != 16000:
            resampler = torchaudio.transforms.Resample(fs, 16000)
            input_wav_t = torch.from_numpy(input_wav).to(torch.float32)
            input_wav = resampler(input_wav_t[None, :])[0, :].numpy()

        # Model inference.
        text = model.generate(
            input=input_wav,
            cache={},
            language=language,
            use_itn=True,
            batch_size_s=500,
            merge_vad=True
        )
        # Format the raw output (special tokens -> emojis).
        return format_str_v3(text[0]["text"])
    except Exception as e:
        return f"处理失败:{str(e)}"
# Gradio UI: audio (microphone or upload) plus a language selector, handled
# by transcribe_audio_gradio. Components are created in the same order as
# before: Audio, Dropdown, Textbox.
_audio_input = gr.Audio(
    sources=["microphone", "upload"],
    type="numpy",
    label="上传音频或使用麦克风录音",
)
_language_input = gr.Dropdown(
    choices=["auto", "zh", "en", "yue", "ja", "ko", "nospeech"],
    value="auto",
    label="选择语言",
)
_text_output = gr.Textbox(label="识别结果")

demo = gr.Interface(
    fn=transcribe_audio_gradio,
    inputs=[_audio_input, _language_input],
    outputs=_text_output,
    title="SenseVoice 语音识别",
    description="支持中文、英语、粤语、日语、韩语等多种语言的语音转写服务",
    examples=[
        ["examples/chinese.wav", "zh"],
        ["examples/english.wav", "en"],
    ],
)

# Serve the Gradio app from the FastAPI app's root path.
# NOTE(review): a mount at "/" matches every path, so routes added to `app`
# after this line are likely shadowed by it — verify route ordering.
app = gr.mount_gradio_app(app, demo, path="/")
# NOTE(review): no route decorator is visible on this handler, so it is not
# registered with FastAPI; confirm whether one was lost.
async def custom_swagger_ui_html():
    """Return a tiny HTML page that immediately redirects to the API docs.

    The target is "/docs" (FastAPI's default docs URL) rather than "/docs/".
    The Gradio app is mounted at "/", so a trailing-slash variant is handed to
    Gradio instead of being redirected to the docs route.
    """
    return HTMLResponse("""
<!DOCTYPE html>
<html>
<head>
<title>SenseVoice API 文档</title>
<meta http-equiv="refresh" content="0;url=/docs" />
</head>
<body>
<p>正在跳转到API文档...</p>
</body>
</html>
""")
if __name__ == "__main__":
    # Serve the FastAPI app (with the mounted Gradio UI) on all interfaces.
    from uvicorn import run

    run(app, host="0.0.0.0", port=7860)