Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -328,32 +328,70 @@ def compute_shap_difference(shap1_norm, shap2_norm):
|
|
328 |
"""Compute the SHAP difference between normalized sequences"""
|
329 |
return shap2_norm - shap1_norm
|
330 |
|
331 |
-
def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
|
332 |
"""
|
333 |
-
Plot heatmap using relative positions (0-100%)
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
"""
|
335 |
heatmap_data = shap_diff.reshape(1, -1)
|
336 |
extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
|
337 |
|
338 |
-
|
|
|
|
|
|
|
339 |
cmap = get_zero_centered_cmap()
|
340 |
cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
|
341 |
|
342 |
-
# Create percentage-based x-axis ticks
|
343 |
num_ticks = 5
|
344 |
tick_positions = np.linspace(0, shap_diff.shape[0]-1, num_ticks)
|
345 |
tick_labels = [f"{int(x*100)}%" for x in np.linspace(0, 1, num_ticks)]
|
346 |
ax.set_xticks(tick_positions)
|
347 |
ax.set_xticklabels(tick_labels)
|
348 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
349 |
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
|
350 |
cbar.ax.tick_params(labelsize=8)
|
351 |
cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
|
352 |
|
|
|
353 |
ax.set_yticks([])
|
354 |
-
ax.set_xlabel('Relative Position
|
|
|
355 |
ax.set_title(title, pad=10)
|
356 |
-
|
|
|
|
|
357 |
|
358 |
return fig
|
359 |
|
|
|
328 |
"""Compute the SHAP difference between normalized sequences"""
|
329 |
return shap2_norm - shap1_norm
|
330 |
|
331 |
+
def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
    """
    Plot a one-row heatmap of SHAP differences with two x-axes.

    The bottom axis shows relative position (0-100%). A twinned axis on top
    shows the matching absolute positions in each sequence, labelled
    "S1: ..." / "S2: ...".

    Parameters:
        shap_diff: array-like of SHAP differences; flattened to a single row
        seq1_length: length of sequence 1 (drives the "S1" tick labels)
        seq2_length: length of sequence 2 (drives the "S2" tick labels)
        title: plot title

    Returns:
        The matplotlib figure holding the heatmap and its colorbar.

    Raises:
        ValueError: if shap_diff contains no values.
    """
    values = np.asarray(shap_diff).reshape(-1)
    if values.size == 0:
        raise ValueError("shap_diff must contain at least one value")

    heatmap_data = values.reshape(1, -1)
    # Symmetric colour limits so that zero sits at the midpoint of the colormap.
    extent = float(np.max(np.abs(values)))

    def format_position(x):
        """Render a position compactly, e.g. 950 -> '950', 1500 -> '1.5K'."""
        if x >= 1e6:
            return f"{x / 1e6:.1f}M"
        if x >= 1e3:
            # One decimal keeps neighbouring ticks distinguishable; with zero
            # decimals 1500 and 2000 would both be rendered as "2K".
            return f"{x / 1e3:.1f}K"
        return f"{int(x)}"

    # A short, wide figure: a single heatmap row plus room for both x-axes.
    fig, ax = plt.subplots(figsize=(12, 2.4))

    # Main heatmap.
    # NOTE(review): get_zero_centered_cmap() is defined elsewhere in this file;
    # presumably it returns a diverging colormap - confirm.
    cmap = get_zero_centered_cmap()
    heat_img = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)

    # Percentage-based ticks on the primary (bottom) x-axis.
    num_ticks = 5
    tick_positions = np.linspace(0, values.size - 1, num_ticks)
    tick_labels = [f"{int(x*100)}%" for x in np.linspace(0, 1, num_ticks)]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)

    # Second x-axis for absolute positions. twiny() places it on TOP of the
    # plot; sharing the x-limits keeps its ticks aligned with the heatmap.
    ax2 = ax.twiny()
    ax2.set_xlim(ax.get_xlim())

    # Absolute positions in each sequence at the same relative tick locations.
    seq1_labels = [format_position(x) for x in np.linspace(0, seq1_length, num_ticks)]
    seq2_labels = [format_position(x) for x in np.linspace(0, seq2_length, num_ticks)]
    ax2.set_xticks(tick_positions)
    ax2.set_xticklabels([f"S1: {s1}\nS2: {s2}" for s1, s2 in zip(seq1_labels, seq2_labels)])

    # Colorbar, attached explicitly to the heatmap axes so it does not depend
    # on which axes pyplot currently treats as "current" (the twin, after twiny()).
    cbar = fig.colorbar(heat_img, ax=ax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
    cbar.ax.tick_params(labelsize=8)
    cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)

    # Axis labels and title.
    ax.set_yticks([])
    ax.set_xlabel('Relative Position (%)', fontsize=10)
    ax2.set_xlabel('Sequence Positions', fontsize=10)
    ax.set_title(title, pad=10)

    # Leave margins for the two-line tick labels and the colorbar.
    fig.subplots_adjust(bottom=0.35, left=0.05, right=0.95, top=0.85)

    return fig
|
397 |
|