Spaces:
Running
Running
import tempfile | |
import torch | |
import torch.nn.functional as F | |
import torchaudio | |
import gradio as gr | |
from transformers import Wav2Vec2FeatureExtractor, AutoConfig | |
from models import Wav2Vec2ForSpeechClassification, HubertForSpeechClassification | |
# Identifier of the pretrained checkpoint shared by the config, the feature
# extractor and the classification model.
MODEL_ID = "Gizachew/wev2vec-large960-agu-amharic"

# Loaded once at import time and reused by every call to the prediction function.
config = AutoConfig.from_pretrained(MODEL_ID)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForSpeechClassification.from_pretrained(MODEL_ID)

# Sampling rate the feature extractor is configured with.
sampling_rate = feature_extractor.sampling_rate

# Gradio widgets: the audio input hands the callback a file path, and the
# output is a right-aligned, right-to-left text area.
audio_input = gr.Audio(type="filepath", label="Upload file")
text_output = gr.TextArea(
    label="Emotion Prediction Output",
    text_align="right",
    rtl=True,
    type="text",
)
def SER(audio):
    """Predict the emotion expressed in an audio file.

    Args:
        audio: Path to the audio file, as delivered by the Gradio audio
            input (configured with ``type="filepath"``).

    Returns:
        A string of the form ``"<label>: <score>%"`` for the class with the
        highest softmax score, with the score shown to one decimal place.

    Raises:
        gr.Error: If no audio file was provided.
    """
    # Submitting the form without a file passes None; report that clearly
    # instead of failing later with an opaque TypeError.
    if not audio:
        raise gr.Error("Please upload an audio file before submitting.")

    # Gradio already supplies a path on disk, so load from it directly rather
    # than copying the bytes into a temporary file (which also left the source
    # file handle unclosed).
    speech_array, source_rate = torchaudio.load(audio)

    # Down-mix multi-channel recordings to mono; otherwise each channel would
    # reach the feature extractor as a separate row.
    # NOTE(review): assumes torchaudio's default channels-first layout
    # (channels, frames) — confirm if a different backend/layout is used.
    if speech_array.dim() > 1 and speech_array.size(0) > 1:
        speech_array = speech_array.mean(dim=0, keepdim=True)

    # Resample explicitly to the rate the feature extractor expects instead of
    # relying on Resample's default target frequency.
    if source_rate != sampling_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=source_rate, new_freq=sampling_rate
        )
        speech_array = resampler(speech_array)

    speech = speech_array.squeeze().numpy()

    inputs = feature_extractor(
        speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True
    )
    inputs = {key: inputs[key] for key in inputs}

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Class probabilities for the single item in the batch.
    scores = F.softmax(logits, dim=1).cpu().numpy()[0]

    # Cast to a plain int so the id2label lookup does not depend on NumPy
    # integer hashing.
    max_index = int(scores.argmax())
    label = config.id2label[max_index]
    score = scores[max_index]

    return f"{label}: {score * 100:.1f}%"
# Wire the prediction function to its input and output widgets.
iface = gr.Interface(fn=SER, inputs=audio_input, outputs=text_output)

# Start the web app; share=True requests a public share link from Gradio.
iface.launch(share=True)