File size: 5,194 Bytes
e61d4fe 1e44fe0 2278032 1e44fe0 e61d4fe 1e44fe0 2278032 e61d4fe 1e44fe0 705afb7 1e44fe0 705afb7 1e44fe0 e61d4fe 1e44fe0 2278032 1e44fe0 e61d4fe 1e44fe0 e61d4fe 2278032 e61d4fe 705afb7 e61d4fe 1e44fe0 e61d4fe 1e44fe0 e61d4fe 705afb7 e61d4fe 2278032 e61d4fe 705afb7 e61d4fe 1e44fe0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import base64
import io
import logging
import os
import pathlib
import typing
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI, Request, UploadFile, File, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel, Field, ValidationError, model_validator, ValidationInfo
from starlette.websockets import WebSocketState, WebSocketDisconnect
@asynccontextmanager
async def register_init(app: FastAPI):
    """
    Application lifespan hook: load the ASR model once before serving traffic.

    :param app: the FastAPI application being started.
    :return: async context manager consumed by FastAPI's lifespan machinery.
    """
    print('Loading ASR model...')
    setup_asr_model()
    yield
def register_middleware(app: FastAPI):
    """Attach the HTTP middleware stack to *app*.

    Registration order (GZip first, CORS second) is preserved deliberately
    from the original configuration.
    """
    # Response compression.
    app.add_middleware(GZipMiddleware)
    # Wide-open CORS policy: any origin, method and header, with credentials.
    cors_options = dict(
        allow_origins=['*'],
        allow_credentials=True,
        allow_methods=['*'],
        allow_headers=['*'],
    )
    app.add_middleware(CORSMiddleware, **cors_options)
def create_app():
    """Build the FastAPI application: lifespan hook plus middleware stack."""
    application = FastAPI(lifespan=register_init)
    register_middleware(application)
    return application
app = create_app()
# Model size is configurable via environment. BUGFIX: the original key
# 'WHISPER-MODEL-SIZE' contains a hyphen, which cannot be exported from a
# POSIX shell; prefer the underscore form and keep the old key as fallback
# for backward compatibility.
model_size = os.environ.get(
    'WHISPER_MODEL_SIZE',
    os.environ.get('WHISPER-MODEL-SIZE', 'large-v3'),
)
# Run on GPU with FP16; populated lazily by setup_asr_model() at startup.
asr_model: typing.Optional[WhisperModel] = None
def setup_asr_model():
    """
    Lazily construct the module-wide Whisper model (CUDA, float16).

    Idempotent: a second call returns the already-loaded instance.
    :return: the shared WhisperModel.
    """
    global asr_model
    if asr_model is not None:
        return asr_model
    logging.info('Loading ASR model...')
    asr_model = WhisperModel(model_size, device='cuda', compute_type='float16')
    logging.info('Load ASR model finished.')
    return asr_model
class TranscribeRequestParams(BaseModel):
    """Payload accepted by the HTTP and WebSocket transcribe endpoints."""

    uuid: str = Field(title='Request Unique Id.')
    # Either a filesystem path, or base64-encoded audio bytes when
    # using_file_content is True.
    audio_file: str
    language: typing.Literal['en', 'zh',]
    using_file_content: bool

    @model_validator(mode='after')
    def check_audio_file(self):
        """Ensure audio_file points at an existing file when it is a path.

        :return: the validated model (required for ``mode='after'`` validators).
        :raises ValueError: when the path does not exist; pydantic wraps this
            into a ValidationError so callers' ``except ValidationError`` works.
        """
        if self.using_file_content:
            return self
        if not pathlib.Path(self.audio_file).exists():
            # BUGFIX: was FileNotFoundError, which pydantic does not convert
            # into a ValidationError and which would escape callers untouched.
            raise ValueError('Audio file not exists.')
        # BUGFIX: an 'after' model validator must return the model on every
        # path; the original fell through here returning None.
        return self
