hiyata committed on
Commit
87c2305
·
verified ·
1 Parent(s): 77621ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -21
app.py CHANGED
@@ -320,38 +320,65 @@ def analyze_subregion(state, header, region_start, region_end):
320
  ###############################################################################
321
 
322
  def normalize_shap_lengths(shap1, shap2, num_points=1000):
 
 
 
 
 
323
  x1 = np.linspace(0, 1, len(shap1))
324
  x2 = np.linspace(0, 1, len(shap2))
325
- f1 = interp1d(x1, shap1, kind='linear')
326
- f2 = interp1d(x2, shap2, kind='linear')
327
- x_new = np.linspace(0, 1, num_points)
328
- shap1_norm = f1(x_new)
329
- shap2_norm = f2(x_new)
330
- return shap1_norm, shap2_norm
 
 
 
331
 
332
  def compute_shap_difference(shap1_norm, shap2_norm):
 
333
  return shap2_norm - shap1_norm
334
 
335
  def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
 
 
 
336
  heatmap_data = shap_diff.reshape(1, -1)
337
  extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
338
- cmap = get_zero_centered_cmap()
339
  fig, ax = plt.subplots(figsize=(12, 1.8))
 
340
  cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
 
 
 
 
 
 
 
 
341
  cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
342
  cbar.ax.tick_params(labelsize=8)
343
  cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
 
344
  ax.set_yticks([])
345
- ax.set_xlabel('Normalized Position (0-100%)', fontsize=10)
346
  ax.set_title(title, pad=10)
347
  plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
 
348
  return fig
349
 
350
  def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
 
 
 
351
  # Analyze first sequence
352
  res1 = analyze_sequence(file1, top_kmers=10, fasta_text=fasta1, window_size=500)
353
  if isinstance(res1[0], str) and "Error" in res1[0]:
354
  return (f"Error in sequence 1: {res1[0]}", None, None)
 
355
  # Analyze second sequence
356
  res2 = analyze_sequence(file2, top_kmers=10, fasta_text=fasta2, window_size=500)
357
  if isinstance(res2[0], str) and "Error" in res2[0]:
@@ -359,46 +386,52 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
359
 
360
  shap1 = res1[3]["shap_means"]
361
  shap2 = res2[3]["shap_means"]
 
 
362
  shap1_norm, shap2_norm = normalize_shap_lengths(shap1, shap2)
363
  shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
364
 
 
365
  avg_diff = np.mean(shap_diff)
366
  std_diff = np.std(shap_diff)
367
  max_diff = np.max(shap_diff)
368
  min_diff = np.min(shap_diff)
 
369
  threshold = 0.05
370
  substantial_diffs = np.abs(shap_diff) > threshold
371
  frac_different = np.mean(substantial_diffs)
372
-
373
- classification1 = res1[0].split('Classification: ')[1].split('\n')[0].strip()
374
- classification2 = res2[0].split('Classification: ')[1].split('\n')[0].strip()
375
  len1_formatted = "{:,}".format(len(shap1))
376
  len2_formatted = "{:,}".format(len(shap2))
377
- frac_formatted = "{:.2%}".format(frac_different)
378
-
 
379
  comparison_text = (
380
  "Sequence Comparison Results:\n"
381
- f"Sequence 1: {res1[4]}\n"
382
- f"Length: {len1_formatted} bases\n"
383
  f"Classification: {classification1}\n\n"
384
- f"Sequence 2: {res2[4]}\n"
385
- f"Length: {len2_formatted} bases\n"
386
  f"Classification: {classification2}\n\n"
387
  "Comparison Statistics:\n"
388
  f"Average SHAP difference: {avg_diff:.4f}\n"
389
  f"Standard deviation: {std_diff:.4f}\n"
390
  f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
391
  f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
392
- f"Fraction of positions with substantial differences: {frac_formatted}\n\n"
 
393
  "Interpretation:\n"
394
- "Positive values (red) indicate regions where Sequence 2 is more 'human-like'\n"
395
- "Negative values (blue) indicate regions where Sequence 1 is more 'human-like'"
 
396
  )
397
-
 
398
  heatmap_fig = plot_comparative_heatmap(shap_diff)
399
  heatmap_img = fig_to_image(heatmap_fig)
400
  hist_fig = plot_shap_histogram(shap_diff, title="Distribution of SHAP Differences")
401
  hist_img = fig_to_image(hist_fig)
 
402
  return comparison_text, heatmap_img, hist_img
403
 
404
  ###############################################################################
 
320
  ###############################################################################
321
 
322
  def normalize_shap_lengths(shap1, shap2, num_points=1000):
323
+ """
324
+ Normalize SHAP values to relative positions (0-1 scale).
325
+ Each point represents a relative position in the sequence (e.g., 0.75 = 75% through sequence).
326
+ """
327
+ # Create relative position arrays (0 to 1)
328
  x1 = np.linspace(0, 1, len(shap1))
