# NOTE(review): removed non-Python residue ("Spaces: / Running / Running")
# captured from the hosting page during extraction.
# coding=utf-8 | |
import os
import secrets
from io import BytesIO
from typing import Optional

import numpy as np
import torch
import torchaudio
from dotenv import load_dotenv
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from funasr import AutoModel
# Load environment variables from a local .env file (no-op if absent).
load_dotenv()
# Bearer token that clients must present; fail fast at import time if unset.
API_TOKEN = os.getenv("API_TOKEN")
if not API_TOKEN:
    raise RuntimeError("API_TOKEN environment variable is not set")
# HTTP Bearer scheme consumed by the verify_token dependency below.
security = HTTPBearer()
app = FastAPI(
    title="SenseVoice API",
    description="语音识别 API 服务",
    version="1.0.0"
)
# Allow cross-origin requests from any origin (wide-open CORS).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the SenseVoice ASR model once at startup (shared by all requests).
model = AutoModel(
    model="FunAudioLLM/SenseVoiceSmall",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",  # VAD front-end for segmenting long audio
    vad_kwargs={"max_single_segment_time": 30000},  # cap a single VAD segment at 30 s
    hub="hf",  # fetch weights from the Hugging Face hub
    device="cuda"  # NOTE(review): assumes a CUDA GPU is available — confirm deployment target
)
# Token tables reused by the formatting helpers below.
# Emotion tokens -> display emoji (NEUTRAL intentionally maps to "").
emo_dict = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}
# Audio-event tokens -> display emoji (silent events map to "").
event_dict = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}
# Full replacement table for every special token the model may emit.
# NOTE: the combined "<|nospeech|><|Event_UNK|>" key is listed first on
# purpose — it must be matched before its components.
emoji_dict = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}
# Language tokens are all normalized to a single "<|lang|>" separator.
lang_dict = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}
# Emoji sets used for whitespace cleanup and seam de-duplication.
emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷"}
def format_str(s, token_map=None):
    """Replace every special model token in *s* with its display form.

    Args:
        s: Raw model output that may contain ``<|...|>`` special tokens.
        token_map: Optional mapping of token -> replacement. Defaults to the
            module-level ``emoji_dict`` (backward compatible with the
            original single-argument signature).

    Returns:
        The string with all mapped tokens substituted, in mapping order.
    """
    mapping = emoji_dict if token_map is None else token_map
    for token, replacement in mapping.items():
        s = s.replace(token, replacement)
    return s
def format_str_v2(s):
    """Strip special tokens from *s*, prefix one emoji per detected event,
    append the dominant emotion emoji, and tighten spaces around emojis."""
    # Count each token's occurrences while removing it from the text.
    # Iteration follows emoji_dict order so the combined
    # "<|nospeech|><|Event_UNK|>" key is consumed before its parts.
    counts = {}
    for token in emoji_dict:
        counts[token] = s.count(token)
        s = s.replace(token, "")
    # Dominant emotion: strict-greater scan, so earlier tokens win ties and
    # NEUTRAL is the starting candidate.
    dominant = "<|NEUTRAL|>"
    for token in emo_dict:
        if counts[token] > counts[dominant]:
            dominant = token
    # Prepend an emoji for every event that occurred (later events end up
    # leftmost, matching the original loop's repeated prefixing).
    for token in event_dict:
        if counts[token] > 0:
            s = event_dict[token] + s
    s = s + emo_dict[dominant]
    # Drop the space on either side of every known emoji.
    for mark in emo_set | event_set:
        s = s.replace(" " + mark, mark)
        s = s.replace(mark + " ", mark)
    return s.strip()
def format_str_v3(s):
    """Full post-processing of raw model output.

    Normalizes language tokens into a common separator, formats each
    language segment with format_str_v2, then concatenates the segments
    while de-duplicating a repeated leading event emoji or trailing emotion
    emoji across segment boundaries.
    """
    def get_emo(s):
        # Trailing char if it is a known emotion emoji, else None.
        return s[-1] if s[-1] in emo_set else None
    def get_event(s):
        # Leading char if it is a known event emoji, else None.
        return s[0] if s[0] in event_set else None
    # Collapse the combined unknown-event marker before token normalization.
    s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
    # Normalize every language token to the "<|lang|>" split marker.
    for lang in lang_dict:
        s = s.replace(lang, "<|lang|>")
    s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
    # Leading space guarantees get_emo/get_event never index an empty string.
    new_s = " " + s_list[0]
    cur_ent_event = get_event(new_s)
    for i in range(1, len(s_list)):
        if len(s_list[i]) == 0:
            continue
        # Drop a leading event emoji that repeats the previous segment's.
        if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None:
            s_list[i] = s_list[i][1:]
        cur_ent_event = get_event(s_list[i])
        # Drop a trailing emotion emoji duplicated across the seam.
        if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s):
            new_s = new_s[:-1]
        new_s += s_list[i].strip().lstrip()
    # NOTE(review): appears to scrub a recurring decoding artifact — confirm.
    new_s = new_s.replace("The.", " ")
    return new_s.strip()
