DmitryRyumin committed
Commit cde16b5 · 1 Parent(s): 21c0992

Update submit.py

Files changed (1)
  1. app/event_handlers/submit.py +121 -117
app/event_handlers/submit.py CHANGED
@@ -6,7 +6,6 @@ License: MIT License
 """
 
 import spaces
-import time
 import torch
 import pandas as pd
 import cv2
@@ -15,6 +14,7 @@ import gradio as gr
 # Importing necessary components for the Gradio app
 from app.config import config_data
 from app.utils import (
+    Timer,
     convert_video_to_audio,
     readetect_speech,
     slice_audio,
@@ -55,130 +55,134 @@ def event_handler_submit(
     gr.Textbox,
     gr.Textbox,
 ]:
-    start_time = time.time()
-
-    if video:
-        if video.split(".")[-1] == "webm":
-            video = convert_webm_to_mp4(video)
-
-    audio_file_path = convert_video_to_audio(file_path=video, sr=config_data.General_SR)
-    wav, vad_info = readetect_speech(
-        file_path=audio_file_path,
-        read_audio=read_audio,
-        get_speech_timestamps=get_speech_timestamps,
-        vad_model=vad_model,
-        sr=config_data.General_SR,
-    )
-
-    audio_windows = slice_audio(
-        start_time=config_data.General_START_TIME,
-        end_time=int(len(wav)),
-        win_max_length=int(config_data.General_WIN_MAX_LENGTH * config_data.General_SR),
-        win_shift=int(config_data.General_WIN_SHIFT * config_data.General_SR),
-        win_min_length=int(config_data.General_WIN_MIN_LENGTH * config_data.General_SR),
-    )
-
-    intersections = find_intersections(
-        x=audio_windows,
-        y=vad_info,
-        min_length=config_data.General_WIN_MIN_LENGTH * config_data.General_SR,
-    )
-
-    vfe = VideoFeatureExtractor(video_model, file_path=video, with_features=False)
-    vfe.preprocess_video()
-
-    transcriptions, total_text = asr(wav, audio_windows)
-
-    window_frames = []
-    preds_emo = []
-    preds_sen = []
-    for w_idx, window in enumerate(audio_windows):
-        a_w = intersections[w_idx]
-        if not a_w["speech"]:
-            a_pred = None
-        else:
-            wave = wav[a_w["start"] : a_w["end"]].clone()
-            a_pred, _ = audio_model(wave)
-
-        v_pred, _ = vfe(window, config_data.General_WIN_MAX_LENGTH)
-
-        t_pred, _ = text_model(transcriptions[w_idx][0])
-
-        if a_pred:
-            pred_emo = (a_pred["emo"] + v_pred["emo"] + t_pred["emo"]) / 3
-            pred_sen = (a_pred["sen"] + v_pred["sen"] + t_pred["sen"]) / 3
-        else:
-            pred_emo = (v_pred["emo"] + t_pred["emo"]) / 2
-            pred_sen = (v_pred["sen"] + t_pred["sen"]) / 2
-
-        frames = list(
-            range(
-                int(window["start"] * vfe.fps / config_data.General_SR) + 1,
-                int(window["end"] * vfe.fps / config_data.General_SR) + 2,
-            )
-        )
-        preds_emo.extend([torch.argmax(pred_emo).numpy()] * len(frames))
-        preds_sen.extend([torch.argmax(pred_sen).numpy()] * len(frames))
-        window_frames.extend(frames)
-
-    if max(window_frames) < vfe.frame_number:
-        missed_frames = list(range(max(window_frames) + 1, vfe.frame_number + 1))
-        window_frames.extend(missed_frames)
-        preds_emo.extend([preds_emo[-1]] * len(missed_frames))
-        preds_sen.extend([preds_sen[-1]] * len(missed_frames))
-
-    df_pred = pd.DataFrame(columns=["frames", "pred_emo", "pred_sent"])
-    df_pred["frames"] = window_frames
-    df_pred["pred_emo"] = preds_emo
-    df_pred["pred_sent"] = preds_sen
-
-    df_pred = df_pred.groupby("frames").agg(
-        {
-            "pred_emo": calculate_mode,
-            "pred_sent": calculate_mode,
-        }
-    )
-
-    frame_indices = get_evenly_spaced_frame_indices(vfe.frame_number, 9)
-    num_frames = len(wav)
-    time_axis = [i / config_data.General_SR for i in range(num_frames)]
-    plt_audio = plot_audio(time_axis, wav.unsqueeze(0), frame_indices, vfe.fps, (12, 2))
-
-    all_idx_faces = list(vfe.faces[1].keys())
-    need_idx_faces = find_nearest_frames(frame_indices, all_idx_faces)
-    faces = []
-    for idx_frame, idx_faces in zip(frame_indices, need_idx_faces):
-        cur_face = cv2.resize(
-            vfe.faces[1][idx_faces], (224, 224), interpolation=cv2.INTER_AREA
-        )
-        faces.append(
-            display_frame_info(
-                cur_face, "Frame: {}".format(idx_frame + 1), box_scale=0.3
-            )
-        )
-    plt_faces = plot_images(faces)
-
-    plt_emo = plot_predictions(
-        df_pred,
-        "pred_emo",
-        "Emotion",
-        list(config_data.General_DICT_EMO),
-        (12, 2.5),
-        [i + 1 for i in frame_indices],
-        2,
-    )
-    plt_sent = plot_predictions(
-        df_pred,
-        "pred_sent",
-        "Sentiment",
-        list(config_data.General_DICT_SENT),
-        (12, 1.5),
-        [i + 1 for i in frame_indices],
-        2,
-    )
-
-    end_time = time.time()
-    inference_time = end_time - start_time
+    with Timer() as timer:
+        if video:
+            if video.split(".")[-1] == "webm":
+                video = convert_webm_to_mp4(video)
+
+        audio_file_path = convert_video_to_audio(
+            file_path=video, sr=config_data.General_SR
+        )
+        wav, vad_info = readetect_speech(
+            file_path=audio_file_path,
+            read_audio=read_audio,
+            get_speech_timestamps=get_speech_timestamps,
+            vad_model=vad_model,
+            sr=config_data.General_SR,
+        )
+
+        audio_windows = slice_audio(
+            start_time=config_data.General_START_TIME,
+            end_time=int(len(wav)),
+            win_max_length=int(
+                config_data.General_WIN_MAX_LENGTH * config_data.General_SR
+            ),
+            win_shift=int(config_data.General_WIN_SHIFT * config_data.General_SR),
+            win_min_length=int(
+                config_data.General_WIN_MIN_LENGTH * config_data.General_SR
+            ),
+        )
+
+        intersections = find_intersections(
+            x=audio_windows,
+            y=vad_info,
+            min_length=config_data.General_WIN_MIN_LENGTH * config_data.General_SR,
+        )
+
+        vfe = VideoFeatureExtractor(video_model, file_path=video, with_features=False)
+        vfe.preprocess_video()
+
+        transcriptions, total_text = asr(wav, audio_windows)
+
+        window_frames = []
+        preds_emo = []
+        preds_sen = []
+        for w_idx, window in enumerate(audio_windows):
+            a_w = intersections[w_idx]
+            if not a_w["speech"]:
+                a_pred = None
+            else:
+                wave = wav[a_w["start"] : a_w["end"]].clone()
+                a_pred, _ = audio_model(wave)
+
+            v_pred, _ = vfe(window, config_data.General_WIN_MAX_LENGTH)
+
+            t_pred, _ = text_model(transcriptions[w_idx][0])
+
+            if a_pred:
+                pred_emo = (a_pred["emo"] + v_pred["emo"] + t_pred["emo"]) / 3
+                pred_sen = (a_pred["sen"] + v_pred["sen"] + t_pred["sen"]) / 3
+            else:
+                pred_emo = (v_pred["emo"] + t_pred["emo"]) / 2
+                pred_sen = (v_pred["sen"] + t_pred["sen"]) / 2
+
+            frames = list(
+                range(
+                    int(window["start"] * vfe.fps / config_data.General_SR) + 1,
+                    int(window["end"] * vfe.fps / config_data.General_SR) + 2,
+                )
+            )
+            preds_emo.extend([torch.argmax(pred_emo).numpy()] * len(frames))
+            preds_sen.extend([torch.argmax(pred_sen).numpy()] * len(frames))
+            window_frames.extend(frames)
+
+        if max(window_frames) < vfe.frame_number:
+            missed_frames = list(range(max(window_frames) + 1, vfe.frame_number + 1))
+            window_frames.extend(missed_frames)
+            preds_emo.extend([preds_emo[-1]] * len(missed_frames))
+            preds_sen.extend([preds_sen[-1]] * len(missed_frames))
+
+        df_pred = pd.DataFrame(columns=["frames", "pred_emo", "pred_sent"])
+        df_pred["frames"] = window_frames
+        df_pred["pred_emo"] = preds_emo
+        df_pred["pred_sent"] = preds_sen
+
+        df_pred = df_pred.groupby("frames").agg(
+            {
+                "pred_emo": calculate_mode,
+                "pred_sent": calculate_mode,
+            }
+        )
+
+        frame_indices = get_evenly_spaced_frame_indices(vfe.frame_number, 9)
+        num_frames = len(wav)
+        time_axis = [i / config_data.General_SR for i in range(num_frames)]
+        plt_audio = plot_audio(
+            time_axis, wav.unsqueeze(0), frame_indices, vfe.fps, (12, 2)
+        )
+
+        all_idx_faces = list(vfe.faces[1].keys())
+        need_idx_faces = find_nearest_frames(frame_indices, all_idx_faces)
+        faces = []
+        for idx_frame, idx_faces in zip(frame_indices, need_idx_faces):
+            cur_face = cv2.resize(
+                vfe.faces[1][idx_faces], (224, 224), interpolation=cv2.INTER_AREA
+            )
+            faces.append(
+                display_frame_info(
+                    cur_face, "Frame: {}".format(idx_frame + 1), box_scale=0.3
+                )
+            )
+        plt_faces = plot_images(faces)
+
+        plt_emo = plot_predictions(
+            df_pred,
+            "pred_emo",
+            "Emotion",
+            list(config_data.General_DICT_EMO),
+            (12, 2.5),
+            [i + 1 for i in frame_indices],
+            2,
+        )
+        plt_sent = plot_predictions(
+            df_pred,
+            "pred_sent",
+            "Sentiment",
+            list(config_data.General_DICT_SENT),
+            (12, 1.5),
+            [i + 1 for i in frame_indices],
+            2,
+        )
 
     return (
         gr.Textbox(
@@ -199,7 +203,7 @@ def event_handler_submit(
             visible=True,
         ),
         gr.Textbox(
-            value=config_data.OtherMessages_SEC.format(inference_time),
+            value=timer.execution_time,
             info=config_data.InformationMessages_INFERENCE_TIME,
             container=True,
             visible=True,
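
Note: the Timer helper is imported from app.utils but its implementation is not part of this diff. Going only by the usage visible here (with Timer() as timer: and timer.execution_time passed as a gr.Textbox value), a minimal sketch of such a context manager could look like the code below; the attribute names start_time/end_time and the formatted-string result are assumptions, not the project's actual code.

# Hypothetical sketch only; the real Timer in app/utils.py may differ.
import time


class Timer:
    """Context manager that measures the wall-clock time of its block."""

    def __enter__(self) -> "Timer":
        self.start_time = time.time()  # assumed attribute name
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.end_time = time.time()  # assumed attribute name
        # Formatted so it can be passed directly as a gr.Textbox value.
        self.execution_time = "{:.2f} sec".format(self.end_time - self.start_time)

With a helper of this shape, the elapsed-time bookkeeping lives in app/utils.py, which is why the diff can drop the import time statement and the manual start_time/end_time arithmetic from submit.py.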