sanchit-gandhi (HF staff) committed
Commit af74e64
1 Parent(s): a1c0e65

Update app.py

Files changed (1): app.py +10 -19
app.py CHANGED
@@ -1,15 +1,12 @@
 import gradio as gr
 import requests
-import pytube
 from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
-from transformers.pipelines.audio_utils import ffmpeg_read
 
-title = "Whisper JAX: The Fastest Whisper API Available ⚡️"
 
-description = """Whisper JAX is an optimised implementation of the [Whisper model](https://huggingface.co/openai/whisper-large-v2) by OpenAI. It runs on JAX with a TPU v4-8 in the backend. Compared to PyTorch on an A100 GPU, it is over **12x** faster, making it the fastest Whisper API available.
+title = "Whisper JAX: The Fastest Whisper API Available ⚡️"
 
-You can submit requests to Whisper JAX through this Gradio Demo, or directly through API calls (see below). This notebook demonstrates how you can run the Whisper JAX model yourself on a TPU v2-8 in a Google Colab: TODO.
-"""
+description = "Whisper JAX is an optimised implementation of the [Whisper model](https://huggingface.co/openai/whisper-large-v2) by OpenAI. It runs on JAX with a TPU v4-8 in the backend. Compared to PyTorch on an A100 GPU, it is over **12x** faster, making it the fastest Whisper API available."
+#description += "\nYou can submit requests to Whisper JAX through this Gradio Demo, or directly through API calls (see below). This notebook demonstrates how you can run the Whisper JAX model yourself on a TPU v2-8 in a Google Colab: TODO."
 
 API_URL = "https://whisper-jax.ngrok.io/generate/"
 
@@ -24,9 +21,13 @@ def query(payload):
     return response.json(), response.status_code
 
 
-def inference(inputs, task, return_timestamps):
+def inference(inputs, language=None, task=None, return_timestamps=False):
     payload = {"inputs": inputs, "task": task, "return_timestamps": return_timestamps}
 
+    # language can come as an empty string from the Gradio `None` default, so we handle it separately
+    if language:
+        payload["language"] = language
+
     data, status_code = query(payload)
 
     if status_code == 200:
@@ -72,22 +73,12 @@ def _return_yt_html_embed(yt_url):
 
 
 def transcribe_youtube(yt_url, task, return_timestamps):
-    yt = pytube.YouTube(yt_url)
     html_embed_str = _return_yt_html_embed(yt_url)
-    stream = yt.streams.filter(only_audio=True)[0]
-    stream.download(filename="audio.mp3")
 
-    with open("audio.mp3", "rb") as f:
-        inputs = f.read()
+    text, timestamps = inference(inputs=yt_url, task=task, return_timestamps=return_timestamps)
 
-    inputs = ffmpeg_read(inputs, SAMPLING_RATE)
-    inputs = {"array": inputs.tolist(), "sampling_rate": SAMPLING_RATE}
-
-    yield html_embed_str, "Video loaded, transcribing audio...", None
-
-    text, timestamps = inference(inputs=inputs, task=task, return_timestamps=return_timestamps)
+    return html_embed_str, text, timestamps
 
-    yield html_embed_str, text, timestamps
 
 audio = gr.Interface(
     fn=transcribe_audio,
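The commented-out description still advertises direct API calls, but the diff never shows how `query` sends the request. Below is a minimal sketch of what a direct call might look like, assuming the endpoint accepts a JSON POST with the same payload that `inference` builds; the `requests.post` transport and the example YouTube URL are assumptions, not taken from this commit.

import requests

API_URL = "https://whisper-jax.ngrok.io/generate/"

# Same payload shape that `inference` builds in app.py.
payload = {
    "inputs": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",  # hypothetical example video
    "task": "transcribe",  # or "translate"
    "return_timestamps": True,
}

# Mirror the guard added in this commit: Gradio can pass "" instead of None,
# so only attach `language` when it is a non-empty string.
language = "english"
if language:
    payload["language"] = language

# Assumed transport: the diff only shows that `query` returns
# `response.json(), response.status_code`.
response = requests.post(API_URL, json=payload)
data, status_code = response.json(), response.status_code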
 
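The `TO_LANGUAGE_CODE` import is kept by this commit, although the lines that use it fall outside the hunks shown. For reference, it maps human-readable Whisper language names to the short codes the model expects, which pairs naturally with the new `language` parameter on `inference`:

from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE

# Map a dropdown-style language name to the code Whisper expects.
print(TO_LANGUAGE_CODE["english"])  # en
print(TO_LANGUAGE_CODE["french"])   # fr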