reab5555 committed on
Commit
30a22c3
·
verified ·
1 Parent(s): 869705c

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +22 -29
visualization.py CHANGED
@@ -226,36 +226,29 @@ def fill_with_zeros(mse_array, total_frames):
226
  return result
227
 
228
  def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, desired_fps, total_frames, video_width):
229
- frame_count = int(t * desired_fps)
230
- window_size = min(600, total_frames) # Increased window size for better context
231
- start_frame = max(0, frame_count - window_size // 2)
232
- end_frame = min(total_frames, start_frame + window_size)
233
-
234
- combined_mse = np.array([
235
- mse_embeddings[start_frame:end_frame],
236
- mse_posture[start_frame:end_frame],
237
- mse_voice[start_frame:end_frame]
238
- ])
239
-
240
- # Calculate global min and max for consistent scaling
241
- vmin = 0
242
- vmax = max(np.max(mse_embeddings), np.max(mse_posture), np.max(mse_voice))
243
-
244
- fig, ax = plt.subplots(figsize=(video_width / 100, 0.4)) # Adjusted figure size
245
- im = ax.imshow(combined_mse, aspect='auto', cmap='Reds',
246
- extent=[start_frame/desired_fps, end_frame/desired_fps, 0, 3],
247
- vmin=vmin, vmax=vmax, interpolation='nearest')
248
 
 
 
 
 
 
249
  ax.set_yticks([0.5, 1.5, 2.5])
250
  ax.set_yticklabels(['Face', 'Posture', 'Voice'], fontsize=7)
251
 
 
 
 
252
  # Add vertical line for current time
253
  current_time = t
254
  ax.axvline(x=current_time, color='black', linewidth=2)
255
 
256
  # Set x-axis ticks and labels
257
- ax.set_xticks([start_frame/desired_fps, current_time, end_frame/desired_fps])
258
- ax.set_xticklabels([f'{start_frame/desired_fps:.2f}', f'{current_time:.2f}', f'{end_frame/desired_fps:.2f}'], fontsize=6)
259
 
260
  plt.tight_layout(pad=0.5)
261
 
@@ -280,15 +273,15 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
280
  width, height = video.w, video.h
281
  total_frames = int(video.duration * desired_fps)
282
 
283
- # Interpolate MSE values to match the desired fps
284
- def interpolate_mse(mse_array):
285
- original_indices = np.linspace(0, total_frames - 1, len(mse_array))
286
- new_indices = np.arange(total_frames)
287
- return np.interp(new_indices, original_indices, mse_array)
288
 
289
- mse_embeddings = interpolate_mse(mse_embeddings)
290
- mse_posture = interpolate_mse(mse_posture)
291
- mse_voice = interpolate_mse(mse_voice)
292
 
293
  def combine_video_and_heatmap(t):
294
  original_frame = int(t * video.fps)
 
226
  return result
227
 
228
  def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, desired_fps, total_frames, video_width):
229
+ fig, ax = plt.subplots(figsize=(video_width / 100, 0.4))
230
+
231
+ # Create the full heatmap for the entire video duration
232
+ combined_mse = np.array([mse_embeddings, mse_posture, mse_voice])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
+ # Use pcolormesh for better performance with large datasets
235
+ im = ax.pcolormesh(np.arange(total_frames) / desired_fps, [0, 1, 2], combined_mse,
236
+ cmap='Reds', vmin=0, vmax=np.max(combined_mse))
237
+
238
+ ax.set_ylim(0, 3)
239
  ax.set_yticks([0.5, 1.5, 2.5])
240
  ax.set_yticklabels(['Face', 'Posture', 'Voice'], fontsize=7)
241
 
242
+ # Set x-axis to show full video duration
243
+ ax.set_xlim(0, total_frames / desired_fps)
244
+
245
  # Add vertical line for current time
246
  current_time = t
247
  ax.axvline(x=current_time, color='black', linewidth=2)
248
 
249
  # Set x-axis ticks and labels
250
+ ax.set_xticks([0, current_time, total_frames / desired_fps])
251
+ ax.set_xticklabels(['0:00', f'{current_time:.2f}', f'{total_frames / desired_fps:.2f}'], fontsize=6)
252
 
253
  plt.tight_layout(pad=0.5)
254
 
 
273
  width, height = video.w, video.h
274
  total_frames = int(video.duration * desired_fps)
275
 
276
+ # Ensure MSE arrays have the same length as total_frames
277
+ def pad_mse_array(mse_array, total_frames):
278
+ if len(mse_array) < total_frames:
279
+ return np.pad(mse_array, (0, total_frames - len(mse_array)), 'constant', constant_values=0)
280
+ return mse_array[:total_frames]
281
 
282
+ mse_embeddings = pad_mse_array(mse_embeddings, total_frames)
283
+ mse_posture = pad_mse_array(mse_posture, total_frames)
284
+ mse_voice = pad_mse_array(mse_voice, total_frames)
285
 
286
  def combine_video_and_heatmap(t):
287
  original_frame = int(t * video.fps)