hiyata commited on
Commit
37ce441
·
verified ·
1 Parent(s): c305525

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -32
app.py CHANGED
@@ -329,64 +329,77 @@ def compute_shap_difference(shap1_norm, shap2_norm):
329
  return shap2_norm - shap1_norm
330
 
331
  def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
 
 
 
 
 
 
 
 
 
 
 
332
  heatmap_data = shap_diff.reshape(1, -1)
333
  extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
334
 
335
- # Create figure with additional space
336
- fig, ax = plt.subplots(figsize=(12, 3.2))
337
 
338
- # Plot main heatmap
339
  cmap = get_zero_centered_cmap()
340
- cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
 
 
 
 
 
 
 
 
 
 
 
341
 
342
- # Create percentage-based x-axis ticks (top)
343
  num_ticks = 5
344
- tick_positions = np.linspace(0, shap_diff.shape[0]-1, num_ticks)
345
- tick_labels = [f"{int(x*100)}%" for x in np.linspace(0, 1, num_ticks)]
346
  ax.set_xticks(tick_positions)
347
- ax.set_xticklabels(tick_labels)
 
 
 
 
348
 
349
- # Create second x-axis for actual positions (bottom)
350
- ax2 = ax.twiny()
351
- ax2.set_xlim(ax.get_xlim())
352
 
353
- # Calculate actual positions for both sequences
354
  seq1_positions = np.linspace(0, seq1_length, num_ticks)
355
  seq2_positions = np.linspace(0, seq2_length, num_ticks)
356
 
357
- # Format position labels with appropriate scaling
358
  def format_position(x):
359
  if x >= 1e6:
360
  return f"{x/1e6:.1f}M"
361
  elif x >= 1e3:
362
  return f"{x/1e3:.0f}K"
363
  else:
364
- return f"{int(x)}"
365
 
366
  seq1_labels = [format_position(x) for x in seq1_positions]
367
  seq2_labels = [format_position(x) for x in seq2_positions]
368
 
369
- # Set positions for bottom axis
 
370
  ax2.set_xticks(tick_positions)
371
- ax2.set_xticklabels([f"S1: {s1}\nS2: {s2}" for s1, s2 in zip(seq1_labels, seq2_labels)])
372
-
373
- # Add colorbar at the bottom with more spacing
374
- cbar = plt.colorbar(cax, orientation='horizontal', pad=0.6, aspect=40, shrink=0.8)
375
- cbar.ax.tick_params(labelsize=8)
376
- cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
377
-
378
- # Remove percentage labels from colorbar
379
- cbar.ax.set_xticklabels([f'{x:.3f}' for x in cbar.ax.get_xticks()])
380
-
381
- # Adjust labels and layout
382
- ax.set_yticks([])
383
- ax.set_xlabel('Relative Position (%)', fontsize=10, labelpad=10)
384
  ax2.set_xlabel('Sequence Positions', fontsize=10, labelpad=10)
385
- ax.set_title(title, pad=10)
386
-
387
- # Increase bottom margin to prevent overlap
388
- plt.subplots_adjust(bottom=0.55, left=0.05, right=0.95, top=0.85)
389
 
 
 
390
  return fig
391
 
392
  def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
 
329
  return shap2_norm - shap1_norm
330
 
331
  def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
332
+ """
333
+ Plot a comparative SHAP-difference heatmap, showing how Sequence 2 differs
334
+ from Sequence 1 with respect to 'human-likeness'.
335
+
336
+ shap_diff: 1D array of (Seq2 - Seq1) SHAP differences, already normalized to the same length
337
+ seq1_length: length of sequence 1 (in bases)
338
+ seq2_length: length of sequence 2 (in bases)
339
+ title: title of the plot
340
+ """
341
+
342
+ # Prepare data for the heatmap
343
  heatmap_data = shap_diff.reshape(1, -1)
344
  extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
345
 
346
+ # Create figure and main axis
347
+ fig, ax = plt.subplots(figsize=(12, 3))
348
 
349
+ # Plot the main heatmap
350
  cmap = get_zero_centered_cmap()
351
+ cax = ax.imshow(
352
+ heatmap_data,
353
+ aspect='auto',
354
+ cmap=cmap,
355
+ vmin=-extent,
356
+ vmax=extent
357
+ )
358
+
359
+ # Add a vertical colorbar on the right
360
+ cbar = plt.colorbar(cax, ax=ax, orientation='vertical', fraction=0.025, pad=0.03)
361
+ cbar.ax.tick_params(labelsize=9)
362
+ cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
363
 
364
+ # Configure the top axis for relative (%) positions
365
  num_ticks = 5
366
+ tick_positions = np.linspace(0, shap_diff.shape[0] - 1, num_ticks)
367
+ tick_labels = [f"{int(x * 100)}%" for x in np.linspace(0, 1, num_ticks)]
368
  ax.set_xticks(tick_positions)
369
+ ax.set_xticklabels(tick_labels, fontsize=9)
370
+
371
+ ax.set_xlabel('Relative Position (%)', fontsize=10, labelpad=8)
372
+ ax.set_title(title, fontsize=12, pad=10)
373
+ ax.set_yticks([]) # Hide the y-axis ticks for a 1D heatmap
374
 
375
+ # Create a second (bottom) x-axis for actual positions
376
+ ax2 = ax.secondary_xaxis('bottom')
377
+ ax2.set_xlim(ax.get_xlim()) # Match the same data range as the top axis
378
 
379
+ # Prepare positions for each sequence
380
  seq1_positions = np.linspace(0, seq1_length, num_ticks)
381
  seq2_positions = np.linspace(0, seq2_length, num_ticks)
382
 
383
+ # Format large numbers with 'K' or 'M'
384
  def format_position(x):
385
  if x >= 1e6:
386
  return f"{x/1e6:.1f}M"
387
  elif x >= 1e3:
388
  return f"{x/1e3:.0f}K"
389
  else:
390
+ return str(int(x))
391
 
392
  seq1_labels = [format_position(x) for x in seq1_positions]
393
  seq2_labels = [format_position(x) for x in seq2_positions]
394
 
395
+ # Combine the two label sets in a single tick label
396
+ bottom_labels = [f"S1: {s1}\nS2: {s2}" for s1, s2 in zip(seq1_labels, seq2_labels)]
397
  ax2.set_xticks(tick_positions)
398
+ ax2.set_xticklabels(bottom_labels, fontsize=9)
 
 
 
 
 
 
 
 
 
 
 
 
399
  ax2.set_xlabel('Sequence Positions', fontsize=10, labelpad=10)
 
 
 
 
400
 
401
+ # Use tight_layout to reduce overlap
402
+ plt.tight_layout()
403
  return fig
404
 
405
  def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):