Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -330,41 +330,70 @@ def compute_shap_difference(shap1_norm, shap2_norm):
|
|
330 |
|
331 |
def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
|
332 |
"""
|
333 |
-
Plots a comparative heatmap of SHAP differences between two sequences.
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
340 |
"""
|
341 |
-
#
|
342 |
heatmap_data = shap_diff.reshape(1, -1)
|
|
|
343 |
extent = max(abs(shap_diff.min()), abs(shap_diff.max()))
|
344 |
|
345 |
-
# Create figure
|
346 |
-
fig, ax = plt.subplots(figsize=(
|
347 |
-
|
348 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
349 |
|
350 |
-
#
|
351 |
num_ticks = 5
|
352 |
tick_positions = np.linspace(0, shap_diff.shape[0] - 1, num_ticks)
|
|
|
|
|
353 |
ax.set_xticks(tick_positions)
|
354 |
-
ax.set_xticklabels(
|
355 |
-
|
|
|
|
|
|
|
356 |
|
357 |
-
# Top
|
358 |
ax_top = ax.twiny()
|
359 |
-
ax_top.set_xlim(ax.get_xlim())
|
|
|
|
|
360 |
seq1_positions = np.linspace(0, seq1_length, num_ticks)
|
361 |
seq2_positions = np.linspace(0, seq2_length, num_ticks)
|
362 |
|
|
|
363 |
def format_position(x):
|
364 |
if x >= 1e6:
|
365 |
return f"{x / 1e6:.1f}M"
|
366 |
elif x >= 1e3:
|
367 |
-
return f"{x / 1e3
|
368 |
else:
|
369 |
return f"{int(x)}"
|
370 |
|
@@ -372,24 +401,33 @@ def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Di
|
|
372 |
seq2_labels = [format_position(x) for x in seq2_positions]
|
373 |
|
374 |
ax_top.set_xticks(tick_positions)
|
|
|
375 |
ax_top.set_xticklabels(
|
376 |
[f"S1: {s1}\nS2: {s2}" for s1, s2 in zip(seq1_labels, seq2_labels)],
|
377 |
-
fontsize=
|
378 |
)
|
379 |
ax_top.set_xlabel("Sequence Positions", fontsize=10, labelpad=15)
|
380 |
|
381 |
-
# Colorbar on the right
|
382 |
-
|
383 |
-
cbar
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
|
385 |
-
#
|
386 |
-
ax.set_yticks([])
|
387 |
ax.set_title(title, fontsize=12, pad=10)
|
388 |
|
389 |
-
|
|
|
|
|
390 |
|
391 |
return fig
|
392 |
|
|
|
393 |
def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
|
394 |
"""
|
395 |
Plot histogram of SHAP values with configurable number of bins
|
|
|
330 |
|
331 |
def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
    """
    Plot a one-row heatmap of SHAP differences between two sequences.

    Parameters
    ----------
    shap_diff : 1D array-like of differences (Seq2 SHAP - Seq1 SHAP).
        Negative = Seq1 more human-like, Positive = Seq2 more human-like.
        Lists/tuples are accepted and converted with ``np.asarray``.
    seq1_length : int, length of sequence 1 (used only for tick labels).
    seq2_length : int, length of sequence 2 (used only for tick labels).
    title : str, figure title.

    Returns
    -------
    The matplotlib figure created by ``plt.subplots``.

    Raises
    ------
    ValueError
        If ``shap_diff`` is empty (there is nothing to plot).

    Figure Layout
    -------------
    - Bottom X-Axis: Relative positions in percent (0% to 100%).
    - Top X-Axis   : Actual positions for both sequences (S1 and S2).
    - Y-Axis       : Hidden (this is effectively a 1D heatmap).
    - Colorbar     : Vertical, placed on the right.

    The colormap is obtained from ``get_zero_centered_cmap()`` and the color
    limits are symmetric around zero, so a difference of zero sits at the
    middle of the colormap.
    """
    values = np.asarray(shap_diff)
    if values.size == 0:
        raise ValueError("shap_diff must contain at least one value")

    # Reshape the differences into a 1 x N image for imshow
    heatmap_data = values.reshape(1, -1)

    # Symmetric color range around zero
    extent = max(abs(values.min()), abs(values.max()))
    if extent == 0:
        # All differences are zero. With vmin == vmax matplotlib's Normalize
        # maps every value to the low end of the colormap, which would paint
        # "no difference" as strongly negative. Use a non-zero range so that
        # zero stays at the middle of the colormap.
        extent = 1.0

    # Create the figure (width x height in inches)
    fig, ax = plt.subplots(figsize=(10, 3))

    # Main heatmap
    cmap = get_zero_centered_cmap()
    cax = ax.imshow(
        heatmap_data,
        aspect='auto',
        cmap=cmap,
        vmin=-extent,
        vmax=extent,
    )

    # Number of major ticks along the x-axis (shared by both axes)
    num_ticks = 5
    tick_positions = np.linspace(0, values.shape[0] - 1, num_ticks)

    # ----------------- Bottom Axis: Percentage ----------------- #
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(
        [f"{int(x * 100)}%" for x in np.linspace(0, 1, num_ticks)],
        fontsize=9,
    )
    ax.set_xlabel("Relative Position (%)", fontsize=10, labelpad=10)

    # ----------------- Top Axis: Actual Positions ----------------- #
    ax_top = ax.twiny()
    ax_top.set_xlim(ax.get_xlim())  # Match the bottom axis

    # Evenly spaced positions along each sequence, one per tick
    seq1_positions = np.linspace(0, seq1_length, num_ticks)
    seq2_positions = np.linspace(0, seq2_length, num_ticks)

    def format_position(x):
        """Format a position compactly: millions as 'x.xM', thousands as 'nK'."""
        if x >= 1e6:
            return f"{x / 1e6:.1f}M"
        if x >= 1e3:
            # NOTE(review): int() truncates, so e.g. 1999 is rendered as "1K".
            return f"{int(x / 1e3)}K"
        return f"{int(x)}"

    seq1_labels = [format_position(x) for x in seq1_positions]
    seq2_labels = [format_position(x) for x in seq2_positions]

    ax_top.set_xticks(tick_positions)
    # Each tick label shows the corresponding position in Seq1 and Seq2
    ax_top.set_xticklabels(
        [f"S1: {s1}\nS2: {s2}" for s1, s2 in zip(seq1_labels, seq2_labels)],
        fontsize=9,
    )
    ax_top.set_xlabel("Sequence Positions", fontsize=10, labelpad=15)

    # ----------------- Colorbar (Vertical, on the right) ----------------- #
    # 'fraction' = share of the axes given to the colorbar,
    # 'pad' = gap between the heatmap axes and the colorbar
    cbar = fig.colorbar(
        cax, ax=ax, orientation='vertical', fraction=0.03, pad=0.07
    )
    cbar.set_label("SHAP Difference\n(Seq2 - Seq1)", fontsize=10, labelpad=5)

    # Hide the y-axis (not needed in a 1D heatmap)
    ax.set_yticks([])

    # Set a descriptive title
    ax.set_title(title, fontsize=12, pad=10)

    # Fit everything without overlap; rect keeps the layout inside the left
    # 90% of the figure width
    fig.tight_layout(rect=[0, 0, 0.9, 1])  # Adjust as necessary

    return fig
|
429 |
|
430 |
+
|
431 |
def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
|
432 |
"""
|
433 |
Plot histogram of SHAP values with configurable number of bins
|