hiyata commited on
Commit
1869cbd
·
verified ·
1 Parent(s): 8ef755b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -88
app.py CHANGED
@@ -328,106 +328,35 @@ def compute_shap_difference(shap1_norm, shap2_norm):
328
  """Compute the SHAP difference between normalized sequences"""
329
  return shap2_norm - shap1_norm
330
 
331
- def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
332
  """
333
- Plots a comparative heatmap of SHAP differences between two sequences.
334
-
335
- Parameters
336
- ----------
337
- shap_diff : 1D array of differences (Seq2 SHAP - Seq1 SHAP).
338
- Negative = Seq1 more human-like, Positive = Seq2 more human-like.
339
- seq1_length : int, length of sequence 1 (for labeling).
340
- seq2_length : int, length of sequence 2 (for labeling).
341
- title : str, figure title.
342
-
343
- Figure Layout
344
- -------------
345
- - Bottom X-Axis: Relative positions in percent (0% to 100%).
346
- - Top X-Axis : Actual positions for both sequences (S1 and S2).
347
- - Y-Axis : Hidden (this is effectively a 1D heatmap).
348
- - Colorbar : Vertical, placed on the right.
349
-
350
- The final layout uses `tight_layout` with a rectangular constraint to ensure
351
- nothing overlaps while still providing clear labeling.
352
  """
353
- # Reshape the 1D differences into a 1 x N image
354
  heatmap_data = shap_diff.reshape(1, -1)
355
- # Determine the symmetric range around zero
356
- extent = max(abs(shap_diff.min()), abs(shap_diff.max()))
357
-
358
- # Create the figure (width x height in inches)
359
- fig, ax = plt.subplots(figsize=(10, 3))
360
 
361
- # Main heatmap
362
  cmap = get_zero_centered_cmap()
363
- cax = ax.imshow(
364
- heatmap_data,
365
- aspect='auto',
366
- cmap=cmap,
367
- vmin=-extent,
368
- vmax=extent
369
- )
370
 
371
- # Number of major ticks along x-axis
372
  num_ticks = 5
373
- tick_positions = np.linspace(0, shap_diff.shape[0] - 1, num_ticks)
374
-
375
- # ----------------- Bottom Axis: Percentage ----------------- #
376
  ax.set_xticks(tick_positions)
377
- ax.set_xticklabels(
378
- [f"{int(x * 100)}%" for x in np.linspace(0, 1, num_ticks)],
379
- fontsize=9
380
- )
381
- ax.set_xlabel("Relative Position (%)", fontsize=10, labelpad=10)
382
-
383
- # ----------------- Top Axis: Actual Positions ----------------- #
384
- ax_top = ax.twiny()
385
- ax_top.set_xlim(ax.get_xlim()) # Match the bottom axis
386
-
387
- # Create position arrays for both sequences
388
- seq1_positions = np.linspace(0, seq1_length, num_ticks)
389
- seq2_positions = np.linspace(0, seq2_length, num_ticks)
390
-
391
- # Helper function to format large positions nicely
392
- def format_position(x):
393
- if x >= 1e6:
394
- return f"{x / 1e6:.1f}M"
395
- elif x >= 1e3:
396
- return f"{int(x / 1e3)}K"
397
- else:
398
- return f"{int(x)}"
399
 
400
- seq1_labels = [format_position(x) for x in seq1_positions]
401
- seq2_labels = [format_position(x) for x in seq2_positions]
402
-
403
- ax_top.set_xticks(tick_positions)
404
- # Each tick label shows the corresponding position in Seq1 and Seq2
405
- ax_top.set_xticklabels(
406
- [f"S1: {s1}\nS2: {s2}" for s1, s2 in zip(seq1_labels, seq2_labels)],
407
- fontsize=9
408
- )
409
- ax_top.set_xlabel("Sequence Positions", fontsize=10, labelpad=15)
410
-
411
- # ----------------- Colorbar (Vertical, on the right) ----------------- #
412
- # 'fraction' = thickness of colorbar, 'pad' = gap from the right edge
413
- cbar = fig.colorbar(
414
- cax, ax=ax, orientation='vertical', fraction=0.03, pad=0.07
415
- )
416
- cbar.set_label("SHAP Difference\n(Seq2 - Seq1)", fontsize=10, labelpad=5)
417
 
418
- # Hide the y-axis (not needed in a 1D heatmap)
419
  ax.set_yticks([])
420
-
421
- # Set a descriptive title
422
- ax.set_title(title, fontsize=12, pad=10)
423
-
424
- # Adjust layout so everything fits without overlapping
425
- # The rect parameter leaves space on the right for the colorbar
426
- fig.tight_layout(rect=[0, 0, 0.9, 1]) # Adjust as necessary
427
 
428
  return fig
429
 
430
-
431
  def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
432
  """
433
  Plot histogram of SHAP values with configurable number of bins
@@ -608,8 +537,6 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
608
  # Generate visualizations
609
  heatmap_fig = plot_comparative_heatmap(
610
  shap_diff,
611
- seq1_length=len1,
612
- seq2_length=len2,
613
  title=f"SHAP Difference Heatmap (window: {smooth_window})"
614
  )
615
  heatmap_img = fig_to_image(heatmap_fig)
 
328
  """Compute the SHAP difference between normalized sequences"""
329
  return shap2_norm - shap1_norm
330
 
331
+ def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
332
  """
333
+ Plot heatmap using relative positions (0-100%)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  """
 
335
  heatmap_data = shap_diff.reshape(1, -1)
336
+ extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
 
 
 
 
337
 
338
+ fig, ax = plt.subplots(figsize=(12, 1.8))
339
  cmap = get_zero_centered_cmap()
340
+ cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
 
 
 
 
 
 
341
 
342
+ # Create percentage-based x-axis ticks
343
  num_ticks = 5
344
+ tick_positions = np.linspace(0, shap_diff.shape[0]-1, num_ticks)
345
+ tick_labels = [f"{int(x*100)}%" for x in np.linspace(0, 1, num_ticks)]
 
346
  ax.set_xticks(tick_positions)
347
+ ax.set_xticklabels(tick_labels)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
 
349
+ cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
350
+ cbar.ax.tick_params(labelsize=8)
351
+ cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
 
 
353
  ax.set_yticks([])
354
+ ax.set_xlabel('Relative Position in Sequence', fontsize=10)
355
+ ax.set_title(title, pad=10)
356
+ plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
 
 
 
 
357
 
358
  return fig
359
 
 
360
  def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
361
  """
362
  Plot histogram of SHAP values with configurable number of bins
 
537
  # Generate visualizations
538
  heatmap_fig = plot_comparative_heatmap(
539
  shap_diff,
 
 
540
  title=f"SHAP Difference Heatmap (window: {smooth_window})"
541
  )
542
  heatmap_img = fig_to_image(heatmap_fig)