Fedir Zadniprovskyi committed
Commit 8f3dcc9 · 1 Parent(s): 624f97e

refactor: split out app into multiple router modules

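This commit moves the HTTP and WebSocket endpoints out of main.py into per-concern APIRouter modules (stt, list_models, misc) that main.py then mounts. For orientation, the pattern looks roughly like the hypothetical sketch below (the /ping route and single-file layout are illustrative only; just the APIRouter and include_router usage mirrors the actual change):

# sketch.py - hypothetical minimal example of the router pattern adopted here
from fastapi import APIRouter, FastAPI

# in the real project this would live in its own module, e.g. faster_whisper_server/routers/misc.py
router = APIRouter()


@router.get("/ping")  # illustrative route, not part of the project's API
def ping() -> str:
    return "pong"


# main.py then only assembles the application from routers
app = FastAPI()
app.include_router(router)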
src/faster_whisper_server/main.py CHANGED
@@ -1,62 +1,34 @@
 from __future__ import annotations
 
-import asyncio
 from contextlib import asynccontextmanager
-import gc
-from io import BytesIO
-from typing import TYPE_CHECKING, Annotated, Literal
+from typing import TYPE_CHECKING
 
 from fastapi import (
     FastAPI,
-    Form,
-    HTTPException,
-    Path,
-    Query,
-    Response,
-    UploadFile,
-    WebSocket,
-    WebSocketDisconnect,
 )
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
-from fastapi.websockets import WebSocketState
-from faster_whisper.vad import VadOptions, get_speech_timestamps
-import huggingface_hub
-from huggingface_hub.hf_api import RepositoryNotFoundError
-from pydantic import AfterValidator
 
-from faster_whisper_server import hf_utils
-from faster_whisper_server.asr import FasterWhisperASR
-from faster_whisper_server.audio import AudioStream, audio_samples_from_file
 from faster_whisper_server.config import (
-    SAMPLES_PER_SECOND,
-    Language,
-    ResponseFormat,
-    Task,
     config,
 )
-from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.logger import logger
-from faster_whisper_server.model_manager import ModelManager
-from faster_whisper_server.server_models import (
-    ModelListResponse,
-    ModelObject,
-    TranscriptionJsonResponse,
-    TranscriptionVerboseJsonResponse,
+from faster_whisper_server.model_manager import model_manager
+from faster_whisper_server.routers.list_models import (
+    router as list_models_router,
+)
+from faster_whisper_server.routers.misc import (
+    router as misc_router,
+)
+from faster_whisper_server.routers.stt import (
+    router as stt_router,
 )
-from faster_whisper_server.transcriber import audio_transcriber
 
 if TYPE_CHECKING:
-    from collections.abc import AsyncGenerator, Generator, Iterable
-
-    from faster_whisper.transcribe import TranscriptionInfo
-    from huggingface_hub.hf_api import ModelInfo
+    from collections.abc import AsyncGenerator
 
 
 logger.debug(f"Config: {config}")
 
-model_manager = ModelManager()
-
 
 @asynccontextmanager
 async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
@@ -67,6 +39,10 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
 
 app = FastAPI(lifespan=lifespan)
 
+app.include_router(stt_router)
+app.include_router(list_models_router)
+app.include_router(misc_router)
+
 if config.allow_origins is not None:
     app.add_middleware(
         CORSMiddleware,
@@ -76,315 +52,6 @@ if config.allow_origins is not None:
         allow_headers=["*"],
     )
 
-
-@app.get("/health")
-def health() -> Response:
-    return Response(status_code=200, content="OK")
-
-
-@app.post("/api/pull/{model_name:path}", tags=["experimental"], summary="Download a model from Hugging Face.")
-def pull_model(model_name: str) -> Response:
-    if hf_utils.does_local_model_exist(model_name):
-        return Response(status_code=200, content="Model already exists")
-    try:
-        huggingface_hub.snapshot_download(model_name, repo_type="model")
-    except RepositoryNotFoundError as e:
-        return Response(status_code=404, content=str(e))
-    return Response(status_code=201, content="Model downloaded")
-
-
-@app.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
-def get_running_models() -> dict[str, list[str]]:
-    return {"models": list(model_manager.loaded_models.keys())}
-
-
-@app.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.")
-def load_model_route(model_name: str) -> Response:
-    if model_name in model_manager.loaded_models:
-        return Response(status_code=409, content="Model already loaded")
-    model_manager.load_model(model_name)
-    return Response(status_code=201)
-
-
-@app.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
-def stop_running_model(model_name: str) -> Response:
-    model = model_manager.loaded_models.get(model_name)
-    if model is not None:
-        del model_manager.loaded_models[model_name]
-        gc.collect()
-        return Response(status_code=204)
-    return Response(status_code=404)
-
-
-@app.get("/v1/models")
-def get_models() -> ModelListResponse:
-    models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
-    models = list(models)
-    models.sort(key=lambda model: model.downloads, reverse=True)  # type: ignore  # noqa: PGH003
-    transformed_models: list[ModelObject] = []
-    for model in models:
-        assert model.created_at is not None
-        assert model.card_data is not None
-        assert model.card_data.language is None or isinstance(model.card_data.language, str | list)
-        if model.card_data.language is None:
-            language = []
-        elif isinstance(model.card_data.language, str):
-            language = [model.card_data.language]
-        else:
-            language = model.card_data.language
-        transformed_model = ModelObject(
-            id=model.id,
-            created=int(model.created_at.timestamp()),
-            object_="model",
-            owned_by=model.id.split("/")[0],
-            language=language,
-        )
-        transformed_models.append(transformed_model)
-    return ModelListResponse(data=transformed_models)
-
-
-@app.get("/v1/models/{model_name:path}")
-# NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537
-def get_model(
-    model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")],
-) -> ModelObject:
-    models = huggingface_hub.list_models(
-        model_name=model_name, library="ctranslate2", tags="automatic-speech-recognition", cardData=True
-    )
-    models = list(models)
-    models.sort(key=lambda model: model.downloads, reverse=True)  # type: ignore  # noqa: PGH003
-    if len(models) == 0:
-        raise HTTPException(status_code=404, detail="Model doesn't exists")
-    exact_match: ModelInfo | None = None
-    for model in models:
-        if model.id == model_name:
-            exact_match = model
-            break
-    if exact_match is None:
-        raise HTTPException(
-            status_code=404,
-            detail=f"Model doesn't exists. Possible matches: {', '.join([model.id for model in models])}",
-        )
-    assert exact_match.created_at is not None
-    assert exact_match.card_data is not None
-    assert exact_match.card_data.language is None or isinstance(exact_match.card_data.language, str | list)
-    if exact_match.card_data.language is None:
-        language = []
-    elif isinstance(exact_match.card_data.language, str):
-        language = [exact_match.card_data.language]
-    else:
-        language = exact_match.card_data.language
-    return ModelObject(
-        id=exact_match.id,
-        created=int(exact_match.created_at.timestamp()),
-        object_="model",
-        owned_by=exact_match.id.split("/")[0],
-        language=language,
-    )
-
-
-def segments_to_response(
-    segments: Iterable[Segment],
-    transcription_info: TranscriptionInfo,
-    response_format: ResponseFormat,
-) -> Response:
-    segments = list(segments)
-    if response_format == ResponseFormat.TEXT:  # noqa: RET503
-        return Response(segments_to_text(segments), media_type="text/plain")
-    elif response_format == ResponseFormat.JSON:
-        return Response(
-            TranscriptionJsonResponse.from_segments(segments).model_dump_json(),
-            media_type="application/json",
-        )
-    elif response_format == ResponseFormat.VERBOSE_JSON:
-        return Response(
-            TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info).model_dump_json(),
-            media_type="application/json",
-        )
-    elif response_format == ResponseFormat.VTT:
-        return Response(
-            "".join(segments_to_vtt(segment, i) for i, segment in enumerate(segments)), media_type="text/vtt"
-        )
-    elif response_format == ResponseFormat.SRT:
-        return Response(
-            "".join(segments_to_srt(segment, i) for i, segment in enumerate(segments)), media_type="text/plain"
-        )
-
-
-def format_as_sse(data: str) -> str:
-    return f"data: {data}\n\n"
-
-
-def segments_to_streaming_response(
-    segments: Iterable[Segment],
-    transcription_info: TranscriptionInfo,
-    response_format: ResponseFormat,
-) -> StreamingResponse:
-    def segment_responses() -> Generator[str, None, None]:
-        for i, segment in enumerate(segments):
-            if response_format == ResponseFormat.TEXT:
-                data = segment.text
-            elif response_format == ResponseFormat.JSON:
-                data = TranscriptionJsonResponse.from_segments([segment]).model_dump_json()
-            elif response_format == ResponseFormat.VERBOSE_JSON:
-                data = TranscriptionVerboseJsonResponse.from_segment(segment, transcription_info).model_dump_json()
-            elif response_format == ResponseFormat.VTT:
-                data = segments_to_vtt(segment, i)
-            elif response_format == ResponseFormat.SRT:
-                data = segments_to_srt(segment, i)
-            yield format_as_sse(data)
-
-    return StreamingResponse(segment_responses(), media_type="text/event-stream")
-
-
-def handle_default_openai_model(model_name: str) -> str:
-    """Exists because some callers may not be able override the default("whisper-1") model name.
-
-    For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
-    """
-    if model_name == "whisper-1":
-        logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
-        return config.whisper.model
-    return model_name
-
-
-ModelName = Annotated[str, AfterValidator(handle_default_openai_model)]
-
-
-@app.post(
-    "/v1/audio/translations",
-    response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
-)
-def translate_file(
-    file: Annotated[UploadFile, Form()],
-    model: Annotated[ModelName, Form()] = config.whisper.model,
-    prompt: Annotated[str | None, Form()] = None,
-    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
-    temperature: Annotated[float, Form()] = 0.0,
-    stream: Annotated[bool, Form()] = False,
-) -> Response | StreamingResponse:
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSLATE,
-        initial_prompt=prompt,
-        temperature=temperature,
-        vad_filter=True,
-    )
-    segments = Segment.from_faster_whisper_segments(segments)
-
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
-
-
-# https://platform.openai.com/docs/api-reference/audio/createTranscription
-# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
-@app.post(
-    "/v1/audio/transcriptions",
-    response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
-)
-def transcribe_file(
-    file: Annotated[UploadFile, Form()],
-    model: Annotated[ModelName, Form()] = config.whisper.model,
-    language: Annotated[Language | None, Form()] = config.default_language,
-    prompt: Annotated[str | None, Form()] = None,
-    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
-    temperature: Annotated[float, Form()] = 0.0,
-    timestamp_granularities: Annotated[
-        list[Literal["segment", "word"]],
-        Form(alias="timestamp_granularities[]"),
-    ] = ["segment"],
-    stream: Annotated[bool, Form()] = False,
-    hotwords: Annotated[str | None, Form()] = None,
-) -> Response | StreamingResponse:
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSCRIBE,
-        language=language,
-        initial_prompt=prompt,
-        word_timestamps="word" in timestamp_granularities,
-        temperature=temperature,
-        vad_filter=True,
-        hotwords=hotwords,
-    )
-    segments = Segment.from_faster_whisper_segments(segments)
-
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
-
-
-async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
-    try:
-        while True:
-            bytes_ = await asyncio.wait_for(ws.receive_bytes(), timeout=config.max_no_data_seconds)
-            logger.debug(f"Received {len(bytes_)} bytes of audio data")
-            audio_samples = audio_samples_from_file(BytesIO(bytes_))
-            audio_stream.extend(audio_samples)
-            if audio_stream.duration - config.inactivity_window_seconds >= 0:
-                audio = audio_stream.after(audio_stream.duration - config.inactivity_window_seconds)
-                vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
-                # NOTE: This is a synchronous operation that runs every time new data is received.
-                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato.  # noqa: E501
-                timestamps = get_speech_timestamps(audio.data, vad_opts)
-                if len(timestamps) == 0:
-                    logger.info(f"No speech detected in the last {config.inactivity_window_seconds} seconds.")
-                    break
-                elif (
-                    # last speech end time
-                    config.inactivity_window_seconds - timestamps[-1]["end"] / SAMPLES_PER_SECOND
-                    >= config.max_inactivity_seconds
-                ):
-                    logger.info(f"Not enough speech in the last {config.inactivity_window_seconds} seconds.")
-                    break
-    except TimeoutError:
-        logger.info(f"No data received in {config.max_no_data_seconds} seconds. Closing the connection.")
-    except WebSocketDisconnect as e:
-        logger.info(f"Client disconnected: {e}")
-    audio_stream.close()
-
-
-@app.websocket("/v1/audio/transcriptions")
-async def transcribe_stream(
-    ws: WebSocket,
-    model: Annotated[ModelName, Query()] = config.whisper.model,
-    language: Annotated[Language | None, Query()] = config.default_language,
-    response_format: Annotated[ResponseFormat, Query()] = config.default_response_format,
-    temperature: Annotated[float, Query()] = 0.0,
-) -> None:
-    await ws.accept()
-    transcribe_opts = {
-        "language": language,
-        "temperature": temperature,
-        "vad_filter": True,
-        "condition_on_previous_text": False,
-    }
-    whisper = model_manager.load_model(model)
-    asr = FasterWhisperASR(whisper, **transcribe_opts)
-    audio_stream = AudioStream()
-    async with asyncio.TaskGroup() as tg:
-        tg.create_task(audio_receiver(ws, audio_stream))
-        async for transcription in audio_transcriber(asr, audio_stream):
-            logger.debug(f"Sending transcription: {transcription.text}")
-            if ws.client_state == WebSocketState.DISCONNECTED:
-                break
-
-            if response_format == ResponseFormat.TEXT:
-                await ws.send_text(transcription.text)
-            elif response_format == ResponseFormat.JSON:
-                await ws.send_json(TranscriptionJsonResponse.from_transcription(transcription).model_dump())
-            elif response_format == ResponseFormat.VERBOSE_JSON:
-                await ws.send_json(TranscriptionVerboseJsonResponse.from_transcription(transcription).model_dump())
-
-    if ws.client_state != WebSocketState.DISCONNECTED:
-        logger.info("Closing the connection.")
-        await ws.close()
-
-
 if config.enable_ui:
     import gradio as gr
 
src/faster_whisper_server/model_manager.py CHANGED
@@ -41,3 +41,6 @@ class ModelManager:
         )
         self.loaded_models[model_name] = whisper
         return whisper
+
+
+model_manager = ModelManager()
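The ModelManager instance also moves out of main.py: model_manager.py now exposes a module-level model_manager singleton, so each router module can import the same shared instance instead of reaching back into main. A minimal, hypothetical sketch of a router consuming it, mirroring the imports used by the new router files below:

# hypothetical sketch; the real routes live in faster_whisper_server/routers/misc.py and stt.py
from fastapi import APIRouter

from faster_whisper_server.model_manager import model_manager  # shared module-level singleton

router = APIRouter()


@router.get("/api/ps")
def get_running_models() -> dict[str, list[str]]:
    # same pattern as the /api/ps route in routers/misc.py
    return {"models": list(model_manager.loaded_models.keys())}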
src/faster_whisper_server/routers/__init__.py ADDED
File without changes
src/faster_whisper_server/routers/list_models.py ADDED
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Annotated
+
+from fastapi import (
+    APIRouter,
+    HTTPException,
+    Path,
+)
+import huggingface_hub
+
+from faster_whisper_server.server_models import (
+    ModelListResponse,
+    ModelObject,
+)
+
+if TYPE_CHECKING:
+    from huggingface_hub.hf_api import ModelInfo
+
+router = APIRouter()
+
+
+@router.get("/v1/models")
+def get_models() -> ModelListResponse:
+    models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
+    models = list(models)
+    models.sort(key=lambda model: model.downloads, reverse=True)  # type: ignore  # noqa: PGH003
+    transformed_models: list[ModelObject] = []
+    for model in models:
+        assert model.created_at is not None
+        assert model.card_data is not None
+        assert model.card_data.language is None or isinstance(model.card_data.language, str | list)
+        if model.card_data.language is None:
+            language = []
+        elif isinstance(model.card_data.language, str):
+            language = [model.card_data.language]
+        else:
+            language = model.card_data.language
+        transformed_model = ModelObject(
+            id=model.id,
+            created=int(model.created_at.timestamp()),
+            object_="model",
+            owned_by=model.id.split("/")[0],
+            language=language,
+        )
+        transformed_models.append(transformed_model)
+    return ModelListResponse(data=transformed_models)
+
+
+@router.get("/v1/models/{model_name:path}")
+# NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537
+def get_model(
+    model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")],
+) -> ModelObject:
+    models = huggingface_hub.list_models(
+        model_name=model_name, library="ctranslate2", tags="automatic-speech-recognition", cardData=True
+    )
+    models = list(models)
+    models.sort(key=lambda model: model.downloads, reverse=True)  # type: ignore  # noqa: PGH003
+    if len(models) == 0:
+        raise HTTPException(status_code=404, detail="Model doesn't exists")
+    exact_match: ModelInfo | None = None
+    for model in models:
+        if model.id == model_name:
+            exact_match = model
+            break
+    if exact_match is None:
+        raise HTTPException(
+            status_code=404,
+            detail=f"Model doesn't exists. Possible matches: {', '.join([model.id for model in models])}",
+        )
+    assert exact_match.created_at is not None
+    assert exact_match.card_data is not None
+    assert exact_match.card_data.language is None or isinstance(exact_match.card_data.language, str | list)
+    if exact_match.card_data.language is None:
+        language = []
+    elif isinstance(exact_match.card_data.language, str):
+        language = [exact_match.card_data.language]
+    else:
+        language = exact_match.card_data.language
+    return ModelObject(
+        id=exact_match.id,
+        created=int(exact_match.created_at.timestamp()),
+        object_="model",
+        owned_by=exact_match.id.split("/")[0],
+        language=language,
+    )
src/faster_whisper_server/routers/misc.py ADDED
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+import gc
+
+from fastapi import (
+    APIRouter,
+    Response,
+)
+from faster_whisper_server import hf_utils
+from faster_whisper_server.model_manager import model_manager
+import huggingface_hub
+from huggingface_hub.hf_api import RepositoryNotFoundError
+
+router = APIRouter()
+
+
+@router.get("/health")
+def health() -> Response:
+    return Response(status_code=200, content="OK")
+
+
+@router.post("/api/pull/{model_name:path}", tags=["experimental"], summary="Download a model from Hugging Face.")
+def pull_model(model_name: str) -> Response:
+    if hf_utils.does_local_model_exist(model_name):
+        return Response(status_code=200, content="Model already exists")
+    try:
+        huggingface_hub.snapshot_download(model_name, repo_type="model")
+    except RepositoryNotFoundError as e:
+        return Response(status_code=404, content=str(e))
+    return Response(status_code=201, content="Model downloaded")
+
+
+@router.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
+def get_running_models() -> dict[str, list[str]]:
+    return {"models": list(model_manager.loaded_models.keys())}
+
+
+@router.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.")
+def load_model_route(model_name: str) -> Response:
+    if model_name in model_manager.loaded_models:
+        return Response(status_code=409, content="Model already loaded")
+    model_manager.load_model(model_name)
+    return Response(status_code=201)
+
+
+@router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
+def stop_running_model(model_name: str) -> Response:
+    model = model_manager.loaded_models.get(model_name)
+    if model is not None:
+        del model_manager.loaded_models[model_name]
+        gc.collect()
+        return Response(status_code=204)
+    return Response(status_code=404)
src/faster_whisper_server/routers/stt.py ADDED
@@ -0,0 +1,246 @@
+from __future__ import annotations
+
+import asyncio
+from io import BytesIO
+from typing import TYPE_CHECKING, Annotated, Literal
+
+from fastapi import (
+    APIRouter,
+    Form,
+    Query,
+    Response,
+    UploadFile,
+    WebSocket,
+    WebSocketDisconnect,
+)
+from fastapi.responses import StreamingResponse
+from fastapi.websockets import WebSocketState
+from faster_whisper.vad import VadOptions, get_speech_timestamps
+from faster_whisper_server.asr import FasterWhisperASR
+from faster_whisper_server.audio import AudioStream, audio_samples_from_file
+from faster_whisper_server.config import (
+    SAMPLES_PER_SECOND,
+    Language,
+    ResponseFormat,
+    Task,
+    config,
+)
+from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
+from faster_whisper_server.logger import logger
+from faster_whisper_server.model_manager import model_manager
+from faster_whisper_server.server_models import (
+    TranscriptionJsonResponse,
+    TranscriptionVerboseJsonResponse,
+)
+from faster_whisper_server.transcriber import audio_transcriber
+from pydantic import AfterValidator
+
+if TYPE_CHECKING:
+    from collections.abc import Generator, Iterable
+
+    from faster_whisper.transcribe import TranscriptionInfo
+
+
+router = APIRouter()
+
+
+def segments_to_response(
+    segments: Iterable[Segment],
+    transcription_info: TranscriptionInfo,
+    response_format: ResponseFormat,
+) -> Response:
+    segments = list(segments)
+    if response_format == ResponseFormat.TEXT:  # noqa: RET503
+        return Response(segments_to_text(segments), media_type="text/plain")
+    elif response_format == ResponseFormat.JSON:
+        return Response(
+            TranscriptionJsonResponse.from_segments(segments).model_dump_json(),
+            media_type="application/json",
+        )
+    elif response_format == ResponseFormat.VERBOSE_JSON:
+        return Response(
+            TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info).model_dump_json(),
+            media_type="application/json",
+        )
+    elif response_format == ResponseFormat.VTT:
+        return Response(
+            "".join(segments_to_vtt(segment, i) for i, segment in enumerate(segments)), media_type="text/vtt"
+        )
+    elif response_format == ResponseFormat.SRT:
+        return Response(
+            "".join(segments_to_srt(segment, i) for i, segment in enumerate(segments)), media_type="text/plain"
+        )
+
+
+def format_as_sse(data: str) -> str:
+    return f"data: {data}\n\n"
+
+
+def segments_to_streaming_response(
+    segments: Iterable[Segment],
+    transcription_info: TranscriptionInfo,
+    response_format: ResponseFormat,
+) -> StreamingResponse:
+    def segment_responses() -> Generator[str, None, None]:
+        for i, segment in enumerate(segments):
+            if response_format == ResponseFormat.TEXT:
+                data = segment.text
+            elif response_format == ResponseFormat.JSON:
+                data = TranscriptionJsonResponse.from_segments([segment]).model_dump_json()
+            elif response_format == ResponseFormat.VERBOSE_JSON:
+                data = TranscriptionVerboseJsonResponse.from_segment(segment, transcription_info).model_dump_json()
+            elif response_format == ResponseFormat.VTT:
+                data = segments_to_vtt(segment, i)
+            elif response_format == ResponseFormat.SRT:
+                data = segments_to_srt(segment, i)
+            yield format_as_sse(data)
+
+    return StreamingResponse(segment_responses(), media_type="text/event-stream")
+
+
+def handle_default_openai_model(model_name: str) -> str:
+    """Exists because some callers may not be able override the default("whisper-1") model name.
+
+    For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
+    """
+    if model_name == "whisper-1":
+        logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
+        return config.whisper.model
+    return model_name
+
+
+ModelName = Annotated[str, AfterValidator(handle_default_openai_model)]
+
+
+@router.post(
+    "/v1/audio/translations",
+    response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
+)
+def translate_file(
+    file: Annotated[UploadFile, Form()],
+    model: Annotated[ModelName, Form()] = config.whisper.model,
+    prompt: Annotated[str | None, Form()] = None,
+    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
+    temperature: Annotated[float, Form()] = 0.0,
+    stream: Annotated[bool, Form()] = False,
+) -> Response | StreamingResponse:
+    whisper = model_manager.load_model(model)
+    segments, transcription_info = whisper.transcribe(
+        file.file,
+        task=Task.TRANSLATE,
+        initial_prompt=prompt,
+        temperature=temperature,
+        vad_filter=True,
+    )
+    segments = Segment.from_faster_whisper_segments(segments)
+
+    if stream:
+        return segments_to_streaming_response(segments, transcription_info, response_format)
+    else:
+        return segments_to_response(segments, transcription_info, response_format)
+
+
+# https://platform.openai.com/docs/api-reference/audio/createTranscription
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
+@router.post(
+    "/v1/audio/transcriptions",
+    response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
+)
+def transcribe_file(
+    file: Annotated[UploadFile, Form()],
+    model: Annotated[ModelName, Form()] = config.whisper.model,
+    language: Annotated[Language | None, Form()] = config.default_language,
+    prompt: Annotated[str | None, Form()] = None,
+    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
+    temperature: Annotated[float, Form()] = 0.0,
+    timestamp_granularities: Annotated[
+        list[Literal["segment", "word"]],
+        Form(alias="timestamp_granularities[]"),
+    ] = ["segment"],
+    stream: Annotated[bool, Form()] = False,
+    hotwords: Annotated[str | None, Form()] = None,
+) -> Response | StreamingResponse:
+    whisper = model_manager.load_model(model)
+    segments, transcription_info = whisper.transcribe(
+        file.file,
+        task=Task.TRANSCRIBE,
+        language=language,
+        initial_prompt=prompt,
+        word_timestamps="word" in timestamp_granularities,
+        temperature=temperature,
+        vad_filter=True,
+        hotwords=hotwords,
+    )
+    segments = Segment.from_faster_whisper_segments(segments)
+
+    if stream:
+        return segments_to_streaming_response(segments, transcription_info, response_format)
+    else:
+        return segments_to_response(segments, transcription_info, response_format)
+
+
+async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
+    try:
+        while True:
+            bytes_ = await asyncio.wait_for(ws.receive_bytes(), timeout=config.max_no_data_seconds)
+            logger.debug(f"Received {len(bytes_)} bytes of audio data")
+            audio_samples = audio_samples_from_file(BytesIO(bytes_))
+            audio_stream.extend(audio_samples)
+            if audio_stream.duration - config.inactivity_window_seconds >= 0:
+                audio = audio_stream.after(audio_stream.duration - config.inactivity_window_seconds)
+                vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
+                # NOTE: This is a synchronous operation that runs every time new data is received.
+                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato.  # noqa: E501
+                timestamps = get_speech_timestamps(audio.data, vad_opts)
+                if len(timestamps) == 0:
+                    logger.info(f"No speech detected in the last {config.inactivity_window_seconds} seconds.")
+                    break
+                elif (
+                    # last speech end time
+                    config.inactivity_window_seconds - timestamps[-1]["end"] / SAMPLES_PER_SECOND
+                    >= config.max_inactivity_seconds
+                ):
+                    logger.info(f"Not enough speech in the last {config.inactivity_window_seconds} seconds.")
+                    break
+    except TimeoutError:
+        logger.info(f"No data received in {config.max_no_data_seconds} seconds. Closing the connection.")
+    except WebSocketDisconnect as e:
+        logger.info(f"Client disconnected: {e}")
+    audio_stream.close()
+
+
+@router.websocket("/v1/audio/transcriptions")
+async def transcribe_stream(
+    ws: WebSocket,
+    model: Annotated[ModelName, Query()] = config.whisper.model,
+    language: Annotated[Language | None, Query()] = config.default_language,
+    response_format: Annotated[ResponseFormat, Query()] = config.default_response_format,
+    temperature: Annotated[float, Query()] = 0.0,
+) -> None:
+    await ws.accept()
+    transcribe_opts = {
+        "language": language,
+        "temperature": temperature,
+        "vad_filter": True,
+        "condition_on_previous_text": False,
+    }
+    whisper = model_manager.load_model(model)
+    asr = FasterWhisperASR(whisper, **transcribe_opts)
+    audio_stream = AudioStream()
+    async with asyncio.TaskGroup() as tg:
+        tg.create_task(audio_receiver(ws, audio_stream))
+        async for transcription in audio_transcriber(asr, audio_stream):
+            logger.debug(f"Sending transcription: {transcription.text}")
+            if ws.client_state == WebSocketState.DISCONNECTED:
+                break
+
+            if response_format == ResponseFormat.TEXT:
+                await ws.send_text(transcription.text)
+            elif response_format == ResponseFormat.JSON:
+                await ws.send_json(TranscriptionJsonResponse.from_transcription(transcription).model_dump())
+            elif response_format == ResponseFormat.VERBOSE_JSON:
+                await ws.send_json(TranscriptionVerboseJsonResponse.from_transcription(transcription).model_dump())
+
+    if ws.client_state != WebSocketState.DISCONNECTED:
+        logger.info("Closing the connection.")
+        await ws.close()