Vishal-Padia commited on
Commit
3d9d3c4
·
verified ·
1 Parent(s): c8fe9e1

Upload speech emotion recognition model

Browse files
Files changed (1) hide show
  1. emotion_predictor.py +157 -0
emotion_predictor.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import librosa
4
+ import numpy as np
5
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
6
+ from main import Config, HybridEmotionRecognitionModel, extract_advanced_features
7
+
8
+
9
+ class EmotionPredictor:
10
+ def __init__(self, model_path="best_emotion_model.pth"):
11
+ """
12
+ Initialize the emotion predictor
13
+
14
+ Args:
15
+ model_path (str): Path to the saved model weights
16
+ """
17
+ # Prepare feature extraction specifics
18
+ self.features = Config.FEATURES
19
+
20
+ # Emotion mapping (same as in original script)
21
+ self.emotion_map = {
22
+ "01": "neutral",
23
+ "02": "calm",
24
+ "03": "happy",
25
+ "04": "sad",
26
+ "05": "angry",
27
+ "06": "fearful",
28
+ "07": "disgust",
29
+ "08": "surprised",
30
+ }
31
+
32
+ # Load the model
33
+ # First, prepare a dummy dataset to get the input dimension and number of classes
34
+ dummy_features, dummy_labels = self._prepare_dummy_dataset()
35
+
36
+ # Initialize the model
37
+ self.model = HybridEmotionRecognitionModel(
38
+ input_dim=len(dummy_features[0]), num_classes=len(np.unique(dummy_labels))
39
+ )
40
+
41
+ # Load the saved weights
42
+ self.model.load_state_dict(torch.load(model_path))
43
+ self.model.eval() # Set to evaluation mode
44
+
45
+ # Prepare label encoder
46
+ self.label_encoder = LabelEncoder()
47
+ self.label_encoder.fit(dummy_labels)
48
+
49
+ # Prepare scaler
50
+ self.scaler = StandardScaler()
51
+ self.scaler.fit(dummy_features)
52
+
53
+ def _prepare_dummy_dataset(self):
54
+ """
55
+ Prepare a dummy dataset similar to the original preparation method
56
+
57
+ Returns:
58
+ tuple: Features and labels
59
+ """
60
+ features = []
61
+ labels = []
62
+
63
+ # Walk through all directories and subdirectories
64
+ for root, dirs, files in os.walk(Config.DATA_DIR):
65
+ for filename in files:
66
+ if filename.endswith(".wav"):
67
+ # Full file path
68
+ file_path = os.path.join(root, filename)
69
+
70
+ try:
71
+ # Extract emotion from filename
72
+ emotion_code = filename.split("-")[2]
73
+ emotion = self.emotion_map.get(emotion_code, "unknown")
74
+
75
+ # Extract features
76
+ file_features = extract_advanced_features(file_path)
77
+ features.append(file_features)
78
+ labels.append(emotion)
79
+
80
+ except Exception as e:
81
+ print(f"Error processing {filename}: {e}")
82
+
83
+ # Limit to a small number of files for efficiency
84
+ if len(features) >= 100:
85
+ break
86
+
87
+ if len(features) >= 100:
88
+ break
89
+
90
+ if len(features) >= 100:
91
+ break
92
+
93
+ return np.array(features), np.array(labels)
94
+
95
+ def predict_emotion(self, audio_file_path):
96
+ """
97
+ Predict emotion for a given audio file
98
+
99
+ Args:
100
+ audio_file_path (str): Path to the audio file
101
+
102
+ Returns:
103
+ str: Predicted emotion
104
+ """
105
+ # Extract features
106
+ try:
107
+ features = extract_advanced_features(audio_file_path)
108
+ except Exception as e:
109
+ print(f"Error extracting features: {e}")
110
+ return "Unknown"
111
+
112
+ # Standardize features
113
+ features = self.scaler.transform(features.reshape(1, -1))
114
+
115
+ # Convert to tensor
116
+ features_tensor = torch.FloatTensor(features)
117
+
118
+ # Predict
119
+ with torch.no_grad():
120
+ outputs = self.model(features_tensor)
121
+ _, predicted = torch.max(outputs, 1)
122
+ predicted_label_index = predicted.numpy()[0]
123
+
124
+ # Convert numeric label to emotion string
125
+ return self.label_encoder.classes_[predicted_label_index]
126
+
127
+
128
+ def main():
129
+ # Initialize predictor
130
+ predictor = EmotionPredictor()
131
+
132
+ # Example usage
133
+ print("Emotion Prediction Script")
134
+ print("------------------------")
135
+
136
+ # Prompt user to input audio file path
137
+ while True:
138
+ audio_path = input("Enter the path to an audio file (or 'q' to quit): ").strip()
139
+
140
+ if audio_path.lower() == "q":
141
+ break
142
+
143
+ if not os.path.exists(audio_path):
144
+ print("File does not exist. Please check the path.")
145
+ continue
146
+
147
+ try:
148
+ # Predict emotion
149
+ emotion = predictor.predict_emotion(audio_path)
150
+ print(f"Predicted Emotion: {emotion}")
151
+
152
+ except Exception as e:
153
+ print(f"Error predicting emotion: {e}")
154
+
155
+
156
+ if __name__ == "__main__":
157
+ main()