omsandeeppatil committed
Commit 0a54d22 · verified · 1 Parent(s): db29d72

Update app.py

Files changed (1)
  1. app.py +74 -69
app.py CHANGED
@@ -2,82 +2,87 @@ import gradio as gr
 import torch
 import torchaudio
 from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification
-from queue import Queue
-import threading
-import numpy as np
 
-# Check for device
+# Initialize device and model
 device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Model setup
 model_name = "Hatman/audio-emotion-detection"
 feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
-model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name).to(device)
-
-# Real-time audio processing setup
-def preprocess_audio_chunk(audio_chunk, sampling_rate):
-    resampled_waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)(audio_chunk)
-    return {'speech': resampled_waveform.numpy().flatten(), 'sampling_rate': 16000}
-
-def inference_chunk(audio_chunk, sampling_rate):
-    example = preprocess_audio_chunk(audio_chunk, sampling_rate)
-    inputs = feature_extractor(example['speech'], sampling_rate=16000, return_tensors="pt", padding=True)
+model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
+
+# Define emotion labels
+EMOTION_LABELS = {
+    0: "angry",
+    1: "disgust",
+    2: "fear",
+    3: "happy",
+    4: "neutral",
+    5: "sad",
+    6: "surprise"
+}
+
+def preprocess_audio(audio):
+    """Preprocess audio file for model input"""
+    waveform, sampling_rate = torchaudio.load(audio)
+    resampled_waveform = torchaudio.transforms.Resample(
+        orig_freq=sampling_rate,
+        new_freq=16000
+    )(waveform)
+    return {
+        'speech': resampled_waveform.numpy().flatten(),
+        'sampling_rate': 16000
+    }
+
+def inference(audio):
+    """Full inference function returning emotion, logits, and predicted IDs"""
+    example = preprocess_audio(audio)
+    inputs = feature_extractor(
+        example['speech'],
+        sampling_rate=16000,
+        return_tensors="pt",
+        padding=True
+    )
+
+    # Move inputs to appropriate device
     inputs = {k: v.to(device) for k, v in inputs.items()}
+
     with torch.no_grad():
-        logits = model(**inputs).logits
-        predicted_ids = torch.argmax(logits, dim=-1)
-    emotion = model.config.id2label[predicted_ids.item()]
+        outputs = model(**inputs)
+        logits = outputs.logits
+        predicted_ids = torch.argmax(logits, dim=-1)
+
+    predicted_emotion = EMOTION_LABELS[predicted_ids.item()]
+    return predicted_emotion, logits.tolist(), predicted_ids.tolist()
+
+def inference_label(audio):
+    """Simplified inference function returning only the emotion label"""
+    emotion, _, _ = inference(audio)
     return emotion
 
-# Queue for processing audio chunks
-audio_queue = Queue()
-results_queue = Queue()
-
-# Thread for processing audio in real-time
-def audio_processing_thread():
-    while True:
-        if not audio_queue.empty():
-            audio_chunk, sampling_rate = audio_queue.get()
-            emotion = inference_chunk(audio_chunk, sampling_rate)
-            results_queue.put(emotion)
-
-processing_thread = threading.Thread(target=audio_processing_thread, daemon=True)
-processing_thread.start()
-
-# Gradio interface for real-time streaming
-def real_time_inference_live(microphone_audio):
-    waveform = torch.tensor(microphone_audio["array"]).float()
-    sampling_rate = microphone_audio["sampling_rate"]
-
-    # Chunk size in samples (5 seconds chunks)
-    chunk_size = int(5 * sampling_rate)
-
-    # Process each chunk and collect live emotions
-    emotions = []
-    for start in range(0, len(waveform), chunk_size):
-        end = min(start + chunk_size, len(waveform))
-        audio_chunk = waveform[start:end]
-        if audio_chunk.size(0) > 0:
-            audio_queue.put((audio_chunk, sampling_rate))
-
-    # Retrieve results from the results queue
-    while not results_queue.empty():
-        emotion = results_queue.get()
-        emotions.append(emotion)
-
-    return "\n".join(emotions)
-
+# Create Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# Live Emotion Detection from Audio")
-
-    audio_input = gr.Audio(streaming=True, label="Real-Time Audio Input", type="numpy")
-    emotion_output = gr.Textbox(label="Detected Emotions", lines=10)
-
-    def stream_audio_live(audio):
-        return real_time_inference_live(audio)
-
-    audio_input.stream(stream_audio_live, outputs=emotion_output)
-
-    gr.Markdown("This application processes audio in 5-second chunks and detects emotions in real-time.")
-
+    gr.Markdown("# Audio Emotion Detection")
+
+    with gr.Tab("Quick Analysis"):
+        gr.Interface(
+            fn=inference_label,
+            inputs=gr.Audio(type="filepath"),
+            outputs=gr.Label(label="Detected Emotion"),
+            title="Audio Emotion Analysis",
+            description="Upload or record audio to detect the emotional content."
+        )
+
+    with gr.Tab("Detailed Analysis"):
+        gr.Interface(
+            fn=inference,
+            inputs=gr.Audio(type="filepath"),
+            outputs=[
+                gr.Label(label="Detected Emotion"),
+                gr.JSON(label="Confidence Scores"),
+                gr.JSON(label="Internal IDs")
+            ],
+            title="Audio Emotion Analysis (Detailed)",
+            description="Get detailed analysis including confidence scores for each emotion."
+        )
+
+# Launch the app
 demo.launch(share=True)
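
For reference, a minimal standalone sketch (not part of the commit) of how the new file-based inference path can be exercised outside Gradio, and how the hardcoded EMOTION_LABELS can be cross-checked against the model's own id2label mapping, which the previous version relied on. The file name sample.wav and the explicit model.to(device) call are assumptions added here for illustration.

import torch
import torchaudio
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification

model_name = "Hatman/audio-emotion-detection"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)  # assumption: keep the model on the same device as the inputs

# The model config ships its own label mapping (what the old code used);
# printing it makes any mismatch with the hardcoded EMOTION_LABELS visible.
print(model.config.id2label)

# "sample.wav" is a hypothetical local file used only for illustration.
waveform, sampling_rate = torchaudio.load("sample.wav")
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)(waveform)

inputs = feature_extractor(waveform.numpy().flatten(), sampling_rate=16000, return_tensors="pt", padding=True)
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    predicted_id = torch.argmax(model(**inputs).logits, dim=-1).item()

print(model.config.id2label[predicted_id])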