Fedir Zadniprovskyi committed
Commit 3344380 · Parent: 883ce74

feat: gradio speech generation tab
src/faster_whisper_server/gradio_app.py CHANGED
@@ -7,6 +7,15 @@ from httpx_sse import connect_sse
 from openai import OpenAI
 
 from faster_whisper_server.config import Config, Task
+from faster_whisper_server.hf_utils import PiperModel
+
+# FIX: this won't work on ARM
+from faster_whisper_server.routers.speech import (
+    DEFAULT_VOICE,
+    MAX_SAMPLE_RATE,
+    MIN_SAMPLE_RATE,
+    SUPPORTED_RESPONSE_FORMATS,
+)
 
 TRANSCRIPTION_ENDPOINT = "/v1/audio/transcriptions"
 TRANSLATION_ENDPOINT = "/v1/audio/translations"
@@ -14,12 +23,15 @@ TIMEOUT_SECONDS = 180
 TIMEOUT = httpx.Timeout(timeout=TIMEOUT_SECONDS)
 
 
-def create_gradio_demo(config: Config) -> gr.Blocks:
+def create_gradio_demo(config: Config) -> gr.Blocks:  # noqa: C901, PLR0915
     base_url = f"http://{config.host}:{config.port}"
     http_client = httpx.Client(base_url=base_url, timeout=TIMEOUT)
     openai_client = OpenAI(base_url=f"{base_url}/v1", api_key="cant-be-empty")
 
-    def handler(file_path: str, model: str, task: Task, temperature: float, stream: bool) -> Generator[str, None, None]:
+    # TODO: make async
+    def whisper_handler(
+        file_path: str, model: str, task: Task, temperature: float, stream: bool
+    ) -> Generator[str, None, None]:
         if task == Task.TRANSCRIBE:
             endpoint = TRANSCRIPTION_ENDPOINT
         elif task == Task.TRANSLATE:
@@ -65,7 +77,7 @@ def create_gradio_demo(config: Config) -> gr.Blocks:
             for event in event_source.iter_sse():
                 yield event.data
 
-    def update_model_dropdown() -> gr.Dropdown:
+    def update_whisper_model_dropdown() -> gr.Dropdown:
         models = openai_client.models.list().data
         model_names: list[str] = [model.id for model in models]
         assert config.whisper.model in model_names
@@ -73,37 +85,100 @@ def create_gradio_demo(config: Config) -> gr.Blocks:
         other_models = [model for model in model_names if model not in recommended_models]
         model_names = list(recommended_models) + other_models
         return gr.Dropdown(
-            # no idea why it's complaining
-            choices=model_names,  # pyright: ignore[reportArgumentType]
+            choices=model_names,
             label="Model",
             value=config.whisper.model,
         )
 
-    model_dropdown = gr.Dropdown(
-        choices=[config.whisper.model],
-        label="Model",
-        value=config.whisper.model,
-    )
-    task_dropdown = gr.Dropdown(
-        choices=[task.value for task in Task],
-        label="Task",
-        value=Task.TRANSCRIBE,
-    )
-    temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.0)
-    stream_checkbox = gr.Checkbox(label="Stream", value=True)
-    with gr.Interface(
-        title="Whisper Playground",
-        description="""Consider supporting the project by starring the <a href="https://github.com/fedirz/faster-whisper-server">repository on GitHub</a>.""",  # noqa: E501
-        inputs=[
-            gr.Audio(type="filepath"),
-            model_dropdown,
-            task_dropdown,
-            temperature_slider,
-            stream_checkbox,
-        ],
-        fn=handler,
-        outputs="text",
-        analytics_enabled=False,  # disable telemetry
-    ) as demo:
-        demo.load(update_model_dropdown, inputs=None, outputs=model_dropdown)
+    def update_piper_voices_dropdown() -> gr.Dropdown:
+        res = http_client.get("/v1/audio/speech/voices").raise_for_status()
+        piper_models = [PiperModel.model_validate(x) for x in res.json()]
+        return gr.Dropdown(choices=[model.voice for model in piper_models], label="Voice", value=DEFAULT_VOICE)
+
+    # TODO: make async
+    def handle_audio_speech(text: str, voice: str, response_format: str, speed: float, sample_rate: int | None) -> Path:
+        res = openai_client.audio.speech.create(
+            input=text,
+            model="piper",
+            voice=voice,  # pyright: ignore[reportArgumentType]
+            response_format=response_format,  # pyright: ignore[reportArgumentType]
+            speed=speed,
+            extra_body={"sample_rate": sample_rate},
+        )
+        audio_bytes = res.response.read()
+        file_path = Path(f"audio.{response_format}")
+        with file_path.open("wb") as file:
+            file.write(audio_bytes)
+        return file_path
+
+    with gr.Blocks(title="faster-whisper-server Playground") as demo:
+        gr.Markdown(
+            "### Consider supporting the project by starring the [repository on GitHub](https://github.com/fedirz/faster-whisper-server)."
+        )
+        with gr.Tab(label="Transcribe/Translate"):
+            audio = gr.Audio(type="filepath")
+            model_dropdown = gr.Dropdown(
+                choices=[config.whisper.model],
+                label="Model",
+                value=config.whisper.model,
+            )
+            task_dropdown = gr.Dropdown(
+                choices=[task.value for task in Task],
+                label="Task",
+                value=Task.TRANSCRIBE,
+            )
+            temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.0)
+            stream_checkbox = gr.Checkbox(label="Stream", value=True)
+            button = gr.Button("Generate")
+
+            output = gr.Textbox()
+
+            # NOTE: the inputs order must match the `whisper_handler` signature
+            button.click(
+                whisper_handler, [audio, model_dropdown, task_dropdown, temperature_slider, stream_checkbox], output
+            )
+
+        with gr.Tab(label="Speech Generation"):
+            # TODO: add warning about ARM
+            text = gr.Textbox(label="Input Text")
+            voice_dropdown = gr.Dropdown(
+                choices=["en_US-amy-medium"],
+                label="Voice",
+                value="en_US-amy-medium",
+                info="""
+                The last part of the voice name is the quality (x_low, low, medium, high).
+                Each quality has a different default sample rate:
+                - x_low: 16000 Hz
+                - low: 16000 Hz
+                - medium: 22050 Hz
+                - high: 22050 Hz
+                """,
+            )
+            response_format_dropdown = gr.Dropdown(
+                choices=SUPPORTED_RESPONSE_FORMATS,
+                label="Response Format",
+                value="wav",
+            )
+            speed_slider = gr.Slider(minimum=0.25, maximum=4.0, step=0.05, label="Speed", value=1.0)
+            sample_rate_slider = gr.Number(
+                minimum=MIN_SAMPLE_RATE,
+                maximum=MAX_SAMPLE_RATE,
+                label="Desired Sample Rate",
+                info="""
+                Setting this will resample the generated audio to the desired sample rate.
+                You may want to set this if you are going to use voices of different qualities but want to keep the same sample rate.
+                Default: None (No resampling)
+                """,
+                value=lambda: None,
+            )
+            button = gr.Button("Generate Speech")
+            output = gr.Audio(type="filepath")
+            button.click(
+                handle_audio_speech,
+                [text, voice_dropdown, response_format_dropdown, speed_slider, sample_rate_slider],
+                output,
+            )
+
+        demo.load(update_whisper_model_dropdown, inputs=None, outputs=model_dropdown)
+        demo.load(update_piper_voices_dropdown, inputs=None, outputs=voice_dropdown)
     return demo
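
The new tab is a thin UI over the server's OpenAI-compatible speech endpoint, so the same request can be made from any client. A minimal sketch of the equivalent call (the host/port assume the server's defaults, and the voice name and output path are just illustrative):

```python
from pathlib import Path

from openai import OpenAI

# The API key can be any non-empty string; the server does not check it.
client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="cant-be-empty")

res = client.audio.speech.create(
    input="Hello from the speech generation tab.",
    model="piper",
    voice="en_US-amy-medium",  # any voice listed by GET /v1/audio/speech/voices
    response_format="wav",
    speed=1.0,
    extra_body={"sample_rate": 16000},  # optional resampling, as in the tab
)
Path("speech.wav").write_bytes(res.response.read())
```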
src/faster_whisper_server/hf_utils.py CHANGED
@@ -1,5 +1,5 @@
 from collections.abc import Generator
-from functools import lru_cache
+from functools import cached_property, lru_cache
 import json
 import logging
 from pathlib import Path
@@ -8,7 +8,7 @@ from typing import Any, Literal
 
 import huggingface_hub
 from huggingface_hub.constants import HF_HUB_CACHE
-from pydantic import BaseModel
+from pydantic import BaseModel, Field, computed_field
 
 from faster_whisper_server.api_models import Model
 
@@ -95,13 +95,51 @@ def get_whisper_models() -> Generator[Model, None, None]:
         yield transformed_model
 
 
+PiperVoiceQuality = Literal["x_low", "low", "medium", "high"]
+PIPER_VOICE_QUALITY_SAMPLE_RATE_MAP: dict[PiperVoiceQuality, int] = {
+    "x_low": 16000,
+    "low": 16000,
+    "medium": 22050,
+    "high": 22050,
+}
+
+
 class PiperModel(BaseModel):
-    id: str
+    """Similar structure to the GET /v1/models response but with extra fields."""
+
     object: Literal["model"] = "model"
     created: int
     owned_by: Literal["rhasspy"] = "rhasspy"
-    path: Path
-    config_path: Path
+    model_path: Path = Field(
+        examples=[
+            "/home/nixos/.cache/huggingface/hub/models--rhasspy--piper-voices/snapshots/3d796cc2f2c884b3517c527507e084f7bb245aea/en/en_US/amy/medium/en_US-amy-medium.onnx"
+        ]
+    )
+
+    @computed_field(examples=["rhasspy/piper-voices/en_US-amy-medium"])
+    @cached_property
+    def id(self) -> str:
+        return f"rhasspy/piper-voices/{self.model_path.name.removesuffix(".onnx")}"
+
+    @computed_field(examples=["en_US-amy-medium"])
+    @cached_property
+    def voice(self) -> str:
+        return self.model_path.name.removesuffix(".onnx")
+
+    @computed_field
+    @cached_property
+    def config_path(self) -> Path:
+        return Path(str(self.model_path) + ".json")
+
+    @computed_field
+    @cached_property
+    def quality(self) -> PiperVoiceQuality:
+        return self.id.split("-")[-1]  # pyright: ignore[reportReturnType]
+
+    @computed_field
+    @cached_property
+    def sample_rate(self) -> int:
+        return PIPER_VOICE_QUALITY_SAMPLE_RATE_MAP[self.quality]
 
 
 def get_model_path(model_id: str, *, cache_dir: str | Path | None = None) -> Path | None:
@@ -151,12 +189,9 @@ def list_model_files(
 def list_piper_models() -> Generator[PiperModel, None, None]:
     model_weights_files = list_model_files("rhasspy/piper-voices", glob_pattern="**/*.onnx")
     for model_weights_file in model_weights_files:
-        model_config_file = model_weights_file.with_suffix(".json")
         yield PiperModel(
-            id=model_weights_file.name,
             created=int(model_weights_file.stat().st_mtime),
-            path=model_weights_file,
-            config_path=model_config_file,
+            model_path=model_weights_file,
         )
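
With these changes, a `PiperModel` is constructed from nothing but the `.onnx` weights path and a timestamp; `id`, `voice`, `config_path`, `quality`, and `sample_rate` are all derived. A quick sketch of what the computed fields produce (the relative path and timestamp are illustrative):

```python
from pathlib import Path

from faster_whisper_server.hf_utils import PiperModel

model = PiperModel(
    created=1700000000,
    model_path=Path("en/en_US/amy/medium/en_US-amy-medium.onnx"),
)
assert model.voice == "en_US-amy-medium"
assert model.id == "rhasspy/piper-voices/en_US-amy-medium"
assert model.quality == "medium"  # last "-" component of the voice name
assert model.sample_rate == 22050  # looked up in the quality->rate map
assert model.config_path.name == "en_US-amy-medium.onnx.json"
```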
src/faster_whisper_server/routers/speech.py CHANGED
@@ -12,7 +12,11 @@ from pydantic import BaseModel, BeforeValidator, Field, ValidationError, model_v
 import soundfile as sf
 
 from faster_whisper_server.dependencies import PiperModelManagerDependency
-from faster_whisper_server.hf_utils import read_piper_voices_config
+from faster_whisper_server.hf_utils import (
+    PiperModel,
+    list_piper_models,
+    read_piper_voices_config,
+)
 
 DEFAULT_MODEL = "piper"
 # https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-response_format
@@ -126,6 +130,14 @@ class CreateSpeechRequestBody(BaseModel):
         ],
     )
    voice: Voice = DEFAULT_VOICE
+    """
+    The last part of the voice name is the quality (x_low, low, medium, high).
+    Each quality has a different default sample rate:
+    - x_low: 16000 Hz
+    - low: 16000 Hz
+    - medium: 22050 Hz
+    - high: 22050 Hz
+    """
    response_format: ResponseFormat = Field(
        DEFAULT_RESPONSE_FORMAT,
        description=f"The format to return audio in. Supported formats are {", ".join(SUPPORTED_RESPONSE_FORMATS)}. {", ".join(UNSUPORTED_RESPONSE_FORMATS)} are not supported",  # noqa: E501
@@ -136,6 +148,7 @@ class CreateSpeechRequestBody(BaseModel):
     """The speed of the generated audio. Select a value from 0.25 to 4.0. 1.0 is the default."""
     sample_rate: int | None = Field(None, ge=MIN_SAMPLE_RATE, le=MAX_SAMPLE_RATE)
     """Desired sample rate to convert the generated audio to. If not provided, the model's default sample rate will be used."""  # noqa: E501
+    # TODO: document default sample rate for each voice quality
 
     # TODO: move into `Voice`
     @model_validator(mode="after")
@@ -163,3 +176,8 @@ def synthesize(
     )
 
     return StreamingResponse(audio_generator, media_type=f"audio/{body.response_format}")
+
+
+@router.get("/v1/audio/speech/voices")
+def list_voices() -> list[PiperModel]:
+    return list(list_piper_models())
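
Because `list_voices` returns the `PiperModel` objects directly, pydantic serializes the computed fields into the JSON response, so a client can discover each voice's quality and native sample rate before synthesizing. A minimal sketch, again assuming the server's default host/port:

```python
import httpx

# Query the new endpoint added in this commit.
res = httpx.get("http://0.0.0.0:8000/v1/audio/speech/voices")
res.raise_for_status()
for voice in res.json():
    # Computed fields such as "voice", "quality", and "sample_rate" are included.
    print(voice["voice"], voice["quality"], voice["sample_rate"])
```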