# coding=utf-8
"""SenseVoice speech-recognition service.

Exposes a Bearer-token protected FastAPI transcription endpoint
(``/v1/audio/transcriptions``) and a Gradio UI mounted at ``/``. Both are
backed by the single FunASR ``AutoModel`` instance created at import time.
"""
from io import BytesIO
from typing import Optional, Dict, Any, List, Set, Union, Tuple
import os
import secrets
import time
import traceback

# Third-party imports
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import HTMLResponse
import numpy as np
import torch
import torchaudio
from funasr import AutoModel
from dotenv import load_dotenv
import gradio as gr

# Load environment variables (python-dotenv reads a .env file if present).
load_dotenv()

# Bearer token required by the REST API; refuse to start without one.
API_TOKEN: str = os.getenv("API_TOKEN", "")
if not API_TOKEN:
    raise RuntimeError("API_TOKEN environment variable is not set")

# Bearer-token authentication scheme.
security = HTTPBearer()

app = FastAPI(
    title="SenseVoice API",
    description="语音识别 API 服务",
    version="1.0.0",
)

# Allow cross-origin requests from any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared configuration used by both the REST endpoint and the Gradio UI.
MODEL_NAME: str = "FunAudioLLM/SenseVoiceSmall"
SUPPORTED_LANGUAGES: Tuple[str, ...] = ("auto", "zh", "en", "yue", "ja", "ko", "nospeech")
SUPPORTED_AUDIO_EXTENSIONS: Tuple[str, ...] = (".mp3", ".wav", ".flac", ".ogg", ".m4a")
# Sample rate the audio is resampled to before inference.
TARGET_SAMPLE_RATE: int = 16000

# Initialize the ASR model (with VAD segmentation) once at import time.
model = AutoModel(
    model=MODEL_NAME,
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
    hub="hf",
    device="cuda",
)

# Token -> emoji tables reused from the original formatting code.
emotion_dict: Dict[str, str] = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

event_dict: Dict[str, str] = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

emoji_dict: Dict[str, str] = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

# Every language tag is collapsed into one placeholder used as a segment separator.
lang_dict: Dict[str, str] = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

emo_set: Set[str] = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set: Set[str] = {"🎼", "👏", "😀", "😭", "🤧", "😷"}


def format_text_basic(text: str) -> str:
    """Replace every special token in *text* with its emoji (or remove it)."""
    for token, emoji in emoji_dict.items():
        text = text.replace(token, emoji)
    return text


def format_text_with_emotion(text: str) -> str:
    """Format one segment with event markers in front and the dominant emotion at the end.

    Event emojis for every event token present are prepended. All tokens are then
    replaced by their emojis, and the emoji of the most frequent emotion token
    (ties keep the earlier entry of ``emotion_dict``; default neutral) is appended.
    Spaces adjacent to emotion/event emojis are removed.
    """
    # Count tokens on the raw text, before any replacement happens.
    token_count: Dict[str, int] = {token: text.count(token) for token in emoji_dict}

    # Determine the dominant emotion (strictly greater count wins).
    dominant_emotion = "<|NEUTRAL|>"
    for emotion in emotion_dict:
        if token_count[emotion] > token_count[dominant_emotion]:
            dominant_emotion = emotion

    # Prepend a marker for every event that occurs at least once.
    for event in event_dict:
        if token_count[event] > 0:
            text = event_dict[event] + text

    # Replace all tokens with their emoji equivalents.
    for token, emoji in emoji_dict.items():
        text = text.replace(token, emoji)

    # Append the dominant emotion.
    text = text + emotion_dict[dominant_emotion]

    # Glue emojis to the surrounding text.
    for emoji in emo_set.union(event_set):
        text = text.replace(" " + emoji, emoji)
        text = text.replace(emoji + " ", emoji)
    return text.strip()


