Update visualization.py
Browse files- visualization.py +33 -1
visualization.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import matplotlib.pyplot as plt
|
|
|
2 |
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
3 |
import matplotlib.colors as mcolors
|
4 |
from matplotlib.colors import LinearSegmentedColormap
|
@@ -315,4 +316,35 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
|
|
315 |
return heatmap_video_path
|
316 |
else:
|
317 |
print(f"Failed to create heatmap video at: {heatmap_video_path}")
|
318 |
-
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import matplotlib.pyplot as plt
|
2 |
+
from mpl_toolkits.mplot3d import Axes3D
|
3 |
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
4 |
import matplotlib.colors as mcolors
|
5 |
from matplotlib.colors import LinearSegmentedColormap
|
|
|
316 |
return heatmap_video_path
|
317 |
else:
|
318 |
print(f"Failed to create heatmap video at: {heatmap_video_path}")
|
319 |
+
return None
|
320 |
+
|
321 |
+
|
322 |
+
# Function to create the correlation heatmap
|
323 |
+
def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
|
324 |
+
mse_data = {
|
325 |
+
'Facial Features MSE': mse_embeddings,
|
326 |
+
'Body Posture MSE': mse_posture,
|
327 |
+
'Voice MSE': mse_voice
|
328 |
+
}
|
329 |
+
mse_df = pd.DataFrame(mse_data)
|
330 |
+
correlation_matrix = mse_df.corr()
|
331 |
+
|
332 |
+
plt.figure(figsize=(8, 6))
|
333 |
+
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
|
334 |
+
plt.title("Correlation Heatmap of MSEs")
|
335 |
+
plt.close()
|
336 |
+
return plt.gcf()
|
337 |
+
|
338 |
+
# Function to create the 3D scatter plot
|
339 |
+
def plot_3d_scatter(mse_embeddings, mse_posture, mse_voice):
|
340 |
+
fig = plt.figure(figsize=(10, 8))
|
341 |
+
ax = fig.add_subplot(111, projection='3d')
|
342 |
+
ax.scatter(mse_posture, mse_embeddings, mse_voice, c='b', marker='o')
|
343 |
+
|
344 |
+
ax.set_xlabel('Body Posture MSE')
|
345 |
+
ax.set_ylabel('Facial Features MSE')
|
346 |
+
ax.set_zlabel('Voice MSE')
|
347 |
+
ax.set_title('3D Scatter Plot of MSEs')
|
348 |
+
|
349 |
+
plt.close()
|
350 |
+
return fig
|