Fedir Zadniprovskyi committed on
Commit
3e15f14
·
1 Parent(s): fe09516

feat: add a playground

Browse files

A notable change is that the whisper model will no longer be loaded on startup.

faster_whisper_server/config.py CHANGED
@@ -168,6 +168,11 @@ class Language(enum.StrEnum):
168
  ZH = "zh"
169
 
170
 
 
 
 
 
 
171
  class WhisperConfig(BaseModel):
172
  model: str = Field(default="Systran/faster-whisper-medium.en")
173
  """
 
168
  ZH = "zh"
169
 
170
 
171
+ class Task(enum.StrEnum):
172
+ TRANSCRIPTION = "transcription"
173
+ TRANSLATION = "translation"
174
+
175
+
176
  class WhisperConfig(BaseModel):
177
  model: str = Field(default="Systran/faster-whisper-medium.en")
178
  """
faster_whisper_server/gradio_app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Generator
3
+
4
+ import gradio as gr
5
+ import httpx
6
+ from httpx_sse import connect_sse
7
+
8
+ from faster_whisper_server.config import Config, Task
9
+
10
+ TRANSCRIPTION_ENDPOINT = "/v1/audio/transcriptions"
11
+ TRANSLATION_ENDPOINT = "/v1/audio/translations"
12
+
13
+
14
def create_gradio_demo(config: Config) -> gr.Blocks:
    """Build the Gradio playground UI for the server.

    The demo does not call the model directly: it POSTs the uploaded or
    recorded audio file to this server's own OpenAI-compatible HTTP
    endpoints, optionally consuming partial results via server-sent events.

    Args:
        config: server configuration; only ``config.whisper.model`` is read,
            to pre-populate the model dropdown.

    Returns:
        A ``gr.Interface`` (a ``gr.Blocks`` subclass) ready to be mounted.
    """
    host = os.getenv("UVICORN_HOST", "0.0.0.0")
    port = os.getenv("UVICORN_PORT", 8000)
    # NOTE: worth looking into generated clients
    # timeout=None: transcription of long files can take arbitrarily long.
    http_client = httpx.Client(base_url=f"http://{host}:{port}", timeout=None)

    def endpoint_for(task: Task) -> str:
        # Single place that maps a task to its endpoint; the original
        # duplicated this mapping (and its `if`/`elif` variant could leave
        # `endpoint` unbound for an unexpected task value).
        return TRANSCRIPTION_ENDPOINT if task == Task.TRANSCRIPTION else TRANSLATION_ENDPOINT

    def handler(
        file_path: str | None, model: str, task: Task, temperature: float, stream: bool
    ) -> Generator[str, None, None]:
        """Gradio callback: yields transcription text (each yield replaces the output)."""
        if file_path is None:
            yield ""
            return
        if stream:
            yield from transcribe_audio_streaming(file_path, task, temperature, model)
            # BUG FIX: the original fell through here and issued a second,
            # non-streaming request after streaming had already finished.
            return
        yield transcribe_audio(file_path, task, temperature, model)

    def transcribe_audio(file_path: str, task: Task, temperature: float, model: str) -> str:
        """Run one blocking request and return the full transcription text."""
        with open(file_path, "rb") as file:
            response = http_client.post(
                endpoint_for(task),
                files={"file": file},
                data={
                    "model": model,
                    "response_format": "text",
                    "temperature": temperature,
                },
            )
        response.raise_for_status()
        return response.text

    def transcribe_audio_streaming(
        file_path: str, task: Task, temperature: float, model: str
    ) -> Generator[str, None, None]:
        """Stream transcription text chunks from the server via SSE."""
        with open(file_path, "rb") as file:
            kwargs = {
                "files": {"file": file},
                "data": {
                    "response_format": "text",
                    "temperature": temperature,
                    "model": model,
                    "stream": True,
                },
            }
            with connect_sse(http_client, "POST", endpoint_for(task), **kwargs) as event_source:
                for event in event_source.iter_sse():
                    yield event.data

    model_dropdown = gr.Dropdown(
        # TODO: use output from /v1/models
        choices=[config.whisper.model],
        label="Model",
        value=config.whisper.model,
    )
    task_dropdown = gr.Dropdown(
        choices=[task.value for task in Task],
        label="Task",
        value=Task.TRANSCRIPTION,
    )
    temperature_slider = gr.Slider(
        minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.0
    )
    stream_checkbox = gr.Checkbox(label="Stream", value=True)
    return gr.Interface(
        title="Whisper Playground",
        description="""Consider supporting the project by starring the <a href="https://github.com/fedirz/faster-whisper-server">repository on GitHub</a>.""",
        inputs=[
            gr.Audio(type="filepath"),
            model_dropdown,
            task_dropdown,
            temperature_slider,
            stream_checkbox,
        ],
        fn=handler,
        outputs="text",
    )