def format_text_advanced(text: str) -> str:
    """Format raw SenseVoice output that may switch language mid-utterance.

    The text is split on language tags and each segment is formatted with
    :func:`format_text_with_emotion`. Segments are then concatenated, dropping an
    event marker that repeats the previous segment's event and an emotion emoji
    that repeats the text accumulated so far.
    """

    def get_emotion(segment: str) -> Optional[str]:
        # Guard against empty segments (e.g. a segment that was only an event emoji).
        return segment[-1] if segment and segment[-1] in emo_set else None

    def get_event(segment: str) -> Optional[str]:
        return segment[0] if segment and segment[0] in event_set else None

    # Handle special cases
    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang, placeholder in lang_dict.items():
        text = text.replace(lang, placeholder)

    # Process text segments
    text_segments: List[str] = [
        format_text_with_emotion(segment).strip() for segment in text.split("<|lang|>")
    ]
    formatted_text = " " + text_segments[0]
    # Use the segment itself: the " "-prefixed formatted_text never starts with an
    # event emoji, so the first segment's event was previously never detected.
    current_event = get_event(text_segments[0])

    # Merge segments
    for segment in text_segments[1:]:
        if not segment:
            continue
        event = get_event(segment)
        if event is not None and event == current_event:
            segment = segment[1:]  # drop the repeated event marker
        current_event = get_event(segment)
        emotion = get_emotion(segment)
        if emotion is not None and emotion == get_emotion(formatted_text):
            formatted_text = formatted_text[:-1]  # drop the repeated emotion emoji
        formatted_text += segment.strip()

    formatted_text = formatted_text.replace("The.", " ")
    return formatted_text.strip()


def _to_float_mono(audio: np.ndarray) -> np.ndarray:
    """Convert ``(samples,)`` or ``(samples, channels)`` audio to float32 mono.

    Integer PCM is scaled by the max value of its own dtype (identical to the
    previous int16 behaviour). Floating-point input is assumed to be in
    [-1, 1] already and is not rescaled.
    """
    if np.issubdtype(audio.dtype, np.integer):
        wav = audio.astype(np.float32) / np.iinfo(audio.dtype).max
    else:
        wav = audio.astype(np.float32)
    if wav.ndim > 1:
        wav = wav.mean(-1)  # average over the trailing channel axis
    return wav


def _resample(wav: np.ndarray, sample_rate: int) -> np.ndarray:
    """Resample a 1-D float32 waveform to ``TARGET_SAMPLE_RATE``."""
    if sample_rate == TARGET_SAMPLE_RATE:
        return wav
    resampler = torchaudio.transforms.Resample(sample_rate, TARGET_SAMPLE_RATE)
    wav_tensor = torch.from_numpy(wav).to(torch.float32)
    return resampler(wav_tensor[None, :])[0, :].numpy()


def _transcribe_numpy(audio: np.ndarray, sample_rate: int, language: str = "auto") -> str:
    """Run the full normalize -> mono -> resample -> inference -> format pipeline."""
    input_wav = _resample(_to_float_mono(audio), sample_rate)
    results = model.generate(
        input=input_wav,
        cache={},
        language=language,
        use_itn=True,
        batch_size_s=500,
        merge_vad=True,
    )
    if not results:
        # Nothing was produced (e.g. empty audio); avoid IndexError on results[0].
        return ""
    return format_text_advanced(results[0]["text"])


async def audio_stt(audio: np.ndarray, sample_rate: int, language: str = "auto") -> str:
    """Transcribe an in-memory waveform.

    Args:
        audio: ``(samples,)`` or ``(samples, channels)`` array; integer PCM or
            float in [-1, 1].
        sample_rate: Sample rate of *audio* in Hz.
        language: SenseVoice language code (see ``SUPPORTED_LANGUAGES``).

    Returns:
        Formatted transcription text.
    """
    # NOTE(review): model inference is synchronous and blocks the event loop for
    # the duration of the call.
    return _transcribe_numpy(audio, sample_rate, language)


async def process_audio(audio_data: bytes, language: str = "auto") -> str:
    """Decode encoded audio bytes and return the transcription result.

    Raises:
        HTTPException: 500 when decoding or inference fails.
    """
    try:
        waveform, sample_rate = torchaudio.load(BytesIO(audio_data))
        # torchaudio.load returns a float32 (channels, frames) tensor; transpose to
        # the (frames, channels) numpy layout audio_stt expects.
        return await audio_stt(waveform.numpy().T, sample_rate, language)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Audio processing failed: {str(e)}") from e


