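"""Gradio app for speech emotion recognition.

Classifies the emotion in an uploaded audio clip with SpeechBrain's
wav2vec2 model fine-tuned on IEMOCAP, with optional noise reduction,
an ensemble mode for long recordings, an emoji-formatted result, and
a waveform visualization.
"""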
import io
import os
import tempfile
from collections import Counter

import gradio as gr
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf  # Hoisted to module level (was imported inline before each write)
from PIL import Image  # Used to convert the waveform plot buffer into an image
from speechbrain.inference.interfaces import foreign_class

# Try to import noisereduce (if not available, noise reduction will be skipped)
try:
    import noisereduce as nr
    NOISEREDUCE_AVAILABLE = True
except ImportError:
    NOISEREDUCE_AVAILABLE = False

# Mapping from emotion labels to emojis
emotion_to_emoji = {
    "angry": "😠",
    "happy": "😊",
    "sad": "😒",
    "neutral": "😐",
    "excited": "😄",
    "fear": "😨",
    "disgust": "🤢",
    "surprise": "😲"
}

def add_emoji_to_label(label):
    """Append an emoji corresponding to the emotion label."""
    emoji = emotion_to_emoji.get(label.lower(), "")
    return f"{label.capitalize()} {emoji}"

# Load the pre-trained SpeechBrain classifier
classifier = foreign_class(
    source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
    pymodule_file="custom_interface.py",
    classname="CustomEncoderWav2vec2Classifier",
    run_opts={"device": "cpu"}  # Change to {"device": "cuda"} if a GPU is available
)
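# Note: classify_file on this model returns a 4-tuple; per the model card the
# elements are (out_prob, score, index, text_lab), where text_lab is a list
# holding the predicted label string. The helpers below therefore unpack four
# values and read label[0].
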
def preprocess_audio(audio_file, apply_noise_reduction=False):
    """Load the audio as 16 kHz mono, optionally apply noise reduction,
    peak-normalize, and return the path of a temporary WAV file."""
    y, sr = librosa.load(audio_file, sr=16000, mono=True)
    if apply_noise_reduction and NOISEREDUCE_AVAILABLE:
        y = nr.reduce_noise(y=y, sr=sr)
    if np.max(np.abs(y)) > 0:
        y = y / np.max(np.abs(y))
    temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    temp_file.close()  # Close the handle so soundfile can write to the path on all platforms
    sf.write(temp_file.name, y, sr)
    return temp_file.name

def ensemble_prediction(audio_file, apply_noise_reduction=False, segment_duration=3.0, overlap=1.0):
    """Split longer audio files into overlapping segments, classify each
    segment, and return the majority-voted emotion label."""
    y, sr = librosa.load(audio_file, sr=16000, mono=True)
    total_duration = librosa.get_duration(y=y, sr=sr)
    # Short clips need no segmentation; classify them directly.
    if total_duration <= segment_duration:
        temp_file = preprocess_audio(audio_file, apply_noise_reduction)
        _, _, _, label = classifier.classify_file(temp_file)
        os.remove(temp_file)
        return label[0]
    # Consecutive segments start `segment_duration - overlap` seconds apart.
    step = segment_duration - overlap
    segments = []
    for start in np.arange(0, total_duration - segment_duration + 0.001, step):
        start_sample = int(start * sr)
        end_sample = int((start + segment_duration) * sr)
        segment_audio = y[start_sample:end_sample]
        temp_seg = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        temp_seg.close()
        sf.write(temp_seg.name, segment_audio, sr)
        segments.append(temp_seg.name)
    predictions = []
    for seg in segments:
        temp_file = preprocess_audio(seg, apply_noise_reduction)
        _, _, _, label = classifier.classify_file(temp_file)
        predictions.append(label[0])  # Extract the predicted emotion
        os.remove(temp_file)
        os.remove(seg)
    # Majority vote across the segment-level predictions
    vote = Counter(predictions)
    return vote.most_common(1)[0][0]

def predict_emotion(audio_file, use_ensemble=False, apply_noise_reduction=False, segment_duration=3.0, overlap=1.0):
    """Predict the emotion in an audio file and return it formatted with an emoji."""
    try:
        if use_ensemble:
            label = ensemble_prediction(audio_file, apply_noise_reduction, segment_duration, overlap)
        else:
            temp_file = preprocess_audio(audio_file, apply_noise_reduction)
            result = classifier.classify_file(temp_file)
            os.remove(temp_file)
            if isinstance(result, tuple) and len(result) > 3:
                label = result[3][0]  # Extract the predicted emotion label
            else:
                label = str(result)  # Fall back to a string for unexpected formats
        return add_emoji_to_label(label.lower())  # Format and add an emoji
    except Exception as e:
        return f"Error processing file: {str(e)}"
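
# Example usage (file names are hypothetical):
#   predict_emotion("short_clip.wav")                    # e.g. "Happy 😊"
#   predict_emotion("long_clip.wav", use_ensemble=True)  # majority vote over segments
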
def plot_waveform(audio_file):
    """Generate and return a waveform plot image (as a PIL Image) for the given audio file."""
    y, sr = librosa.load(audio_file, sr=16000, mono=True)
    plt.figure(figsize=(10, 3))
    librosa.display.waveshow(y, sr=sr)
    plt.title("Waveform")
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close()
    buf.seek(0)
    return Image.open(buf)

def predict_and_plot(audio_file, use_ensemble, apply_noise_reduction, segment_duration, overlap):
    """Run emotion prediction and generate a waveform plot."""
    emotion = predict_emotion(audio_file, use_ensemble, apply_noise_reduction, segment_duration, overlap)
    waveform = plot_waveform(audio_file)
    return emotion, waveform  # Two values, matching the two Gradio outputs below

# Build the enhanced UI using Gradio Blocks
with gr.Blocks(css=".gradio-container {background-color: #f7f7f7; font-family: Arial;}") as demo:
    gr.Markdown("<h1 style='text-align: center;'>Enhanced Emotion Recognition</h1>")
    gr.Markdown(
        "Upload an audio file, and the model will predict the emotion using a wav2vec2 model fine-tuned on IEMOCAP data. "
        "The prediction is accompanied by an emoji, and you can also view the audio's waveform. "
        "Use the options below to adjust ensemble prediction and noise reduction settings."
    )
    with gr.Tabs():
        with gr.TabItem("Emotion Recognition"):
            with gr.Row():
                audio_input = gr.Audio(type="filepath", label="Upload Audio")
                use_ensemble = gr.Checkbox(label="Use Ensemble Prediction (for long audio)", value=False)
                apply_noise_reduction = gr.Checkbox(label="Apply Noise Reduction", value=False)
            with gr.Row():
                segment_duration = gr.Slider(minimum=1.0, maximum=10.0, step=0.5, value=3.0, label="Segment Duration (s)")
                overlap = gr.Slider(minimum=0.0, maximum=5.0, step=0.5, value=1.0, label="Segment Overlap (s)")
            predict_button = gr.Button("Predict Emotion")
            result_text = gr.Textbox(label="Predicted Emotion")
            waveform_image = gr.Image(label="Audio Waveform", type="pil")
            predict_button.click(
                predict_and_plot,
                inputs=[audio_input, use_ensemble, apply_noise_reduction, segment_duration, overlap],
                outputs=[result_text, waveform_image]
            )
        with gr.TabItem("About"):
            gr.Markdown("""
            **Enhanced Emotion Recognition App**

            - **Model:** SpeechBrain's wav2vec2 model fine-tuned on IEMOCAP for emotion recognition.
            - **Features:**
              - Ensemble prediction for long audio files.
              - Optional noise reduction.
              - Visualization of the audio waveform.
              - Emoji representation of the predicted emotion.

            **Credits:**
            - [SpeechBrain](https://speechbrain.github.io)
            - [Gradio](https://gradio.app)
            """)

if __name__ == "__main__":
    demo.launch()