# TTS / app.py
# rawanahmed's picture
# Update app.py
# 7e676bf verified
import gradio as gr
from datasets import load_dataset
import torch
from transformers import SpeechT5ForSpeechToText, SpeechT5Processor
# Fetch the English ("en") configuration of VoxPopuli from the Hugging Face Hub.
dataset = load_dataset(path="facebook/voxpopuli", name="en")
# Example function to load audio and transcriptions
def get_sample(dataset, index=0):
    """Return the audio path and transcription of one training example.

    Args:
        dataset: Mapping from split name to an indexable split; only the
            ``"train"`` split is read.
        index: Position of the example inside the train split. Defaults to
            ``0``, which is the example the function always returned before
            this parameter existed.

    Returns:
        A ``(audio_file, transcription)`` tuple, where ``audio_file`` is the
        value stored under ``sample["audio"]["path"]``.

    Raises:
        KeyError: If the example carries none of the known transcription
            columns.
    """
    sample = dataset["train"][index]
    audio_file = sample["audio"]["path"]

    # "sentence" is tried first so datasets that really have that column behave
    # as before; VoxPopuli itself stores its text under "normalized_text" and
    # "raw_text", so a bare sample["sentence"] lookup fails there.
    text_columns = ("sentence", "normalized_text", "raw_text")
    for column in text_columns:
        if column in sample:
            return audio_file, sample[column]

    raise KeyError(
        f"sample has no transcription column; expected one of {text_columns}"
    )
# Load the SpeechT5 speech-to-text processor and model from one checkpoint id.
# "microsoft/speecht5_asr" is the published SpeechT5 ASR fine-tune; the former
# "facebook/speech_t5_base" id is not a SpeechT5 checkpoint on the Hub, so
# from_pretrained() could not resolve it.
ASR_CHECKPOINT = "microsoft/speecht5_asr"
processor = SpeechT5Processor.from_pretrained(ASR_CHECKPOINT)
model = SpeechT5ForSpeechToText.from_pretrained(ASR_CHECKPOINT)
# Example Gradio interface function
def transcribe(audio):
    """Transcribe an uploaded audio file to text with the SpeechT5 ASR model.

    Args:
        audio: Path to the audio file to transcribe (the interface's audio
            input is configured with ``type="filepath"``), or ``None``/empty
            when nothing was submitted.

    Returns:
        The decoded transcription string, or ``""`` when no audio was given.
    """
    # Function-scope import: transformers is already a dependency of this
    # file, and this keeps the block self-contained.
    from transformers.pipelines.audio_utils import ffmpeg_read

    if not audio:
        return ""

    target_rate = 16000  # sampling rate the processor call below declares
    max_new_length = 200  # upper bound on generated token ids per request

    # The processor needs raw waveform samples, not a path, so decode (and
    # resample) the file first. ffmpeg_read requires the ffmpeg binary.
    with open(audio, "rb") as audio_fh:
        waveform = ffmpeg_read(audio_fh.read(), target_rate)

    inputs = processor(
        audio=waveform, sampling_rate=target_rate, return_tensors="pt"
    )

    # SpeechT5ForSpeechToText is an encoder-decoder model: text is produced
    # by autoregressive generation, not by decoding a single forward pass.
    with torch.no_grad():
        predicted_ids = model.generate(**inputs, max_length=max_new_length)

    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
# Pull one sample at startup as a sanity check that the dataset is wired up.
audio_file, transcription = get_sample(dataset)

# Build the Gradio UI: one audio input handed to transcribe() as a file path,
# one text output. The old `source="upload"` keyword is omitted on purpose:
# Gradio 4 removed it (passing it raises TypeError), and on Gradio 3 "upload"
# was already the default, so this form works on both.
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
)

# Start the web server only when executed as a script, so importing this
# module (e.g. from a test) does not block on launch().
if __name__ == "__main__":
    iface.launch()