async def verify_token(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> HTTPAuthorizationCredentials:
    """Verify the Bearer token against ``API_TOKEN``.

    Raises:
        HTTPException: 401 when the token does not match.
    """
    # Constant-time comparison; compare bytes so non-ASCII tokens do not raise TypeError.
    if not secrets.compare_digest(
        credentials.credentials.encode("utf-8"), API_TOKEN.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials


@app.post("/v1/audio/transcriptions")
async def transcribe_audio(
    file: UploadFile = File(...),
    model: Optional[str] = MODEL_NAME,
    language: Optional[str] = "auto",
    token: HTTPAuthorizationCredentials = Depends(verify_token),
) -> Dict[str, Union[str, int, float]]:
    """Audio transcription endpoint

    Args:
        file: Audio file (supports common audio formats)
        model: Model name, currently only supports FunAudioLLM/SenseVoiceSmall
        language: Language code, supports auto/zh/en/yue/ja/ko/nospeech

    Returns:
        Dict[str, Union[str, int, float]]: {
            "text": "Transcription result",
            "error_code": 0,
            "error_msg": "",
            "process_time": 1.234  # Processing time in seconds
        }
    """
    start_time = time.perf_counter()

    def _response(text: str = "", error_code: int = 0, error_msg: str = "") -> Dict[str, Union[str, int, float]]:
        # Every reply (success or error) has the same shape and reports elapsed time.
        return {
            "text": text,
            "error_code": error_code,
            "error_msg": error_msg,
            "process_time": time.perf_counter() - start_time,
        }

    try:
        # Validate file format (filename may be absent on some uploads).
        if not (file.filename or "").lower().endswith(SUPPORTED_AUDIO_EXTENSIONS):
            return _response(error_code=400, error_msg="Unsupported audio format")

        # Validate model
        if model != MODEL_NAME:
            return _response(error_code=400, error_msg="Unsupported model")

        # Validate language
        if language not in SUPPORTED_LANGUAGES:
            return _response(error_code=400, error_msg="Unsupported language")

        # Process audio
        content = await file.read()
        text = await process_audio(content, language)
        return _response(text=text)
    except Exception as e:
        return _response(error_code=500, error_msg=str(e))


def transcribe_audio_gradio(audio: Optional[Tuple[int, np.ndarray]], language: str = "auto") -> str:
    """Gradio handler: transcribe a ``(sample_rate, samples)`` tuple.

    Errors are returned as text so they are shown in the UI instead of raised.
    """
    try:
        if audio is None:
            return "Please upload an audio file"
        sample_rate, input_wav = audio
        return _transcribe_numpy(input_wav, sample_rate, language)
    except Exception as e:
        return f"Processing failed: {str(e)}"


# Create Gradio interface with localized labels
demo = gr.Interface(
    fn=transcribe_audio_gradio,
    inputs=[
        gr.Audio(
            sources=["upload", "microphone"],
            type="numpy",
            label="Upload audio or record from microphone",
        ),
        gr.Dropdown(
            choices=list(SUPPORTED_LANGUAGES),
            value="auto",
            label="Select Language",
        ),
    ],
    outputs=gr.Textbox(label="Recognition Result"),
    title="SenseVoice Speech Recognition",
    description="Multi-language speech transcription service supporting Chinese, English, Cantonese, Japanese, and Korean",
    examples=[
        ["examples/zh.mp3", "zh"],
        ["examples/en.mp3", "en"],
    ],
)

# Mount Gradio app to FastAPI
app = gr.mount_gradio_app(app, demo, path="/")


# Custom Swagger UI redirect
# NOTE(review): FastAPI registers its default /docs route when the app is
# constructed, and Gradio is mounted at "/" before this route. This handler is
# therefore likely never reached — confirm. Its body also contains no redirect
# markup, so it only displays the text.
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    return HTMLResponse("""
Redirecting to API documentation...
""")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)