faster_whisper_server/main.py CHANGED
@@ -2,10 +2,10 @@ from __future__ import annotations
2
 
3
  import asyncio
4
  import time
5
- from contextlib import asynccontextmanager
6
  from io import BytesIO
7
  from typing import Annotated, Generator, Iterable, Literal, OrderedDict
8
 
 
9
  import huggingface_hub
10
  from fastapi import (
11
  FastAPI,
@@ -33,8 +33,10 @@ from faster_whisper_server.config import (
33
  SAMPLES_PER_SECOND,
34
  Language,
35
  ResponseFormat,
 
36
  config,
37
  )
 
38
  from faster_whisper_server.logger import logger
39
  from faster_whisper_server.server_models import (
40
  ModelObject,
@@ -71,16 +73,7 @@ def load_model(model_name: str) -> WhisperModel:
71
  return whisper
72
 
73
 
74
- @asynccontextmanager
75
- async def lifespan(_: FastAPI):
76
- load_model(config.whisper.model)
77
- yield
78
- for model in loaded_models.keys():
79
- logger.info(f"Unloading {model}")
80
- del loaded_models[model]
81
-
82
-
83
- app = FastAPI(lifespan=lifespan)
84
 
85
 
86
  @app.get("/health")
@@ -210,7 +203,7 @@ def translate_file(
210
  whisper = load_model(model)
211
  segments, transcription_info = whisper.transcribe(
212
  file.file,
213
- task="translate",
214
  initial_prompt=prompt,
215
  temperature=temperature,
216
  vad_filter=True,
@@ -251,7 +244,7 @@ def transcribe_file(
251
  whisper = load_model(model)
252
  segments, transcription_info = whisper.transcribe(
253
  file.file,
254
- task="transcribe",
255
  language=language,
256
  initial_prompt=prompt,
257
  word_timestamps="word" in timestamp_granularities,
@@ -353,3 +346,6 @@ async def transcribe_stream(
353
  if not ws.client_state == WebSocketState.DISCONNECTED:
354
  logger.info("Closing the connection.")
355
  await ws.close()
 
 
 
 
2
 
3
  import asyncio
4
  import time
 
5
  from io import BytesIO
6
  from typing import Annotated, Generator, Iterable, Literal, OrderedDict
7
 
8
+ import gradio as gr
9
  import huggingface_hub
10
  from fastapi import (
11
  FastAPI,
 
33
  SAMPLES_PER_SECOND,
34
  Language,
35
  ResponseFormat,
36
+ Task,
37
  config,
38
  )
39
+ from faster_whisper_server.gradio_app import create_gradio_demo
40
  from faster_whisper_server.logger import logger
41
  from faster_whisper_server.server_models import (
42
  ModelObject,
 
73
  return whisper
74
 
75
 
76
+ app = FastAPI()
 
 
 
 
 
 
 
 
 
77
 
78
 
79
  @app.get("/health")
 
203
  whisper = load_model(model)
204
  segments, transcription_info = whisper.transcribe(
205
  file.file,
206
+ task=Task.TRANSLATION,
207
  initial_prompt=prompt,
208
  temperature=temperature,
209
  vad_filter=True,
 
244
  whisper = load_model(model)
245
  segments, transcription_info = whisper.transcribe(
246
  file.file,
247
+ task=Task.TRANSCRIPTION,
248
  language=language,
249
  initial_prompt=prompt,
250
  word_timestamps="word" in timestamp_granularities,
 
346
  if not ws.client_state == WebSocketState.DISCONNECTED:
347
  logger.info("Closing the connection.")
348
  await ws.close()
349
+
350
+
351
+ app = gr.mount_gradio_app(app, create_gradio_demo(config), path="/")