|
import os

import soundfile as sf
import streamlit as st
import torch
from scipy.signal import resample_poly
from transformers import AutoModel, AutoFeatureExtractor
|
|
|
|
|
# Hugging Face access token (for gated/private repos); None when the
# HF_TOKEN environment variable is not set, which is fine for public models.
token = os.getenv("HF_TOKEN")
|
|
|
|
|
# Load the backbone model and its matching feature extractor once at startup.
# `token=` replaces the deprecated `use_auth_token=` keyword (removed in
# recent transformers releases); both carry the same HF access token.
try:
    model = AutoModel.from_pretrained("sami606713/emotion_classification", token=token)
    feature_extractor = AutoFeatureExtractor.from_pretrained("sami606713/emotion_classification", token=token)
except Exception as e:
    # Surface the failure and halt the script run: everything below
    # dereferences `model` / `feature_extractor`, and continuing would
    # only replace this clear message with a confusing NameError.
    st.error(f"Error loading model: {e}")
    st.stop()
|
|
|
|
|
# ---- Page chrome -----------------------------------------------------------
st.title("Audio Emotion Classification")
st.write("Upload an audio file and the model will classify the emotion.")

# Accepted upload extensions for the file picker.
SUPPORTED_AUDIO_TYPES = ["wav", "mp3", "ogg"]

uploaded_file = st.file_uploader("Choose an audio file...", type=SUPPORTED_AUDIO_TYPES)
|
|
|
def _prepare_audio(audio, source_rate, target_rate=16000):
    """Return (mono_audio, target_rate) suitable for the feature extractor.

    Fixes the original bug where `sample_rate` was simply overwritten with
    16000 without converting the signal: any file recorded at another rate
    was fed to the model time-scaled. Here we actually resample, and we
    collapse multi-channel (e.g. stereo) input to mono by channel averaging.
    """
    if audio.ndim > 1:
        # soundfile returns shape (frames, channels) for multi-channel files.
        audio = audio.mean(axis=1)
    if source_rate != target_rate:
        # Polyphase resampling; resample_poly reduces the ratio internally.
        audio = resample_poly(audio, target_rate, source_rate)
    return audio, target_rate


if uploaded_file is not None:
    # Decode the upload; soundfile yields (samples ndarray, native sample rate).
    audio_input, sample_rate = sf.read(uploaded_file)
    audio_input, sample_rate = _prepare_audio(audio_input, sample_rate)

    # Let the user play back what they uploaded.
    st.audio(uploaded_file)

    if st.button("Classifying"):
        try:
            inputs = feature_extractor(audio_input, sampling_rate=sample_rate, return_tensors="pt")

            # Inference only — no gradients needed.
            with torch.no_grad():
                outputs = model(**inputs)

            # Pooled utterance-level embedding from the backbone.
            embeddings = outputs.pooler_output

            # Index -> emotion label mapping for the 8-way head below.
            id2label = {
                0: "angry",
                1: "calm",
                2: "disgust",
                3: "fearful",
                4: "happy",
                5: "neutral",
                6: "sad",
                7: "surprised",
            }

            # NOTE(review): this Linear head is created here with RANDOM,
            # untrained weights, so the "prediction" is essentially arbitrary.
            # Seeding at least makes it deterministic across clicks, but the
            # proper fix is to load the checkpoint with
            # AutoModelForAudioClassification and use `outputs.logits` and
            # `model.config.id2label` directly.
            torch.manual_seed(0)
            classifier = torch.nn.Linear(embeddings.shape[-1], len(id2label))

            logits = classifier(embeddings)

            predicted_class_idx = logits.argmax(-1).item()
            predicted_class = id2label[predicted_class_idx]

            st.write(f"Predicted Emotion: {predicted_class}")
        except Exception as e:
            st.write(f"Error during classification: {e}")
|
|