reab5555 commited on
Commit
eeed558
·
verified ·
1 Parent(s): b58b457

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +19 -14
visualization.py CHANGED
@@ -231,26 +231,24 @@ def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
231
  plt.tight_layout()
232
  return plt.gcf()
233
 
234
- def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="Stacked MSE Heatmaps"):
235
- plt.figure(figsize=(20, 9), dpi=300)
236
- fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20, 9), sharex=True)
237
 
238
  # Face heatmap
239
- sns.heatmap(mse_face.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax1)
240
- ax1.set_yticks([0.5])
241
- ax1.set_yticklabels(['Face'], rotation=0, va='center')
242
- ax1.set_xticks([])
243
 
244
  # Posture heatmap
245
- sns.heatmap(mse_posture.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax2)
246
- ax2.set_yticks([0.5])
247
- ax2.set_yticklabels(['Posture'], rotation=0, va='center')
248
- ax2.set_xticks([])
249
 
250
  # Voice heatmap
251
- sns.heatmap(mse_voice.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax3)
252
- ax3.set_yticks([0.5])
253
- ax3.set_yticklabels(['Voice'], rotation=0, va='center')
254
 
255
  # Set x-axis ticks to timecodes for the bottom subplot
256
  num_ticks = min(60, len(mse_voice))
@@ -259,6 +257,13 @@ def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="Stack
259
  ax3.set_xticks(tick_locations)
260
  ax3.set_xticklabels(tick_labels, rotation=90, ha='center', va='top')
261
 
 
 
 
 
 
 
 
262
  plt.suptitle(title)
263
  plt.tight_layout()
264
  plt.close()
 
231
  plt.tight_layout()
232
  return plt.gcf()
233
 
234
+ def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="Combined MSE Heatmaps"):
235
+ plt.figure(figsize=(20, 6), dpi=300)
236
+ fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20, 6), sharex=True, gridspec_kw={'height_ratios': [1, 1, 1.2], 'hspace': 0})
237
 
238
  # Face heatmap
239
+ sns.heatmap(mse_face.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax1, xticklabels=False, yticklabels=False)
240
+ ax1.set_ylabel('Face', rotation=0, ha='right', va='center')
241
+ ax1.yaxis.set_label_coords(-0.01, 0.5)
 
242
 
243
  # Posture heatmap
244
+ sns.heatmap(mse_posture.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax2, xticklabels=False, yticklabels=False)
245
+ ax2.set_ylabel('Posture', rotation=0, ha='right', va='center')
246
+ ax2.yaxis.set_label_coords(-0.01, 0.5)
 
247
 
248
  # Voice heatmap
249
+ sns.heatmap(mse_voice.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax3, yticklabels=False)
250
+ ax3.set_ylabel('Voice', rotation=0, ha='right', va='center')
251
+ ax3.yaxis.set_label_coords(-0.01, 0.5)
252
 
253
  # Set x-axis ticks to timecodes for the bottom subplot
254
  num_ticks = min(60, len(mse_voice))
 
257
  ax3.set_xticks(tick_locations)
258
  ax3.set_xticklabels(tick_labels, rotation=90, ha='center', va='top')
259
 
260
+ # Remove spines
261
+ for ax in [ax1, ax2, ax3]:
262
+ ax.spines['top'].set_visible(False)
263
+ ax.spines['right'].set_visible(False)
264
+ ax.spines['bottom'].set_visible(False)
265
+ ax.spines['left'].set_visible(False)
266
+
267
  plt.suptitle(title)
268
  plt.tight_layout()
269
  plt.close()