hiyata committed
Commit 6c4adfb · verified · 1 Parent(s): 87c2305

Update app.py

Files changed (1)
app.py +120 -64
app.py CHANGED
@@ -319,67 +319,109 @@ def analyze_subregion(state, header, region_start, region_end):
 # 9. COMPARISON ANALYSIS FUNCTIONS
 ###############################################################################
 
-def normalize_shap_lengths(shap1, shap2, num_points=1000):
-    """
-    Normalize SHAP values to relative positions (0-1 scale).
-    Each point represents a relative position in the sequence (e.g., 0.75 = 75% through sequence).
-    """
-    # Create relative position arrays (0 to 1)
-    x1 = np.linspace(0, 1, len(shap1))
-    x2 = np.linspace(0, 1, len(shap2))
-
-    # Create normalized positions for comparison
-    x_norm = np.linspace(0, 1, num_points)
-
-    # Interpolate both sequences to the normalized positions
-    shap1_interp = np.interp(x_norm, x1, shap1)
-    shap2_interp = np.interp(x_norm, x2, shap2)
-
-    return shap1_interp, shap2_interp
-
-def compute_shap_difference(shap1_norm, shap2_norm):
-    """Compute the SHAP difference between normalized sequences"""
-    return shap2_norm - shap1_norm
-
-def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
-    """
-    Plot heatmap using relative positions (0-100%)
-    """
-    heatmap_data = shap_diff.reshape(1, -1)
-    extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
-
-    fig, ax = plt.subplots(figsize=(12, 1.8))
-    cmap = get_zero_centered_cmap()
-    cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
-
-    # Create percentage-based x-axis ticks
-    num_ticks = 5
-    tick_positions = np.linspace(0, shap_diff.shape[0]-1, num_ticks)
-    tick_labels = [f"{int(x*100)}%" for x in np.linspace(0, 1, num_ticks)]
-    ax.set_xticks(tick_positions)
-    ax.set_xticklabels(tick_labels)
-
-    cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
-    cbar.ax.tick_params(labelsize=8)
-    cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
-
-    ax.set_yticks([])
-    ax.set_xlabel('Relative Position in Sequence', fontsize=10)
-    ax.set_title(title, pad=10)
-    plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
-
-    return fig
+def calculate_adaptive_parameters(len1, len2):
+    """
+    Calculate adaptive parameters based on sequence lengths and their difference.
+
+    Returns:
+        tuple: (num_points, smooth_window, resolution_factor)
+    """
+    length_diff = abs(len1 - len2)
+    max_length = max(len1, len2)
+    length_ratio = min(len1, len2) / max_length
+
+    # Base number of points scales with sequence length
+    base_points = min(2000, max(500, max_length // 100))
+
+    # Adjust resolution based on length difference
+    if length_diff < 500:
+        resolution_factor = 2.0  # Higher resolution for very similar sequences
+        num_points = min(3000, base_points * 2)
+        smooth_window = max(10, length_diff // 50)  # Minimal smoothing
+    elif length_diff < 5000:
+        resolution_factor = 1.5
+        num_points = min(2000, base_points * 1.5)
+        smooth_window = max(20, length_diff // 100)
+    elif length_diff < 50000:
+        resolution_factor = 1.0
+        num_points = base_points
+        smooth_window = max(50, length_diff // 200)
+    else:
+        # For very large differences, reduce resolution but increase smoothing
+        resolution_factor = 0.75
+        num_points = max(500, base_points // 2)
+        smooth_window = max(100, length_diff // 500)
+
+    # Adjust window size based on length ratio
+    smooth_window = int(smooth_window * (1 + (1 - length_ratio)))
+
+    return int(num_points), int(smooth_window), resolution_factor
+
+def sliding_window_smooth(values, window_size=50):
+    """
+    Apply sliding window smoothing with edge handling.
+    Uses exponential decay at edges to reduce boundary effects.
+    """
+    if window_size < 3:
+        return values
+
+    window = np.ones(window_size)
+
+    # Create exponential decay at edges
+    decay = np.exp(-np.linspace(0, 3, window_size // 2))
+    window[:window_size // 2] = decay
+    window[-(window_size // 2):] = decay[::-1]
+
+    # Normalize window
+    window = window / window.sum()
+
+    # Apply convolution
+    smoothed = np.convolve(values, window, mode='valid')
+
+    # Handle edges
+    pad_size = len(values) - len(smoothed)
+    pad_left = pad_size // 2
+    pad_right = pad_size - pad_left
+
+    # Use actual values at edges instead of padding
+    result = np.zeros_like(values)
+    result[pad_left:-pad_right] = smoothed
+    result[:pad_left] = values[:pad_left]      # Keep original values at start
+    result[-pad_right:] = values[-pad_right:]  # Keep original values at end
+
+    return result
+
+def normalize_shap_lengths(shap1, shap2, num_points=1000, smooth_window=50):
+    """
+    Normalize and smooth SHAP values with dynamic adaptation.
+    """
+    # Calculate adaptive parameters (these override the default arguments)
+    num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
+
+    # Apply initial smoothing
+    shap1_smooth = sliding_window_smooth(shap1, smooth_window)
+    shap2_smooth = sliding_window_smooth(shap2, smooth_window)
+
+    # Create relative positions
+    x1 = np.linspace(0, 1, len(shap1_smooth))
+    x2 = np.linspace(0, 1, len(shap2_smooth))
+    x_norm = np.linspace(0, 1, num_points)
+
+    # Interpolate smoothed values
+    shap1_interp = np.interp(x_norm, x1, shap1_smooth)
+    shap2_interp = np.interp(x_norm, x2, shap2_smooth)
+
+    return shap1_interp, shap2_interp, smooth_window
 
 def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
     """
-    Compare two sequences using relative positions (0-1 scale)
+    Fully dynamic sequence comparison with adaptive parameters.
     """
-    # Analyze first sequence
+    # Analyze sequences
     res1 = analyze_sequence(file1, top_kmers=10, fasta_text=fasta1, window_size=500)
     if isinstance(res1[0], str) and "Error" in res1[0]:
         return (f"Error in sequence 1: {res1[0]}", None, None)
 
-    # Analyze second sequence
     res2 = analyze_sequence(file2, top_kmers=10, fasta_text=fasta2, window_size=500)
     if isinstance(res2[0], str) and "Error" in res2[0]:
         return (f"Error in sequence 2: {res2[0]}", None, None)
@@ -387,53 +429,67 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
     shap1 = res1[3]["shap_means"]
     shap2 = res2[3]["shap_means"]
 
-    # Normalize to relative positions
-    shap1_norm, shap2_norm = normalize_shap_lengths(shap1, shap2)
-    shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
+    # Get sequence properties
+    len1, len2 = len(shap1), len(shap2)
+    length_diff = abs(len1 - len2)
+    length_ratio = min(len1, len2) / max(len1, len2)
+
+    # Get normalized values with adaptive parameters
+    shap1_norm, shap2_norm, smooth_window = normalize_shap_lengths(shap1, shap2)
+    shap_diff = shap2_norm - shap1_norm
+
+    # Calculate adaptive threshold
+    base_threshold = 0.05
+    adaptive_threshold = base_threshold * (1 + (1 - length_ratio))
+    if length_diff > 50000:
+        adaptive_threshold *= 1.5  # More forgiving for very large differences
 
     # Calculate statistics
     avg_diff = np.mean(shap_diff)
     std_diff = np.std(shap_diff)
    max_diff = np.max(shap_diff)
     min_diff = np.min(shap_diff)
-
-    threshold = 0.05
-    substantial_diffs = np.abs(shap_diff) > threshold
+    substantial_diffs = np.abs(shap_diff) > adaptive_threshold
     frac_different = np.mean(substantial_diffs)
 
-    # Format output text
-    len1_formatted = "{:,}".format(len(shap1))
-    len2_formatted = "{:,}".format(len(shap2))
-    classification1 = res1[0].split('Classification: ')[1].split('\n')[0].strip()
-    classification2 = res2[0].split('Classification: ')[1].split('\n')[0].strip()
-
+    # Format detailed output
     comparison_text = (
         "Sequence Comparison Results:\n"
-        f"Sequence 1: {res1[4]} (Length: {len1_formatted} bases)\n"
-        f"Classification: {classification1}\n\n"
-        f"Sequence 2: {res2[4]} (Length: {len2_formatted} bases)\n"
-        f"Classification: {classification2}\n\n"
-        "Comparison Statistics:\n"
+        f"Sequence 1: {res1[4]} (Length: {len1:,} bases)\n"
+        f"Classification: {res1[0].split('Classification: ')[1].splitlines()[0].strip()}\n\n"
+        f"Sequence 2: {res2[4]} (Length: {len2:,} bases)\n"
+        f"Classification: {res2[0].split('Classification: ')[1].splitlines()[0].strip()}\n\n"
+        "Comparison Parameters:\n"
+        f"Length Difference: {length_diff:,} bases\n"
+        f"Length Ratio: {length_ratio:.3f}\n"
+        f"Smoothing Window: {smooth_window} points\n"
+        f"Adaptive Threshold: {adaptive_threshold:.3f}\n\n"
+        "Statistics:\n"
         f"Average SHAP difference: {avg_diff:.4f}\n"
         f"Standard deviation: {std_diff:.4f}\n"
         f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
         f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
-        f"Fraction of positions with substantial differences: {frac_different:.2%}\n\n"
-        "Note: Comparisons shown at relative positions (0-100%) in each sequence\n"
+        f"Fraction with substantial differences: {frac_different:.2%}\n\n"
+        "Note: All parameters automatically adjusted based on sequence properties\n"
        "Interpretation:\n"
-        "- Red regions: Sequence 2 is more human-like at that relative position\n"
-        "- Blue regions: Sequence 1 is more human-like at that relative position\n"
+        "- Red regions: Sequence 2 more human-like\n"
+        "- Blue regions: Sequence 1 more human-like\n"
        "- White regions: Similar between sequences"
     )
 
     # Generate visualizations
-    heatmap_fig = plot_comparative_heatmap(shap_diff)
+    heatmap_fig = plot_comparative_heatmap(
+        shap_diff,
+        title=f"SHAP Difference Heatmap (window: {smooth_window})"
+    )
     heatmap_img = fig_to_image(heatmap_fig)
-    hist_fig = plot_shap_histogram(shap_diff, title="Distribution of SHAP Differences")
+
+    # Adaptive number of bins based on data
+    num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
+    hist_fig = plot_shap_histogram(shap_diff, num_bins=num_bins)
     hist_img = fig_to_image(hist_fig)
 
     return comparison_text, heatmap_img, hist_img
-
 ###############################################################################
 # 10. BUILD GRADIO INTERFACE
 ###############################################################################
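As a quick sanity check on the adaptive pipeline above, here is a minimal sketch (not part of the commit) that exercises calculate_adaptive_parameters, sliding_window_smooth, and normalize_shap_lengths on synthetic inputs; the array contents and lengths are illustrative assumptions, not real model output.

import numpy as np

# Synthetic stand-ins for res1[3]["shap_means"] / res2[3]["shap_means"];
# assumes the three helpers above are in scope (e.g. pasted below them in app.py).
rng = np.random.default_rng(0)
shap1 = rng.normal(0.0, 0.1, 30_000)
shap2 = rng.normal(0.0, 0.1, 28_000)

num_points, smooth_window, resolution = calculate_adaptive_parameters(len(shap1), len(shap2))
# length_diff = 2,000 falls in the 500..5,000 band:
#   base_points   = min(2000, max(500, 30_000 // 100)) = 500
#   num_points    = int(min(2000, 500 * 1.5))          = 750
#   smooth_window = max(20, 2_000 // 100) = 20, then
#                   int(20 * (1 + (1 - 28_000/30_000))) = 21
print(num_points, smooth_window, resolution)  # -> 750 21 1.5

# Smoothing preserves array length and keeps the edge samples verbatim
smooth1 = sliding_window_smooth(shap1, smooth_window)
assert smooth1.shape == shap1.shape
assert np.array_equal(smooth1[:5], shap1[:5])

# Interpolation onto a shared 0..1 axis makes the two arrays directly comparable
s1, s2, window = normalize_shap_lengths(shap1, shap2)
assert s1.shape == s2.shape == (num_points,)
shap_diff = s2 - s1  # same quantity plotted by plot_comparative_heatmap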
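The adaptive threshold and histogram bin count in the new analyze_sequence_comparison are plain arithmetic, so a worked example may help; the lengths reuse the assumptions from the sketch above.

len1, len2 = 30_000, 28_000
length_diff = abs(len1 - len2)                        # 2,000
length_ratio = min(len1, len2) / max(len1, len2)      # ~0.933

# The base threshold widens as the sequences become less comparable in length
adaptive_threshold = 0.05 * (1 + (1 - length_ratio))  # ~0.0533
if length_diff > 50000:                               # not triggered here
    adaptive_threshold *= 1.5

# Histogram bins: clamped square-root rule on the interpolated length (750 points)
num_bins = max(20, min(50, int(750 ** 0.5)))          # int(27.39) -> 27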