# coding=utf-8
"""SenseVoice speech-to-text service.

Exposes a token-protected FastAPI endpoint (``/v1/audio/transcriptions``) and
a Gradio web UI mounted at ``/``; both run audio through one shared FunASR
``AutoModel`` instance and post-process the tagged transcript into text with
emoji markers.
"""
# System Libraries
import asyncio
import os
import secrets
import time
from io import BytesIO
from typing import Optional, Dict, Any, List, Set, Union, Tuple

# Third-party imports
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import HTMLResponse
import numpy as np
import torch
import torchaudio
from funasr import AutoModel
from dotenv import load_dotenv
import gradio as gr

# Load environment variables (from a .env file if present)
load_dotenv()

# Read the API token; the service refuses to start without one
API_TOKEN: Optional[str] = os.getenv("API_TOKEN")
if not API_TOKEN:
    raise RuntimeError("API_TOKEN environment variable is not set")

# Values accepted by the transcription endpoint and offered by the Gradio UI
SUPPORTED_MODEL: str = "FunAudioLLM/SenseVoiceSmall"
SUPPORTED_LANGUAGES: Tuple[str, ...] = ("auto", "zh", "en", "yue", "ja", "ko", "nospeech")
SUPPORTED_AUDIO_EXTENSIONS: Tuple[str, ...] = (".mp3", ".wav", ".flac", ".ogg", ".m4a")

# Bearer-token authentication scheme
security = HTTPBearer()

app = FastAPI(
    title="SenseVoice API",
    description="Speech To Text API Service",
    version="1.0.0"
)

# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialise the model once at import time; it is shared by all requests
model = AutoModel(
    model=SUPPORTED_MODEL,
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
    hub="hf",
    device="cuda"
)

# Emotion token -> emoji appended at the end of a segment
emotion_dict: Dict[str, str] = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

# Event token -> emoji prepended to a segment
# NOTE(review): "<|Cough|>" maps to 🤧 here but to 😷 in emoji_dict/event_set;
# looks like an inconsistency — confirm which one is intended.
event_dict: Dict[str, str] = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

# Every known special token -> its inline replacement.
# The combined "<|nospeech|><|Event_UNK|>" key comes first so it is replaced
# before its two component tokens.
emoji_dict: Dict[str, str] = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

# Language tokens, all collapsed to one separator used to split segments
lang_dict: Dict[str, str] = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

emo_set: Set[str] = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set: Set[str] = {"🎼", "👏", "😀", "😭", "🤧", "😷"}


def format_text_with_emotion(text: str) -> str:
    """Format one transcript segment with emotion and event markers.

    Steps, in order: count every known token; pick the emotion token with the
    strictly highest count (``<|NEUTRAL|>`` when none beats it); prepend the
    emoji of every event token that occurs; replace each token inline with its
    ``emoji_dict`` value; append the dominant emotion emoji; drop a single
    space on either side of any emoji.

    Args:
        text: Raw segment text, possibly containing ``<|...|>`` tokens.

    Returns:
        The formatted segment with surrounding whitespace stripped.
    """
    token_count: Dict[str, int] = {token: text.count(token) for token in emoji_dict}

    # Determine dominant emotion (ties keep the earlier candidate)
    dominant_emotion = "<|NEUTRAL|>"
    for emotion in emotion_dict:
        if token_count[emotion] > token_count[dominant_emotion]:
            dominant_emotion = emotion

    # Add event markers; each one is prepended, so later events end up first
    for event, marker in event_dict.items():
        if token_count[event] > 0:
            text = marker + text

    # Replace all tokens with their emoji equivalents.
    # NOTE(review): tokens are replaced inline *and* a prefix/suffix emoji is
    # added, so the same emoji can appear twice — confirm this is intended.
    for token, replacement in emoji_dict.items():
        text = text.replace(token, replacement)

    # Add dominant emotion
    text = text + emotion_dict[dominant_emotion]

    # Clean up emoji spacing
    for emoji in emo_set.union(event_set):
        text = text.replace(" " + emoji, emoji)
        text = text.replace(emoji + " ", emoji)

    return text.strip()


