Update visualization.py
Browse files- visualization.py +19 -14
visualization.py
CHANGED
@@ -231,26 +231,24 @@ def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
|
|
231 |
plt.tight_layout()
|
232 |
return plt.gcf()
|
233 |
|
234 |
-
def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="
|
235 |
-
plt.figure(figsize=(20,
|
236 |
-
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20,
|
237 |
|
238 |
# Face heatmap
|
239 |
-
sns.heatmap(mse_face.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax1)
|
240 |
-
ax1.
|
241 |
-
ax1.
|
242 |
-
ax1.set_xticks([])
|
243 |
|
244 |
# Posture heatmap
|
245 |
-
sns.heatmap(mse_posture.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax2)
|
246 |
-
ax2.
|
247 |
-
ax2.
|
248 |
-
ax2.set_xticks([])
|
249 |
|
250 |
# Voice heatmap
|
251 |
-
sns.heatmap(mse_voice.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax3)
|
252 |
-
ax3.
|
253 |
-
ax3.
|
254 |
|
255 |
# Set x-axis ticks to timecodes for the bottom subplot
|
256 |
num_ticks = min(60, len(mse_voice))
|
@@ -259,6 +257,13 @@ def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="Stack
|
|
259 |
ax3.set_xticks(tick_locations)
|
260 |
ax3.set_xticklabels(tick_labels, rotation=90, ha='center', va='top')
|
261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
plt.suptitle(title)
|
263 |
plt.tight_layout()
|
264 |
plt.close()
|
|
|
231 |
plt.tight_layout()
|
232 |
return plt.gcf()
|
233 |
|
234 |
+
def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="Combined MSE Heatmaps"):
|
235 |
+
plt.figure(figsize=(20, 6), dpi=300)
|
236 |
+
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20, 6), sharex=True, gridspec_kw={'height_ratios': [1, 1, 1.2], 'hspace': 0})
|
237 |
|
238 |
# Face heatmap
|
239 |
+
sns.heatmap(mse_face.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax1, xticklabels=False, yticklabels=False)
|
240 |
+
ax1.set_ylabel('Face', rotation=0, ha='right', va='center')
|
241 |
+
ax1.yaxis.set_label_coords(-0.01, 0.5)
|
|
|
242 |
|
243 |
# Posture heatmap
|
244 |
+
sns.heatmap(mse_posture.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax2, xticklabels=False, yticklabels=False)
|
245 |
+
ax2.set_ylabel('Posture', rotation=0, ha='right', va='center')
|
246 |
+
ax2.yaxis.set_label_coords(-0.01, 0.5)
|
|
|
247 |
|
248 |
# Voice heatmap
|
249 |
+
sns.heatmap(mse_voice.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax3, yticklabels=False)
|
250 |
+
ax3.set_ylabel('Voice', rotation=0, ha='right', va='center')
|
251 |
+
ax3.yaxis.set_label_coords(-0.01, 0.5)
|
252 |
|
253 |
# Set x-axis ticks to timecodes for the bottom subplot
|
254 |
num_ticks = min(60, len(mse_voice))
|
|
|
257 |
ax3.set_xticks(tick_locations)
|
258 |
ax3.set_xticklabels(tick_labels, rotation=90, ha='center', va='top')
|
259 |
|
260 |
+
# Remove spines
|
261 |
+
for ax in [ax1, ax2, ax3]:
|
262 |
+
ax.spines['top'].set_visible(False)
|
263 |
+
ax.spines['right'].set_visible(False)
|
264 |
+
ax.spines['bottom'].set_visible(False)
|
265 |
+
ax.spines['left'].set_visible(False)
|
266 |
+
|
267 |
plt.suptitle(title)
|
268 |
plt.tight_layout()
|
269 |
plt.close()
|