# transcribe/app.py
# Synced from https://github.com/mozilla-ai/speech-to-text-finetune (commit 166c454)
import os
import gradio as gr
import spaces
from huggingface_hub import get_collection, HfApi
from transformers import pipeline, Pipeline
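
# Expected to be set in the Space's environment; used to hide the local-path option when hosted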
is_hf_space = os.getenv("IS_HF_SPACE")


def get_dropdown_model_ids():
    mozilla_ai_model_ids = []
    # Fetch the model ids from the mozilla-ai collection and append each model's
    # language in parentheses, parsed from the model card's metadata
    for model_i in get_collection(
        "mozilla-ai/common-voice-whisper-67b847a74ad7561781aa10fd"
    ).items:
        model_metadata = HfApi().model_info(model_i.item_id)
        # Assumes the card's model name has the form "... finetuned on <Language>"
        language = model_metadata.card_data.model_name.split("on ")[1]
        mozilla_ai_model_ids.append(model_i.item_id + f" ({language})")

    return (
        [""]
        + mozilla_ai_model_ids
        + [
            "openai/whisper-tiny (Multilingual)",
            "openai/whisper-small (Multilingual)",
            "openai/whisper-medium (Multilingual)",
            "openai/whisper-large-v3 (Multilingual)",
            "openai/whisper-large-v3-turbo (Multilingual)",
        ]
    )


def _load_local_model(model_dir: str) -> Pipeline | str:
    """Load a Whisper pipeline from a local directory, or return the error message as a string."""
    from transformers import WhisperProcessor, WhisperForConditionalGeneration

    try:
        # Load inside the try block so that loading failures are also reported to the UI
        processor = WhisperProcessor.from_pretrained(model_dir)
        model = WhisperForConditionalGeneration.from_pretrained(model_dir)
        return pipeline(
            task="automatic-speech-recognition",
            model=model,
            processor=processor,
            chunk_length_s=30,  # max input duration for whisper
        )
    except Exception as e:
        return str(e)


def _load_hf_model(model_repo_id: str) -> Pipeline | str:
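    """Load an ASR pipeline from the Hugging Face Hub, or return the error message as a string."""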
    try:
        return pipeline(
            "automatic-speech-recognition",
            model=model_repo_id,
            chunk_length_s=30,  # max input duration for whisper
        )
    except Exception as e:
        return str(e)


# Copied from https://github.com/openai/whisper/blob/517a43ecd132a2089d85f4ebc044728a71d49f6e/whisper/utils.py#L50
def format_timestamp(
    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
) -> str:
    # e.g. format_timestamp(61.5) == "01:01.500"
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    )


@spaces.GPU(duration=30)
def transcribe(
    dropdown_model_id: str,
    hf_model_id: str,
    local_model_id: str,
    audio: gr.Audio,
    show_timestamps: bool,
) -> str:
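    """Transcribe `audio` using exactly one of: a dropdown choice, an HF model id, or a local model path."""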
    if dropdown_model_id and not hf_model_id and not local_model_id:
        # Strip the " (language)" suffix appended in get_dropdown_model_ids
        dropdown_model_id = dropdown_model_id.split(" (")[0]
        pipe = _load_hf_model(dropdown_model_id)
    elif hf_model_id and not local_model_id and not dropdown_model_id:
        pipe = _load_hf_model(hf_model_id)
    elif local_model_id and not hf_model_id and not dropdown_model_id:
        pipe = _load_local_model(local_model_id)
    else:
        return (
            "⚠️ Error: Please select or fill in exactly one of the model options above"
        )

    if isinstance(pipe, str):
        # The loaders return the exception message as a string on failure
        return f"⚠️ Error: {pipe}"
    output = pipe(
        audio,
        generate_kwargs={"task": "transcribe"},
        batch_size=16,
        return_timestamps=show_timestamps,
    )
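    # Illustrative output shape ("chunks" is only present when return_timestamps is set):
    # {"text": "...", "chunks": [{"timestamp": (0.0, 5.2), "text": "..."}, ...]}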
    text = output["text"]
    if show_timestamps:
        lines = [
            f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
            for chunk in output["chunks"]
        ]
        text = "\n".join(lines)
    return text


def setup_gradio_demo():
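    """Build and launch the Gradio demo: model selection, audio input, and transcription output."""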
    with gr.Blocks() as demo:
        gr.Markdown(
            """ # 🗣️ Speech-to-Text Transcription
            ### 1. Select which model to use from one of the options below.
            ### 2. Record a message or upload an audio file.
            ### 3. Click Transcribe to see the transcription generated by the model.
            """
        )
        ### Model selection ###
        model_ids = get_dropdown_model_ids()

        with gr.Row():
            with gr.Column():
                dropdown_model = gr.Dropdown(
                    choices=model_ids, label="Option 1: Select a model"
                )
            with gr.Column():
                user_model = gr.Textbox(
                    label="Option 2: Paste HF model id",
                    placeholder="my-username/my-whisper-tiny",
                )
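            # The local-path option is hidden when the app runs on a hosted Space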
            with gr.Column(visible=not is_hf_space):
                local_model = gr.Textbox(
                    label="Option 3: Paste local path to model directory",
                    placeholder="artifacts/my-whisper-tiny",
                )
        ### Transcription ###
        with gr.Group():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Record a message / Upload audio file",
                show_download_button=True,
            )
            timestamps_check = gr.Checkbox(label="Show timestamps")
            transcribe_button = gr.Button("Transcribe")
            transcribe_output = gr.Text(label="Output")

        transcribe_button.click(
            fn=transcribe,
            inputs=[
                dropdown_model,
                user_model,
                local_model,
                audio_input,
                timestamps_check,
            ],
            outputs=transcribe_output,
        )

    demo.launch()


if __name__ == "__main__":
    setup_gradio_demo()
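
# Minimal usage sketch outside the UI (illustrative; assumes a local "sample.wav"):
#   pipe = _load_hf_model("openai/whisper-tiny")
#   print(pipe("sample.wav", generate_kwargs={"task": "transcribe"})["text"])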