def format_text_advanced(text: str) -> str:
    """Format a full (possibly multilingual) transcript.

    The text is split at language tokens, each piece is formatted with
    ``format_text_with_emotion``, and the pieces are concatenated. While
    merging, a leading event emoji equal to the previous segment's is dropped,
    and a trailing emotion emoji on the accumulated text is dropped when the
    next segment ends with the same one.

    Args:
        text: Raw model output containing ``<|...|>`` tokens.

    Returns:
        The merged, formatted transcript with surrounding whitespace stripped.
    """

    def get_emotion(text: str) -> Optional[str]:
        # Guard against empty strings: a segment can become empty after its
        # leading event emoji has been stripped below.
        return text[-1] if text and text[-1] in emo_set else None

    def get_event(text: str) -> Optional[str]:
        return text[0] if text and text[0] in event_set else None

    # Handle special cases
    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang in lang_dict:
        text = text.replace(lang, "<|lang|>")

    # Process text segments
    text_segments: List[str] = [
        format_text_with_emotion(segment).strip()
        for segment in text.split("<|lang|>")
    ]

    formatted_text = " " + text_segments[0]
    # NOTE(review): formatted_text starts with a space here, so this is always
    # None for the first segment — confirm whether that is intended.
    current_event = get_event(formatted_text)

    # Merge segments
    for segment in text_segments[1:]:
        if not segment:
            continue
        segment_event = get_event(segment)
        if segment_event is not None and segment_event == current_event:
            segment = segment[1:]
        current_event = get_event(segment)
        segment_emotion = get_emotion(segment)
        if segment_emotion is not None and segment_emotion == get_emotion(formatted_text):
            formatted_text = formatted_text[:-1]
        formatted_text += segment.strip()

    formatted_text = formatted_text.replace("The.", " ")
    return formatted_text.strip()


async def audio_stt(audio: torch.Tensor, sample_rate: int, language: str = "auto") -> str:
    """Process audio tensor and perform speech-to-text conversion.

    Args:
        audio: Input audio tensor; when it has more than one dimension the
            first dimension is averaged away (channels-first layout).
        sample_rate: Audio sample rate in Hz
        language: Target language code (auto/zh/en/yue/ja/ko/nospeech)

    Returns:
        str: Transcribed and formatted text result

    Raises:
        HTTPException: 500 if any step of preprocessing or inference fails.
    """
    try:
        # Normalize integer PCM to float32
        if audio.dtype != torch.float32:
            if audio.dtype == torch.int16:
                audio = audio.float() / torch.iinfo(torch.int16).max
            elif audio.dtype == torch.int32:
                audio = audio.float() / torch.iinfo(torch.int32).max
            else:
                audio = audio.float()

        # Make sure audio in correct range
        peak = audio.abs().max()
        if peak > 1.0:
            audio = audio / peak

        # Convert to mono channel
        if len(audio.shape) > 1:
            audio = audio.mean(dim=0)
        audio = audio.squeeze()

        # Resample to the 16 kHz rate passed to the model
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=16000
            )
            audio = resampler(audio.unsqueeze(0)).squeeze(0)

        text = model.generate(
            input=audio,
            cache={},
            language=language,
            use_itn=True,
            batch_size_s=500,
            merge_vad=True
        )

        # Format the result
        result = text[0]["text"]
        return format_text_advanced(result)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Audio processing failed in audio_stt: {str(e)}"
        ) from e


async def process_audio(audio_data: bytes, language: str = "auto") -> str:
    """Process audio data and return transcription result.

    Args:
        audio_data: Raw audio data in bytes
        language: Target language code

    Returns:
        str: Transcribed and formatted text

    Raises:
        HTTPException: If audio processing fails
    """
    # Only decoding is wrapped here: audio_stt already raises its own
    # HTTPException, and wrapping that again would nest the error messages.
    try:
        audio_buffer = BytesIO(audio_data)
        waveform, sample_rate = torchaudio.load(
            uri=audio_buffer,
            normalize=True,
            channels_first=True
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Audio processing failed: {str(e)}"
        ) from e

    return await audio_stt(waveform, sample_rate, language)


async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> HTTPAuthorizationCredentials:
    """Verify Bearer Token authentication.

    Args:
        credentials: Credentials extracted from the ``Authorization`` header.

    Returns:
        The credentials, unchanged, when the token matches ``API_TOKEN``.

    Raises:
        HTTPException: 401 when the token does not match.
    """
    # Constant-time comparison on bytes (compare_digest rejects non-ASCII str)
    supplied = credentials.credentials.encode("utf-8")
    expected = API_TOKEN.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"}
        )
    return credentials


