Commit cde16b5
1 Parent(s): 21c0992

Update submit.py

Files changed:
- app/event_handlers/submit.py  (+121 -117)

app/event_handlers/submit.py  CHANGED
@@ -6,7 +6,6 @@ License: MIT License
 """

 import spaces
-import time
 import torch
 import pandas as pd
 import cv2
@@ -15,6 +14,7 @@ import gradio as gr
 # Importing necessary components for the Gradio app
 from app.config import config_data
 from app.utils import (
+    Timer,
     convert_video_to_audio,
     readetect_speech,
     slice_audio,
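The substantive change in this commit is that the handler no longer does manual time.time() bookkeeping (note the removed "import time" above and the removed end_time / inference_time lines below); instead it wraps the whole pipeline in a Timer context manager imported from app.utils and reads timer.execution_time at the end. The actual Timer class lives in app.utils and is not part of this diff; below is a minimal sketch of a context manager with that interface, assuming only the execution_time attribute name seen in the diff:

import time


class Timer:
    """Sketch of a timing context manager; the real app.utils.Timer may differ."""

    def __enter__(self):
        self._start = time.perf_counter()  # start of the timed block
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # elapsed wall-clock time of the whole `with` block, in seconds
        self.execution_time = time.perf_counter() - self._start
        return False  # do not suppress exceptions raised inside the block


# usage mirroring the updated handler
with Timer() as timer:
    time.sleep(0.2)
print(round(timer.execution_time, 3))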
@@ -55,130 +55,134 @@ def event_handler_submit(
     gr.Textbox,
     gr.Textbox,
 ]:
-    ... (4 removed lines not captured in this view)
-    video = convert_webm_to_mp4(video)
-    ... (1 removed line not captured in this view)
-    audio_file_path = convert_video_to_audio(file_path=video, sr=config_data.General_SR)
-    wav, vad_info = readetect_speech(
-        file_path=audio_file_path,
-        read_audio=read_audio,
-        get_speech_timestamps=get_speech_timestamps,
-        vad_model=vad_model,
-        sr=config_data.General_SR,
-    )
+    with Timer() as timer:
+        if video:
+            if video.split(".")[-1] == "webm":
+                video = convert_webm_to_mp4(video)

-    ... (7 removed lines not captured in this view)
+        audio_file_path = convert_video_to_audio(
+            file_path=video, sr=config_data.General_SR
+        )
+        wav, vad_info = readetect_speech(
+            file_path=audio_file_path,
+            read_audio=read_audio,
+            get_speech_timestamps=get_speech_timestamps,
+            vad_model=vad_model,
+            sr=config_data.General_SR,
+        )

-    ... (5 removed lines not captured in this view)
+        audio_windows = slice_audio(
+            start_time=config_data.General_START_TIME,
+            end_time=int(len(wav)),
+            win_max_length=int(
+                config_data.General_WIN_MAX_LENGTH * config_data.General_SR
+            ),
+            win_shift=int(config_data.General_WIN_SHIFT * config_data.General_SR),
+            win_min_length=int(
+                config_data.General_WIN_MIN_LENGTH * config_data.General_SR
+            ),
+        )
+
+        intersections = find_intersections(
+            x=audio_windows,
+            y=vad_info,
+            min_length=config_data.General_WIN_MIN_LENGTH * config_data.General_SR,
+        )

-    ... (31 removed lines not captured in this view)
+        vfe = VideoFeatureExtractor(video_model, file_path=video, with_features=False)
+        vfe.preprocess_video()
+
+        transcriptions, total_text = asr(wav, audio_windows)
+
+        window_frames = []
+        preds_emo = []
+        preds_sen = []
+        for w_idx, window in enumerate(audio_windows):
+            a_w = intersections[w_idx]
+            if not a_w["speech"]:
+                a_pred = None
+            else:
+                wave = wav[a_w["start"] : a_w["end"]].clone()
+                a_pred, _ = audio_model(wave)
+
+            v_pred, _ = vfe(window, config_data.General_WIN_MAX_LENGTH)
+
+            t_pred, _ = text_model(transcriptions[w_idx][0])
+
+            if a_pred:
+                pred_emo = (a_pred["emo"] + v_pred["emo"] + t_pred["emo"]) / 3
+                pred_sen = (a_pred["sen"] + v_pred["sen"] + t_pred["sen"]) / 3
+            else:
+                pred_emo = (v_pred["emo"] + t_pred["emo"]) / 2
+                pred_sen = (v_pred["sen"] + t_pred["sen"]) / 2
+
+            frames = list(
+                range(
+                    int(window["start"] * vfe.fps / config_data.General_SR) + 1,
+                    int(window["end"] * vfe.fps / config_data.General_SR) + 2,
+                )
             )
+            preds_emo.extend([torch.argmax(pred_emo).numpy()] * len(frames))
+            preds_sen.extend([torch.argmax(pred_sen).numpy()] * len(frames))
+            window_frames.extend(frames)
+
+        if max(window_frames) < vfe.frame_number:
+            missed_frames = list(range(max(window_frames) + 1, vfe.frame_number + 1))
+            window_frames.extend(missed_frames)
+            preds_emo.extend([preds_emo[-1]] * len(missed_frames))
+            preds_sen.extend([preds_sen[-1]] * len(missed_frames))
+
+        df_pred = pd.DataFrame(columns=["frames", "pred_emo", "pred_sent"])
+        df_pred["frames"] = window_frames
+        df_pred["pred_emo"] = preds_emo
+        df_pred["pred_sent"] = preds_sen
+
+        df_pred = df_pred.groupby("frames").agg(
+            {
+                "pred_emo": calculate_mode,
+                "pred_sent": calculate_mode,
+            }
         )
-        preds_emo.extend([torch.argmax(pred_emo).numpy()] * len(frames))
-        preds_sen.extend([torch.argmax(pred_sen).numpy()] * len(frames))
-        window_frames.extend(frames)
-
-    if max(window_frames) < vfe.frame_number:
-        missed_frames = list(range(max(window_frames) + 1, vfe.frame_number + 1))
-        window_frames.extend(missed_frames)
-        preds_emo.extend([preds_emo[-1]] * len(missed_frames))
-        preds_sen.extend([preds_sen[-1]] * len(missed_frames))
-
-    df_pred = pd.DataFrame(columns=["frames", "pred_emo", "pred_sent"])
-    df_pred["frames"] = window_frames
-    df_pred["pred_emo"] = preds_emo
-    df_pred["pred_sent"] = preds_sen
-
-    df_pred = df_pred.groupby("frames").agg(
-        {
-            "pred_emo": calculate_mode,
-            "pred_sent": calculate_mode,
-        }
-    )

-    ... (5 removed lines not captured in this view)
-    all_idx_faces = list(vfe.faces[1].keys())
-    need_idx_faces = find_nearest_frames(frame_indices, all_idx_faces)
-    faces = []
-    for idx_frame, idx_faces in zip(frame_indices, need_idx_faces):
-        cur_face = cv2.resize(
-            vfe.faces[1][idx_faces], (224, 224), interpolation=cv2.INTER_AREA
+        frame_indices = get_evenly_spaced_frame_indices(vfe.frame_number, 9)
+        num_frames = len(wav)
+        time_axis = [i / config_data.General_SR for i in range(num_frames)]
+        plt_audio = plot_audio(
+            time_axis, wav.unsqueeze(0), frame_indices, vfe.fps, (12, 2)
         )
-    ... (3 removed lines not captured in this view)
+
+        all_idx_faces = list(vfe.faces[1].keys())
+        need_idx_faces = find_nearest_frames(frame_indices, all_idx_faces)
+        faces = []
+        for idx_frame, idx_faces in zip(frame_indices, need_idx_faces):
+            cur_face = cv2.resize(
+                vfe.faces[1][idx_faces], (224, 224), interpolation=cv2.INTER_AREA
             )
+            faces.append(
+                display_frame_info(
+                    cur_face, "Frame: {}".format(idx_frame + 1), box_scale=0.3
+                )
+            )
+        plt_faces = plot_images(faces)
+
+        plt_emo = plot_predictions(
+            df_pred,
+            "pred_emo",
+            "Emotion",
+            list(config_data.General_DICT_EMO),
+            (12, 2.5),
+            [i + 1 for i in frame_indices],
+            2,
+        )
+        plt_sent = plot_predictions(
+            df_pred,
+            "pred_sent",
+            "Sentiment",
+            list(config_data.General_DICT_SENT),
+            (12, 1.5),
+            [i + 1 for i in frame_indices],
+            2,
         )
-    plt_faces = plot_images(faces)
-
-    plt_emo = plot_predictions(
-        df_pred,
-        "pred_emo",
-        "Emotion",
-        list(config_data.General_DICT_EMO),
-        (12, 2.5),
-        [i + 1 for i in frame_indices],
-        2,
-    )
-    plt_sent = plot_predictions(
-        df_pred,
-        "pred_sent",
-        "Sentiment",
-        list(config_data.General_DICT_SENT),
-        (12, 1.5),
-        [i + 1 for i in frame_indices],
-        2,
-    )
-
-    end_time = time.time()
-    inference_time = end_time - start_time

     return (
         gr.Textbox(
@@ -199,7 +203,7 @@ def event_handler_submit(
             visible=True,
         ),
         gr.Textbox(
-            value=
+            value=timer.execution_time,
             info=config_data.InformationMessages_INFERENCE_TIME,
             container=True,
             visible=True,
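A few notes on the new handler body follow. First, slice_audio is called entirely in sample units: General_WIN_MAX_LENGTH, General_WIN_SHIFT and General_WIN_MIN_LENGTH are configured in seconds and multiplied by the sampling rate General_SR before the call, and the windows it returns carry start/end sample indices into wav (the loop later reads window["start"] and window["end"]). The helper itself is defined in app.utils and is not shown in this diff; a rough sketch of such fixed-length sliding windows, with all names and behaviour assumed from the call site:

from typing import Dict, List


def slice_audio_sketch(
    start_time: int,
    end_time: int,
    win_max_length: int,
    win_shift: int,
    win_min_length: int,
) -> List[Dict[str, int]]:
    """Rough stand-in for app.utils.slice_audio: sliding windows in samples."""
    windows = []
    start = start_time
    while start < end_time:
        end = min(start + win_max_length, end_time)
        # drop trailing windows that are too short to score reliably
        if end - start >= win_min_length:
            windows.append({"start": start, "end": end})
        start += win_shift
    return windows


# 10 s of 16 kHz audio, 4 s windows shifted by 2 s, at least 1 s long
print(slice_audio_sketch(0, 10 * 16000, 4 * 16000, 2 * 16000, 1 * 16000))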
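Within the per-window loop, fusion is a plain average of the per-modality predictions: when the VAD intersection reports speech for the window, audio, video and text are averaged over three modalities; otherwise a_pred stays None and only video and text are averaged. Stripped of the project's models, the logic reduces to the following sketch (the "emo"/"sen" keys come from the diff; the tensor sizes and random values are placeholders for real model outputs):

import torch

# stand-ins for the outputs of audio_model, the video extractor and text_model
v_pred = {"emo": torch.rand(7), "sen": torch.rand(3)}
t_pred = {"emo": torch.rand(7), "sen": torch.rand(3)}
a_pred = None  # what the handler leaves when the window contains no speech

if a_pred:
    pred_emo = (a_pred["emo"] + v_pred["emo"] + t_pred["emo"]) / 3
    pred_sen = (a_pred["sen"] + v_pred["sen"] + t_pred["sen"]) / 3
else:
    # no usable audio in this window: average the two remaining modalities
    pred_emo = (v_pred["emo"] + t_pred["emo"]) / 2
    pred_sen = (v_pred["sen"] + t_pred["sen"]) / 2

print(int(torch.argmax(pred_emo)), int(torch.argmax(pred_sen)))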
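Each window's label is then replicated across the video frames it covers: sample positions are converted to 1-based frame numbers with fps / SR, and after the loop any trailing frames past the last window simply inherit the last prediction (the missed_frames branch). The index arithmetic on its own, with made-up rates and window bounds:

SR = 16000   # audio sampling rate, config_data.General_SR in the app
FPS = 25.0   # video frame rate, vfe.fps in the app

window = {"start": 2 * SR, "end": 4 * SR}  # a window spanning 2 s to 4 s, in samples

# 1-based frame numbers covered by this window, as computed in the handler
frames = list(
    range(
        int(window["start"] * FPS / SR) + 1,
        int(window["end"] * FPS / SR) + 2,
    )
)
print(frames[0], frames[-1])  # 51 101: frames 51..101 get this window's labels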
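Because neighbouring windows overlap, a frame usually ends up with several labels in df_pred; grouping by "frames" and aggregating with calculate_mode keeps one row per frame, the majority label per column. calculate_mode is another app.utils helper not shown here; the aggregation behaves roughly like this:

import pandas as pd


def calculate_mode_sketch(values: pd.Series):
    """Rough stand-in for app.utils.calculate_mode: most frequent value wins."""
    return values.mode().iloc[0]


df_pred = pd.DataFrame(
    {
        "frames": [1, 1, 2, 2, 2],
        "pred_emo": [3, 4, 4, 4, 0],
        "pred_sent": [1, 1, 2, 1, 1],
    }
)
df_pred = df_pred.groupby("frames").agg(
    {"pred_emo": calculate_mode_sketch, "pred_sent": calculate_mode_sketch}
)
print(df_pred)  # one row per frame: emo 3 and 4, sent 1 and 1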
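For the preview figures, the handler picks nine evenly spaced frame indices over the clip and, since a face crop may not exist for exactly those frames, snaps each one to the nearest frame that does have a detected face before resizing the crops to 224x224. get_evenly_spaced_frame_indices and find_nearest_frames are app.utils helpers whose code is not in this diff; under the behaviour implied by the call sites they reduce to roughly:

from typing import List


def get_evenly_spaced_frame_indices_sketch(total_frames: int, count: int = 9) -> List[int]:
    """Rough sketch: `count` indices spread evenly over [0, total_frames)."""
    if total_frames <= 0:
        return []
    step = total_frames / count
    return [min(int(i * step), total_frames - 1) for i in range(count)]


def find_nearest_frames_sketch(wanted: List[int], available: List[int]) -> List[int]:
    """Rough sketch: for each wanted index, the closest frame with a detected face."""
    return [min(available, key=lambda f: abs(f - idx)) for idx in wanted]


wanted = get_evenly_spaced_frame_indices_sketch(250, 9)
available = [0, 30, 65, 120, 180, 240]  # frames for which vfe.faces holds a crop
print(find_nearest_frames_sketch(wanted, available))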