Update visualization.py
Browse files- visualization.py +22 -29
visualization.py
CHANGED
@@ -226,36 +226,29 @@ def fill_with_zeros(mse_array, total_frames):
|
|
226 |
return result
|
227 |
|
228 |
def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, desired_fps, total_frames, video_width):
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
combined_mse = np.array([
|
235 |
-
mse_embeddings[start_frame:end_frame],
|
236 |
-
mse_posture[start_frame:end_frame],
|
237 |
-
mse_voice[start_frame:end_frame]
|
238 |
-
])
|
239 |
-
|
240 |
-
# Calculate global min and max for consistent scaling
|
241 |
-
vmin = 0
|
242 |
-
vmax = max(np.max(mse_embeddings), np.max(mse_posture), np.max(mse_voice))
|
243 |
-
|
244 |
-
fig, ax = plt.subplots(figsize=(video_width / 100, 0.4)) # Adjusted figure size
|
245 |
-
im = ax.imshow(combined_mse, aspect='auto', cmap='Reds',
|
246 |
-
extent=[start_frame/desired_fps, end_frame/desired_fps, 0, 3],
|
247 |
-
vmin=vmin, vmax=vmax, interpolation='nearest')
|
248 |
|
|
|
|
|
|
|
|
|
|
|
249 |
ax.set_yticks([0.5, 1.5, 2.5])
|
250 |
ax.set_yticklabels(['Face', 'Posture', 'Voice'], fontsize=7)
|
251 |
|
|
|
|
|
|
|
252 |
# Add vertical line for current time
|
253 |
current_time = t
|
254 |
ax.axvline(x=current_time, color='black', linewidth=2)
|
255 |
|
256 |
# Set x-axis ticks and labels
|
257 |
-
ax.set_xticks([
|
258 |
-
ax.set_xticklabels([
|
259 |
|
260 |
plt.tight_layout(pad=0.5)
|
261 |
|
@@ -280,15 +273,15 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
|
|
280 |
width, height = video.w, video.h
|
281 |
total_frames = int(video.duration * desired_fps)
|
282 |
|
283 |
-
#
|
284 |
-
def
|
285 |
-
|
286 |
-
|
287 |
-
return
|
288 |
|
289 |
-
mse_embeddings =
|
290 |
-
mse_posture =
|
291 |
-
mse_voice =
|
292 |
|
293 |
def combine_video_and_heatmap(t):
|
294 |
original_frame = int(t * video.fps)
|
|
|
226 |
return result
|
227 |
|
228 |
def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, desired_fps, total_frames, video_width):
|
229 |
+
fig, ax = plt.subplots(figsize=(video_width / 100, 0.4))
|
230 |
+
|
231 |
+
# Create the full heatmap for the entire video duration
|
232 |
+
combined_mse = np.array([mse_embeddings, mse_posture, mse_voice])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
|
234 |
+
# Use pcolormesh for better performance with large datasets
|
235 |
+
im = ax.pcolormesh(np.arange(total_frames) / desired_fps, [0, 1, 2], combined_mse,
|
236 |
+
cmap='Reds', vmin=0, vmax=np.max(combined_mse))
|
237 |
+
|
238 |
+
ax.set_ylim(0, 3)
|
239 |
ax.set_yticks([0.5, 1.5, 2.5])
|
240 |
ax.set_yticklabels(['Face', 'Posture', 'Voice'], fontsize=7)
|
241 |
|
242 |
+
# Set x-axis to show full video duration
|
243 |
+
ax.set_xlim(0, total_frames / desired_fps)
|
244 |
+
|
245 |
# Add vertical line for current time
|
246 |
current_time = t
|
247 |
ax.axvline(x=current_time, color='black', linewidth=2)
|
248 |
|
249 |
# Set x-axis ticks and labels
|
250 |
+
ax.set_xticks([0, current_time, total_frames / desired_fps])
|
251 |
+
ax.set_xticklabels(['0:00', f'{current_time:.2f}', f'{total_frames / desired_fps:.2f}'], fontsize=6)
|
252 |
|
253 |
plt.tight_layout(pad=0.5)
|
254 |
|
|
|
273 |
width, height = video.w, video.h
|
274 |
total_frames = int(video.duration * desired_fps)
|
275 |
|
276 |
+
# Ensure MSE arrays have the same length as total_frames
|
277 |
+
def pad_mse_array(mse_array, total_frames):
|
278 |
+
if len(mse_array) < total_frames:
|
279 |
+
return np.pad(mse_array, (0, total_frames - len(mse_array)), 'constant', constant_values=0)
|
280 |
+
return mse_array[:total_frames]
|
281 |
|
282 |
+
mse_embeddings = pad_mse_array(mse_embeddings, total_frames)
|
283 |
+
mse_posture = pad_mse_array(mse_posture, total_frames)
|
284 |
+
mse_voice = pad_mse_array(mse_voice, total_frames)
|
285 |
|
286 |
def combine_video_and_heatmap(t):
|
287 |
original_frame = int(t * video.fps)
|