import gradio as gr
import torch
import librosa
import numpy as np
import torch.nn.functional as F
import os
from encoders.transformer import Wav2Vec2EmotionClassifier
# Define the emotions
emotions = ["happy", "sad", "angry", "neutral", "fear", "disgust", "surprise"]
label_mapping = {str(idx): emotion for idx, emotion in enumerate(emotions)}
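# NOTE: the index -> label order above is assumed to match the class order used
# at training time; if the checkpoint was trained with a different ordering,
# the predicted labels will come out permuted.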
# Load the trained model
model_path = "lora_only_model.pth"
cfg = {
    "model": {
        "encoder": "Wav2Vec2Classifier",
        "optimizer": {
            "name": "Adam",
            "lr": 0.0003,
            "weight_decay": 3e-4
        },
        "l1_lambda": 0.0
    }
}
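# NOTE: only cfg["model"]["optimizer"] is consumed below; the "encoder" and
# "l1_lambda" entries are presumably kept for parity with the training config
# and are unused at inference time.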
model = Wav2Vec2EmotionClassifier(num_classes=len(emotions), optimizer_cfg=cfg["model"]["optimizer"])
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
model.load_state_dict(state_dict, strict=False)
model.eval()
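# load_state_dict(strict=False) silently skips mismatched keys, which is what
# lets a LoRA-only checkpoint load on top of the pretrained backbone, but it
# also hides typos in key names. A minimal optional sanity check (sketch;
# torch returns an _IncompatibleKeys namedtuple):
#
#   result = model.load_state_dict(state_dict, strict=False)
#   print(f"missing: {len(result.missing_keys)}, unexpected: {len(result.unexpected_keys)}")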
# Debug: print the parameters that remained trainable (presumably the LoRA weights)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"{name}: {param.data}")
# Optional: we define a minimum number of samples to avoid Wav2Vec2 conv errors
MIN_SAMPLES = 10 # or 16000 if you want at least 1 second
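# Background (an assumption based on standard Hugging Face Wav2Vec2 defaults):
# the convolutional feature extractor has a receptive field of roughly 400
# samples, so shorter clips yield zero frames and raise shape errors. In
# practice, MIN_SAMPLES = 16000 (1 second at 16 kHz) is the safer floor.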
# Preprocessing function
def preprocess_audio(file_path, sample_rate=16000):
    """
    Safely loads the file at file_path and returns a (1, samples) torch tensor.
    Returns None if the file is invalid or too short.
    """
    if not file_path or not os.path.exists(file_path):
        # file_path could be None or an empty string if the user didn't record properly
        return None

    # Load with librosa (resamples to sample_rate and merges multi-channel audio to mono by default)
    waveform, sr = librosa.load(file_path, sr=sample_rate)

    # Check length
    if len(waveform) < MIN_SAMPLES:
        return None

    # Convert to a torch tensor of shape (1, samples)
    waveform_tensor = torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)
    return waveform_tensor
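# Example usage (hypothetical file name), showing the expected shape:
#   wav = preprocess_audio("clip.wav")  # -> tensor of shape (1, n_samples), or None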
# Prediction function
def predict_emotion(audio_file):
    """
    audio_file is a file path from Gradio (type='filepath').
    """
    # Preprocess
    waveform = preprocess_audio(audio_file, sample_rate=16000)

    # If invalid or too short, return an error-like message
    if waveform is None:
        return (
            "Audio is too short or invalid. Please record/upload a longer clip.",
            ""
        )

    # Perform inference
    with torch.no_grad():
        logits = model(waveform)
        probabilities = F.softmax(logits, dim=-1).cpu().numpy()[0]

    # Get the predicted class
    predicted_class = np.argmax(probabilities)
    predicted_emotion = label_mapping[str(predicted_class)]

    # Format probabilities for visualization (one horizontal bar per emotion)
    probabilities_output = [
        f"""
        <div style='display: flex; align-items: center; margin: 5px 0;'>
            <div style='width: 20%; text-align: right; margin-right: 10px; font-weight: bold;'>{emotions[i]}</div>
            <div style='flex-grow: 1; background-color: #374151; border-radius: 4px; overflow: hidden;'>
                <div style='width: {probabilities[i]*100:.2f}%; background-color: #FFA500; height: 10px;'></div>
            </div>
            <div style='width: 10%; text-align: right; margin-left: 10px;'>{probabilities[i]*100:.2f}%</div>
        </div>
        """
        for i in range(len(emotions))
    ]
    return predicted_emotion, "\n".join(probabilities_output)
# Create Gradio interface
def gradio_interface(audio):
    detected_emotion, probabilities_html = predict_emotion(audio)
    return detected_emotion, gr.HTML(probabilities_html)
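# NOTE: returning a freshly constructed gr.HTML component relies on Gradio's
# component-update behavior (Gradio 4.x); returning the raw HTML string also
# works here, since the bound output component is already a gr.HTML.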
# Define Gradio UI
with gr.Blocks(css="""
    body {
        background-color: #121212;
        color: white;
        font-family: Arial, sans-serif;
    }
    h1 {
        color: #FFA500;
        font-size: 48px;
        text-align: center;
        margin-bottom: 10px;
    }
    p {
        text-align: center;
        font-size: 18px;
    }
    .gradio-row {
        justify-content: center;
        align-items: center;
    }
    #submit_button {
        background-color: #FFA500 !important;
        color: black !important;
        font-size: 18px;
        padding: 10px 20px;
        margin-top: 20px;
    }
    #detected_emotion {
        font-size: 24px;
        font-weight: bold;
        text-align: center;
    }
    .probabilities-container {
        margin-top: 20px;
        padding: 10px;
        background-color: #1F2937;
        border-radius: 8px;
    }
""") as demo:
    gr.Markdown(
        """
        <div>
            <h1>Speech Emotion Recognition</h1>
            <p>🎵 Upload or record an audio file (max 1 minute) to detect emotions.</p>
            <p>Supported Emotions: 😊 Happy | 😭 Sad | 😡 Angry | 😐 Neutral | 😨 Fear | 🤢 Disgust | 😮 Surprise</p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1, elem_id="audio-block"):
            # type="filepath" means we get a temporary file path from Gradio
            audio_input = gr.Audio(label="🎤 Record or Upload Audio", type="filepath")
            submit_button = gr.Button("Submit", elem_id="submit_button")
        with gr.Column(scale=1):
            detected_emotion_label = gr.Label(label="Detected Emotion", elem_id="detected_emotion")
            probabilities_html = gr.HTML(label="Probabilities", elem_id="probabilities")

    submit_button.click(
        fn=gradio_interface,
        inputs=audio_input,
        outputs=[detected_emotion_label, probabilities_html]
    )
# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)