329
  x2 = np.linspace(0, 1, len(shap2))
330
+
331
+ # Create normalized positions for comparison
332
+ x_norm = np.linspace(0, 1, num_points)
333
+
334
+ # Interpolate both sequences to the normalized positions
335
+ shap1_interp = np.interp(x_norm, x1, shap1)
336
+ shap2_interp = np.interp(x_norm, x2, shap2)
337
+
338
+ return shap1_interp, shap2_interp
339
 
340
  def compute_shap_difference(shap1_norm, shap2_norm):
341
+ """Compute the SHAP difference between normalized sequences"""
342
  return shap2_norm - shap1_norm
343
 
344
  def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
345
+ """
346
+ Plot heatmap using relative positions (0-100%)
347
+ """
348
  heatmap_data = shap_diff.reshape(1, -1)
349
  extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
350
+
351
  fig, ax = plt.subplots(figsize=(12, 1.8))
352
+ cmap = get_zero_centered_cmap()
353
  cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
354
+
355
+ # Create percentage-based x-axis ticks
356
+ num_ticks = 5
357
+ tick_positions = np.linspace(0, shap_diff.shape[0]-1, num_ticks)
358
+ tick_labels = [f"{int(x*100)}%" for x in np.linspace(0, 1, num_ticks)]
359
+ ax.set_xticks(tick_positions)
360
+ ax.set_xticklabels(tick_labels)
361
+
362
  cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
363
  cbar.ax.tick_params(labelsize=8)
364
  cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
365
+
366
  ax.set_yticks([])
367
+ ax.set_xlabel('Relative Position in Sequence', fontsize=10)
368
  ax.set_title(title, pad=10)
369
  plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
370
+
371
  return fig
372
 
373
  def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
374
+ """
375
+ Compare two sequences using relative positions (0-1 scale)
376
+ """
377
  # Analyze first sequence
378
  res1 = analyze_sequence(file1, top_kmers=10, fasta_text=fasta1, window_size=500)
379
  if isinstance(res1[0], str) and "Error" in res1[0]:
380
  return (f"Error in sequence 1: {res1[0]}", None, None)
381
+
382
  # Analyze second sequence
383
  res2 = analyze_sequence(file2, top_kmers=10, fasta_text=fasta2, window_size=500)
384
  if isinstance(res2[0], str) and "Error" in res2[0]:
 
386
 
387
  shap1 = res1[3]["shap_means"]
388
  shap2 = res2[3]["shap_means"]
389
+
390
+ # Normalize to relative positions
391
  shap1_norm, shap2_norm = normalize_shap_lengths(shap1, shap2)
392
  shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
393
 
394
+ # Calculate statistics
395
  avg_diff = np.mean(shap_diff)
396
  std_diff = np.std(shap_diff)
397
  max_diff = np.max(shap_diff)
398
  min_diff = np.min(shap_diff)
399
+
400
  threshold = 0.05
401
  substantial_diffs = np.abs(shap_diff) > threshold
402
  frac_different = np.mean(substantial_diffs)
403
+
404
+ # Format output text
 
405
  len1_formatted = "{:,}".format(len(shap1))
406
  len2_formatted = "{:,}".format(len(shap2))
407
+ classification1 = res1[0].split('Classification: ')[1].split('\n')[0].strip()
408
+ classification2 = res2[0].split('Classification: ')[1].split('\n')[0].strip()
409
+
410
  comparison_text = (
411
  "Sequence Comparison Results:\n"
412
+ f"Sequence 1: {res1[4]} (Length: {len1_formatted} bases)\n"
 
413
  f"Classification: {classification1}\n\n"
414
+ f"Sequence 2: {res2[4]} (Length: {len2_formatted} bases)\n"
 
415
  f"Classification: {classification2}\n\n"
416
  "Comparison Statistics:\n"
417
  f"Average SHAP difference: {avg_diff:.4f}\n"
418
  f"Standard deviation: {std_diff:.4f}\n"
419
  f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
420
  f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
421
+ f"Fraction of positions with substantial differences: {frac_different:.2%}\n\n"
422
+ "Note: Comparisons shown at relative positions (0-100%) in each sequence\n"
423
  "Interpretation:\n"
424
+ "- Red regions: Sequence 2 is more human-like at that relative position\n"
425
+ "- Blue regions: Sequence 1 is more human-like at that relative position\n"
426
+ "- White regions: Similar between sequences"
427
  )
428
+
429
+ # Generate visualizations
430
  heatmap_fig = plot_comparative_heatmap(shap_diff)
431
  heatmap_img = fig_to_image(heatmap_fig)
432
  hist_fig = plot_shap_histogram(shap_diff, title="Distribution of SHAP Differences")
433
  hist_img = fig_to_image(hist_fig)
434
+
435
  return comparison_text, heatmap_img, hist_img
436
 
437
  ###############################################################################