Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -329,64 +329,77 @@ def compute_shap_difference(shap1_norm, shap2_norm):
|
|
329 |
return shap2_norm - shap1_norm
|
330 |
|
331 |
def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
332 |
heatmap_data = shap_diff.reshape(1, -1)
|
333 |
extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
|
334 |
|
335 |
-
# Create figure
|
336 |
-
fig, ax = plt.subplots(figsize=(12, 3
|
337 |
|
338 |
-
# Plot main heatmap
|
339 |
cmap = get_zero_centered_cmap()
|
340 |
-
cax = ax.imshow(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
|
342 |
-
#
|
343 |
num_ticks = 5
|
344 |
-
tick_positions = np.linspace(0, shap_diff.shape[0]-1, num_ticks)
|
345 |
-
tick_labels = [f"{int(x*100)}%" for x in np.linspace(0, 1, num_ticks)]
|
346 |
ax.set_xticks(tick_positions)
|
347 |
-
ax.set_xticklabels(tick_labels)
|
|
|
|
|
|
|
|
|
348 |
|
349 |
-
# Create second x-axis for actual positions
|
350 |
-
ax2 = ax.
|
351 |
-
ax2.set_xlim(ax.get_xlim())
|
352 |
|
353 |
-
#
|
354 |
seq1_positions = np.linspace(0, seq1_length, num_ticks)
|
355 |
seq2_positions = np.linspace(0, seq2_length, num_ticks)
|
356 |
|
357 |
-
# Format
|
358 |
def format_position(x):
|
359 |
if x >= 1e6:
|
360 |
return f"{x/1e6:.1f}M"
|
361 |
elif x >= 1e3:
|
362 |
return f"{x/1e3:.0f}K"
|
363 |
else:
|
364 |
-
return
|
365 |
|
366 |
seq1_labels = [format_position(x) for x in seq1_positions]
|
367 |
seq2_labels = [format_position(x) for x in seq2_positions]
|
368 |
|
369 |
-
#
|
|
|
370 |
ax2.set_xticks(tick_positions)
|
371 |
-
ax2.set_xticklabels(
|
372 |
-
|
373 |
-
# Add colorbar at the bottom with more spacing
|
374 |
-
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.6, aspect=40, shrink=0.8)
|
375 |
-
cbar.ax.tick_params(labelsize=8)
|
376 |
-
cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
|
377 |
-
|
378 |
-
# Remove percentage labels from colorbar
|
379 |
-
cbar.ax.set_xticklabels([f'{x:.3f}' for x in cbar.ax.get_xticks()])
|
380 |
-
|
381 |
-
# Adjust labels and layout
|
382 |
-
ax.set_yticks([])
|
383 |
-
ax.set_xlabel('Relative Position (%)', fontsize=10, labelpad=10)
|
384 |
ax2.set_xlabel('Sequence Positions', fontsize=10, labelpad=10)
|
385 |
-
ax.set_title(title, pad=10)
|
386 |
-
|
387 |
-
# Increase bottom margin to prevent overlap
|
388 |
-
plt.subplots_adjust(bottom=0.55, left=0.05, right=0.95, top=0.85)
|
389 |
|
|
|
|
|
390 |
return fig
|
391 |
|
392 |
def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
|
|
|
329 |
return shap2_norm - shap1_norm
|
330 |
|
331 |
def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
|
332 |
+
"""
|
333 |
+
Plot a comparative SHAP-difference heatmap, showing how Sequence 2 differs
|
334 |
+
from Sequence 1 with respect to 'human-likeness'.
|
335 |
+
|
336 |
+
shap_diff: 1D array of (Seq2 - Seq1) SHAP differences, already normalized to the same length
|
337 |
+
seq1_length: length of sequence 1 (in bases)
|
338 |
+
seq2_length: length of sequence 2 (in bases)
|
339 |
+
title: title of the plot
|
340 |
+
"""
|
341 |
+
|
342 |
+
# Prepare data for the heatmap
|
343 |
heatmap_data = shap_diff.reshape(1, -1)
|
344 |
extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
|
345 |
|
346 |
+
# Create figure and main axis
|
347 |
+
fig, ax = plt.subplots(figsize=(12, 3))
|
348 |
|
349 |
+
# Plot the main heatmap
|
350 |
cmap = get_zero_centered_cmap()
|
351 |
+
cax = ax.imshow(
|
352 |
+
heatmap_data,
|
353 |
+
aspect='auto',
|
354 |
+
cmap=cmap,
|
355 |
+
vmin=-extent,
|
356 |
+
vmax=extent
|
357 |
+
)
|
358 |
+
|
359 |
+
# Add a vertical colorbar on the right
|
360 |
+
cbar = plt.colorbar(cax, ax=ax, orientation='vertical', fraction=0.025, pad=0.03)
|
361 |
+
cbar.ax.tick_params(labelsize=9)
|
362 |
+
cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
|
363 |
|
364 |
+
# Configure the top axis for relative (%) positions
|
365 |
num_ticks = 5
|
366 |
+
tick_positions = np.linspace(0, shap_diff.shape[0] - 1, num_ticks)
|
367 |
+
tick_labels = [f"{int(x * 100)}%" for x in np.linspace(0, 1, num_ticks)]
|
368 |
ax.set_xticks(tick_positions)
|
369 |
+
ax.set_xticklabels(tick_labels, fontsize=9)
|
370 |
+
|
371 |
+
ax.set_xlabel('Relative Position (%)', fontsize=10, labelpad=8)
|
372 |
+
ax.set_title(title, fontsize=12, pad=10)
|
373 |
+
ax.set_yticks([]) # Hide the y-axis ticks for a 1D heatmap
|
374 |
|
375 |
+
# Create a second (bottom) x-axis for actual positions
|
376 |
+
ax2 = ax.secondary_xaxis('bottom')
|
377 |
+
ax2.set_xlim(ax.get_xlim()) # Match the same data range as the top axis
|
378 |
|
379 |
+
# Prepare positions for each sequence
|
380 |
seq1_positions = np.linspace(0, seq1_length, num_ticks)
|
381 |
seq2_positions = np.linspace(0, seq2_length, num_ticks)
|
382 |
|
383 |
+
# Format large numbers with 'K' or 'M'
|
384 |
def format_position(x):
|
385 |
if x >= 1e6:
|
386 |
return f"{x/1e6:.1f}M"
|
387 |
elif x >= 1e3:
|
388 |
return f"{x/1e3:.0f}K"
|
389 |
else:
|
390 |
+
return str(int(x))
|
391 |
|
392 |
seq1_labels = [format_position(x) for x in seq1_positions]
|
393 |
seq2_labels = [format_position(x) for x in seq2_positions]
|
394 |
|
395 |
+
# Combine the two label sets in a single tick label
|
396 |
+
bottom_labels = [f"S1: {s1}\nS2: {s2}" for s1, s2 in zip(seq1_labels, seq2_labels)]
|
397 |
ax2.set_xticks(tick_positions)
|
398 |
+
ax2.set_xticklabels(bottom_labels, fontsize=9)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
399 |
ax2.set_xlabel('Sequence Positions', fontsize=10, labelpad=10)
|
|
|
|
|
|
|
|
|
400 |
|
401 |
+
# Use tight_layout to reduce overlap
|
402 |
+
plt.tight_layout()
|
403 |
return fig
|
404 |
|
405 |
def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
|