Chenhao committed
Commit 660c142 · 1 Parent(s): 5899d37

Format the code with Claude 3.5

Files changed (4):
  1. .gitignore +5 -0
  2. api.py +143 -111
  3. start.sh +3 -1
  4. test/01_rpc_test.py +67 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+
+.venv/
+.vscode/
+
+*.pyc
api.py CHANGED
@@ -33,7 +33,7 @@ security = HTTPBearer()
 
 app = FastAPI(
     title="SenseVoice API",
-    description="语音识别 API 服务",
+    description="Speech To Text API Service",
     version="1.0.0"
 )
 
@@ -55,78 +55,77 @@ model = AutoModel(
     device="cuda"
 )
 
-# Reuse the original formatting helpers
 emotion_dict: Dict[str, str] = {
-    "<|HAPPY|>": "😊",
-    "<|SAD|>": "😔",
-    "<|ANGRY|>": "😡",
-    "<|NEUTRAL|>": "",
-    "<|FEARFUL|>": "😰",
-    "<|DISGUSTED|>": "🤢",
-    "<|SURPRISED|>": "😮",
+    "<|HAPPY|>": "😊",
+    "<|SAD|>": "😔",
+    "<|ANGRY|>": "😡",
+    "<|NEUTRAL|>": "",
+    "<|FEARFUL|>": "😰",
+    "<|DISGUSTED|>": "🤢",
+    "<|SURPRISED|>": "😮",
 }
 
 event_dict: Dict[str, str] = {
-    "<|BGM|>": "🎼",
-    "<|Speech|>": "",
-    "<|Applause|>": "👏",
-    "<|Laughter|>": "😀",
-    "<|Cry|>": "😭",
-    "<|Sneeze|>": "🤧",
-    "<|Breath|>": "",
-    "<|Cough|>": "🤧",
+    "<|BGM|>": "🎼",
+    "<|Speech|>": "",
+    "<|Applause|>": "👏",
+    "<|Laughter|>": "😀",
+    "<|Cry|>": "😭",
+    "<|Sneeze|>": "🤧",
+    "<|Breath|>": "",
+    "<|Cough|>": "🤧",
 }
 
 emoji_dict: Dict[str, str] = {
     "<|nospeech|><|Event_UNK|>": "❓",
-    "<|zh|>": "",
-    "<|en|>": "",
-    "<|yue|>": "",
-    "<|ja|>": "",
-    "<|ko|>": "",
-    "<|nospeech|>": "",
-    "<|HAPPY|>": "😊",
-    "<|SAD|>": "😔",
-    "<|ANGRY|>": "😡",
-    "<|NEUTRAL|>": "",
-    "<|BGM|>": "🎼",
-    "<|Speech|>": "",
-    "<|Applause|>": "👏",
-    "<|Laughter|>": "😀",
-    "<|FEARFUL|>": "😰",
-    "<|DISGUSTED|>": "🤢",
-    "<|SURPRISED|>": "😮",
-    "<|Cry|>": "😭",
-    "<|EMO_UNKNOWN|>": "",
-    "<|Sneeze|>": "🤧",
-    "<|Breath|>": "",
-    "<|Cough|>": "😷",
-    "<|Sing|>": "",
+    "<|zh|>": "",
+    "<|en|>": "",
+    "<|yue|>": "",
+    "<|ja|>": "",
+    "<|ko|>": "",
+    "<|nospeech|>": "",
+    "<|HAPPY|>": "😊",
+    "<|SAD|>": "😔",
+    "<|ANGRY|>": "😡",
+    "<|NEUTRAL|>": "",
+    "<|BGM|>": "🎼",
+    "<|Speech|>": "",
+    "<|Applause|>": "👏",
+    "<|Laughter|>": "😀",
+    "<|FEARFUL|>": "😰",
+    "<|DISGUSTED|>": "🤢",
+    "<|SURPRISED|>": "😮",
+    "<|Cry|>": "😭",
+    "<|EMO_UNKNOWN|>": "",
+    "<|Sneeze|>": "🤧",
+    "<|Breath|>": "",
+    "<|Cough|>": "😷",
+    "<|Sing|>": "",
     "<|Speech_Noise|>": "",
-    "<|withitn|>": "",
-    "<|woitn|>": "",
-    "<|GBG|>": "",
-    "<|Event_UNK|>": "",
+    "<|withitn|>": "",
+    "<|woitn|>": "",
+    "<|GBG|>": "",
+    "<|Event_UNK|>": "",
 }
 
 lang_dict: Dict[str, str] = {
-    "<|zh|>": "<|lang|>",
-    "<|en|>": "<|lang|>",
-    "<|yue|>": "<|lang|>",
-    "<|ja|>": "<|lang|>",
-    "<|ko|>": "<|lang|>",
-    "<|nospeech|>": "<|lang|>",
+    "<|zh|>": "<|lang|>",
+    "<|en|>": "<|lang|>",
+    "<|yue|>": "<|lang|>",
+    "<|ja|>": "<|lang|>",
+    "<|ko|>": "<|lang|>",
+    "<|nospeech|>": "<|lang|>",
 }
 
 emo_set: Set[str] = {"😊", "😔", "😡", "😰", "🤢", "😮"}
 event_set: Set[str] = {"🎼", "👏", "😀", "😭", "🤧", "😷"}
 
 
-def format_text_basic(text: str) -> str:
-    """Replace special tokens with corresponding emojis"""
-    for token in emoji_dict:
-        text = text.replace(token, emoji_dict[token])
-    return text
+# def format_text_basic(text: str) -> str:
+#     """Replace special tokens with corresponding emojis"""
+#     for token in emoji_dict:
#         text = text.replace(token, emoji_dict[token])
+#     return text
 
 
 def format_text_with_emotion(text: str) -> str:
@@ -198,53 +197,90 @@ def format_text_advanced(text: str) -> str:
 
 
 async def audio_stt(audio: torch.Tensor, sample_rate: int, language: str = "auto") -> str:
-    """ Audio as an already normalized Float32 Tensor
+    """Process an audio tensor and perform speech-to-text conversion.
+
+    Args:
+        audio: Input audio tensor
+        sample_rate: Audio sample rate in Hz
+        language: Target language code (auto/zh/en/yue/ja/ko/nospeech)
+
+    Returns:
+        str: Transcribed and formatted text result
     """
-    # Step 01. Normalize
-    input_wav = audio.to(torch.float32)
-
-    # Step 02. Convert audio to mono channel
-    if len(input_wav.shape) > 1:
-        input_wav = input_wav.mean(dim=0)
-    input_wav = input_wav.squeeze()
-    # Step 03. Resample to 16kHz
-    if sample_rate != 16000:
-        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
-        input_wav = resampler(input_wav[None, :])[0, :].numpy()
-    # Step 04. Model Inference
-    text = model.generate(
-        input=input_wav,
-        cache={},
-        language=language,
-        use_itn=True,
-        batch_size_s=500,
-        merge_vad=True
-    )
-    # Step 05. Format Result
-    result = text[0]["text"]
-    result = format_text_advanced(result)
-    return result
+    try:
+        # Normalize
+        if audio.dtype != torch.float32:
+            if audio.dtype == torch.int16:
+                audio = audio.float() / torch.iinfo(torch.int16).max
+            elif audio.dtype == torch.int32:
+                audio = audio.float() / torch.iinfo(torch.int32).max
+            else:
+                audio = audio.float()
+
+        # Make sure the audio is in the correct range
+        if audio.abs().max() > 1.0:
+            audio = audio / audio.abs().max()
+
+        # Convert to mono channel
+        if len(audio.shape) > 1:
+            audio = audio.mean(dim=0)
+        audio = audio.squeeze()
+
+        # Resample
+        if sample_rate != 16000:
+            resampler = torchaudio.transforms.Resample(
+                orig_freq=sample_rate,
+                new_freq=16000
+            )
+            audio = resampler(audio.unsqueeze(0)).squeeze(0)
+
+        text = model.generate(
+            input=audio,
+            cache={},
+            language=language,
+            use_itn=True,
+            batch_size_s=500,
+            merge_vad=True
+        )
+
+        # Format the result
+        result = text[0]["text"]
+        return format_text_advanced(result)
+
+    except Exception as e:
+        raise HTTPException(
+            status_code=500,
+            detail=f"Audio processing failed in audio_stt: {str(e)}"
+        )
 
 async def process_audio(audio_data: bytes, language: str = "auto") -> str:
-    """Process audio data and return transcription result"""
+    """Process audio data and return the transcription result.
+
+    Args:
+        audio_data: Raw audio data in bytes
+        language: Target language code
+
+    Returns:
+        str: Transcribed and formatted text
+
+    Raises:
+        HTTPException: If audio processing fails
+    """
     try:
-        # Convert bytes to numpy array
         audio_buffer = BytesIO(audio_data)
         waveform, sample_rate = torchaudio.load(
-            uri = audio_buffer,
-            normalize = True,
-            channels_first = True,
+            uri=audio_buffer,
+            normalize=True,
+            channels_first=True
         )
-
         result = await audio_stt(waveform, sample_rate, language)
-
         return result
 
     except Exception as e:
-        import traceback
-        traceback.print_exc()
-        traceback.print_stack()
-        raise HTTPException(status_code=500, detail=f"Audio processing failed: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail=f"Audio processing failed: {str(e)}"
+        )
 
 
 async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> HTTPAuthorizationCredentials:
@@ -260,56 +296,52 @@ async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(secur
 @app.post("/v1/audio/transcriptions")
 async def transcribe_audio(
     file: UploadFile = File(...),
-    model: Optional[str] = "FunAudioLLM/SenseVoiceSmall",
-    language: Optional[str] = "auto",
+    model: str = "FunAudioLLM/SenseVoiceSmall",
+    language: str = "auto",
     token: HTTPAuthorizationCredentials = Depends(verify_token)
 ) -> Dict[str, Union[str, int, float]]:
-    """Audio transcription endpoint
+    """Audio transcription endpoint.
 
     Args:
-        file: Audio file (supports common audio formats)
-        model: Model name, currently only supports FunAudioLLM/SenseVoiceSmall
-        language: Language code, supports auto/zh/en/yue/ja/ko/nospeech
+        file: Audio file (supports mp3, wav, flac, ogg, m4a)
+        model: Model name
+        language: Language code
+        token: Authentication token
 
     Returns:
-        Dict[str, Union[str, int, float]]: {
-            "text": "Transcription result",
-            "error_code": 0,
-            "error_msg": "",
-            "process_time": 1.234  # Processing time in seconds
-        }
+        Dict containing transcription result and metadata
     """
     start_time = time.time()
 
     try:
-        # Validate file format
+        # Check the file format
         if not file.filename.lower().endswith((".mp3", ".wav", ".flac", ".ogg", ".m4a")):
             return {
                 "text": "",
                 "error_code": 400,
-                "error_msg": "Unsupported audio format",
+                "error_msg": "不支持的音频格式",
                 "process_time": time.time() - start_time
             }
 
-        # Validate model
+        # Check the model
         if model != "FunAudioLLM/SenseVoiceSmall":
             return {
                 "text": "",
                 "error_code": 400,
-                "error_msg": "Unsupported model",
+                "error_msg": "不支持的模型",
                 "process_time": time.time() - start_time
             }
 
-        # Validate language
+        # Check the language
         if language not in ["auto", "zh", "en", "yue", "ja", "ko", "nospeech"]:
             return {
                 "text": "",
                 "error_code": 400,
-                "error_msg": "Unsupported language",
+                "error_msg": "不支持的语言",
                 "process_time": time.time() - start_time
             }
 
-        # Process audio
+        # STT
         content = await file.read()
         text = await process_audio(content, language)
 
@@ -341,8 +373,8 @@ def transcribe_audio_gradio(audio: Optional[Tuple[int, np.ndarray]], language: s
     # Normalize audio
     input_wav = input_wav.astype(np.float32) / np.iinfo(np.int16).max
 
+    # Model Inference
     input_wav = torch.from_numpy(input_wav)
-
    result = asyncio.run(audio_stt(input_wav, sample_rate, language))
 
     return result
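
The new audio_stt normalizes by dtype before downmixing, resampling, and inference. A minimal standalone sketch of that normalization logic, assuming only torch is installed; normalize_audio is a hypothetical helper for illustration, not a function in api.py:

import torch

def normalize_audio(audio: torch.Tensor) -> torch.Tensor:
    # Integer PCM is scaled by the dtype's max value; other dtypes are just cast.
    if audio.dtype == torch.int16:
        audio = audio.float() / torch.iinfo(torch.int16).max
    elif audio.dtype == torch.int32:
        audio = audio.float() / torch.iinfo(torch.int32).max
    else:
        audio = audio.float()
    # Rescale anything that still falls outside [-1, 1].
    if audio.abs().max() > 1.0:
        audio = audio / audio.abs().max()
    return audio

print(normalize_audio(torch.tensor([16384, -32768], dtype=torch.int16)))
# roughly tensor([ 0.5000, -1.0000])

Scaling integer PCM by the dtype maximum and then rescaling any residual overshoot keeps every input inside the [-1, 1] float range before it reaches model.generate.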
start.sh CHANGED
@@ -1,7 +1,9 @@
 #!/bin/bash
 
+export API_TOKEN=your-secret-token-here
+
 # Keep Alive
 python3 awake.py &
 
 # Start the FastAPI service
-python -m uvicorn api:app --host 0.0.0.0 --port 7860
+python -m uvicorn api:app --host 0.0.0.0 --port 8000
test/01_rpc_test.py ADDED
@@ -0,0 +1,67 @@
+
+import asyncio
+import httpx
+from pathlib import Path
+from typing import Optional
+
+async def transcribe_audio(
+    file_path: str,
+    api_token: str,
+    model: str = "FunAudioLLM/SenseVoiceSmall",
+    api_url: str = "http://127.0.0.1:8000/v1/audio/transcriptions"
+) -> Optional[dict]:
+    """Send an asynchronous speech-recognition request.
+
+    Args:
+        file_path: Path to the audio file
+        api_token: API authentication token
+        model: Model name, defaults to FunAudioLLM/SenseVoiceSmall
+        api_url: URL of the API service
+
+    Returns:
+        dict: Dictionary containing the recognition result, or None on failure
+    """
+    try:
+        # Check that the file exists
+        audio_file = Path(file_path)
+        if not audio_file.exists():
+            print(f"Error: file {file_path} does not exist")
+            return None
+
+        # Prepare the request headers and file payload
+        headers = {"Authorization": f"Bearer {api_token}"}
+        files = {
+            "file": (audio_file.name, audio_file.open("rb")),
+            "model": (None, model)
+        }
+
+        # Send the async request
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                api_url,
+                headers=headers,
+                files=files,
+                timeout=60,
+            )
+            print(response.text)
+            response.raise_for_status()
+            return response.json()
+
+    except httpx.HTTPError as e:
+        print(f"HTTP request error: {str(e)}")
+        return None
+    except Exception as e:
+        print(f"An error occurred: {str(e)}")
+        return None
+
+async def main():
+    # Usage example
+    file_path = "../examples/zh.mp3"
+    api_token = "your-secret-token-here"
+
+    result = await transcribe_audio(file_path, api_token)
+    if result:
+        print(f"Recognition result: {result['text']}")
+
+if __name__ == "__main__":
+    asyncio.run(main())
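
With the server from start.sh listening on port 8000, the script above can be run from inside the test directory (the audio path is relative). For a one-off synchronous check, a minimal sketch using httpx directly; the token and file path simply mirror the placeholder defaults above and are assumptions about the local setup:

import httpx

# Hypothetical synchronous smoke test mirroring test/01_rpc_test.py defaults.
with open("../examples/zh.mp3", "rb") as f:
    response = httpx.post(
        "http://127.0.0.1:8000/v1/audio/transcriptions",
        headers={"Authorization": "Bearer your-secret-token-here"},
        files={"file": ("zh.mp3", f), "model": (None, "FunAudioLLM/SenseVoiceSmall")},
        timeout=60,
    )
response.raise_for_status()
print(response.json()["text"])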