@app.post('/transcribe')
async def transcribe_api(
    request: Request,
    obj: TranscribeRequestParams
):
    """Transcribe audio given either a file path or inline base64 content.

    Responds with a JSON body carrying ``if_success`` and ``uuid`` plus either
    ``transcribed_text`` or an error ``msg``. Only the first recognized
    segment is returned (NOTE(review): presumably inputs are single short
    utterances — confirm with callers).
    """
    try:
        source = obj.audio_file
        if obj.using_file_content:
            # Inline payload: decode base64 into an in-memory file object.
            source = io.BytesIO(base64.b64decode(obj.audio_file))
        segments, _ = asr_model.transcribe(source, language=obj.language)
        first_segment = next(iter(segments), None)
        transcribed_text = '' if first_segment is None else first_segment.text
    except Exception as exc:
        logging.exception(exc)
        return {
            "if_success": False,
            'uuid': obj.uuid,
            'msg': f'{exc}'
        }
    return {
        "if_success": True,
        'uuid': obj.uuid,
        'transcribed_text': transcribed_text
    }
@app.post('/transcribe-file')
async def transcribe_file_api(
    request: Request,
    uuid: str,
    audio_file: typing.Annotated[UploadFile, File()],
    language: typing.Literal['en', 'zh']
):
    """Transcribe an uploaded audio file (multipart form upload).

    Mirrors the JSON response contract of ``/transcribe``: ``if_success``,
    ``uuid`` and either ``transcribed_text`` or ``msg``. Only the first
    recognized segment is kept, consistent with the other endpoints.
    """
    try:
        segments, _ = asr_model.transcribe(audio_file.file, language=language)
        first_segment = next(iter(segments), None)
        transcribed_text = '' if first_segment is None else first_segment.text
    except Exception as exc:
        logging.exception(exc)
        return {
            "if_success": False,
            'uuid': uuid,
            'msg': f'{exc}'
        }
    return {
        "if_success": True,
        'uuid': uuid,
        'transcribed_text': transcribed_text
    }
@app.websocket('/transcribe')
async def transcribe_ws_api(
    websocket: WebSocket
):
    """WebSocket transcription endpoint.

    Protocol: the client sends JSON messages matching TranscribeRequestParams;
    each message is answered with a JSON body carrying ``if_success`` and
    ``uuid`` plus either ``transcribed_text`` or an error ``msg``.
    """
    await websocket.accept()
    try:
        while websocket.client_state == WebSocketState.CONNECTED:
            request_params = await websocket.receive_json()
            try:
                form = TranscribeRequestParams.model_validate(request_params)
            except ValidationError as exc:
                # Bad payload: report the validation error and keep serving.
                logging.exception(exc)
                await websocket.send_json({
                    "if_success": False,
                    'uuid': request_params.get('uuid', ''),
                    'msg': f'{exc}'
                })
                continue
            try:
                audio_file = form.audio_file
                if form.using_file_content:
                    # Inline payload: decode base64 into an in-memory file.
                    audio_file = io.BytesIO(base64.b64decode(form.audio_file))
                segments, _ = asr_model.transcribe(audio_file, language=form.language)
                # Only the first recognized segment is reported.
                transcribed_text = ''
                for segment in segments:
                    transcribed_text = segment.text
                    break
            except Exception as exc:
                logging.exception(exc)
                response_body = {
                    "if_success": False,
                    'uuid': form.uuid,
                    'msg': f'{exc}'
                }
            else:
                response_body = {
                    "if_success": True,
                    'uuid': form.uuid,
                    'transcribed_text': transcribed_text
                }
            await websocket.send_json(response_body)
    except WebSocketDisconnect:
        # BUGFIX: a client disconnect makes receive_json() raise
        # WebSocketDisconnect; previously it escaped as an unhandled server
        # error. End the receive loop quietly instead.
        return
if __name__ == '__main__':
    # Bind address and port are overridable via the HOST / PORT env vars.
    listen_host = os.environ.get('HOST', '0.0.0.0')
    listen_port = int(os.environ.get('PORT', 8080))
    uvicorn.run(app, host=listen_host, port=listen_port)
|