import gradio as gr
import librosa
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# load the Whisper processor and model, and force English transcription output
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")


def process_audio(sampling_rate, waveform):
    # convert from int16 to floating point
    waveform = waveform / 32768.0
    # convert to mono if stereo
    if len(waveform.shape) > 1:
        waveform = librosa.to_mono(waveform.T)
    # resample to 16 kHz if necessary
    if sampling_rate != 16000:
        waveform = librosa.resample(waveform, orig_sr=sampling_rate, target_sr=16000)
    # limit to 30 seconds
    waveform = waveform[:16000 * 30]
    # make PyTorch tensor
    waveform = torch.tensor(waveform)
    return waveform


def predict(audio, mic_audio=None):
    # audio = tuple (sample_rate, frames) or (sample_rate, (frames, channels))
    if mic_audio is not None:
        sampling_rate, waveform = mic_audio
    elif audio is not None:
        sampling_rate, waveform = audio
    else:
        return "(please provide audio)"

    waveform = process_audio(sampling_rate, waveform)
    input_features = processor(waveform, sampling_rate=16000, return_tensors="pt").input_features
    predicted_ids = model.generate(input_features, max_length=400)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    return transcription[0]


title = "Demo for Whisper -> Something -> XLS-R"

description = """
<b>How to use:</b> Upload an audio file or record using the microphone. The audio is converted to mono and resampled to 16 kHz before
being passed into the model. The output is the text transcription of the audio.
"""

gr.Interface(
    fn=predict,
    inputs=[
        gr.Audio(label="Upload Speech", source="upload", type="numpy"),
        gr.Audio(label="Record Speech", source="microphone", type="numpy"),
    ],
    outputs=[
        gr.Text(label="Transcription"),
    ],
    title=title,
    description=description,
).launch()