hiyata commited on
Commit
1b8562c
·
verified ·
1 Parent(s): 88b80ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -6
app.py CHANGED
@@ -328,32 +328,70 @@ def compute_shap_difference(shap1_norm, shap2_norm):
328
  """Compute the SHAP difference between normalized sequences"""
329
  return shap2_norm - shap1_norm
330
 
331
- def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
332
  """
333
- Plot heatmap using relative positions (0-100%)
 
 
 
 
 
 
334
  """
335
  heatmap_data = shap_diff.reshape(1, -1)
336
  extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
337
 
338
- fig, ax = plt.subplots(figsize=(12, 1.8))
 
 
 
339
  cmap = get_zero_centered_cmap()
340
  cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
341
 
342
- # Create percentage-based x-axis ticks
343
  num_ticks = 5
344
  tick_positions = np.linspace(0, shap_diff.shape[0]-1, num_ticks)
345
  tick_labels = [f"{int(x*100)}%" for x in np.linspace(0, 1, num_ticks)]
346
  ax.set_xticks(tick_positions)
347
  ax.set_xticklabels(tick_labels)
348
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
350
  cbar.ax.tick_params(labelsize=8)
351
  cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
352
 
 
353
  ax.set_yticks([])
354
- ax.set_xlabel('Relative Position in Sequence', fontsize=10)
 
355
  ax.set_title(title, pad=10)
356
- plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
 
 
357
 
358
  return fig
359
 
 
328
  """Compute the SHAP difference between normalized sequences"""
329
  return shap2_norm - shap1_norm
330
 
331
+ def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
332
  """
333
+ Plot heatmap using both relative positions (0-100%) and actual sequence positions
334
+
335
+ Parameters:
336
+ shap_diff: numpy array of SHAP differences
337
+ seq1_length: length of sequence 1
338
+ seq2_length: length of sequence 2
339
+ title: plot title
340
  """
341
  heatmap_data = shap_diff.reshape(1, -1)
342
  extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
343
 
344
+ # Create figure with additional space for the second x-axis
345
+ fig, ax = plt.subplots(figsize=(12, 2.4))
346
+
347
+ # Plot main heatmap
348
  cmap = get_zero_centered_cmap()
349
  cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
350
 
351
+ # Create percentage-based x-axis ticks (top)
352
  num_ticks = 5
353
  tick_positions = np.linspace(0, shap_diff.shape[0]-1, num_ticks)
354
  tick_labels = [f"{int(x*100)}%" for x in np.linspace(0, 1, num_ticks)]
355
  ax.set_xticks(tick_positions)
356
  ax.set_xticklabels(tick_labels)
357
 
358
+ # Create second x-axis for actual positions (bottom)
359
+ ax2 = ax.twiny()
360
+ ax2.set_xlim(ax.get_xlim())
361
+
362
+ # Calculate actual positions for both sequences
363
+ seq1_positions = np.linspace(0, seq1_length, num_ticks)
364
+ seq2_positions = np.linspace(0, seq2_length, num_ticks)
365
+
366
+ # Format position labels with appropriate scaling
367
+ def format_position(x):
368
+ if x >= 1e6:
369
+ return f"{x/1e6:.1f}M"
370
+ elif x >= 1e3:
371
+ return f"{x/1e3:.0f}K"
372
+ else:
373
+ return f"{int(x)}"
374
+
375
+ seq1_labels = [format_position(x) for x in seq1_positions]
376
+ seq2_labels = [format_position(x) for x in seq2_positions]
377
+
378
+ # Set positions for bottom axis
379
+ ax2.set_xticks(tick_positions)
380
+ ax2.set_xticklabels([f"S1: {s1}\nS2: {s2}" for s1, s2 in zip(seq1_labels, seq2_labels)])
381
+
382
+ # Add colorbar
383
  cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
384
  cbar.ax.tick_params(labelsize=8)
385
  cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
386
 
387
+ # Adjust labels and layout
388
  ax.set_yticks([])
389
+ ax.set_xlabel('Relative Position (%)', fontsize=10)
390
+ ax2.set_xlabel('Sequence Positions', fontsize=10)
391
  ax.set_title(title, pad=10)
392
+
393
+ # Adjust layout to prevent label overlap
394
+ plt.subplots_adjust(bottom=0.35, left=0.05, right=0.95, top=0.85)
395
 
396
  return fig
397