hiyata commited on
Commit
e502db5
·
verified ·
1 Parent(s): 5a41c75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -0
app.py CHANGED
@@ -319,6 +319,39 @@ def analyze_subregion(state, header, region_start, region_end):
319
  # 9. COMPARISON ANALYSIS FUNCTIONS
320
  ###############################################################################
321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  def calculate_adaptive_parameters(len1, len2):
323
  """
324
  Calculate adaptive parameters based on sequence lengths and their difference.
 
319
  # 9. COMPARISON ANALYSIS FUNCTIONS
320
  ###############################################################################
321
 
322
+ def compute_shap_difference(shap1_norm, shap2_norm):
323
+ """Compute the SHAP difference between normalized sequences"""
324
+ return shap2_norm - shap1_norm
325
+
326
+ def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
327
+ """
328
+ Plot heatmap using relative positions (0-100%)
329
+ """
330
+ heatmap_data = shap_diff.reshape(1, -1)
331
+ extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
332
+
333
+ fig, ax = plt.subplots(figsize=(12, 1.8))
334
+ cmap = get_zero_centered_cmap()
335
+ cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
336
+
337
+ # Create percentage-based x-axis ticks
338
+ num_ticks = 5
339
+ tick_positions = np.linspace(0, shap_diff.shape[0]-1, num_ticks)
340
+ tick_labels = [f"{int(x*100)}%" for x in np.linspace(0, 1, num_ticks)]
341
+ ax.set_xticks(tick_positions)
342
+ ax.set_xticklabels(tick_labels)
343
+
344
+ cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
345
+ cbar.ax.tick_params(labelsize=8)
346
+ cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
347
+
348
+ ax.set_yticks([])
349
+ ax.set_xlabel('Relative Position in Sequence', fontsize=10)
350
+ ax.set_title(title, pad=10)
351
+ plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
352
+
353
+ return fig
354
+
355
  def calculate_adaptive_parameters(len1, len2):
356
  """
357
  Calculate adaptive parameters based on sequence lengths and their difference.