Antonio committed on
Commit df9bdb0
1 Parent(s): 313b56d
app.py ADDED
@@ -0,0 +1,239 @@
+ import gradio as gr
+ import os
+ import subprocess
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import librosa
+ import av
+ from transformers import VivitImageProcessor, VivitForVideoClassification
+ from transformers import AutoConfig, Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
+ from moviepy.editor import VideoFileClip
+
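+ # Maps the emotion code in the third hyphen-separated field of the filename to a text label
+ # (returns None for unmapped codes); presumably intended for RAVDESS-style file names.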
+ def get_emotion_from_filename(filename):
+     parts = filename.split('-')
+     emotion_code = int(parts[2])
+     emotion_labels = {
+         1: 'neutral',
+         3: 'happy',
+         4: 'sad',
+         5: 'angry',
+         6: 'fearful',
+         7: 'disgust'
+     }
+     return emotion_labels.get(emotion_code, None)
+
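+ # Splits the uploaded MP4 into a video-only MP4 and a 16 kHz PCM WAV in ./temp/ using two ffmpeg subprocesses.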
+ def separate_video_audio(file_path):
+     output_dir = './temp/'
+     os.makedirs(output_dir, exist_ok=True)  # make sure the temp directory exists before ffmpeg writes into it
+     video_path = os.path.join(output_dir, os.path.basename(file_path).replace('.mp4', '_video.mp4'))
+     audio_path = os.path.join(output_dir, os.path.basename(file_path).replace('.mp4', '_audio.wav'))
+
+     video_cmd = ['ffmpeg', '-loglevel', 'quiet', '-i', file_path, '-an', '-c:v', 'libx264', '-preset', 'ultrafast', video_path]
+     subprocess.run(video_cmd, check=True)
+
+     audio_cmd = ['ffmpeg', '-loglevel', 'quiet', '-i', file_path, '-vn', '-acodec', 'pcm_s16le', '-ar', '16000', audio_path]
+     subprocess.run(audio_cmd, check=True)
+
+     return video_path, audio_path
+
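+ # Removes every regular file in the given directory; failures are logged rather than raised.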
+ def delete_files_in_directory(directory):
+     for filename in os.listdir(directory):
+         file_path = os.path.join(directory, filename)
+         try:
+             if os.path.isfile(file_path):
+                 os.remove(file_path)
+         except Exception as e:
+             print(f"Failed to delete {file_path}. Reason: {e}")
+
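+ # Opens the clip with PyAV and samples 32 frames (every 2nd frame) as input for the ViViT model.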
+ def process_video(file_path):
+     container = av.open(file_path)
+     indices = sample_frame_indices(clip_len=32, frame_sample_rate=2, seg_len=container.streams.video[0].frames)
+     video = read_video_pyav(container=container, indices=indices)
+     container.close()
+     return video
+
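+ # Decodes only the requested frame indices, resizes each frame to 224x224, and stacks them into an RGB ndarray.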
+ def read_video_pyav(container, indices):
+     frames = []
+     container.seek(0)
+     start_index = indices[0]
+     end_index = indices[-1]
+     for i, frame in enumerate(container.decode(video=0)):
+         if i > end_index:
+             break
+         if i >= start_index and i in indices:
+             frame = frame.reformat(width=224, height=224)
+             frames.append(frame)
+     return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
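+ # Chooses a random window of clip_len * frame_sample_rate frames and returns clip_len evenly spaced indices inside it.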
+ def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
+     converted_len = int(clip_len * frame_sample_rate)
+     end_idx = np.random.randint(converted_len, seg_len)
+     start_idx = end_idx - converted_len
+     indices = np.linspace(start_idx, end_idx, num=clip_len)
+     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
+     return indices
+
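+ # The video model's id2label entries are generic names like "LABEL_3"; map the numeric part to an emotion name.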
+ def video_label_to_emotion(label):
+     label_map = {0: 'neutral', 1: 'happy', 2: 'sad', 3: 'angry', 4: 'fearful', 5: 'disgust'}
+     label_index = int(label.split('_')[1])
+     return label_map.get(label_index, "Unknown Label")
+
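+ # Runs the ViViT classifier on the sampled frames and returns a {emotion: probability} dict.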
+ def predict_video(file_path, video_model, image_processor):
+     video = process_video(file_path)
+     inputs = image_processor(list(video), return_tensors="pt")
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     inputs = inputs.to(device)
+     video_model = video_model.to(device)  # keep the model on the same device as the inputs
+     video_model.eval()
+
+     with torch.no_grad():
+         outputs = video_model(**inputs)
+         logits = outputs.logits
+         probs = F.softmax(logits, dim=-1).squeeze()
+
+     emotion_probabilities = {video_label_to_emotion(video_model.config.id2label[idx]): float(prob) for idx, prob in enumerate(probs)}
+     return emotion_probabilities
+
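+ # Same mapping for the audio model, whose "LABEL_n" indices follow the alphabetical order of the emotions.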
+ def audio_label_to_emotion(label):
+     label_map = {0: 'angry', 1: 'disgust', 2: 'fearful', 3: 'happy', 4: 'neutral', 5: 'sad'}
+     label_index = int(label.split('_')[1])
+     return label_map.get(label_index, "Unknown Label")
+
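+ # Loads the WAV at 16 kHz, runs the Wav2Vec2 sequence classifier, and returns a {emotion: probability} dict.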
+ def preprocess_and_predict_audio(file_path, model, processor):
+     audio_array, _ = librosa.load(file_path, sr=16000)
+     inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True, max_length=75275)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = model.to(device)
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         output = model(**inputs)
+         logits = output.logits
+         probabilities = F.softmax(logits, dim=-1)
+     emotion_probabilities = {audio_label_to_emotion(model.config.id2label[idx]): float(prob) for idx, prob in enumerate(probabilities[0])}
+     return emotion_probabilities
+
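+ # Late-fusion strategy 1: per-label average of the video and audio probabilities; the argmax is the consensus.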
+ def averaging_method(video_prediction, audio_prediction):
+     combined_probabilities = {}
+     for label in set(video_prediction) | set(audio_prediction):
+         combined_probabilities[label] = (video_prediction.get(label, 0) + audio_prediction.get(label, 0)) / 2
+     consensus_label = max(combined_probabilities, key=combined_probabilities.get)
+     return consensus_label
+
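+ # Late-fusion strategy 2: fixed weights that favour the video modality (0.88 vs 0.6), normalized by their sum.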
+ def weighted_average_method(video_prediction, audio_prediction):
+     video_weight = 0.88
+     audio_weight = 0.6
+     combined_probabilities = {}
+     for label in set(video_prediction) | set(audio_prediction):
+         video_prob = video_prediction.get(label, 0)
+         audio_prob = audio_prediction.get(label, 0)
+         combined_probabilities[label] = (video_weight * video_prob + audio_weight * audio_prob) / (video_weight + audio_weight)
+     consensus_label = max(combined_probabilities, key=combined_probabilities.get)
+     return consensus_label
+
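+ # Late-fusion strategy 3: trust the video prediction outright when its top probability clears the threshold, otherwise fall back to averaging.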
+ def confidence_level_method(video_prediction, audio_prediction, threshold=0.7):
+     highest_video_label = max(video_prediction, key=video_prediction.get)
+     highest_video_confidence = video_prediction[highest_video_label]
+     if highest_video_confidence >= threshold:
+         return highest_video_label
+     combined_probabilities = {}
+     for label in set(video_prediction) | set(audio_prediction):
+         video_prob = video_prediction.get(label, 0)
+         audio_prob = audio_prediction.get(label, 0)
+         combined_probabilities[label] = (video_prob + audio_prob) / 2
+     return max(combined_probabilities, key=combined_probabilities.get)
+
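+ # Late-fusion strategy 4: per-label weights proportional to how confident each modality is in that label.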
+ def dynamic_weighting_method(video_prediction, audio_prediction):
+     combined_probabilities = {}
+     for label in set(video_prediction) | set(audio_prediction):
+         video_prob = video_prediction.get(label, 0)
+         audio_prob = audio_prediction.get(label, 0)
+         video_confidence = video_prob / sum(video_prediction.values())
+         audio_confidence = audio_prob / sum(audio_prediction.values())
+         video_weight = video_confidence / (video_confidence + audio_confidence)
+         audio_weight = audio_confidence / (video_confidence + audio_confidence)
+         combined_probabilities[label] = (video_weight * video_prob + audio_weight * audio_prob)
+     return max(combined_probabilities, key=combined_probabilities.get)
+
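+ # Late-fusion strategy 5: keep the label when both modalities agree with high confidence; otherwise defer to the more confident modality, with averaging as a tie-break.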
+ def rule_based_method(video_prediction, audio_prediction, threshold=0.5):
+     highest_video_label = max(video_prediction, key=video_prediction.get)
+     highest_audio_label = max(audio_prediction, key=audio_prediction.get)
+     video_confidence = video_prediction[highest_video_label] / sum(video_prediction.values())
+     audio_confidence = audio_prediction[highest_audio_label] / sum(audio_prediction.values())
+     combined_probabilities = {}
+     for label in set(video_prediction) | set(audio_prediction):
+         video_prob = video_prediction.get(label, 0)
+         audio_prob = audio_prediction.get(label, 0)
+         combined_probabilities[label] = (video_prob + audio_prob) / 2
+     if (highest_video_label == highest_audio_label and video_confidence > threshold and audio_confidence > threshold):
+         return highest_video_label
+     elif video_confidence > audio_confidence:
+         return highest_video_label
+     elif audio_confidence > video_confidence:
+         return highest_audio_label
+     return max(combined_probabilities, key=combined_probabilities.get)
+
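+ # Maps the dropdown choice in the UI to the corresponding fusion function.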
+ decision_frameworks = {
+     "Averaging": averaging_method,
+     "Weighted Average": weighted_average_method,
+     "Confidence Level": confidence_level_method,
+     "Dynamic Weighting": dynamic_weighting_method,
+     "Rule-Based": rule_based_method
+ }
+
+ # Define the prediction function
+ def predict(video_file, video_model_name, audio_model_name, framework_name):
+
+     image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
+     video_model = torch.load(video_model_name)
+
+     model_id = "facebook/wav2vec2-large"
+     config = AutoConfig.from_pretrained(model_id, num_labels=6)
+     audio_processor = AutoFeatureExtractor.from_pretrained(model_id)
+     audio_model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id, config=config)
+     audio_model.load_state_dict(torch.load(audio_model_name))
+     audio_model.eval()
+
+     delete_directory_path = "./temp/"
+
+     # Separate video and audio
+     video_path, audio_path = separate_video_audio(video_file.name)
+
+     # Predict video
+     video_prediction = predict_video(video_path, video_model, image_processor)
+
+     # Predict audio
+     audio_prediction = preprocess_and_predict_audio(audio_path, audio_model, audio_processor)
+
+     # Use selected decision framework
+     framework_function = decision_frameworks[framework_name]
+     consensus_label = framework_function(video_prediction, audio_prediction)
+
+     # Clean up the temporary files
+     delete_files_in_directory(delete_directory_path)
+
+     return {
+         "Video Predictions": video_prediction,
+         "Audio Predictions": audio_prediction,
+         "Consensus Label": consensus_label
+     }
+
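+ # Note: gr.inputs / gr.outputs is the legacy (pre-4.x) Gradio component namespace; newer Gradio releases expose gr.File, gr.Dropdown and gr.JSON directly.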
+ # Create Gradio Interface
+ inputs = [
+     gr.inputs.File(label="Upload Video", type="file"),
+     gr.inputs.Dropdown(["video_model_60_acc.pth", "video_model_80_acc.pth"], label="Select Video Model"),
+     gr.inputs.Dropdown(["audio_model_state_dict_6e.pth"], label="Select Audio Model"),
+     gr.inputs.Dropdown(list(decision_frameworks.keys()), label="Select Decision Framework")
+ ]
+
+ outputs = [
+     gr.outputs.JSON(label="Predictions")
+ ]
+
+ iface = gr.Interface(
+     fn=predict,
+     inputs=inputs,
+     outputs=outputs,
+     title="Video and Audio Emotion Prediction",
+     description="Upload a video to get emotion predictions from selected video and audio models."
+ )
+
+ iface.launch()
audio_model_state_dict_6e.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7de405afabfe8d0b81a95fdc9de37e11d3abb46564e4a5d2f21febb41fd6f0b
+ size 1262945578
video_model_60_acc.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8ad865fb090facae3cdfc80f22ac8aac576945a2a42d19bbc92ae4efe4a68778
+ size 354725762
video_model_80_acc.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c66e87da97d7bea2bf99e8a12dfc56bccd1e54360d3774b0812cd86d76ab93de
+ size 354725826