chinmaydan committed on
Commit
47bfd84
·
1 Parent(s): 666f810

accepting input from upload or mic

Browse files
Files changed (1) hide show
  1. app.py +14 -5
app.py CHANGED
@@ -3,13 +3,21 @@ os.system("pip install git+https://github.com/openai/whisper.git")
3
  import gradio as gr
4
  import whisper
5
 
6
- from transformers import WhisperProcessor, WhisperForConditionalGeneration
7
-
8
  model = whisper.load_model("small")
9
 
 
 
10
 
11
- def inference(audio):
12
- audio = whisper.load_audio(audio)
 
 
 
 
 
 
 
 
13
  audio = whisper.pad_or_trim(audio)
14
 
15
  mel = whisper.log_mel_spectrogram(audio).to(model.device)
@@ -23,6 +31,7 @@ def inference(audio):
23
  return result.text, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
24
 
25
 
 
26
  title = "Demo for Whisper -> Something -> XLS-R"
27
 
28
  description = """
@@ -31,7 +40,7 @@ being passed into the model. The output is the text transcription of the audio.
31
  """
32
 
33
  gr.Interface(
34
- fn=inference,
35
  inputs=[
36
  gr.Audio(label="Upload Speech", source="upload", type="numpy"),
37
  gr.Audio(label="Record Speech", source="microphone", type="numpy"),
 
3
  import gradio as gr
4
  import whisper
5
 
 
 
6
  model = whisper.load_model("small")
7
 
8
+ model.config.forced_decoder_ids = None
9
+
10
 
11
+ def predict(audio, mic_audio=None):
12
+ # audio = tuple (sample_rate, frames) or (sample_rate, (frames, channels))
13
+ if mic_audio is not None:
14
+ sampling_rate, waveform = mic_audio
15
+ elif audio is not None:
16
+ sampling_rate, waveform = audio
17
+ else:
18
+ return "(please provide audio)"
19
+
20
+ audio = whisper.load_audio(waveform)
21
  audio = whisper.pad_or_trim(audio)
22
 
23
  mel = whisper.log_mel_spectrogram(audio).to(model.device)
 
31
  return result.text, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
32
 
33
 
34
+
35
  title = "Demo for Whisper -> Something -> XLS-R"
36
 
37
  description = """
 
40
  """
41
 
42
  gr.Interface(
43
+ fn=predict,
44
  inputs=[
45
  gr.Audio(label="Upload Speech", source="upload", type="numpy"),
46
  gr.Audio(label="Record Speech", source="microphone", type="numpy"),