"""Gradio demo client for a remote Whisper JAX transcription backend.

Audio is either pre-processed locally into feature batches and sent to the
``generate_from_features`` endpoint, or (for YouTube URLs) forwarded as-is to
the ``generate`` endpoint.
"""
import base64
import html
from functools import partial
from multiprocessing import Pool

import gradio as gr
import numpy as np
import requests
from processing_whisper import WhisperPrePostProcessor
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
from transformers.pipelines.audio_utils import ffmpeg_read


title = "Whisper JAX: The Fastest Whisper API ⚡️"

description = "Whisper JAX is an optimised implementation of the [Whisper model](https://huggingface.co/openai/whisper-large-v2) by OpenAI. It runs on JAX with a TPU v4-8 in the backend. Compared to PyTorch on an A100 GPU, it is over **12x** faster, making it the fastest Whisper API available."

API_URL = "https://whisper-jax.ngrok.io/generate/"
FEATURES_API_URL = "https://whisper-jax.ngrok.io/generate_from_features"

article = "Whisper large-v2 model by OpenAI. Backend running JAX on a TPU v4-8 through the generous support of the [TRC](https://sites.research.google/trc/about/) programme. Whisper JAX code and Gradio demo by 🤗 Hugging Face."

language_names = sorted(TO_LANGUAGE_CODE.keys())
CHUNK_LENGTH_S = 30
BATCH_SIZE = 16
NUM_PROC = 16


def query(payload):
    """POST ``payload`` as JSON to ``API_URL``.

    Returns:
        A ``(decoded_json_body, http_status_code)`` tuple.
    """
    response = requests.post(API_URL, json=payload)
    return response.json(), response.status_code


def inference(inputs, language=None, task=None, return_timestamps=False):
    """Send a transcription request to the backend and unpack the reply.

    Args:
        inputs: Value sent under the ``"inputs"`` key of the request.
        language: Optional language; omitted from the request when falsy.
        task: Value sent under the ``"task"`` key.
        return_timestamps: Whether to request and return timestamp chunks.

    Returns:
        A ``(text, timestamps)`` tuple. On a 200 response ``text`` is the
        ``"text"`` field and ``timestamps`` is the ``"chunks"`` field (or
        ``None`` when timestamps were not requested). On any other status
        ``text`` is the ``"detail"`` field and ``timestamps`` is ``None``.
    """
    payload = {"inputs": inputs, "task": task, "return_timestamps": return_timestamps}

    # language can come as an empty string from the Gradio `None` default, so we handle it separately
    if language:
        payload["language"] = language

    data, status_code = query(payload)

    if status_code != 200:
        # Error replies are read via "detail"; do not look for "chunks" in
        # them, otherwise a failed request with timestamps enabled would
        # raise KeyError instead of surfacing the error message.
        return data["detail"], None

    text = data["text"]
    timestamps = data["chunks"] if return_timestamps else None
    return text, timestamps


def chunked_query(payload):
    """POST ``payload`` as JSON to the feature endpoint and return the decoded JSON body."""
    response = requests.post(FEATURES_API_URL, json=payload)
    return response.json()


def forward(batch, task=None, return_timestamps=False):
    """Send one pre-processed batch to the backend and return its outputs.

    Args:
        batch: Mapping with an ``"input_features"`` entry exposing ``.shape``
            and ``.tobytes()`` (presumably a NumPy array - confirm against
            ``WhisperPrePostProcessor.preprocess_batch``). Other entries are
            forwarded unchanged.
        task: Value sent under the ``"task"`` key.
        return_timestamps: Value sent under the ``"return_timestamps"`` key.

    Returns:
        The decoded JSON reply with its ``"tokens"`` entry converted to a
        NumPy array.
    """
    features = batch["input_features"]
    feature_shape = features.shape

    # Work on a shallow copy so the caller's batch keeps its original features.
    encoded_batch = dict(batch)
    # Raw bytes are not JSON-serialisable, so ship them base64-encoded as text;
    # the shape is sent alongside.
    encoded_batch["input_features"] = base64.b64encode(features.tobytes()).decode()

    outputs = chunked_query(
        {
            "batch": encoded_batch,
            "task": task,
            "return_timestamps": return_timestamps,
            "feature_shape": feature_shape,
        }
    )
    outputs["tokens"] = np.asarray(outputs["tokens"])
    return outputs


