DmitryRyumin committed
Commit 21c0992
1 Parent(s): 257d039

Update submit.py

Files changed (1)
  1. app/event_handlers/submit.py +117 -121
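The diff below replaces the Timer context manager from app.utils with explicit time.time() bookkeeping and formats the elapsed time via config_data.OtherMessages_SEC before returning it. A minimal sketch of the two timing patterns, assuming a Timer that records execution_time on exit (the project's actual helper may differ):

import time

class Timer:
    """Assumed shape of the removed helper: a context manager that stores elapsed seconds."""

    def __enter__(self):
        self._start = time.time()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.execution_time = time.time() - self._start

# Old pattern (removed): the whole handler body ran inside the context manager.
with Timer() as timer:
    time.sleep(0.1)  # stand-in for the inference pipeline
print(timer.execution_time)

# New pattern (added): explicit bookkeeping around the same body.
start_time = time.time()
time.sleep(0.1)  # stand-in for the inference pipeline
inference_time = time.time() - start_time
print(inference_time)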
app/event_handlers/submit.py CHANGED
@@ -6,6 +6,7 @@ License: MIT License
 """
 
 import spaces
+import time
 import torch
 import pandas as pd
 import cv2
@@ -14,7 +15,6 @@ import gradio as gr
 # Importing necessary components for the Gradio app
 from app.config import config_data
 from app.utils import (
-    Timer,
     convert_video_to_audio,
     readetect_speech,
     slice_audio,
@@ -55,134 +55,130 @@ def event_handler_submit(
     gr.Textbox,
     gr.Textbox,
 ]:
-    with Timer() as timer:
-        if video:
-            if video.split(".")[-1] == "webm":
-                video = convert_webm_to_mp4(video)
-
-            audio_file_path = convert_video_to_audio(
-                file_path=video, sr=config_data.General_SR
-            )
-            wav, vad_info = readetect_speech(
-                file_path=audio_file_path,
-                read_audio=read_audio,
-                get_speech_timestamps=get_speech_timestamps,
-                vad_model=vad_model,
-                sr=config_data.General_SR,
-            )
-
-            audio_windows = slice_audio(
-                start_time=config_data.General_START_TIME,
-                end_time=int(len(wav)),
-                win_max_length=int(
-                    config_data.General_WIN_MAX_LENGTH * config_data.General_SR
-                ),
-                win_shift=int(config_data.General_WIN_SHIFT * config_data.General_SR),
-                win_min_length=int(
-                    config_data.General_WIN_MIN_LENGTH * config_data.General_SR
-                ),
-            )
-
-            intersections = find_intersections(
-                x=audio_windows,
-                y=vad_info,
-                min_length=config_data.General_WIN_MIN_LENGTH * config_data.General_SR,
-            )
-
-            vfe = VideoFeatureExtractor(video_model, file_path=video, with_features=False)
-            vfe.preprocess_video()
-
-            transcriptions, total_text = asr(wav, audio_windows)
-
-            window_frames = []
-            preds_emo = []
-            preds_sen = []
-            for w_idx, window in enumerate(audio_windows):
-                a_w = intersections[w_idx]
-                if not a_w["speech"]:
-                    a_pred = None
-                else:
-                    wave = wav[a_w["start"] : a_w["end"]].clone()
-                    a_pred, _ = audio_model(wave)
-
-                v_pred, _ = vfe(window, config_data.General_WIN_MAX_LENGTH)
-
-                t_pred, _ = text_model(transcriptions[w_idx][0])
-
-                if a_pred:
-                    pred_emo = (a_pred["emo"] + v_pred["emo"] + t_pred["emo"]) / 3
-                    pred_sen = (a_pred["sen"] + v_pred["sen"] + t_pred["sen"]) / 3
-                else:
-                    pred_emo = (v_pred["emo"] + t_pred["emo"]) / 2
-                    pred_sen = (v_pred["sen"] + t_pred["sen"]) / 2
-
-                frames = list(
-                    range(
-                        int(window["start"] * vfe.fps / config_data.General_SR) + 1,
-                        int(window["end"] * vfe.fps / config_data.General_SR) + 2,
-                    )
-                )
-                preds_emo.extend([torch.argmax(pred_emo).numpy()] * len(frames))
-                preds_sen.extend([torch.argmax(pred_sen).numpy()] * len(frames))
-                window_frames.extend(frames)
-
-            if max(window_frames) < vfe.frame_number:
-                missed_frames = list(range(max(window_frames) + 1, vfe.frame_number + 1))
-                window_frames.extend(missed_frames)
-                preds_emo.extend([preds_emo[-1]] * len(missed_frames))
-                preds_sen.extend([preds_sen[-1]] * len(missed_frames))
-
-            df_pred = pd.DataFrame(columns=["frames", "pred_emo", "pred_sent"])
-            df_pred["frames"] = window_frames
-            df_pred["pred_emo"] = preds_emo
-            df_pred["pred_sent"] = preds_sen
-
-            df_pred = df_pred.groupby("frames").agg(
-                {
-                    "pred_emo": calculate_mode,
-                    "pred_sent": calculate_mode,
-                }
-            )
-
-            frame_indices = get_evenly_spaced_frame_indices(vfe.frame_number, 9)
-            num_frames = len(wav)
-            time_axis = [i / config_data.General_SR for i in range(num_frames)]
-            plt_audio = plot_audio(
-                time_axis, wav.unsqueeze(0), frame_indices, vfe.fps, (12, 2)
-            )
-
-            all_idx_faces = list(vfe.faces[1].keys())
-            need_idx_faces = find_nearest_frames(frame_indices, all_idx_faces)
-            faces = []
-            for idx_frame, idx_faces in zip(frame_indices, need_idx_faces):
-                cur_face = cv2.resize(
-                    vfe.faces[1][idx_faces], (224, 224), interpolation=cv2.INTER_AREA
-                )
-                faces.append(
-                    display_frame_info(
-                        cur_face, "Frame: {}".format(idx_frame + 1), box_scale=0.3
-                    )
-                )
-            plt_faces = plot_images(faces)
-
-            plt_emo = plot_predictions(
-                df_pred,
-                "pred_emo",
-                "Emotion",
-                list(config_data.General_DICT_EMO),
-                (12, 2.5),
-                [i + 1 for i in frame_indices],
-                2,
-            )
-            plt_sent = plot_predictions(
-                df_pred,
-                "pred_sent",
-                "Sentiment",
-                list(config_data.General_DICT_SENT),
-                (12, 1.5),
-                [i + 1 for i in frame_indices],
-                2,
-            )
+    start_time = time.time()
+
+    if video:
+        if video.split(".")[-1] == "webm":
+            video = convert_webm_to_mp4(video)
+
+        audio_file_path = convert_video_to_audio(file_path=video, sr=config_data.General_SR)
+        wav, vad_info = readetect_speech(
+            file_path=audio_file_path,
+            read_audio=read_audio,
+            get_speech_timestamps=get_speech_timestamps,
+            vad_model=vad_model,
+            sr=config_data.General_SR,
+        )
+
+        audio_windows = slice_audio(
+            start_time=config_data.General_START_TIME,
+            end_time=int(len(wav)),
+            win_max_length=int(config_data.General_WIN_MAX_LENGTH * config_data.General_SR),
+            win_shift=int(config_data.General_WIN_SHIFT * config_data.General_SR),
+            win_min_length=int(config_data.General_WIN_MIN_LENGTH * config_data.General_SR),
+        )
+
+        intersections = find_intersections(
+            x=audio_windows,
+            y=vad_info,
+            min_length=config_data.General_WIN_MIN_LENGTH * config_data.General_SR,
+        )
+
+        vfe = VideoFeatureExtractor(video_model, file_path=video, with_features=False)
+        vfe.preprocess_video()
+
+        transcriptions, total_text = asr(wav, audio_windows)
+
+        window_frames = []
+        preds_emo = []
+        preds_sen = []
+        for w_idx, window in enumerate(audio_windows):
+            a_w = intersections[w_idx]
+            if not a_w["speech"]:
+                a_pred = None
+            else:
+                wave = wav[a_w["start"] : a_w["end"]].clone()
+                a_pred, _ = audio_model(wave)
+
+            v_pred, _ = vfe(window, config_data.General_WIN_MAX_LENGTH)
+
+            t_pred, _ = text_model(transcriptions[w_idx][0])
+
+            if a_pred:
+                pred_emo = (a_pred["emo"] + v_pred["emo"] + t_pred["emo"]) / 3
+                pred_sen = (a_pred["sen"] + v_pred["sen"] + t_pred["sen"]) / 3
+            else:
+                pred_emo = (v_pred["emo"] + t_pred["emo"]) / 2
+                pred_sen = (v_pred["sen"] + t_pred["sen"]) / 2
+
+            frames = list(
+                range(
+                    int(window["start"] * vfe.fps / config_data.General_SR) + 1,
+                    int(window["end"] * vfe.fps / config_data.General_SR) + 2,
+                )
+            )
+            preds_emo.extend([torch.argmax(pred_emo).numpy()] * len(frames))
+            preds_sen.extend([torch.argmax(pred_sen).numpy()] * len(frames))
+            window_frames.extend(frames)
+
+        if max(window_frames) < vfe.frame_number:
+            missed_frames = list(range(max(window_frames) + 1, vfe.frame_number + 1))
+            window_frames.extend(missed_frames)
+            preds_emo.extend([preds_emo[-1]] * len(missed_frames))
+            preds_sen.extend([preds_sen[-1]] * len(missed_frames))
+
+        df_pred = pd.DataFrame(columns=["frames", "pred_emo", "pred_sent"])
+        df_pred["frames"] = window_frames
+        df_pred["pred_emo"] = preds_emo
+        df_pred["pred_sent"] = preds_sen
+
+        df_pred = df_pred.groupby("frames").agg(
+            {
+                "pred_emo": calculate_mode,
+                "pred_sent": calculate_mode,
+            }
+        )
+
+        frame_indices = get_evenly_spaced_frame_indices(vfe.frame_number, 9)
+        num_frames = len(wav)
+        time_axis = [i / config_data.General_SR for i in range(num_frames)]
+        plt_audio = plot_audio(time_axis, wav.unsqueeze(0), frame_indices, vfe.fps, (12, 2))
+
+        all_idx_faces = list(vfe.faces[1].keys())
+        need_idx_faces = find_nearest_frames(frame_indices, all_idx_faces)
+        faces = []
+        for idx_frame, idx_faces in zip(frame_indices, need_idx_faces):
+            cur_face = cv2.resize(
+                vfe.faces[1][idx_faces], (224, 224), interpolation=cv2.INTER_AREA
+            )
+            faces.append(
+                display_frame_info(
+                    cur_face, "Frame: {}".format(idx_frame + 1), box_scale=0.3
+                )
+            )
+        plt_faces = plot_images(faces)
+
+        plt_emo = plot_predictions(
+            df_pred,
+            "pred_emo",
+            "Emotion",
+            list(config_data.General_DICT_EMO),
+            (12, 2.5),
+            [i + 1 for i in frame_indices],
+            2,
+        )
+        plt_sent = plot_predictions(
+            df_pred,
+            "pred_sent",
+            "Sentiment",
+            list(config_data.General_DICT_SENT),
+            (12, 1.5),
+            [i + 1 for i in frame_indices],
+            2,
+        )
+
+    end_time = time.time()
+    inference_time = end_time - start_time
 
     return (
         gr.Textbox(
@@ -203,7 +199,7 @@ def event_handler_submit(
             visible=True,
         ),
         gr.Textbox(
-            value=timer.execution_time,
+            value=config_data.OtherMessages_SEC.format(inference_time),
             info=config_data.InformationMessages_INFERENCE_TIME,
             container=True,
             visible=True,
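For reference, the per-frame labels in the handler are collapsed with df_pred.groupby("frames").agg(...) using the calculate_mode helper from app.utils. A self-contained sketch of that aggregation pattern; the mode implementation shown here is an assumption for illustration, not the project's actual helper:

import pandas as pd
from collections import Counter

def calculate_mode(values):
    # Assumed behavior: most frequent class index per frame (ties resolved by first occurrence).
    return Counter(values).most_common(1)[0][0]

# Toy per-window predictions already expanded to frame level.
df_pred = pd.DataFrame(
    {
        "frames": [1, 1, 2, 2, 2],
        "pred_emo": [3, 3, 0, 1, 1],
        "pred_sent": [2, 2, 2, 0, 0],
    }
)

df_pred = df_pred.groupby("frames").agg(
    {
        "pred_emo": calculate_mode,
        "pred_sent": calculate_mode,
    }
)
print(df_pred)  # one modal emotion / sentiment label per frame index

The aggregated, frame-indexed table is what the handler then passes to plot_predictions.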