File size: 5,194 Bytes
e61d4fe
 
 
1e44fe0
2278032
1e44fe0
 
 
 
e61d4fe
1e44fe0
 
 
2278032
e61d4fe
1e44fe0
 
 
 
 
 
 
 
 
705afb7
 
1e44fe0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
705afb7
 
 
 
 
 
 
 
 
 
1e44fe0
 
 
 
e61d4fe
1e44fe0
2278032
 
 
 
 
 
 
 
 
1e44fe0
 
 
e61d4fe
1e44fe0
 
 
e61d4fe
2278032
 
 
e61d4fe
705afb7
 
 
 
 
 
e61d4fe
 
 
 
 
 
 
 
 
 
 
 
 
 
1e44fe0
 
 
e61d4fe
1e44fe0
 
 
 
 
e61d4fe
705afb7
 
 
 
 
 
e61d4fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2278032
 
 
 
e61d4fe
705afb7
 
 
 
 
 
e61d4fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e44fe0
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import base64
import io
import logging
import os
import pathlib
import typing
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI, Request, UploadFile, File, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel, Field, ValidationError, model_validator, ValidationInfo
from starlette.websockets import WebSocketDisconnect, WebSocketState


@asynccontextmanager
async def register_init(app: FastAPI):
    """
    Startup initialization (used as the FastAPI ``lifespan`` handler).

    The code before ``yield`` runs once when the application starts and
    loads the ASR model via ``setup_asr_model()``. Nothing follows the
    ``yield``, so no work is done on shutdown.

    :param app: the FastAPI application being started (not used here)
    :return: an async context manager that yields no value
    """
    # Console notice; setup_asr_model() additionally reports through logging.
    print('Loading ASR model...')
    setup_asr_model()

    yield


def register_middleware(app: FastAPI):
    """
    Install the HTTP middleware stack on ``app``.

    Registration order is significant and is kept as: GZip first, CORS last.

    :param app: the FastAPI application to configure (mutated in place)
    :return: None
    """
    # Fully permissive CORS policy: any origin, method and header, with credentials.
    cors_options = {
        'allow_origins': ['*'],
        'allow_credentials': True,
        'allow_methods': ['*'],
        'allow_headers': ['*'],
    }

    # Gzip: always registered first.
    app.add_middleware(GZipMiddleware)
    # CORS: always registered last.
    app.add_middleware(CORSMiddleware, **cors_options)


def create_app():
    """
    Build the FastAPI application.

    The app is created with ``register_init`` as its lifespan handler and
    then passed through ``register_middleware``.

    :return: the configured FastAPI instance
    """
    application = FastAPI(lifespan=register_init)
    register_middleware(application)
    return application


# Application instance used by the route decorators below and by uvicorn.run().
app = create_app()

# Whisper model name/size passed to WhisperModel.
# The original hyphenated variable name is still honoured first, but names
# containing '-' cannot be exported from POSIX shells, so the underscore
# spelling is accepted as a fallback before the 'large-v3' default.
model_size = os.environ.get(
    'WHISPER-MODEL-SIZE',
    os.environ.get('WHISPER_MODEL_SIZE', 'large-v3'),
)
# Shared model handle; stays None until setup_asr_model() creates it
# (on GPU with FP16).
asr_model: typing.Optional[WhisperModel] = None


def setup_asr_model():
    """
    Create the shared Whisper model on first use and return it.

    The instance is stored in the module-level ``asr_model``; later calls
    return that same object without loading again.

    :return: the shared ``WhisperModel`` instance
    """
    global asr_model

    # Already loaded: hand back the existing instance.
    if asr_model is not None:
        return asr_model

    logging.info('Loading ASR model...')
    asr_model = WhisperModel(model_size, device='cuda', compute_type='float16')
    logging.info('Load ASR model finished.')
    return asr_model


class TranscribeRequestParams(BaseModel):
    """
    Request payload for the ``/transcribe`` HTTP and WebSocket endpoints.

    ``audio_file`` has two meanings, selected by ``using_file_content``:
    when true it carries base64-encoded audio content, otherwise it is a
    filesystem path that must exist on the server.
    """
    # Caller-supplied request id, echoed back in every response.
    uuid: str = Field(title='Request Unique Id.')
    # Base64 audio content or a server-side path (see using_file_content).
    audio_file: str
    # Language code forwarded to the ASR model.
    language: typing.Literal['en', 'zh',]
    # True: audio_file is base64 content; False: audio_file is a path.
    using_file_content: bool

    @model_validator(mode='after')
    def check_audio_file(self):
        """
        Verify that ``audio_file`` points at an existing file in path mode.

        :return: the validated model instance (an ``after`` validator must
                 return it on every path)
        :raises ValueError: if path mode is used and the path does not exist;
                 pydantic reports this as a ``ValidationError``
        """
        if self.using_file_content:
            return self

        if not pathlib.Path(self.audio_file).exists():
            # ValueError (not FileNotFoundError) so pydantic wraps it into a
            # ValidationError instead of letting a raw OSError escape.
            raise ValueError('Audio file not exists.')

        # Previously missing: without this the validator returned None.
        return self


