hiyata committed on
Commit
ec4c39f
·
verified ·
1 Parent(s): 37ce441

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -49
app.py CHANGED
@@ -330,76 +330,64 @@ def compute_shap_difference(shap1_norm, shap2_norm):
330
 
331
  def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
332
  """
333
- Plot a comparative SHAP-difference heatmap, showing how Sequence 2 differs
334
- from Sequence 1 with respect to 'human-likeness'.
335
-
336
- shap_diff: 1D array of (Seq2 - Seq1) SHAP differences, already normalized to the same length
337
- seq1_length: length of sequence 1 (in bases)
338
- seq2_length: length of sequence 2 (in bases)
339
- title: title of the plot
340
  """
341
-
342
- # Prepare data for the heatmap
343
  heatmap_data = shap_diff.reshape(1, -1)
344
- extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
345
 
346
- # Create figure and main axis
347
  fig, ax = plt.subplots(figsize=(12, 3))
 
 
348
 
349
- # Plot the main heatmap
350
- cmap = get_zero_centered_cmap()
351
- cax = ax.imshow(
352
- heatmap_data,
353
- aspect='auto',
354
- cmap=cmap,
355
- vmin=-extent,
356
- vmax=extent
357
- )
358
-
359
- # Add a vertical colorbar on the right
360
- cbar = plt.colorbar(cax, ax=ax, orientation='vertical', fraction=0.025, pad=0.03)
361
- cbar.ax.tick_params(labelsize=9)
362
- cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
363
-
364
- # Configure the top axis for relative (%) positions
365
  num_ticks = 5
366
  tick_positions = np.linspace(0, shap_diff.shape[0] - 1, num_ticks)
367
- tick_labels = [f"{int(x * 100)}%" for x in np.linspace(0, 1, num_ticks)]
368
  ax.set_xticks(tick_positions)
369
- ax.set_xticklabels(tick_labels, fontsize=9)
370
-
371
- ax.set_xlabel('Relative Position (%)', fontsize=10, labelpad=8)
372
- ax.set_title(title, fontsize=12, pad=10)
373
- ax.set_yticks([]) # Hide the y-axis ticks for a 1D heatmap
374
-
375
- # Create a second (bottom) x-axis for actual positions
376
- ax2 = ax.secondary_xaxis('bottom')
377
- ax2.set_xlim(ax.get_xlim()) # Match the same data range as the top axis
378
 
379
- # Prepare positions for each sequence
 
 
380
  seq1_positions = np.linspace(0, seq1_length, num_ticks)
381
  seq2_positions = np.linspace(0, seq2_length, num_ticks)
382
 
383
- # Format large numbers with 'K' or 'M'
384
  def format_position(x):
385
  if x >= 1e6:
386
- return f"{x/1e6:.1f}M"
387
  elif x >= 1e3:
388
- return f"{x/1e3:.0f}K"
389
  else:
390
- return str(int(x))
391
 
392
  seq1_labels = [format_position(x) for x in seq1_positions]
393
  seq2_labels = [format_position(x) for x in seq2_positions]
394
 
395
- # Combine the two label sets in a single tick label
396
- bottom_labels = [f"S1: {s1}\nS2: {s2}" for s1, s2 in zip(seq1_labels, seq2_labels)]
397
- ax2.set_xticks(tick_positions)
398
- ax2.set_xticklabels(bottom_labels, fontsize=9)
399
- ax2.set_xlabel('Sequence Positions', fontsize=10, labelpad=10)
 
 
 
 
 
 
 
 
 
 
 
400
 
401
- # Use tight_layout to reduce overlap
402
- plt.tight_layout()
403
  return fig
404
 
405
  def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
 
330
 
331
  def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
332
  """
333
+ Plots a comparative heatmap of SHAP differences between two sequences.
334
+ - The bottom x-axis shows relative positions (%) across the normalized dimension.
335
+ - The top x-axis shows actual positions for each sequence (S1 and S2).
336
+ - A vertical colorbar is placed to the right.
337
+ - Negative (blue) indicates Seq1 is more human-like in that region,
338
+ positive (red) indicates Seq2 is more human-like,
339
+ white indicates no substantial difference.
340
  """
341
+ # Prepare data
 
342
  heatmap_data = shap_diff.reshape(1, -1)
343
+ extent = max(abs(shap_diff.min()), abs(shap_diff.max()))
344
 
345
+ # Create figure and axis
346
  fig, ax = plt.subplots(figsize=(12, 3))
347
+ cmap = get_zero_centered_cmap() # Ensure this function is defined above
348
+ cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
349
 
350
+ # Bottom axis: percentage-based x-axis
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  num_ticks = 5
352
  tick_positions = np.linspace(0, shap_diff.shape[0] - 1, num_ticks)
 
353
  ax.set_xticks(tick_positions)
354
+ ax.set_xticklabels([f"{int(x * 100)}%" for x in np.linspace(0, 1, num_ticks)], fontsize=9)
355
+ ax.set_xlabel("Relative Position (%)", fontsize=10)
 
 
 
 
 
 
 
356
 
357
+ # Top axis: actual sequence positions for both Seq1 and Seq2
358
+ ax_top = ax.twiny()
359
+ ax_top.set_xlim(ax.get_xlim())
360
  seq1_positions = np.linspace(0, seq1_length, num_ticks)
361
  seq2_positions = np.linspace(0, seq2_length, num_ticks)
362
 
 
363
  def format_position(x):
364
  if x >= 1e6:
365
+ return f"{x / 1e6:.1f}M"
366
  elif x >= 1e3:
367
+ return f"{x / 1e3:.0f}K"
368
  else:
369
+ return f"{int(x)}"
370
 
371
  seq1_labels = [format_position(x) for x in seq1_positions]
372
  seq2_labels = [format_position(x) for x in seq2_positions]
373
 
374
+ ax_top.set_xticks(tick_positions)
375
+ ax_top.set_xticklabels(
376
+ [f"S1: {s1}\nS2: {s2}" for s1, s2 in zip(seq1_labels, seq2_labels)],
377
+ fontsize=8
378
+ )
379
+ ax_top.set_xlabel("Sequence Positions", fontsize=10, labelpad=15)
380
+
381
+ # Colorbar on the right
382
+ cbar = fig.colorbar(cax, ax=ax, orientation='vertical', fraction=0.02, pad=0.02)
383
+ cbar.set_label("SHAP Difference (Seq2 - Seq1)", fontsize=10)
384
+
385
+ # Aesthetics
386
+ ax.set_yticks([])
387
+ ax.set_title(title, fontsize=12, pad=10)
388
+
389
+ fig.tight_layout(rect=[0, 0, 0.88, 1])
390
 
 
 
391
  return fig
392
 
393
  def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):