# coding=utf-8
import os
import time
from io import BytesIO
from typing import Optional, Dict, Any, List, Set, Union, Tuple

# Third-party imports
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import HTMLResponse
import numpy as np
import torch
import torchaudio
from funasr import AutoModel
from dotenv import load_dotenv
import gradio as gr
# Load environment variables
load_dotenv()

# Read the API token
API_TOKEN: str = os.getenv("API_TOKEN")
if not API_TOKEN:
    raise RuntimeError("API_TOKEN environment variable is not set")

# Set up Bearer token authentication
security = HTTPBearer()
app = FastAPI(
    title="SenseVoice API",
    description="Speech recognition API service",
    version="1.0.0"
)

# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize the model (SenseVoiceSmall with an FSMN VAD front-end)
model = AutoModel(
    model="FunAudioLLM/SenseVoiceSmall",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
    hub="hf",
    device="cuda"
)
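# Quick sanity check (illustrative sketch, not part of the service): the same
# generate() call used by the endpoints below can be run directly on a local file:
#   res = model.generate(input="examples/zh.mp3", cache={}, language="auto",
#                        use_itn=True, batch_size_s=500, merge_vad=True)
#   print(res[0]["text"])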
# Reuse the original token-formatting functions
emotion_dict: Dict[str, str] = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

event_dict: Dict[str, str] = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

emoji_dict: Dict[str, str] = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

lang_dict: Dict[str, str] = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

emo_set: Set[str] = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set: Set[str] = {"🎼", "👏", "😀", "😭", "🤧", "😷"}
def format_text_basic(text: str) -> str:
    """Replace special tokens with their corresponding emojis."""
    for token in emoji_dict:
        text = text.replace(token, emoji_dict[token])
    return text
def format_text_with_emotion(text: str) -> str:
    """Format text with emotion and event markers."""
    token_count: Dict[str, int] = {}
    original_text = text
    for token in emoji_dict:
        token_count[token] = text.count(token)

    # Determine the dominant emotion
    dominant_emotion = "<|NEUTRAL|>"
    for emotion in emotion_dict:
        if token_count[emotion] > token_count[dominant_emotion]:
            dominant_emotion = emotion

    # Prefix event markers
    text = original_text
    for event in event_dict:
        if token_count[event] > 0:
            text = event_dict[event] + text

    # Replace all tokens with their emoji equivalents
    for token in emoji_dict:
        text = text.replace(token, emoji_dict[token])

    # Append the dominant emotion
    text = text + emotion_dict[dominant_emotion]

    # Clean up emoji spacing
    for emoji in emo_set.union(event_set):
        text = text.replace(" " + emoji, emoji)
        text = text.replace(emoji + " ", emoji)
    return text.strip()
def format_text_advanced(text: str) -> str:
    """Advanced text formatting with multilingual and complex token handling."""
    def get_emotion(text: str) -> Optional[str]:
        return text[-1] if text[-1] in emo_set else None

    def get_event(text: str) -> Optional[str]:
        return text[0] if text[0] in event_set else None

    # Handle special cases
    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang in lang_dict:
        text = text.replace(lang, "<|lang|>")

    # Process text segments
    text_segments: List[str] = [format_text_with_emotion(segment).strip() for segment in text.split("<|lang|>")]
    formatted_text = " " + text_segments[0]
    current_event = get_event(formatted_text)

    # Merge segments
    for i in range(1, len(text_segments)):
        if not text_segments[i]:
            continue
        if get_event(text_segments[i]) == current_event and get_event(text_segments[i]) is not None:
            text_segments[i] = text_segments[i][1:]
        current_event = get_event(text_segments[i])
        if get_emotion(text_segments[i]) is not None and get_emotion(text_segments[i]) == get_emotion(formatted_text):
            formatted_text = formatted_text[:-1]
        formatted_text += text_segments[i].strip()
    formatted_text = formatted_text.replace("The.", " ")
    return formatted_text.strip()
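# Illustrative example (the raw token string is an assumption, not taken from the source):
#   format_text_advanced("<|en|><|HAPPY|><|Speech|><|withitn|>Hello there.")
#   strips the language/ITN tokens and surfaces the emotion emoji, roughly yielding
#   "😊Hello there.😊".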
async def process_audio(audio_data: bytes, language: str = "auto") -> str:
    """Process raw audio bytes and return the transcription result."""
    try:
        # Decode the audio bytes with torchaudio
        audio_buffer = BytesIO(audio_data)
        waveform, sample_rate = torchaudio.load(audio_buffer)
        print(1, waveform.shape)

        # Convert to mono
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0)
        print(2, waveform.shape)

        # Convert to a float32 numpy array (torchaudio already returns samples in [-1, 1])
        input_wav = waveform.numpy().astype(np.float32)
        print(3, input_wav.shape)

        # Resample to 16 kHz if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            input_wav = resampler(torch.from_numpy(input_wav)[None, :])[0, :].numpy()
        print(4, input_wav.shape)

        # Model inference
        text = model.generate(
            input=input_wav,
            cache={},
            language=language,
            use_itn=True,
            batch_size_s=500,
            merge_vad=True
        )

        # Format the result
        result = text[0]["text"]
        result = format_text_advanced(result)
        return result
    except Exception as e:
        import traceback
        traceback.print_exc()
        traceback.print_stack()
        raise HTTPException(status_code=500, detail=f"Audio processing failed: {str(e)}")
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> HTTPAuthorizationCredentials:
    """Verify Bearer token authentication."""
    if credentials.credentials != API_TOKEN:
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"}
        )
    return credentials
# The route path below is an assumption (OpenAI-style transcription endpoint);
# adjust it to the deployment's actual API path if it differs.
@app.post("/v1/audio/transcriptions")
async def transcribe_audio(
    file: UploadFile = File(...),
    model: Optional[str] = "FunAudioLLM/SenseVoiceSmall",
    language: Optional[str] = "auto",
    token: HTTPAuthorizationCredentials = Depends(verify_token)
) -> Dict[str, Union[str, int, float]]:
    """Audio transcription endpoint.

    Args:
        file: Audio file (supports common audio formats)
        model: Model name; currently only FunAudioLLM/SenseVoiceSmall is supported
        language: Language code, one of auto/zh/en/yue/ja/ko/nospeech

    Returns:
        Dict[str, Union[str, int, float]]: {
            "text": "Transcription result",
            "error_code": 0,
            "error_msg": "",
            "process_time": 1.234  # Processing time in seconds
        }
    """
    # NOTE: the `model` parameter shadows the global AutoModel inside this function;
    # inference still uses the global model via process_audio().
    start_time = time.time()
    try:
        # Validate file format
        if not file.filename.lower().endswith((".mp3", ".wav", ".flac", ".ogg", ".m4a")):
            return {
                "text": "",
                "error_code": 400,
                "error_msg": "Unsupported audio format",
                "process_time": time.time() - start_time
            }

        # Validate model
        if model != "FunAudioLLM/SenseVoiceSmall":
            return {
                "text": "",
                "error_code": 400,
                "error_msg": "Unsupported model",
                "process_time": time.time() - start_time
            }

        # Validate language
        if language not in ["auto", "zh", "en", "yue", "ja", "ko", "nospeech"]:
            return {
                "text": "",
                "error_code": 400,
                "error_msg": "Unsupported language",
                "process_time": time.time() - start_time
            }

        # Process audio
        content = await file.read()
        text = await process_audio(content, language)
        return {
            "text": text,
            "error_code": 0,
            "error_msg": "",
            "process_time": time.time() - start_time
        }
    except Exception as e:
        return {
            "text": "",
            "error_code": 500,
            "error_msg": str(e),
            "process_time": time.time() - start_time
        }
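# Example request (path and port are assumptions; `language` is passed as a query
# parameter because it is not declared as a Form field):
#   curl -X POST "http://localhost:7860/v1/audio/transcriptions?language=zh" \
#        -H "Authorization: Bearer $API_TOKEN" \
#        -F "file=@examples/zh.mp3"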
def transcribe_audio_gradio(audio: Optional[Tuple[int, np.ndarray]], language: str = "auto") -> str:
    """Gradio interface callback for audio transcription."""
    try:
        if audio is None:
            return "Please upload an audio file"

        # Gradio passes audio as a (sample_rate, samples) tuple
        sample_rate, input_wav = audio

        # Normalize from int16 range to [-1, 1]
        input_wav = input_wav.astype(np.float32) / np.iinfo(np.int16).max

        # Convert to mono
        if len(input_wav.shape) > 1:
            input_wav = input_wav.mean(-1)

        # Resample to 16 kHz if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            input_wav_tensor = torch.from_numpy(input_wav).to(torch.float32)
            input_wav = resampler(input_wav_tensor[None, :])[0, :].numpy()

        # Model inference
        text = model.generate(
            input=input_wav,
            cache={},
            language=language,
            use_itn=True,
            batch_size_s=500,
            merge_vad=True
        )

        # Format the result
        result = text[0]["text"]
        result = format_text_advanced(result)
        return result
    except Exception as e:
        return f"Processing failed: {str(e)}"
# Create the Gradio interface
demo = gr.Interface(
    fn=transcribe_audio_gradio,
    inputs=[
        gr.Audio(
            sources=["upload", "microphone"],
            type="numpy",
            label="Upload audio or record from microphone"
        ),
        gr.Dropdown(
            choices=["auto", "zh", "en", "yue", "ja", "ko", "nospeech"],
            value="auto",
            label="Select Language"
        )
    ],
    outputs=gr.Textbox(label="Recognition Result"),
    title="SenseVoice Speech Recognition",
    description="Multi-language speech transcription service supporting Chinese, English, Cantonese, Japanese, and Korean",
    examples=[
        ["examples/zh.mp3", "zh"],
        ["examples/en.mp3", "en"],
    ]
)
# Mount the Gradio app onto the FastAPI app
app = gr.mount_gradio_app(app, demo, path="/")
# Custom Swagger UI redirect page.
# The route path below is an assumption; pick any path that does not collide
# with the mounted Gradio UI or the built-in /docs route.
@app.get("/api-docs", include_in_schema=False)
async def custom_swagger_ui_html():
    return HTMLResponse("""
    <!DOCTYPE html>
    <html>
    <head>
        <title>SenseVoice API Documentation</title>
        <meta http-equiv="refresh" content="0;url=/docs/" />
    </head>
    <body>
        <p>Redirecting to API documentation...</p>
    </body>
    </html>
    """)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)