hiyata committed on
Commit
8ef755b
·
verified ·
1 Parent(s): ec4c39f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -25
app.py CHANGED
@@ -330,41 +330,70 @@ def compute_shap_difference(shap1_norm, shap2_norm):
330
 
331
  def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
332
  """
333
- Plots a comparative heatmap of SHAP differences between two sequences.
334
- - The bottom x-axis shows relative positions (%) across the normalized dimension.
335
- - The top x-axis shows actual positions for each sequence (S1 and S2).
336
- - A vertical colorbar is placed to the right.
337
- - Negative (blue) indicates Seq1 is more human-like in that region,
338
- positive (red) indicates Seq2 is more human-like,
339
- white indicates no substantial difference.
 
 
 
 
 
 
 
 
 
 
 
 
340
  """
341
- # Prepare data
342
  heatmap_data = shap_diff.reshape(1, -1)
 
343
  extent = max(abs(shap_diff.min()), abs(shap_diff.max()))
344
 
345
- # Create figure and axis
346
- fig, ax = plt.subplots(figsize=(12, 3))
347
- cmap = get_zero_centered_cmap() # Ensure this function is defined above
348
- cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
 
 
 
 
 
 
 
 
349
 
350
- # Bottom axis: percentage-based x-axis
351
  num_ticks = 5
352
  tick_positions = np.linspace(0, shap_diff.shape[0] - 1, num_ticks)
 
 
353
  ax.set_xticks(tick_positions)
354
- ax.set_xticklabels([f"{int(x * 100)}%" for x in np.linspace(0, 1, num_ticks)], fontsize=9)
355
- ax.set_xlabel("Relative Position (%)", fontsize=10)
 
 
 
356
 
357
- # Top axis: actual sequence positions for both Seq1 and Seq2
358
  ax_top = ax.twiny()
359
- ax_top.set_xlim(ax.get_xlim())
 
 
360
  seq1_positions = np.linspace(0, seq1_length, num_ticks)
361
  seq2_positions = np.linspace(0, seq2_length, num_ticks)
362
 
 
363
  def format_position(x):
364
  if x >= 1e6:
365
  return f"{x / 1e6:.1f}M"
366
  elif x >= 1e3:
367
- return f"{x / 1e3:.0f}K"
368
  else:
369
  return f"{int(x)}"
370
 
@@ -372,24 +401,33 @@ def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Di
372
  seq2_labels = [format_position(x) for x in seq2_positions]
373
 
374
  ax_top.set_xticks(tick_positions)
 
375
  ax_top.set_xticklabels(
376
  [f"S1: {s1}\nS2: {s2}" for s1, s2 in zip(seq1_labels, seq2_labels)],
377
- fontsize=8
378
  )
379
  ax_top.set_xlabel("Sequence Positions", fontsize=10, labelpad=15)
380
 
381
- # Colorbar on the right
382
- cbar = fig.colorbar(cax, ax=ax, orientation='vertical', fraction=0.02, pad=0.02)
383
- cbar.set_label("SHAP Difference (Seq2 - Seq1)", fontsize=10)
 
 
 
 
 
 
384
 
385
- # Aesthetics
386
- ax.set_yticks([])
387
  ax.set_title(title, fontsize=12, pad=10)
388
 
389
- fig.tight_layout(rect=[0, 0, 0.88, 1])
 
 
390
 
391
  return fig
392
 
 
393
  def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
394
  """
395
  Plot histogram of SHAP values with configurable number of bins
 
330
 
331
  def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
332
  """
