File size: 1,763 Bytes
0c4e1ba
 
 
60620c4
0c4e1ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60620c4
 
 
 
0c4e1ba
60620c4
 
0c4e1ba
 
 
60620c4
 
0c4e1ba
60620c4
 
 
0c4e1ba
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import streamlit as st
import torch
import librosa
import numpy as np
from transformers import HubertForSequenceClassification, Wav2Vec2FeatureExtractor
from transformers import pipeline

# Title of the app
st.title("Emotion Recognition from Speech")

# Upload audio file
uploaded_file = st.file_uploader("Choose an audio file...", type=["wav"])

MODEL_NAME = "superb/hubert-large-superb-er"


@st.cache_resource
def _load_models(model_name=MODEL_NAME):
    """Load the emotion classifier, its feature extractor and a pipeline.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the large checkpoint would be reloaded on each rerun.
    ``st.cache_resource`` keeps one shared instance per process.

    The audio-classification pipeline is built from the already-loaded model
    and feature extractor instead of instantiating a second copy of the same
    weights from ``model_name``.

    Args:
        model_name: Hugging Face model id or local path of the checkpoint.

    Returns:
        Tuple ``(model, feature_extractor, classifier)``.
    """
    hubert = HubertForSequenceClassification.from_pretrained(model_name)
    extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
    audio_classifier = pipeline(
        "audio-classification",
        model=hubert,
        feature_extractor=extractor,
    )
    return hubert, extractor, audio_classifier


# Load the model and feature extractor (cached across Streamlit reruns)
model, feature_extractor, classifier = _load_models()

if uploaded_file is not None:
    # Rate the feature extractor expects; the waveform is resampled to it so
    # the extractor and the pipeline both receive audio at the right rate.
    target_sr = feature_extractor.sampling_rate

    # Display audio player
    st.audio(uploaded_file, format='audio/wav')

    # Decode to a mono float waveform at target_sr. Rewind first so decoding
    # does not depend on whether st.audio moved the upload's stream position.
    uploaded_file.seek(0)
    try:
        speech, sr = librosa.load(uploaded_file, sr=target_sr, mono=True)
    except Exception as err:  # UI boundary: report bad uploads, not a traceback
        st.error(f"Could not decode the uploaded audio: {err}")
        st.stop()

    # A zero-length waveform cannot be fed through the model.
    if speech.size == 0:
        st.error("The uploaded audio file contains no samples.")
        st.stop()

    # Predict emotion using the model directly
    inputs_pt = feature_extractor(
        speech, sampling_rate=sr, padding=True, return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(**inputs_pt).logits
    # Single (unbatched) input, so take the first row's argmax.
    predicted_id = int(torch.argmax(logits, dim=-1)[0])

    # Display the result from the model directly
    st.write("Predicted Emotion:", model.config.id2label[predicted_id])

    # Alternatively, using the pipeline. The pipeline runs the feature
    # extractor itself, so it gets the raw waveform; passing the extractor's
    # `input_values` would apply the preprocessing twice. A bare ndarray is
    # taken to be at the extractor's rate, which librosa resampled to above.
    top_k = min(5, model.config.num_labels)  # cannot show more labels than exist
    results = classifier(speech, top_k=top_k)
    st.write(f"Top {top_k} Predicted Emotions:")
    for result in results:
        st.write(f"{result['label']}: {result['score']:.4f}")