import gradio as gr
import torch
import librosa
import numpy as np
import torch.nn.functional as F
import os

from encoders.transformer import Wav2Vec2EmotionClassifier

# Define the emotions
emotions = ["happy", "sad", "angry", "neutral", "fear", "disgust", "surprise"]
label_mapping = {str(idx): emotion for idx, emotion in enumerate(emotions)}

# Load the trained model
model_path = "lora_only_model.pth"
cfg = {
    "model": {
        "encoder": "Wav2Vec2Classifier",
        "optimizer": {
            "name": "Adam",
            "lr": 0.0003,
            "weight_decay": 3e-4
        },
        "l1_lambda": 0.0
    }
}
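# Only the optimizer sub-config is actually passed to the classifier constructor below;
# the rest of cfg is presumably kept for reference / parity with the training configuration.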
model = Wav2Vec2EmotionClassifier(num_classes=len(emotions), optimizer_cfg=cfg["model"]["optimizer"])
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
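# NOTE: strict=False is presumably needed because "lora_only_model.pth" holds only the
# LoRA / classifier-head weights, so keys for the frozen Wav2Vec2 backbone will be missing.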
model.load_state_dict(state_dict, strict=False)

# Switch to inference mode (disables dropout and other training-only behaviour).
model.eval()

# Debug aid: list the parameters that are still marked trainable (e.g. the LoRA adapters).
# Printing param.data dumps the full tensors, so trim or remove this in production.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"{name}: {param.data}")

# Minimum number of samples required before we run inference; Wav2Vec2's convolutional
# feature extractor rejects near-empty clips with shape errors.
MIN_SAMPLES = 10  # raise to 16000 to require at least 1 second of 16 kHz audio

# Preprocessing function
def preprocess_audio(file_path, sample_rate=16000):
    """
    Safely loads the file at file_path and returns a (1, samples) torch tensor.
    Returns None if the file is invalid or too short.
    """
    if not file_path or (not os.path.exists(file_path)):
        # file_path could be None or an empty string if user didn't record properly
        return None

    # Load with librosa (which merges to mono by default if multi-channel)
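    # Resampling to sample_rate (16 kHz by default) matches the rate Wav2Vec2 models expect.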
    waveform, sr = librosa.load(file_path, sr=sample_rate)

    # Check length
    if len(waveform) < MIN_SAMPLES:
        return None

    # Convert to torch tensor, shape (1, samples)
    waveform_tensor = torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)

    return waveform_tensor

# Prediction function
def predict_emotion(audio_file):
    """
    audio_file is a file path from Gradio (type='filepath').
    """
    # Preprocess
    waveform = preprocess_audio(audio_file, sample_rate=16000)
    
    # If invalid or too short, return an error-like message
    if waveform is None:
        return (
            "Audio is too short or invalid. Please record/upload a longer clip.",
            ""
        )

    # Perform inference
    with torch.no_grad():
        logits = model(waveform)
        probabilities = F.softmax(logits, dim=-1).cpu().numpy()[0]

    # Get the predicted class
    predicted_class = np.argmax(probabilities)
    predicted_emotion = label_mapping[str(predicted_class)]

    # Format probabilities for visualization
    probabilities_output = [
        f"""
        <div style='display: flex; align-items: center; margin: 5px 0;'>
            <div style='width: 20%; text-align: right; margin-right: 10px; font-weight: bold;'>{emotions[i]}</div>
            <div style='flex-grow: 1; background-color: #374151; border-radius: 4px; overflow: hidden;'>
                <div style='width: {probabilities[i]*100:.2f}%; background-color: #FFA500; height: 10px;'></div>
            </div>
            <div style='width: 10%; text-align: right; margin-left: 10px;'>{probabilities[i]*100:.2f}%</div>
        </div>
        """
        for i in range(len(emotions))
    ]

    return predicted_emotion, "\n".join(probabilities_output)
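
# Quick local sanity check (hypothetical path; assumes a short speech clip on disk):
#   emotion, bars_html = predict_emotion("sample.wav")
#   print(emotion)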

# Create Gradio interface
def gradio_interface(audio):
    detected_emotion, probabilities_html = predict_emotion(audio)
    return detected_emotion, gr.HTML(probabilities_html)

# Define Gradio UI
with gr.Blocks(css="""
    body {
        background-color: #121212;
        color: white;
        font-family: Arial, sans-serif;
    }
    h1 {
        color: #FFA500;
        font-size: 48px;
        text-align: center;
        margin-bottom: 10px;
    }
    p {
        text-align: center;
        font-size: 18px;
    }
    .gradio-row {
        justify-content: center;
        align-items: center;
    }
    #submit_button {
        background-color: #FFA500 !important;
        color: black !important;
        font-size: 18px;
        padding: 10px 20px;
        margin-top: 20px;
    }
    #detected_emotion {
        font-size: 24px;
        font-weight: bold;
        text-align: center;
    }
    .probabilities-container {
        margin-top: 20px;
        padding: 10px;
        background-color: #1F2937;
        border-radius: 8px;
    }
""") as demo:
    gr.Markdown(
        """
        <div>
            <h1>Speech Emotion Recognition</h1>
            <p>🎵 Upload or record an audio file (max 1 minute) to detect emotions.</p>
            <p>Supported Emotions: 😊 Happy | 😭 Sad | 😡 Angry | 😐 Neutral | 😨 Fear | 🤢 Disgust | 😮 Surprise</p>
        </div>
        """
    )

    with gr.Row():
        with gr.Column(scale=1, elem_id="audio-block"):
            #  type="filepath" means we get a temporary file path from Gradio
            audio_input = gr.Audio(label="🎤 Record or Upload Audio", type="filepath")
            submit_button = gr.Button("Submit", elem_id="submit_button")
        with gr.Column(scale=1):
            detected_emotion_label = gr.Label(label="Detected Emotion", elem_id="detected_emotion")
            probabilities_html = gr.HTML(label="Probabilities", elem_id="probabilities")

    submit_button.click(
        fn=gradio_interface,
        inputs=audio_input,
        outputs=[detected_emotion_label, probabilities_html]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)