megatrump committed on
Commit
373e485
·
1 Parent(s): 4a1f483

Added a unified inference entry point

Files changed (1)
  1. api.py +25 -27
api.py CHANGED
@@ -196,6 +196,30 @@ def format_text_advanced(text: str) -> str:
     return formatted_text.strip()
 
 
+async def audio_stt(audio: np.ndarray, sample_rate: int, language: str = "auto") -> str:
+    # Step 01. Normalize & Resample
+    input_wav = audio.astype(np.float32) / np.iinfo(np.int16).max
+    # Step 02. Convert audio to mono channel
+    if len(input_wav.shape) > 1:
+        input_wav = input_wav.mean(-1)
+    # Step 03. Resample to 16kHz
+    resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+    input_wav_tensor = torch.from_numpy(input_wav).to(torch.float32)
+    input_wav = resampler(input_wav_tensor[None, :])[0, :].numpy()
+    # Step 04. Model Inference
+    text = model.generate(
+        input=input_wav,
+        cache={},
+        language=language,
+        use_itn=True,
+        batch_size_s=500,
+        merge_vad=True
+    )
+    # Step 05. Format Result
+    result = text[0]["text"]
+    result = format_text_advanced(result)
+    return result
+
 async def process_audio(audio_data: bytes, language: str = "auto") -> str:
     """Process audio data and return transcription result"""
     try:
@@ -203,33 +227,7 @@ async def process_audio(audio_data: bytes, language: str = "auto") -> str:
         audio_buffer = BytesIO(audio_data)
         waveform, sample_rate = torchaudio.load(audio_buffer)
 
-        # Convert to mono channel
-        if waveform.shape[0] > 1:
-            waveform = waveform.mean(dim=0)
-        else:
-            waveform = np.squeeze(waveform)
-
-        # Convert to numpy array and normalize
-        input_wav = waveform.numpy().astype(np.float32)
-
-        # Resample to 16kHz if needed
-        if sample_rate != 16000:
-            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
-            input_wav = resampler(torch.from_numpy(input_wav)[None, :])[0, :].numpy()
-
-        # Model inference
-        text = model.generate(
-            input=input_wav,
-            cache={},
-            language=language,
-            use_itn=True,
-            batch_size_s=500,
-            merge_vad=True
-        )
-
-        # Format result
-        result = text[0]["text"]
-        result = format_text_advanced(result)
+        result = audio_stt(waveform, sample_rate, language)
 
         return result
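
The two sketches below are for orientation only and are not part of the commit.

The generate(..., cache={}, use_itn=True, batch_size_s=500, merge_vad=True) call matches the FunASR AutoModel API, which suggests the module-level model this diff relies on is built once at startup. A minimal sketch, where the model id, VAD model, and device are assumptions that do not appear anywhere in the diff:

from funasr import AutoModel

# Assumed initialization; the actual model id and device are not shown in this commit.
model = AutoModel(
    model="iic/SenseVoiceSmall",  # hypothetical choice consistent with the kwargs above
    vad_model="fsmn-vad",         # merge_vad=True implies a VAD front end
    device="cpu",
)

Driving the new unified entry point takes some care. torchaudio.load returns a float32 (channels, samples) torch.Tensor, while audio_stt expects an np.ndarray scaled like int16 PCM: it divides by np.iinfo(np.int16).max and folds channels with mean(-1) over the last axis. It is also declared async, so it must be awaited; the committed call site result = audio_stt(waveform, sample_rate, language) passes the tensor as-is and omits the await. A hedged driver sketch, where transcribe_bytes is a hypothetical helper rather than anything in api.py:

import asyncio
from io import BytesIO

import numpy as np
import torchaudio

from api import audio_stt  # the unified entry point added in this commit

async def transcribe_bytes(audio_data: bytes, language: str = "auto") -> str:
    # Decode the container; torchaudio yields a float32 (channels, samples) tensor.
    waveform, sample_rate = torchaudio.load(BytesIO(audio_data))
    # Rescale to int16 range and transpose to (samples, channels), matching the
    # np.iinfo(np.int16).max normalization and mean(-1) mono fold inside audio_stt.
    audio = (waveform.numpy().T * np.iinfo(np.int16).max).astype(np.int16)
    return await audio_stt(audio, sample_rate, language)

# Usage: print(asyncio.run(transcribe_bytes(open("sample.wav", "rb").read())))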