import base64
from functools import partial
from multiprocessing import Pool

import gradio as gr
import numpy as np
import requests
from processing_whisper import WhisperPrePostProcessor
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
from transformers.pipelines.audio_utils import ffmpeg_read


title = "Whisper JAX: The Fastest Whisper API ⚡️"

description = "Whisper JAX is an optimised implementation of the [Whisper model](https://huggingface.co/openai/whisper-large-v2) by OpenAI. It runs on JAX with a TPU v4-8 in the backend. Compared to PyTorch on an A100 GPU, it is over **12x** faster, making it the fastest Whisper API available."

# Endpoint for raw-audio requests (used by `query` / `inference`).
API_URL = "https://whisper-jax.ngrok.io/generate/"
# Endpoint for pre-computed log-mel feature batches (used by `chunked_query`).
FEATURES_API_URL = "https://whisper-jax.ngrok.io/generate_from_features"

article = "Whisper large-v2 model by OpenAI. Backend running JAX on a TPU v4-8 through the generous support of the [TRC](https://sites.research.google/trc/about/) programme. Whisper JAX code and Gradio demo by 🤗 Hugging Face."

# Iterating a dict yields its keys, so this is the sorted list of language names.
language_names = sorted(TO_LANGUAGE_CODE)

CHUNK_LENGTH_S = 30  # passed as `chunk_length_s` when splitting audio for batching
BATCH_SIZE = 16  # passed as `batch_size` when building the feature dataloader
NUM_PROC = 16  # size of the multiprocessing pool that sends batches concurrently


def query(payload):
    """POST ``payload`` as JSON to ``API_URL``.

    Returns:
        A ``(data, status_code)`` tuple. ``data`` is the decoded JSON body. If the body
        is not valid JSON (for example an HTML error page from the tunnel), ``data`` is
        ``{"detail": <raw response text>}`` so callers can still surface the error.
    """
    response = requests.post(API_URL, json=payload)
    try:
        data = response.json()
    except ValueError:
        # The JSON decode error raised by `Response.json()` subclasses ValueError.
        data = {"detail": response.text}
    return data, response.status_code


def inference(inputs, language=None, task=None, return_timestamps=False):
    """Send a single transcription request to the Whisper JAX backend.

    Args:
        inputs: Audio payload forwarded unchanged under the ``"inputs"`` key; its
            expected encoding is defined by the server (not visible here).
        language: Optional language name. Falsy values (``None`` or ``""``) are
            omitted from the payload.
        task: Forwarded unchanged under the ``"task"`` key.
        return_timestamps: Whether to request, and return, the ``"chunks"`` field.

    Returns:
        A ``(text, timestamps)`` tuple. On success ``text`` is the transcription and
        ``timestamps`` is the server's ``"chunks"`` value (or ``None`` when not
        requested). On a non-200 response ``text`` is the server's error detail and
        ``timestamps`` is ``None``.
    """
    payload = {"inputs": inputs, "task": task, "return_timestamps": return_timestamps}
    # language can come as an empty string from the Gradio `None` default, so we handle it separately
    if language:
        payload["language"] = language

    data, status_code = query(payload)

    if status_code != 200:
        # Error responses carry no "text"/"chunks"; reading "chunks" here used to raise
        # KeyError and mask the actual error message returned by the server.
        return data.get("detail", f"Request failed with status code {status_code}"), None

    timestamps = data["chunks"] if return_timestamps else None
    return data["text"], timestamps


def chunked_query(payload):
    """POST a batch of pre-computed features to ``FEATURES_API_URL``.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: If the server answers with an error status. Raising here
            gives a clear error instead of a later ``KeyError`` on a missing
            ``"tokens"`` field in the caller.
    """
    response = requests.post(FEATURES_API_URL, json=payload)
    response.raise_for_status()
    return response.json()


def forward(batch, task=None, return_timestamps=False):
    feature_shape = batch["input_features"].shape
    batch["input_features"] = base64.b64encode(batch["input_features"].tobytes()).decode()
outputs = chunked_query( {"batch": batch, "task": task, "return_timestamps": return_timestamps, "feature_shape": feature_shape} ) outputs["tokens"] = np.asarray(outputs["tokens"]) return outputs if __name__ == "__main__": processor = WhisperPrePostProcessor.from_pretrained("openai/whisper-large-v2") pool = Pool(NUM_PROC) def transcribe_chunked_audio(microphone, file_upload, task, return_timestamps): warn_output = "" if (microphone is not None) and (file_upload is not None): warn_output = ( "WARNING: You've uploaded an audio file and used the microphone. " "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n" ) elif (microphone is None) and (file_upload is None): return "ERROR: You have to either use the microphone or upload an audio file" inputs = microphone if microphone is not None else file_upload with open(inputs, "rb") as f: inputs = f.read() inputs = ffmpeg_read(inputs, processor.feature_extractor.sampling_rate) inputs = {"array": inputs, "sampling_rate": processor.feature_extractor.sampling_rate} dataloader = processor.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE) model_outputs = pool.map(partial(forward, task=task, return_timestamps=return_timestamps), dataloader) post_processed = processor.postprocess(model_outputs, return_timestamps=return_timestamps) timestamps = post_processed.get("chunks") return warn_output + post_processed["text"], timestamps def _return_yt_html_embed(yt_url): video_id = yt_url.split("?v=")[-1] HTML_str = ( f'