if __name__ == "__main__":
    processor = WhisperPrePostProcessor.from_pretrained("openai/whisper-large-v2")
    pool = Pool(NUM_PROC)

    def transcribe_chunked_audio(microphone, file_upload, task, return_timestamps):
        """Transcribe a recorded or uploaded audio file via batched requests.

        Args:
            microphone: File path of the microphone recording, or ``None``.
            file_upload: File path of the uploaded audio, or ``None``.
            task: Task forwarded to the backend.
            return_timestamps: Whether to request timestamp chunks.

        Returns:
            A ``(text, timestamps)`` tuple matching the two output components
            of the interface. When no audio is supplied, ``text`` is an error
            message and ``timestamps`` is ``None``.
        """
        warn_output = ""
        if (microphone is not None) and (file_upload is not None):
            warn_output = (
                "WARNING: You've uploaded an audio file and used the microphone. "
                "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
            )

        elif (microphone is None) and (file_upload is None):
            # Two output components are declared, so return a value for each.
            return "ERROR: You have to either use the microphone or upload an audio file", None

        # The microphone recording takes precedence when both are present.
        audio_path = microphone if microphone is not None else file_upload

        with open(audio_path, "rb") as f:
            raw_bytes = f.read()

        sampling_rate = processor.feature_extractor.sampling_rate
        audio_array = ffmpeg_read(raw_bytes, sampling_rate)
        inputs = {"array": audio_array, "sampling_rate": sampling_rate}

        dataloader = processor.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
        # Each batch is sent to the backend from a worker process of the pool.
        model_outputs = pool.map(partial(forward, task=task, return_timestamps=return_timestamps), dataloader)
        post_processed = processor.postprocess(model_outputs, return_timestamps=return_timestamps)
        timestamps = post_processed.get("chunks")
        return warn_output + post_processed["text"], timestamps

    def _return_yt_html_embed(yt_url):
        """Return an HTML snippet embedding the YouTube video referenced by ``yt_url``.

        The video id is taken as everything after the last ``"?v="`` in the URL.
        """
        video_id = yt_url.split("?v=")[-1]
        # The id originates from user input, so escape it before placing it in markup.
        safe_video_id = html.escape(video_id, quote=True)
        HTML_str = (
            f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{safe_video_id}"> </iframe>'
            " </center>"
        )
        return HTML_str

    def transcribe_youtube(yt_url, task, return_timestamps):
        """Transcribe a YouTube video by URL.

        Returns:
            An ``(html_embed, text, timestamps)`` tuple.
        """
        html_embed_str = _return_yt_html_embed(yt_url)
        text, timestamps = inference(inputs=yt_url, task=task, return_timestamps=return_timestamps)
        return html_embed_str, text, timestamps

    audio_chunked = gr.Interface(
        fn=transcribe_chunked_audio,
        inputs=[
            gr.inputs.Audio(source="microphone", optional=True, type="filepath"),
            gr.inputs.Audio(source="upload", optional=True, type="filepath"),
            gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
            gr.inputs.Checkbox(default=False, label="Return timestamps"),
        ],
        outputs=[
            gr.outputs.Textbox(label="Transcription"),
            gr.outputs.Textbox(label="Timestamps"),
        ],
        allow_flagging="never",
        title=title,
        description=description,
        article=article,
    )

    youtube = gr.Interface(
        fn=transcribe_youtube,
        inputs=[
            gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
            gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
            gr.inputs.Checkbox(default=False, label="Return timestamps"),
        ],
        outputs=[
            gr.outputs.HTML(label="Video"),
            gr.outputs.Textbox(label="Transcription"),
            gr.outputs.Textbox(label="Timestamps"),
        ],
        allow_flagging="never",
        title=title,
        examples=[["https://www.youtube.com/watch?v=m8u-18Q0s7I", "transcribe", False]],
        cache_examples=False,
        description=description,
        article=article,
    )

    demo = gr.Blocks()

    with demo:
        gr.TabbedInterface([audio_chunked, youtube], ["Transcribe Audio", "Transcribe YouTube"])

    demo.queue()
    demo.launch()