megatrump committed on
Commit 60b9834 · 1 Parent(s): 2711de6

Unified code style

Files changed (1)
  1. api.py +131 -100
api.py CHANGED
@@ -1,8 +1,11 @@
 # coding=utf-8
 
 from io import BytesIO
-from typing import Optional
+from typing import Optional, Dict, Any, List, Set, Union, Tuple
+import os
+import time
 
+# Third-party imports
 from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
@@ -20,7 +23,7 @@ import gradio as gr
 load_dotenv()
 
 # 获取API Token
-API_TOKEN = os.getenv("API_TOKEN")
+API_TOKEN: str = os.getenv("API_TOKEN")
 if not API_TOKEN:
     raise RuntimeError("API_TOKEN environment variable is not set")
 
@@ -52,7 +55,7 @@ model = AutoModel(
 )
 
 # 复用原有的格式化函数
-emo_dict = {
+emotion_dict: Dict[str, str] = {
     "<|HAPPY|>": "😊",
     "<|SAD|>": "😔",
     "<|ANGRY|>": "😡",
@@ -62,7 +65,7 @@ emo_dict = {
     "<|SURPRISED|>": "😮",
 }
 
-event_dict = {
+event_dict: Dict[str, str] = {
     "<|BGM|>": "🎼",
     "<|Speech|>": "",
     "<|Applause|>": "👏",
@@ -73,7 +76,7 @@ event_dict = {
     "<|Cough|>": "🤧",
 }
 
-emoji_dict = {
+emoji_dict: Dict[str, str] = {
     "<|nospeech|><|Event_UNK|>": "❓",
     "<|zh|>": "",
     "<|en|>": "",
@@ -105,7 +108,7 @@ emoji_dict = {
     "<|Event_UNK|>": "",
 }
 
-lang_dict = {
+lang_dict: Dict[str, str] = {
     "<|zh|>": "<|lang|>",
     "<|en|>": "<|lang|>",
     "<|yue|>": "<|lang|>",
@@ -114,82 +117,105 @@ lang_dict = {
     "<|nospeech|>": "<|lang|>",
 }
 
-emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
-event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷"}
+emo_set: Set[str] = {"😊", "😔", "😡", "😰", "🤢", "😮"}
+event_set: Set[str] = {"🎼", "👏", "😀", "😭", "🤧", "😷"}
 
 
-def format_str(s):
-    for sptk in emoji_dict:
-        s = s.replace(sptk, emoji_dict[sptk])
-    return s
+def format_text_basic(text: str) -> str:
+    """Replace special tokens with corresponding emojis"""
+    for token in emoji_dict:
+        text = text.replace(token, emoji_dict[token])
+    return text
 
 
-def format_str_v2(s):
-    sptk_dict = {}
-    for sptk in emoji_dict:
-        sptk_dict[sptk] = s.count(sptk)
-        s = s.replace(sptk, "")
-    emo = "<|NEUTRAL|>"
-    for e in emo_dict:
-        if sptk_dict[e] > sptk_dict[emo]:
-            emo = e
-    for e in event_dict:
-        if sptk_dict[e] > 0:
-            s = event_dict[e] + s
-    s = s + emo_dict[emo]
+def format_text_with_emotion(text: str) -> str:
+    """Format text with emotion and event markers"""
+    token_count: Dict[str, int] = {}
+    original_text = text
+    for token in emoji_dict:
+        token_count[token] = text.count(token)
+
+    # Determine dominant emotion
+    dominant_emotion = "<|NEUTRAL|>"
+    for emotion in emotion_dict:
+        if token_count[emotion] > token_count[dominant_emotion]:
+            dominant_emotion = emotion
+
+    # Add event markers
+    text = original_text
+    for event in event_dict:
+        if token_count[event] > 0:
+            text = event_dict[event] + text
+
+    # Replace all tokens with their emoji equivalents
+    for token in emoji_dict:
+        text = text.replace(token, emoji_dict[token])
+
+    # Add dominant emotion
+    text = text + emotion_dict[dominant_emotion]
 
+    # Clean up emoji spacing
     for emoji in emo_set.union(event_set):
-        s = s.replace(" " + emoji, emoji)
-        s = s.replace(emoji + " ", emoji)
-    return s.strip()
+        text = text.replace(" " + emoji, emoji)
+        text = text.replace(emoji + " ", emoji)
+    return text.strip()
 
 
-def format_str_v3(s):
-    def get_emo(s):
-        return s[-1] if s[-1] in emo_set else None
+def format_text_advanced(text: str) -> str:
+    """Advanced text formatting with multilingual and complex token handling"""
+    def get_emotion(text: str) -> Optional[str]:
+        return text[-1] if text[-1] in emo_set else None
 
-    def get_event(s):
-        return s[0] if s[0] in event_set else None
+    def get_event(text: str) -> Optional[str]:
+        return text[0] if text[0] in event_set else None
 
-    s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
+    # Handle special cases
+    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
     for lang in lang_dict:
-        s = s.replace(lang, "<|lang|>")
-    s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
-    new_s = " " + s_list[0]
-    cur_ent_event = get_event(new_s)
-    for i in range(1, len(s_list)):
-        if len(s_list[i]) == 0:
+        text = text.replace(lang, "<|lang|>")
+
+    # Process text segments
+    text_segments: List[str] = [format_text_with_emotion(segment).strip() for segment in text.split("<|lang|>")]
+    formatted_text = " " + text_segments[0]
+    current_event = get_event(formatted_text)
+
+    # Merge segments
+    for i in range(1, len(text_segments)):
+        if not text_segments[i]:
             continue
-        if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None:
-            s_list[i] = s_list[i][1:]
-        cur_ent_event = get_event(s_list[i])
-        if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s):
-            new_s = new_s[:-1]
-        new_s += s_list[i].strip().lstrip()
-    new_s = new_s.replace("The.", " ")
-    return new_s.strip()
+
+        if get_event(text_segments[i]) == current_event and get_event(text_segments[i]) is not None:
+            text_segments[i] = text_segments[i][1:]
+        current_event = get_event(text_segments[i])
+
+        if get_emotion(text_segments[i]) is not None and get_emotion(text_segments[i]) == get_emotion(formatted_text):
+            formatted_text = formatted_text[:-1]
+        formatted_text += text_segments[i].strip()
+
+    formatted_text = formatted_text.replace("The.", " ")
+    return formatted_text.strip()
 
 
 async def process_audio(audio_data: bytes, language: str = "auto") -> str:
-    """处理音频数据并返回识别结果"""
+    """Process audio data and return transcription result"""
     try:
-        # 将字节数据转换为 numpy 数组
+        # Convert bytes to numpy array
         audio_buffer = BytesIO(audio_data)
         waveform, sample_rate = torchaudio.load(audio_buffer)
 
-        # 转换为单声道
+        # Convert to mono channel
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0)
 
-        # 转换为 numpy array 并归一化
+        # Convert to numpy array and normalize
         input_wav = waveform.numpy().astype(np.float32)
 
-        # 重采样到 16kHz
+        # Resample to 16kHz if needed
         if sample_rate != 16000:
             resampler = torchaudio.transforms.Resample(sample_rate, 16000)
             input_wav = resampler(torch.from_numpy(input_wav)[None, :])[0, :].numpy()
 
-        # 模型推理
+        # Model inference
         text = model.generate(
             input=input_wav,
             cache={},
@@ -199,18 +225,18 @@ async def process_audio(audio_data: bytes, language: str = "auto") -> str:
             merge_vad=True
         )
 
-        # 格式化结果
+        # Format result
         result = text[0]["text"]
-        result = format_str_v3(result)
+        result = format_text_advanced(result)
 
         return result
 
     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"音频处理失败:{str(e)}")
+        raise HTTPException(status_code=500, detail=f"Audio processing failed: {str(e)}")
 
 
-async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
-    """验证Bearer Token"""
+async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> HTTPAuthorizationCredentials:
+    """Verify Bearer Token authentication"""
     if credentials.credentials != API_TOKEN:
         raise HTTPException(
             status_code=401,
@@ -225,49 +251,53 @@ async def transcribe_audio(
     model: Optional[str] = "FunAudioLLM/SenseVoiceSmall",
     language: Optional[str] = "auto",
     token: HTTPAuthorizationCredentials = Depends(verify_token)
-):
-    """音频转写接口
+) -> Dict[str, Union[str, int, float]]:
+    """Audio transcription endpoint
 
     Args:
-        file: 音频文件(支持常见音频格式)
-        model: 模型名称,目前仅支持 FunAudioLLM/SenseVoiceSmall
-        language: 语言代码,支持 auto/zh/en/yue/ja/ko/nospeech
+        file: Audio file (supports common audio formats)
+        model: Model name, currently only supports FunAudioLLM/SenseVoiceSmall
+        language: Language code, supports auto/zh/en/yue/ja/ko/nospeech
 
     Returns:
-        {
-            "text": "识别结果",
+        Dict[str, Union[str, int, float]]: {
+            "text": "Transcription result",
             "error_code": 0,
             "error_msg": "",
-            "process_time": 1.234 # 处理时间(秒)
+            "process_time": 1.234  # Processing time in seconds
         }
     """
     start_time = time.time()
 
     try:
+        # Validate file format
         if not file.filename.lower().endswith((".mp3", ".wav", ".flac", ".ogg", ".m4a")):
             return {
                 "text": "",
                 "error_code": 400,
-                "error_msg": "不支持的音频格式",
+                "error_msg": "Unsupported audio format",
                 "process_time": time.time() - start_time
             }
 
+        # Validate model
         if model != "FunAudioLLM/SenseVoiceSmall":
             return {
                 "text": "",
                 "error_code": 400,
-                "error_msg": "不支持的模型",
+                "error_msg": "Unsupported model",
                 "process_time": time.time() - start_time
             }
 
+        # Validate language
        if language not in ["auto", "zh", "en", "yue", "ja", "ko", "nospeech"]:
             return {
                 "text": "",
                 "error_code": 400,
-                "error_msg": "不支持的语言",
+                "error_msg": "Unsupported language",
                 "process_time": time.time() - start_time
             }
 
+        # Process audio
         content = await file.read()
         text = await process_audio(content, language)
 
@@ -287,33 +317,29 @@ async def transcribe_audio(
     }
 
 
-def transcribe_audio_gradio(audio, language="auto"):
-    """Gradio界面的音频转写函数"""
+def transcribe_audio_gradio(audio: Optional[Tuple[int, np.ndarray]], language: str = "auto") -> str:
+    """Gradio interface for audio transcription"""
     try:
         if audio is None:
-            return "请上传音频文件"
+            return "Please upload an audio file"
 
-        # 读取音频数据
-        fs, input_wav = audio
-
-        print('------------------------------')
-        print(fs, type(fs))
-        print(input_wav, type(input_wav))
-        print('------------------------------')
-
+        # Extract audio data
+        sample_rate, input_wav = audio
+
+        # Normalize audio
         input_wav = input_wav.astype(np.float32) / np.iinfo(np.int16).max
 
-        # 转换为单声道
+        # Convert to mono
         if len(input_wav.shape) > 1:
             input_wav = input_wav.mean(-1)
 
-        # 重采样到16kHz
-        if fs != 16000:
-            resampler = torchaudio.transforms.Resample(fs, 16000)
-            input_wav_t = torch.from_numpy(input_wav).to(torch.float32)
-            input_wav = resampler(input_wav_t[None, :])[0, :].numpy()
+        # Resample to 16kHz if needed
+        if sample_rate != 16000:
+            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+            input_wav_tensor = torch.from_numpy(input_wav).to(torch.float32)
+            input_wav = resampler(input_wav_tensor[None, :])[0, :].numpy()
 
-        # 模型推理
+        # Model inference
         text = model.generate(
             input=input_wav,
             cache={},
@@ -323,48 +349,53 @@ def transcribe_audio_gradio(audio, language="auto"):
             merge_vad=True
         )
 
-        # 格式化结果
+        # Format result
         result = text[0]["text"]
-        result = format_str_v3(result)
+        result = format_text_advanced(result)
 
         return result
     except Exception as e:
-        return f"处理失败:{str(e)}"
+        return f"Processing failed: {str(e)}"
 
-# 创建Gradio界面
+# Create Gradio interface with localized labels
 demo = gr.Interface(
     fn=transcribe_audio_gradio,
     inputs=[
-        gr.Audio(sources=["upload", "microphone", ], type="numpy", label="上传音频或使用麦克风录音"),
+        gr.Audio(
+            sources=["upload", "microphone"],
+            type="numpy",
+            label="Upload audio or record from microphone"
+        ),
         gr.Dropdown(
             choices=["auto", "zh", "en", "yue", "ja", "ko", "nospeech"],
             value="auto",
-            label="选择语言"
+            label="Select Language"
         )
     ],
-    outputs=gr.Textbox(label="识别结果"),
-    title="SenseVoice 语音识别",
-    description="支持中文、英语、粤语、日语、韩语等多种语言的语音转写服务",
+    outputs=gr.Textbox(label="Recognition Result"),
+    title="SenseVoice Speech Recognition",
+    description="Multi-language speech transcription service supporting Chinese, English, Cantonese, Japanese, and Korean",
     examples=[
         ["examples/zh.mp3", "zh"],
         ["examples/en.mp3", "en"],
     ]
 )
 
-# Gradio应用挂载到FastAPI
+# Mount Gradio app to FastAPI
 app = gr.mount_gradio_app(app, demo, path="/")
 
+# Custom Swagger UI redirect
 @app.get("/docs", include_in_schema=False)
 async def custom_swagger_ui_html():
     return HTMLResponse("""
     <!DOCTYPE html>
    <html>
     <head>
-        <title>SenseVoice API 文档</title>
+        <title>SenseVoice API Documentation</title>
         <meta http-equiv="refresh" content="0;url=/docs/" />
     </head>
     <body>
-        <p>正在跳转到API文档...</p>
+        <p>Redirecting to API documentation...</p>
     </body>
     </html>
     """)