def _transcription_result(
    text: str, error_code: int, error_msg: str, start_time: float
) -> Dict[str, Union[str, int, float]]:
    """Build the endpoint's response dict, stamping the elapsed time."""
    return {
        "text": text,
        "error_code": error_code,
        "error_msg": error_msg,
        "process_time": time.time() - start_time
    }


@app.post("/v1/audio/transcriptions")
async def transcribe_audio(
    file: UploadFile = File(...),
    model: str = "FunAudioLLM/SenseVoiceSmall",
    language: str = "auto",
    token: HTTPAuthorizationCredentials = Depends(verify_token)
) -> Dict[str, Union[str, int, float]]:
    """Audio transcription endpoint.

    Validation and processing errors are reported inside the returned dict
    (``error_code`` / ``error_msg``), not raised.

    Args:
        file: Audio file (supports mp3, wav, flac, ogg, m4a)
        model: Model name
        language: Language code
        token: Authentication token

    Returns:
        Dict containing transcription result and metadata
    """
    start_time = time.time()
    try:
        # Check the file format; a missing filename counts as unsupported
        filename = (file.filename or "").lower()
        if not filename.endswith(SUPPORTED_AUDIO_EXTENSIONS):
            return _transcription_result("", 400, "不支持的音频格式", start_time)

        # Check the model
        if model != SUPPORTED_MODEL:
            return _transcription_result("", 400, "不支持的模型", start_time)

        # Check the language
        if language not in SUPPORTED_LANGUAGES:
            return _transcription_result("", 400, "不支持的语言", start_time)

        # STT
        content = await file.read()
        text = await process_audio(content, language)
        return _transcription_result(text, 0, "", start_time)
    except HTTPException as e:
        # Report the detail itself rather than str(e), which prefixes the status
        return _transcription_result("", e.status_code, str(e.detail), start_time)
    except Exception as e:
        return _transcription_result("", 500, str(e), start_time)


def transcribe_audio_gradio(audio: Optional[Tuple[int, np.ndarray]], language: str = "auto") -> str:
    """Gradio interface for audio transcription.

    Args:
        audio: ``(sample_rate, samples)`` tuple from the Gradio audio
            component, or ``None`` when nothing was provided.
        language: Language code selected in the UI.

    Returns:
        The transcription, or a human-readable message on failure.
    """
    try:
        if audio is None:
            return "Please upload an audio file"

        # Extract audio data
        sample_rate, input_wav = audio

        # Normalize audio: scale integer PCM by its own dtype's maximum
        if np.issubdtype(input_wav.dtype, np.integer):
            input_wav = input_wav.astype(np.float32) / np.iinfo(input_wav.dtype).max
        else:
            input_wav = input_wav.astype(np.float32)

        # Multi-channel arrays from Gradio are (samples, channels), whereas
        # audio_stt averages over the first axis; down-mix here so the time
        # axis is not collapsed.
        if input_wav.ndim > 1:
            input_wav = input_wav.mean(axis=-1)

        # Model Inference
        input_tensor = torch.from_numpy(input_wav)
        return asyncio.run(audio_stt(input_tensor, sample_rate, language))
    except Exception as e:
        return f"Processing failed: {str(e)}"


# Create Gradio interface with localized labels
demo = gr.Interface(
    fn=transcribe_audio_gradio,
    inputs=[
        gr.Audio(
            sources=["upload", "microphone"],
            type="numpy",
            label="Upload audio or record from microphone"
        ),
        gr.Dropdown(
            choices=list(SUPPORTED_LANGUAGES),
            value="auto",
            label="Select Language"
        )
    ],
    outputs=gr.Textbox(label="Recognition Result"),
    title="SenseVoice Speech Recognition",
    description="Multi-language speech transcription service supporting Chinese, English, Cantonese, Japanese, and Korean",
    examples=[
        ["examples/zh.mp3", "zh"],
        ["examples/en.mp3", "en"],
    ]
)

# Mount Gradio app to FastAPI
app = gr.mount_gradio_app(app, demo, path="/")


# Custom Swagger UI redirect
# NOTE(review): this route is registered after FastAPI's default /docs route
# and after the Gradio mount at "/", so it is probably never reached; the
# response body also contains no markup that would perform a redirect — confirm.
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    """Return a small HTML page announcing a redirect to the API docs."""
    return HTMLResponse("""
Redirecting to API documentation...
""")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)