async def process_audio(audio_data: bytes, language: str = "auto") -> str:
    """Decode raw audio bytes, run SenseVoice inference, and return the
    formatted transcription.

    Args:
        audio_data: Raw bytes of an audio file in any format torchaudio
            can decode.
        language: Language hint forwarded to the model
            ("auto"/"zh"/"en"/"yue"/"ja"/"ko"/"nospeech").

    Returns:
        The recognized text after token/emoji post-processing.

    Raises:
        HTTPException: 500 on any decoding or inference failure.
    """
    try:
        # Decode the in-memory bytes; waveform is (channels, frames).
        audio_buffer = BytesIO(audio_data)
        waveform, sample_rate = torchaudio.load(audio_buffer)
        # Collapse to mono unconditionally so the model always receives a
        # 1-D signal. (The original only averaged multi-channel input,
        # leaving mono files as an inconsistent 2-D [1, N] array.)
        waveform = waveform.mean(dim=0)
        # Convert to float32 numpy for the model (dtype conversion only —
        # no amplitude normalization is performed here).
        input_wav = waveform.numpy().astype(np.float32)
        # SenseVoice expects 16 kHz input; resample when needed.
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            input_wav = resampler(torch.from_numpy(input_wav)[None, :])[0, :].numpy()
        # Model inference.
        text = model.generate(
            input=input_wav,
            cache={},
            language=language,
            use_itn=True,  # inverse text normalization (numbers, punctuation)
            batch_size_s=500,
            merge_vad=True  # merge VAD segments before decoding
        )
        # Format the raw tagged output into display text.
        result = text[0]["text"]
        result = format_str_v3(result)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"音频处理失败:{str(e)}")
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency that validates the client's Bearer token.

    Returns:
        The credentials object when the token matches API_TOKEN.

    Raises:
        HTTPException: 401 (with WWW-Authenticate header) on mismatch.
    """
    # Constant-time comparison avoids leaking the token's length/prefix
    # through response-timing differences (a plain != short-circuits).
    if not secrets.compare_digest(credentials.credentials, API_TOKEN):
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"}
        )
    return credentials
# NOTE(review): no route decorator was present in the source, so this handler
# was never registered with the app. The OpenAI-compatible path below is a
# best guess based on the (file, model, language) signature — confirm the
# intended route.
@app.post("/v1/audio/transcriptions")
async def transcribe_audio(
    file: UploadFile = File(...),
    model: Optional[str] = "FunAudioLLM/SenseVoiceSmall",
    language: Optional[str] = "auto",
    token: HTTPAuthorizationCredentials = Depends(verify_token)
):
    """音频转写接口 (audio transcription endpoint)

    Args:
        file: 音频文件(支持常见音频格式)
        model: 模型名称,目前仅支持 FunAudioLLM/SenseVoiceSmall
        language: 语言代码,支持 auto/zh/en/yue/ja/ko/nospeech

    Returns:
        {"text": "识别结果"}
    """
    # file.filename can be None for some clients; treat that as unsupported.
    if not file.filename or not file.filename.lower().endswith(
        (".mp3", ".wav", ".flac", ".ogg", ".m4a")
    ):
        raise HTTPException(status_code=400, detail="不支持的音频格式")
    if model != "FunAudioLLM/SenseVoiceSmall":
        raise HTTPException(status_code=400, detail="不支持的模型")
    if language not in ["auto", "zh", "en", "yue", "ja", "ko", "nospeech"]:
        raise HTTPException(status_code=400, detail="不支持的语言")
    try:
        content = await file.read()
        text = await process_audio(content, language)
        return {"text": text}
    except HTTPException:
        # Preserve the status/detail raised by process_audio instead of
        # re-wrapping it into a generic 500 with a stringified exception
        # (the original's bare `except Exception` swallowed it).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# NOTE(review): the source had no decorator here either, yet it imports
# HTMLResponse without using it — strong evidence the original route was
# dropped during extraction. Restored so "/" serves the page as HTML
# instead of a JSON-encoded string.
@app.get("/", response_class=HTMLResponse)
async def root():
    """Landing page: service description and a link to the API docs."""
    html_content = """
    <!DOCTYPE html>
    <html>
    <head>
        <title>SenseVoice API</title>
        <style>
            body { font-family: Arial, sans-serif; max-width: 800px; margin: 40px auto; padding: 0 20px; line-height: 1.6; }
            h1 { color: #2c3e50; }
            .api-info { background: #f8f9fa; padding: 20px; border-radius: 5px; margin: 20px 0; }
            .api-link { display: inline-block; background: #3498db; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px; margin-top: 20px; }
            .api-link:hover { background: #2980b9; }
        </style>
    </head>
    <body>
        <h1>欢迎使用 SenseVoice API</h1>
        <div class="api-info">
            <h2>服务信息</h2>
            <p>版本:1.0.0</p>
            <p>描述:多语言语音识别服务,支持中文、英语、粤语、日语、韩语等多种语言的语音转写。</p>
            <h2>主要功能</h2>
            <ul>
                <li>支持多种音频格式:MP3、WAV、FLAC、OGG、M4A</li>
                <li>自动语言检测</li>
                <li>情感和事件识别</li>
                <li>高性能语音识别引擎</li>
            </ul>
        </div>
        <a href="/docs" class="api-link">查看API文档</a>
    </body>
    </html>
    """
    return html_content
if __name__ == "__main__":
    # Run the ASGI app directly with uvicorn (port 7860 is the HF Spaces
    # convention); in production the app is typically served by an external
    # ASGI server instead.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)