Chenhao committed on
Commit
5899d37
·
1 Parent(s): 373e485

使用了统一的推理入口

Browse files
Files changed (3) hide show
  1. Dockerfile +2 -15
  2. api.py +21 -32
  3. build.sh +1 -1
Dockerfile CHANGED
@@ -1,5 +1,4 @@
1
- # 构建阶段
2
- FROM python:3.12-slim as builder
3
 
4
  # 设置工作目录
5
  WORKDIR /app
@@ -16,22 +15,10 @@ COPY requirements.txt .
16
  # 安装Python依赖
17
  RUN pip install --no-cache-dir -r requirements.txt
18
 
19
- # 运行阶段
20
- FROM python:3.12-slim
21
-
22
- # 安装ffmpeg
23
- RUN apt-get update && apt-get install -y --no-install-recommends \
24
- ffmpeg \
25
- && rm -rf /var/lib/apt/lists/*
26
-
27
  # 创建非特权用户
28
  RUN useradd -m -s /bin/bash app
29
 
30
- # 设置工作目录
31
- WORKDIR /app
32
-
33
- # 复制应用代码和依赖
34
- COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
35
  COPY . .
36
 
37
  # 设置权限
 
1
+ FROM python:3.12-slim
 
2
 
3
  # 设置工作目录
4
  WORKDIR /app
 
15
  # 安装Python依赖
16
  RUN pip install --no-cache-dir -r requirements.txt
17
 
 
 
 
 
 
 
 
 
18
  # 创建非特权用户
19
  RUN useradd -m -s /bin/bash app
20
 
21
+ # 复制应用代码
 
 
 
 
22
  COPY . .
23
 
24
  # 设置权限
api.py CHANGED
@@ -4,6 +4,7 @@ from io import BytesIO
4
  from typing import Optional, Dict, Any, List, Set, Union, Tuple
5
  import os
6
  import time
 
7
 
8
  # Third-party imports
9
  from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
@@ -196,16 +197,20 @@ def format_text_advanced(text: str) -> str:
196
  return formatted_text.strip()
197
 
198
 
199
- async def audio_stt(audio: np.ndarray, sample_rate: int, language: str = "auto") -> str:
200
- # Step 01. Normalize & Resample
201
- input_wav = audio.astype(np.float32) / np.iinfo(np.int16).max
 
 
 
202
  # Step 02. Convert audio to mono channel
203
  if len(input_wav.shape) > 1:
204
- input_wav = input_wav.mean(-1)
 
205
  # Step 03. Resample to 16kHz
206
- resampler = torchaudio.transforms.Resample(sample_rate, 16000)
207
- input_wav_tensor = torch.from_numpy(input_wav).to(torch.float32)
208
- input_wav = resampler(input_wav_tensor[None, :])[0, :].numpy()
209
  # Step 04. Model Inference
210
  text = model.generate(
211
  input=input_wav,
@@ -225,9 +230,13 @@ async def process_audio(audio_data: bytes, language: str = "auto") -> str:
225
  try:
226
  # Convert bytes to numpy array
227
  audio_buffer = BytesIO(audio_data)
228
- waveform, sample_rate = torchaudio.load(audio_buffer)
 
 
 
 
229
 
230
- result = audio_stt(waveform, sample_rate, language)
231
 
232
  return result
233
 
@@ -332,29 +341,9 @@ def transcribe_audio_gradio(audio: Optional[Tuple[int, np.ndarray]], language: s
332
  # Normalize audio
333
  input_wav = input_wav.astype(np.float32) / np.iinfo(np.int16).max
334
 
335
- # Convert to mono
336
- if len(input_wav.shape) > 1:
337
- input_wav = input_wav.mean(-1)
338
-
339
- # Resample to 16kHz if needed
340
- if sample_rate != 16000:
341
- resampler = torchaudio.transforms.Resample(sample_rate, 16000)
342
- input_wav_tensor = torch.from_numpy(input_wav).to(torch.float32)
343
- input_wav = resampler(input_wav_tensor[None, :])[0, :].numpy()
344
-
345
- # Model inference
346
- text = model.generate(
347
- input=input_wav,
348
- cache={},
349
- language=language,
350
- use_itn=True,
351
- batch_size_s=500,
352
- merge_vad=True
353
- )
354
-
355
- # Format result
356
- result = text[0]["text"]
357
- result = format_text_advanced(result)
358
 
359
  return result
360
  except Exception as e:
 
4
  from typing import Optional, Dict, Any, List, Set, Union, Tuple
5
  import os
6
  import time
7
+ import asyncio
8
 
9
  # Third-party imports
10
  from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
 
197
  return formatted_text.strip()
198
 
199
 
200
+ async def audio_stt(audio: torch.Tensor, sample_rate: int, language: str = "auto") -> str:
201
+ """ Audio as an already normalized Float32 Tensor
202
+ """
203
+ # Step 01. Normalize
204
+ input_wav = audio.to(torch.float32)
205
+
206
  # Step 02. Convert audio to mono channel
207
  if len(input_wav.shape) > 1:
208
+ input_wav = input_wav.mean(dim=0)
209
+ input_wav = input_wav.squeeze()
210
  # Step 03. Resample to 16kHz
211
+ if sample_rate != 16000:
212
+ resampler = torchaudio.transforms.Resample(sample_rate, 16000)
213
+ input_wav = resampler(input_wav[None, :])[0, :].numpy()
214
  # Step 04. Model Inference
215
  text = model.generate(
216
  input=input_wav,
 
230
  try:
231
  # Convert bytes to numpy array
232
  audio_buffer = BytesIO(audio_data)
233
+ waveform, sample_rate = torchaudio.load(
234
+ uri = audio_buffer,
235
+ normalize = True,
236
+ channels_first = True,
237
+ )
238
 
239
+ result = await audio_stt(waveform, sample_rate, language)
240
 
241
  return result
242
 
 
341
  # Normalize audio
342
  input_wav = input_wav.astype(np.float32) / np.iinfo(np.int16).max
343
 
344
+ input_wav = torch.from_numpy(input_wav)
345
+
346
+ result = asyncio.run(audio_stt(input_wav, sample_rate, language))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
 
348
  return result
349
  except Exception as e:
build.sh CHANGED
@@ -22,7 +22,7 @@ docker build -t $IMAGE_NAME .
22
  echo "启动容器..."
23
  docker run -d \
24
  --name $CONTAINER_NAME \
25
- -p $PORT:8000 \
26
  -e API_TOKEN="your-secret-token-here" \
27
  -e PYTHONUNBUFFERED=1 \
28
  $IMAGE_NAME
 
22
  echo "启动容器..."
23
  docker run -d \
24
  --name $CONTAINER_NAME \
25
+ -p $PORT:7860 \
26
  -e API_TOKEN="your-secret-token-here" \
27
  -e PYTHONUNBUFFERED=1 \
28
  $IMAGE_NAME