@app.post('/transcribe')
async def transcribe_api(
        request: Request,
        obj: TranscribeRequestParams
):
    """
    Transcribe audio described by a JSON body.

    The audio source is either the path in ``obj.audio_file`` or, when
    ``obj.using_file_content`` is set, the base64-decoded bytes of that field.

    :param request: the incoming HTTP request (not used here)
    :param obj: validated request parameters
    :return: dict with ``if_success`` and ``uuid``, plus ``transcribed_text``
             on success or ``msg`` (the exception text) on failure
    """
    try:
        if obj.using_file_content:
            # Inline content: decode base64 into an in-memory binary stream.
            source = io.BytesIO(base64.b64decode(obj.audio_file))
        else:
            source = obj.audio_file

        segment_iter, _info = asr_model.transcribe(source, language=obj.language)

        # NOTE(review): only the first yielded segment is consumed, so any
        # later segments are not part of the reply - confirm this is intended.
        first_segment = next(iter(segment_iter), None)
        transcribed_text = '' if first_segment is None else first_segment.text
    except Exception as exc:
        logging.exception(exc)
        return {
            "if_success": False,
            'uuid': obj.uuid,
            'msg': f'{exc}'
        }

    return {
        "if_success": True,
        'uuid': obj.uuid,
        'transcribed_text': transcribed_text
    }


@app.post('/transcribe-file')
async def transcribe_file_api(
        request: Request,
        uuid: str,
        audio_file: typing.Annotated[UploadFile, File()],
        language: typing.Literal['en', 'zh']
):
    """
    Transcribe an uploaded audio file.

    :param request: the incoming HTTP request (not used here)
    :param uuid: caller-supplied request id, echoed back in the response
    :param audio_file: uploaded file; its underlying file object is passed
                       to the ASR model
    :param language: language code forwarded to the ASR model
    :return: dict with ``if_success`` and ``uuid``, plus ``transcribed_text``
             on success or ``msg`` (the exception text) on failure
    """
    try:
        segment_iter, _info = asr_model.transcribe(audio_file.file, language=language)

        # NOTE(review): only the first yielded segment is consumed, so any
        # later segments are not part of the reply - confirm this is intended.
        first_segment = next(iter(segment_iter), None)
        transcribed_text = '' if first_segment is None else first_segment.text
    except Exception as exc:
        logging.exception(exc)
        return {
            "if_success": False,
            'uuid': uuid,
            'msg': f'{exc}'
        }

    return {
        "if_success": True,
        'uuid': uuid,
        'transcribed_text': transcribed_text
    }


@app.websocket('/transcribe')
async def transcribe_ws_api(
        websocket: WebSocket
):
    """
    WebSocket variant of ``/transcribe``: one JSON reply per JSON request.

    Each received message is validated as ``TranscribeRequestParams`` and
    transcribed. Every reply carries ``if_success`` and ``uuid``, plus
    ``transcribed_text`` on success or ``msg`` (the exception text) on
    failure. The loop ends when the client disconnects.

    :param websocket: the client connection
    :return: None
    """
    await websocket.accept()

    while websocket.client_state == WebSocketState.CONNECTED:
        try:
            request_params = await websocket.receive_json()
        except WebSocketDisconnect:
            # The client closed the connection: normal end of the session.
            break
        except ValueError as exc:
            # Malformed JSON text (JSONDecodeError is a ValueError): report it
            # and keep the connection open for the next message.
            logging.exception(exc)
            await websocket.send_json({
                "if_success": False,
                'uuid': '',
                'msg': f'{exc}'
            })
            continue

        try:
            form = TranscribeRequestParams.model_validate(request_params)
        except ValidationError as exc:
            logging.exception(exc)
            # The payload may be valid JSON without being an object (e.g. a
            # list or a string), in which case there is no uuid to echo back.
            request_uuid = request_params.get('uuid', '') if isinstance(request_params, dict) else ''
            await websocket.send_json({
                "if_success": False,
                'uuid': request_uuid,
                'msg': f'{exc}'
            })
            continue

        try:
            audio_file = form.audio_file
            if form.using_file_content:
                # Inline content: decode base64 into an in-memory binary stream.
                audio_file = io.BytesIO(base64.b64decode(form.audio_file))

            segments, _ = asr_model.transcribe(audio_file, language=form.language)

            # NOTE(review): only the first yielded segment is consumed, so any
            # later segments are not part of the reply - confirm this is intended.
            transcribed_text = ''
            for segment in segments:
                transcribed_text = segment.text
                break
        except Exception as exc:
            logging.exception(exc)
            response_body = {
                "if_success": False,
                'uuid': form.uuid,
                'msg': f'{exc}'
            }
        else:
            response_body = {
                "if_success": True,
                'uuid': form.uuid,
                'transcribed_text': transcribed_text
            }

        await websocket.send_json(response_body)


if __name__ == '__main__':
    # Script entry point: serve the app with uvicorn.
    # HOST and PORT come from the environment, defaulting to 0.0.0.0:8080.
    bind_host = os.environ.get('HOST', '0.0.0.0')
    bind_port = int(os.environ.get('PORT', 8080))
    uvicorn.run(app, host=bind_host, port=bind_port)