import gradio as gr
import torch
import numpy as np
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification
import librosa
# Initialize model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "Hatman/audio-emotion-detection"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
model.to(device)
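model.eval()  # inference-only: disables dropout so repeated predictions are stable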
# Define emotion labels
EMOTION_LABELS = {
    0: "angry",
    1: "disgust",
    2: "fear",
    3: "happy",
    4: "neutral",
    5: "sad",
    6: "surprise",
}
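# If in doubt, the label order can also be read from the checkpoint itself via model.config.id2label.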

def preprocess_audio(audio, orig_sr=16000, target_sr=16000):
    """Enhanced audio preprocessing"""
    try:
        # Convert to numpy array and ensure float32
        audio = np.array(audio, dtype=np.float32)

        # Convert to mono if stereo
        if audio.ndim > 1:
            audio = librosa.to_mono(audio.T)

        # Resample to the model's expected rate if needed
        if orig_sr != target_sr:
            audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)

        # Apply preprocessing steps
        # 1. Pre-emphasis (boosts high frequencies; a light form of noise shaping)
        audio = librosa.effects.preemphasis(audio)

        # 2. Peak-normalize to [-1, 1]
        audio = librosa.util.normalize(audio)

        # 3. Voice activity detection: keep only non-silent intervals
        intervals = librosa.effects.split(audio, top_db=20)
        if len(intervals) > 0:
            audio = np.concatenate([audio[start:end] for start, end in intervals])

        # 4. Ensure minimum length (1 second)
        if len(audio) < target_sr:
            audio = np.pad(audio, (0, target_sr - len(audio)))

        # 5. Take the center 3 seconds if too long
        max_len = 3 * target_sr
        if len(audio) > max_len:
            center = len(audio) // 2
            audio = audio[center - max_len // 2:center + max_len // 2]

        return audio
    except Exception as e:
        print(f"Preprocessing error: {str(e)}")
        return None

def get_emotion_history():
    """Get emotion detection history"""
    if not hasattr(get_emotion_history, "history"):
        get_emotion_history.history = []
    return get_emotion_history.history

def process_audio(audio):
    """Process audio chunk and return emotion"""
    if audio is None:
        return "No audio input detected"

    try:
        # Gradio streams numpy audio as a (sample_rate, data) tuple
        sample_rate = 16000
        if isinstance(audio, tuple):
            sample_rate, audio = audio

        # Preprocess audio (resampled to 16 kHz for the model)
        processed_audio = preprocess_audio(audio, orig_sr=sample_rate)
        if processed_audio is None:
            return "Audio preprocessing failed"
        if np.max(np.abs(processed_audio)) < 0.01:
            return "Audio too quiet"

        # Prepare input for the model
        inputs = feature_extractor(
            processed_audio,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        )

        # Move to device and ensure float32
        inputs = {k: v.to(device, dtype=torch.float32) for k, v in inputs.items()}

        # Get prediction
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            probs = torch.nn.functional.softmax(logits, dim=-1)[0]

        # Get top 2 predictions
        top2_probs, top2_ids = torch.topk(probs, 2)

        # Convert to percentages
        top2_probs = [p * 100 for p in top2_probs.cpu().numpy()]
        top2_emotions = [EMOTION_LABELS[idx.item()] for idx in top2_ids]

        # Update rolling history (keep the last 5 predictions)
        history = get_emotion_history()
        history.append(top2_emotions[0])
        if len(history) > 5:
            history.pop(0)

        # Get most common emotion in history
        if len(history) >= 3:
            from collections import Counter
            most_common = Counter(history).most_common(1)[0][0]
        else:
            most_common = top2_emotions[0]

        result = f"Primary: {top2_emotions[0]} ({top2_probs[0]:.1f}%)\n"
        result += f"Secondary: {top2_emotions[1]} ({top2_probs[1]:.1f}%)\n"
        result += f"Trending: {most_common}"
        return result

    except Exception as e:
        print(f"Error in processing: {str(e)}")
        return "Processing error. Please try again."

# Create Gradio interface
demo = gr.Interface(
    fn=process_audio,
    inputs=[
        gr.Audio(
            sources=["microphone"],
            type="numpy",
            streaming=True,
            label="Speak into your microphone",
            show_label=True
        )
    ],
    outputs=gr.Textbox(
        label="Detected Emotions",
        lines=3
    ),
    title="Enhanced Live Emotion Detection",
    description="Speak naturally into your microphone. Shows primary and secondary emotions with confidence levels.",
    live=True,
    allow_flagging="never"
)

# Launch with a small queue for better real-time performance
if __name__ == "__main__":
    demo.queue(max_size=1).launch(share=True)
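
# Run locally with `python app.py` (requires gradio, torch, transformers, numpy
# and librosa); share=True additionally requests a temporary public gradio.live
# link on top of the local server.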