# coding=utf-8
# System libraries
import os
import time
import asyncio
from io import BytesIO
from typing import Optional, Dict, Any, List, Set, Union, Tuple

# Third-party imports
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import HTMLResponse
import numpy as np
import torch
import torchaudio
from funasr import AutoModel
from dotenv import load_dotenv
import gradio as gr

# Load environment variables
load_dotenv()

# Read the API token used for Bearer authentication
API_TOKEN: str = os.getenv("API_TOKEN")
if not API_TOKEN:
    raise RuntimeError("API_TOKEN environment variable is not set")

# Set up authentication
security = HTTPBearer()
app = FastAPI(
    title="SenseVoice API",
    description="Speech To Text API Service",
    version="1.0.0"
)

# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the model (SenseVoiceSmall with FSMN VAD for long-audio segmentation)
model = AutoModel(
    model="FunAudioLLM/SenseVoiceSmall",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
    hub="hf",
    device="cuda"
)
emotion_dict: Dict[str, str] = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

event_dict: Dict[str, str] = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

emoji_dict: Dict[str, str] = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

lang_dict: Dict[str, str] = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

emo_set: Set[str] = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set: Set[str] = {"🎼", "👏", "😀", "😭", "🤧", "😷"}
def format_text_with_emotion(text: str) -> str:
    """Format text with emotion and event markers"""
    token_count: Dict[str, int] = {}
    original_text = text
    for token in emoji_dict:
        token_count[token] = text.count(token)
    # Determine dominant emotion
    dominant_emotion = "<|NEUTRAL|>"
    for emotion in emotion_dict:
        if token_count[emotion] > token_count[dominant_emotion]:
            dominant_emotion = emotion
    # Add event markers
    text = original_text
    for event in event_dict:
        if token_count[event] > 0:
            text = event_dict[event] + text
    # Replace all tokens with their emoji equivalents
    for token in emoji_dict:
        text = text.replace(token, emoji_dict[token])
    # Add dominant emotion
    text = text + emotion_dict[dominant_emotion]
    # Clean up emoji spacing
    for emoji in emo_set.union(event_set):
        text = text.replace(" " + emoji, emoji)
        text = text.replace(emoji + " ", emoji)
    return text.strip()


def format_text_advanced(text: str) -> str:
    """Advanced text formatting with multilingual and complex token handling"""
    def get_emotion(text: str) -> Optional[str]:
        return text[-1] if text[-1] in emo_set else None

    def get_event(text: str) -> Optional[str]:
        return text[0] if text[0] in event_set else None

    # Handle special cases
    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang in lang_dict:
        text = text.replace(lang, "<|lang|>")
    # Process text segments
    text_segments: List[str] = [format_text_with_emotion(segment).strip() for segment in text.split("<|lang|>")]
    formatted_text = " " + text_segments[0]
    current_event = get_event(formatted_text)
    # Merge segments
    for i in range(1, len(text_segments)):
        if not text_segments[i]:
            continue
        if get_event(text_segments[i]) == current_event and get_event(text_segments[i]) is not None:
            text_segments[i] = text_segments[i][1:]
        current_event = get_event(text_segments[i])
        if get_emotion(text_segments[i]) is not None and get_emotion(text_segments[i]) == get_emotion(formatted_text):
            formatted_text = formatted_text[:-1]
        formatted_text += text_segments[i].strip()
    formatted_text = formatted_text.replace("The.", " ")
    return formatted_text.strip()
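
# Illustrative sketch: the raw tag layout below is an assumption about
# SenseVoiceSmall output, not something produced in this file.
#   format_text_advanced("<|en|><|HAPPY|><|Speech|><|withitn|>Hello there.")
# strips the language/ITN tags, maps the emotion/event tags to emoji, and
# appends the dominant emotion, yielding roughly "😊Hello there.😊".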

async def audio_stt(audio: torch.Tensor, sample_rate: int, language: str = "auto") -> str:
    """Process an audio tensor and perform speech-to-text conversion.

    Args:
        audio: Input audio tensor
        sample_rate: Audio sample rate in Hz
        language: Target language code (auto/zh/en/yue/ja/ko/nospeech)

    Returns:
        str: Transcribed and formatted text result
    """
    try:
        # Normalize to float32 in [-1, 1]
        if audio.dtype != torch.float32:
            if audio.dtype == torch.int16:
                audio = audio.float() / torch.iinfo(torch.int16).max
            elif audio.dtype == torch.int32:
                audio = audio.float() / torch.iinfo(torch.int32).max
            else:
                audio = audio.float()

        # Make sure the audio is in the correct range
        if audio.abs().max() > 1.0:
            audio = audio / audio.abs().max()

        # Convert to a mono channel
        if len(audio.shape) > 1:
            audio = audio.mean(dim=0)
        audio = audio.squeeze()

        # Resample to the 16 kHz rate the model expects
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=16000
            )
            audio = resampler(audio.unsqueeze(0)).squeeze(0)

        text = model.generate(
            input=audio,
            cache={},
            language=language,
            use_itn=True,
            batch_size_s=500,
            merge_vad=True
        )

        # Format the result
        result = text[0]["text"]
        return format_text_advanced(result)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Audio processing failed in audio_stt: {str(e)}"
        )
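
# Local usage sketch (assumes a readable "sample.wav" on disk; the path is a
# placeholder and this call is not part of the service flow):
#   waveform, sr = torchaudio.load("sample.wav")
#   print(asyncio.run(audio_stt(waveform, sr, language="auto")))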

async def process_audio(audio_data: bytes, language: str = "auto") -> str:
    """Process audio data and return transcription result.

    Args:
        audio_data: Raw audio data in bytes
        language: Target language code

    Returns:
        str: Transcribed and formatted text

    Raises:
        HTTPException: If audio processing fails
    """
    try:
        audio_buffer = BytesIO(audio_data)
        waveform, sample_rate = torchaudio.load(
            uri=audio_buffer,
            normalize=True,
            channels_first=True
        )
        result = await audio_stt(waveform, sample_rate, language)
        return result
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Audio processing failed: {str(e)}"
        )


async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> HTTPAuthorizationCredentials:
    """Verify Bearer token authentication"""
    if credentials.credentials != API_TOKEN:
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"}
        )
    return credentials

# Route path assumed here (OpenAI-compatible convention); adjust if the service
# should expose a different path.
@app.post("/v1/audio/transcriptions")
async def transcribe_audio(
    file: UploadFile = File(...),
    model: str = "FunAudioLLM/SenseVoiceSmall",
    language: str = "auto",
    token: HTTPAuthorizationCredentials = Depends(verify_token)
) -> Dict[str, Union[str, int, float]]:
    """Audio transcription endpoint.

    Args:
        file: Audio file (supports mp3, wav, flac, ogg, m4a)
        model: Model name
        language: Language code
        token: Authentication token

    Returns:
        Dict containing transcription result and metadata
    """
    start_time = time.time()
    try:
        # Check the file format
        if not file.filename.lower().endswith((".mp3", ".wav", ".flac", ".ogg", ".m4a")):
            return {
                "text": "",
                "error_code": 400,
                "error_msg": "Unsupported audio format",
                "process_time": time.time() - start_time
            }
        # Check the model
        if model != "FunAudioLLM/SenseVoiceSmall":
            return {
                "text": "",
                "error_code": 400,
                "error_msg": "Unsupported model",
                "process_time": time.time() - start_time
            }
        # Check the language
        if language not in ["auto", "zh", "en", "yue", "ja", "ko", "nospeech"]:
            return {
                "text": "",
                "error_code": 400,
                "error_msg": "Unsupported language",
                "process_time": time.time() - start_time
            }
        # Speech-to-text
        content = await file.read()
        text = await process_audio(content, language)
        return {
            "text": text,
            "error_code": 0,
            "error_msg": "",
            "process_time": time.time() - start_time
        }
    except Exception as e:
        return {
            "text": "",
            "error_code": 500,
            "error_msg": str(e),
            "process_time": time.time() - start_time
        }
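
# Request sketch, assuming the service runs locally on port 7860, a local
# "sample.wav" exists, and the route matches the assumed registration above
# (language is read from the query string because it is a plain scalar parameter):
#   curl -X POST "http://localhost:7860/v1/audio/transcriptions?language=auto" \
#        -H "Authorization: Bearer $API_TOKEN" \
#        -F "file=@sample.wav"
# A successful response has the shape:
#   {"text": "...", "error_code": 0, "error_msg": "", "process_time": <seconds>}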

def transcribe_audio_gradio(audio: Optional[Tuple[int, np.ndarray]], language: str = "auto") -> str:
    """Gradio interface for audio transcription"""
    try:
        if audio is None:
            return "Please upload an audio file"
        # Extract audio data
        sample_rate, input_wav = audio
        # Normalize the int16 PCM samples to float32 in [-1, 1]
        input_wav = input_wav.astype(np.float32) / np.iinfo(np.int16).max
        # Model inference
        input_wav = torch.from_numpy(input_wav)
        result = asyncio.run(audio_stt(input_wav, sample_rate, language))
        return result
    except Exception as e:
        return f"Processing failed: {str(e)}"


# Create the Gradio interface
demo = gr.Interface(
    fn=transcribe_audio_gradio,
    inputs=[
        gr.Audio(
            sources=["upload", "microphone"],
            type="numpy",
            label="Upload audio or record from microphone"
        ),
        gr.Dropdown(
            choices=["auto", "zh", "en", "yue", "ja", "ko", "nospeech"],
            value="auto",
            label="Select Language"
        )
    ],
    outputs=gr.Textbox(label="Recognition Result"),
    title="SenseVoice Speech Recognition",
    description="Multi-language speech transcription service supporting Chinese, English, Cantonese, Japanese, and Korean",
    examples=[
        ["examples/zh.mp3", "zh"],
        ["examples/en.mp3", "en"],
    ]
)

# Mount the Gradio app onto FastAPI at the root path
app = gr.mount_gradio_app(app, demo, path="/")

# Custom Swagger UI redirect.
# Note: this handler is not attached to any route here; register it with
# @app.get(...) on a free path if the redirect page is needed (the Gradio UI
# already occupies "/").
async def custom_swagger_ui_html():
    return HTMLResponse("""
    <!DOCTYPE html>
    <html>
    <head>
        <title>SenseVoice API Documentation</title>
        <meta http-equiv="refresh" content="0;url=/docs/" />
    </head>
    <body>
        <p>Redirecting to API documentation...</p>
    </body>
    </html>
    """)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
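
# Launch sketch (assumes this file is saved as app.py):
#   python app.py                                  # Gradio UI at "/", OpenAPI docs at "/docs"
#   uvicorn app:app --host 0.0.0.0 --port 7860     # equivalent, via the uvicorn CLI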