Spaces:
Runtime error
Runtime error
import tempfile | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchaudio | |
import gradio as gr | |
from transformers import Wav2Vec2FeatureExtractor,AutoConfig | |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |
from models import Wav2Vec2ForSpeechClassification, HubertForSpeechClassification | |
# Hugging Face Hub checkpoint supplying the config, feature extractor and weights.
MODEL_ID = "SeyedAli/Persian-Speech-Emotion-HuBert-V1"

config = AutoConfig.from_pretrained(MODEL_ID)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID)

# Run on GPU when one is available. `device` is kept at module level so the
# inference code can move its input tensors next to the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HubertForSpeechClassification.from_pretrained(MODEL_ID).to(device)
model.eval()  # inference only: keep dropout / training-mode layers disabled

# Sampling rate (Hz) the feature extractor expects its input audio to have.
sampling_rate = feature_extractor.sampling_rate

# UI components. The labels are Persian: "Persian speech audio" (input) and
# "Emotion present in the speech audio" (right-to-left output text area).
audio_input = gr.Audio(label="صوت گفتار فارسی", type="filepath")
text_output = gr.TextArea(label="هیجان موجود در صوت گفتار", text_align="right", rtl=True, type="text")
def _load_mono_speech(path, target_rate):
    """Load an audio file, down-mix it to mono and resample it to ``target_rate``.

    Returns a 1-D numpy array of samples for ``feature_extractor``.
    """
    # torchaudio.load returns (waveform[channels, frames], source_sample_rate).
    waveform, source_rate = torchaudio.load(path)
    # Average the channels so a stereo upload is scored as one utterance,
    # rather than becoming a batch of channels of which only the first is read.
    waveform = waveform.mean(dim=0)
    if source_rate != target_rate:
        # Pass the target rate explicitly instead of relying on Resample's
        # default new_freq happening to match the feature extractor.
        resampler = torchaudio.transforms.Resample(orig_freq=source_rate, new_freq=target_rate)
        waveform = resampler(waveform)
    return waveform.numpy()


def SER(audio):
    """Predict the emotion distribution of a Persian speech recording.

    Args:
        audio: Path to the uploaded audio file (Gradio ``type="filepath"``),
            or ``None`` when nothing was submitted.

    Returns:
        A list with one ``{"Label": <class name>, "Score": "NN.N%"}`` dict per
        class in ``config.id2label``, in class-index order. The list is empty
        when ``audio`` is ``None``.
    """
    if audio is None:
        return []

    # Read the uploaded file directly; no temporary copy (or leaked handle) needed.
    speech = _load_mono_speech(audio, sampling_rate)
    inputs = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)

    # Put the inputs on whatever device the model's weights live on.
    model_device = next(model.parameters()).device
    inputs = {key: value.to(model_device) for key, value in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits
    # Single utterance in the batch -> take row 0 of the class probabilities.
    scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]

    return [
        {"Label": config.id2label[i], "Score": f"{score * 100:.1f}%"}
        for i, score in enumerate(scores)
    ]
# Gradio UI wiring: an uploaded speech file goes to SER, and its result is
# shown in the right-to-left text area.
iface = gr.Interface(fn=SER, inputs=audio_input, outputs=text_output)

# Start the server only when run as a script, not as a side effect of importing this module.
if __name__ == "__main__":
    iface.launch(share=False)