333
+ Plots a comparative heatmap of SHAP differences between two sequences.
334
+
335
+ Parameters
336
+ ----------
337
+ shap_diff : 1D array of differences (Seq2 SHAP - Seq1 SHAP).
338
+ Negative = Seq1 more human-like, Positive = Seq2 more human-like.
339
+ seq1_length : int, length of sequence 1 (for labeling).
340
+ seq2_length : int, length of sequence 2 (for labeling).
341
+ title : str, figure title.
342
+
343
+ Figure Layout
344
+ -------------
345
+ - Bottom X-Axis: Relative positions in percent (0% to 100%).
346
+ - Top X-Axis : Actual positions for both sequences (S1 and S2).
347
+ - Y-Axis : Hidden (this is effectively a 1D heatmap).
348
+ - Colorbar : Vertical, placed on the right.
349
+
350
+ The final layout uses `tight_layout` with a rectangular constraint to ensure
351
+ nothing overlaps while still providing clear labeling.
352
  """
353
+ # Reshape the 1D differences into a 1 x N image
354
  heatmap_data = shap_diff.reshape(1, -1)
355
+ # Determine the symmetric range around zero
356
  extent = max(abs(shap_diff.min()), abs(shap_diff.max()))
357
 
358
+ # Create the figure (width x height in inches)
359
+ fig, ax = plt.subplots(figsize=(10, 3))
360
+
361
+ # Main heatmap
362
+ cmap = get_zero_centered_cmap()
363
+ cax = ax.imshow(
364
+ heatmap_data,
365
+ aspect='auto',
366
+ cmap=cmap,
367
+ vmin=-extent,
368
+ vmax=extent
369
+ )
370
 
371
+ # Number of major ticks along x-axis
372
  num_ticks = 5
373
  tick_positions = np.linspace(0, shap_diff.shape[0] - 1, num_ticks)
374
+
375
+ # ----------------- Bottom Axis: Percentage ----------------- #
376
  ax.set_xticks(tick_positions)
377
+ ax.set_xticklabels(
378
+ [f"{int(x * 100)}%" for x in np.linspace(0, 1, num_ticks)],
379
+ fontsize=9
380
+ )
381
+ ax.set_xlabel("Relative Position (%)", fontsize=10, labelpad=10)
382
 
383
+ # ----------------- Top Axis: Actual Positions ----------------- #
384
  ax_top = ax.twiny()
385
+ ax_top.set_xlim(ax.get_xlim()) # Match the bottom axis
386
+
387
+ # Create position arrays for both sequences
388
  seq1_positions = np.linspace(0, seq1_length, num_ticks)
389
  seq2_positions = np.linspace(0, seq2_length, num_ticks)
390
 
391
+ # Helper function to format large positions nicely
392
  def format_position(x):
393
  if x >= 1e6:
394
  return f"{x / 1e6:.1f}M"
395
  elif x >= 1e3:
396
+ return f"{int(x / 1e3)}K"
397
  else:
398
  return f"{int(x)}"
399
 
 
401
  seq2_labels = [format_position(x) for x in seq2_positions]
402
 
403
  ax_top.set_xticks(tick_positions)
404
+ # Each tick label shows the corresponding position in Seq1 and Seq2
405
  ax_top.set_xticklabels(
406
  [f"S1: {s1}\nS2: {s2}" for s1, s2 in zip(seq1_labels, seq2_labels)],
407
+ fontsize=9
408
  )
409
  ax_top.set_xlabel("Sequence Positions", fontsize=10, labelpad=15)
410
 
411
+ # ----------------- Colorbar (Vertical, on the right) ----------------- #
412
+ # 'fraction' = share of the original axes used for the colorbar, 'pad' = gap between the heatmap axes and the colorbar (also a fraction of the original axes)
413
+ cbar = fig.colorbar(
414
+ cax, ax=ax, orientation='vertical', fraction=0.03, pad=0.07
415
+ )
416
+ cbar.set_label("SHAP Difference\n(Seq2 - Seq1)", fontsize=10, labelpad=5)
417
+
418
+ # Hide the y-axis (not needed in a 1D heatmap)
419
+ ax.set_yticks([])
420
 
421
+ # Set a descriptive title
 
422
  ax.set_title(title, fontsize=12, pad=10)
423
 
424
+ # Adjust layout so everything fits without overlapping
425
+ # The rect parameter leaves space on the right for the colorbar
426
+ fig.tight_layout(rect=[0, 0, 0.9, 1]) # Adjust as necessary
427
 
428
  return fig
429
 
430
+
431
  def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
432
  """
433
  Plot histogram of SHAP values with configurable number of bins