reab5555 commited on
Commit
859cec7
·
verified ·
1 Parent(s): 58366da

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +33 -1
visualization.py CHANGED
@@ -1,4 +1,5 @@
1
  import matplotlib.pyplot as plt
 
2
  from matplotlib.backends.backend_agg import FigureCanvasAgg
3
  import matplotlib.colors as mcolors
4
  from matplotlib.colors import LinearSegmentedColormap
@@ -315,4 +316,35 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
315
  return heatmap_video_path
316
  else:
317
  print(f"Failed to create heatmap video at: {heatmap_video_path}")
318
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import matplotlib.pyplot as plt
2
+ from mpl_toolkits.mplot3d import Axes3D
3
  from matplotlib.backends.backend_agg import FigureCanvasAgg
4
  import matplotlib.colors as mcolors
5
  from matplotlib.colors import LinearSegmentedColormap
 
316
  return heatmap_video_path
317
  else:
318
  print(f"Failed to create heatmap video at: {heatmap_video_path}")
319
+ return None
320
+
321
+
322
+ # Function to create the correlation heatmap
323
+ def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
324
+ mse_data = {
325
+ 'Facial Features MSE': mse_embeddings,
326
+ 'Body Posture MSE': mse_posture,
327
+ 'Voice MSE': mse_voice
328
+ }
329
+ mse_df = pd.DataFrame(mse_data)
330
+ correlation_matrix = mse_df.corr()
331
+
332
+ plt.figure(figsize=(8, 6))
333
+ sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
334
+ plt.title("Correlation Heatmap of MSEs")
335
+ plt.close()
336
+ return plt.gcf()
337
+
338
+ # Function to create the 3D scatter plot
339
+ def plot_3d_scatter(mse_embeddings, mse_posture, mse_voice):
340
+ fig = plt.figure(figsize=(10, 8))
341
+ ax = fig.add_subplot(111, projection='3d')
342
+ ax.scatter(mse_posture, mse_embeddings, mse_voice, c='b', marker='o')
343
+
344
+ ax.set_xlabel('Body Posture MSE')
345
+ ax.set_ylabel('Facial Features MSE')
346
+ ax.set_zlabel('Voice MSE')
347
+ ax.set_title('3D Scatter Plot of MSEs')
348
+
349
+ plt.close()
350
+ return fig