hiyata committed (verified)
Commit 18efb8a · 1 Parent(s): d01c414

Update app.py

Files changed (1)
  1. app.py +331 -285
app.py CHANGED
@@ -67,6 +67,9 @@ def parse_fasta(text):
67
  return sequences
68
 
69
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
 
 
 
70
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
71
  kmer_dict = {km: i for i, km in enumerate(kmers)}
72
  vec = np.zeros(len(kmers), dtype=np.float32)
@@ -84,11 +87,15 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
84
  ###############################################################################
85
 
86
  def calculate_shap_values(model, x_tensor):
 
 
 
 
87
  model.eval()
88
  with torch.no_grad():
89
  baseline_output = model(x_tensor)
90
  baseline_probs = torch.softmax(baseline_output, dim=1)
91
- baseline_prob = baseline_probs[0, 1].item() # Prob of 'human'
92
  shap_values = []
93
  x_zeroed = x_tensor.clone()
94
  for i in range(x_tensor.shape[1]):
@@ -106,6 +113,9 @@ def calculate_shap_values(model, x_tensor):
106
  ###############################################################################
107
 
108
  def compute_positionwise_scores(sequence, shap_values, k=4):
 
 
 
109
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
110
  kmer_dict = {km: i for i, km in enumerate(kmers)}
111
  seq_len = len(sequence)
@@ -126,6 +136,9 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
126
  ###############################################################################
127
 
128
  def find_extreme_subregion(shap_means, window_size=500, mode="max"):
 
 
 
129
  n = len(shap_means)
130
  if n == 0:
131
  return (0, 0, 0.0)
@@ -152,6 +165,9 @@ def find_extreme_subregion(shap_means, window_size=500, mode="max"):
152
  ###############################################################################
153
 
154
  def fig_to_image(fig):
 
 
 
155
  buf = io.BytesIO()
156
  fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
157
  buf.seek(0)
@@ -160,10 +176,16 @@ def fig_to_image(fig):
160
  return img
161
 
162
  def get_zero_centered_cmap():
 
 
 
163
  colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
164
  return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
165
 
166
  def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
 
 
 
167
  if start is not None and end is not None:
168
  local_shap = shap_means[start:end]
169
  subtitle = f" (positions {start}-{end})"
@@ -189,6 +211,9 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
189
  return fig
190
 
191
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
 
 
 
192
  plt.rcParams.update({'font.size': 10})
193
  fig = plt.figure(figsize=(10, 5))
194
  indices = np.argsort(np.abs(shap_values))[-top_k:]
@@ -204,6 +229,9 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
204
  return fig
205
 
206
  def plot_shap_histogram(shap_array, title="SHAP Distribution in Region", num_bins=30):
 
 
 
207
  fig, ax = plt.subplots(figsize=(6, 4))
208
  ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black')
209
  ax.axvline(0, color='red', linestyle='--', label='0.0')
@@ -215,8 +243,11 @@ def plot_shap_histogram(shap_array, title="SHAP Distribution in Region", num_bin
215
  return fig
216
 
217
  def compute_gc_content(sequence):
 
 
 
218
  if not sequence:
219
- return 0
220
  gc_count = sequence.count('G') + sequence.count('C')
221
  return (gc_count / len(sequence)) * 100.0
222
 
@@ -225,6 +256,11 @@ def compute_gc_content(sequence):
225
  ###############################################################################
226
 
227
  def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
 
 
 
 
 
228
  if fasta_text.strip():
229
  text = fasta_text.strip()
230
  elif file_obj is not None:
@@ -236,14 +272,15 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
236
  else:
237
  return ("Please provide a FASTA sequence.", None, None, None, None, None)
238
 
 
239
  sequences = parse_fasta(text)
240
  if not sequences:
241
  return ("No valid FASTA sequences found.", None, None, None, None, None)
242
  header, seq = sequences[0]
243
 
 
244
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
245
  try:
246
- # IMPORTANT: adjust how you load your model as needed
247
  state_dict = torch.load('model.pt', map_location=device)
248
  model = VirusClassifier(256).to(device)
249
  model.load_state_dict(state_dict)
@@ -260,10 +297,12 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
260
  classification = "Human" if prob_human > 0.5 else "Non-human"
261
  confidence = max(prob_human, prob_nonhuman)
262
 
 
263
  shap_means = compute_positionwise_scores(seq, shap_values, k=4)
264
  max_start, max_end, max_avg = find_extreme_subregion(shap_means, window_size, mode="max")
265
  min_start, min_end, min_avg = find_extreme_subregion(shap_means, window_size, mode="min")
266
 
 
267
  results_text = (
268
  f"Sequence: {header}\n"
269
  f"Length: {len(seq):,} bases\n"
@@ -277,6 +316,7 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
277
  f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
278
  )
279
 
 
280
  kmers = [''.join(p) for p in product("ACGT", repeat=4)]
281
  bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
282
  bar_img = fig_to_image(bar_fig)
@@ -284,10 +324,10 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
284
  heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
285
  heatmap_img = fig_to_image(heatmap_fig)
286
 
287
- # You might want to provide a CSV or other data for the 6th return item
288
- # Here, we'll simply return None for the file download:
289
  state_dict_out = {"seq": seq, "shap_means": shap_means}
290
 
 
291
  return (results_text, bar_img, heatmap_img, state_dict_out, header, None)
292
 
293
  ###############################################################################
@@ -295,6 +335,9 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
295
  ###############################################################################
296
 
297
  def analyze_subregion(state, header, region_start, region_end):
 
 
 
298
  if not state or "seq" not in state or "shap_means" not in state:
299
  return ("No sequence data found. Please run Step 1 first.", None, None, None)
300
  seq = state["seq"]
@@ -305,18 +348,22 @@ def analyze_subregion(state, header, region_start, region_end):
305
  region_end = max(0, min(region_end, len(seq)))
306
  if region_end <= region_start:
307
  return ("Invalid region range. End must be > Start.", None, None, None)
 
308
  region_seq = seq[region_start:region_end]
309
  region_shap = shap_means[region_start:region_end]
 
310
  gc_percent = compute_gc_content(region_seq)
311
  avg_shap = float(np.mean(region_shap))
312
  positive_fraction = np.mean(region_shap > 0)
313
  negative_fraction = np.mean(region_shap < 0)
 
314
  if avg_shap > 0.05:
315
  region_classification = "Likely pushing toward human"
316
  elif avg_shap < -0.05:
317
  region_classification = "Likely pushing toward non-human"
318
  else:
319
  region_classification = "Near neutral (no strong push)"
 
320
  region_info = (
321
  f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
322
  f"Region length: {len(region_seq)} bases\n"
@@ -326,30 +373,29 @@ def analyze_subregion(state, header, region_start, region_end):
326
  f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
327
  f"Subregion interpretation: {region_classification}\n"
328
  )
 
329
  heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion SHAP", start=region_start, end=region_end)
330
  heatmap_img = fig_to_image(heatmap_fig)
 
331
  hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
332
  hist_img = fig_to_image(hist_fig)
333
-
334
- # For demonstration, returning None for the file download as well
335
  return (region_info, heatmap_img, hist_img, None)
336
 
337
  ###############################################################################
338
- # 9. COMPARISON ANALYSIS FUNCTIONS
339
  ###############################################################################
340
 
341
- def get_zero_centered_cmap():
342
- """Create a zero-centered blue-white-red colormap"""
343
- colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
344
- return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
345
-
346
  def compute_shap_difference(shap1_norm, shap2_norm):
347
- """Compute the SHAP difference between normalized sequences"""
 
 
348
  return shap2_norm - shap1_norm
349
 
350
  def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
351
  """
352
- Plot heatmap using relative positions (0-100%)
353
  """
354
  heatmap_data = shap_diff.reshape(1, -1)
355
  extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
@@ -378,7 +424,7 @@ def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
378
 
379
  def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
380
  """
381
- Plot histogram of SHAP values with configurable number of bins
382
  """
383
  fig, ax = plt.subplots(figsize=(6, 4))
384
  ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
@@ -392,18 +438,16 @@ def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
392
 
393
  def calculate_adaptive_parameters(len1, len2):
394
  """
395
- Calculate adaptive parameters based on sequence lengths and their difference.
396
- Returns: (num_points, smooth_window, resolution_factor)
397
  """
398
  length_diff = abs(len1 - len2)
399
  max_length = max(len1, len2)
400
  min_length = min(len1, len2)
401
  length_ratio = min_length / max_length
402
 
403
- # Base number of points scales with sequence length
404
  base_points = min(2000, max(500, max_length // 100))
405
 
406
- # Adjust parameters based on sequence properties
407
  if length_diff < 500:
408
  resolution_factor = 2.0
409
  num_points = min(3000, base_points * 2)
@@ -421,29 +465,22 @@ def calculate_adaptive_parameters(len1, len2):
421
  num_points = max(500, base_points // 2)
422
  smooth_window = max(100, length_diff // 500)
423
 
424
- # Adjust window size based on length ratio
425
  smooth_window = int(smooth_window * (1 + (1 - length_ratio)))
426
-
427
  return int(num_points), int(smooth_window), resolution_factor
428
 
429
  def sliding_window_smooth(values, window_size=50):
430
  """
431
- Apply sliding window smoothing with edge handling
432
  """
433
  if window_size < 3:
434
  return values
435
-
436
- # Create window with exponential decay at edges
437
  window = np.ones(window_size)
438
  decay = np.exp(-np.linspace(0, 3, window_size // 2))
439
  window[:window_size // 2] = decay
440
  window[-(window_size // 2):] = decay[::-1]
441
  window = window / window.sum()
442
 
443
- # Apply convolution
444
  smoothed = np.convolve(values, window, mode='valid')
445
-
446
- # Handle edges
447
  pad_size = len(values) - len(smoothed)
448
  pad_left = pad_size // 2
449
  pad_right = pad_size - pad_left
@@ -457,16 +494,13 @@ def sliding_window_smooth(values, window_size=50):
457
 
458
  def normalize_shap_lengths(shap1, shap2):
459
  """
460
- Normalize and smooth SHAP values with dynamic adaptation
461
  """
462
- # Calculate adaptive parameters
463
  num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
464
 
465
- # Apply initial smoothing
466
  shap1_smooth = sliding_window_smooth(shap1, smooth_window)
467
  shap2_smooth = sliding_window_smooth(shap2, smooth_window)
468
 
469
- # Create relative positions and interpolate
470
  x1 = np.linspace(0, 1, len(shap1_smooth))
471
  x2 = np.linspace(0, 1, len(shap2_smooth))
472
  x_norm = np.linspace(0, 1, num_points)
@@ -478,7 +512,8 @@ def normalize_shap_lengths(shap1, shap2):
478
 
479
  def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
480
  """
481
- Compare two sequences with adaptive parameters and visualization
 
482
  """
483
  try:
484
  # Analyze first sequence
@@ -491,26 +526,23 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
491
  if isinstance(res2[0], str) and "Error" in res2[0]:
492
  return (f"Error in sequence 2: {res2[0]}", None, None, None)
493
 
494
- # Extract SHAP values and sequence info
495
  shap1 = res1[3]["shap_means"]
496
  shap2 = res2[3]["shap_means"]
497
 
498
- # Calculate sequence properties
499
  len1, len2 = len(shap1), len(shap2)
500
  length_diff = abs(len1 - len2)
501
  length_ratio = min(len1, len2) / max(len1, len2)
502
-
503
- # Normalize and compare sequences
504
  shap1_norm, shap2_norm, smooth_window = normalize_shap_lengths(shap1, shap2)
505
  shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
506
 
507
- # Calculate adaptive threshold and statistics
508
  base_threshold = 0.05
509
  adaptive_threshold = base_threshold * (1 + (1 - length_ratio))
510
  if length_diff > 50000:
511
  adaptive_threshold *= 1.5
512
 
513
- # Calculate comparison statistics
514
  avg_diff = np.mean(shap_diff)
515
  std_diff = np.std(shap_diff)
516
  max_diff = np.max(shap_diff)
@@ -518,7 +550,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
518
  substantial_diffs = np.abs(shap_diff) > adaptive_threshold
519
  frac_different = np.mean(substantial_diffs)
520
 
521
- # Extract classifications
522
  try:
523
  classification1 = res1[0].split('Classification: ')[1].split('\n')[0].strip()
524
  classification2 = res2[0].split('Classification: ')[1].split('\n')[0].strip()
@@ -526,7 +558,6 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
526
  classification1 = "Unknown"
527
  classification2 = "Unknown"
528
 
529
- # Format output text
530
  comparison_text = (
531
  "Sequence Comparison Results:\n"
532
  f"Sequence 1: {res1[4]}\n"
@@ -553,14 +584,12 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
553
  "- White regions: Similar between sequences"
554
  )
555
 
556
- # Generate visualizations
557
  heatmap_fig = plot_comparative_heatmap(
558
  shap_diff,
559
  title=f"SHAP Difference Heatmap (window: {smooth_window})"
560
  )
561
  heatmap_img = fig_to_image(heatmap_fig)
562
 
563
- # Create histogram with adaptive bins
564
  num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
565
  hist_fig = plot_shap_histogram(
566
  shap_diff,
@@ -569,7 +598,6 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
569
  )
570
  hist_img = fig_to_image(hist_fig)
571
 
572
- # Return 4 outputs (text, image, image, and a file or None for the last)
573
  return (comparison_text, heatmap_img, hist_img, None)
574
 
575
  except Exception as e:
@@ -577,23 +605,55 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
577
  return (error_msg, None, None, None)
578
 
579
  ###############################################################################
580
- # 11. GENE FEATURE ANALYSIS
581
  ###############################################################################
582
 
583
- import io
584
- from io import BytesIO
585
- from PIL import Image, ImageDraw, ImageFont
586
- import numpy as np
587
- import pandas as pd
588
- import tempfile
589
- import os
590
- from typing import List, Dict, Tuple, Optional, Any
591
- import matplotlib.pyplot as plt
592
- from matplotlib.colors import LinearSegmentedColormap
593
- import seaborn as sns
594
 
595
  def parse_gene_features(text: str) -> List[Dict[str, Any]]:
596
- """Parse gene features from text file in FASTA-like format"""
597
  genes = []
598
  current_header = None
599
  current_sequence = []
@@ -602,7 +662,6 @@ def parse_gene_features(text: str) -> List[Dict[str, Any]]:
602
  line = line.strip()
603
  if not line:
604
  continue
605
-
606
  if line.startswith('>'):
607
  if current_header:
608
  genes.append({
@@ -614,36 +673,29 @@ def parse_gene_features(text: str) -> List[Dict[str, Any]]:
614
  current_sequence = []
615
  else:
616
  current_sequence.append(line.upper())
617
-
618
  if current_header:
619
  genes.append({
620
  'header': current_header,
621
  'sequence': ''.join(current_sequence),
622
  'metadata': parse_gene_metadata(current_header)
623
  })
624
-
625
  return genes
626
 
627
  def parse_gene_metadata(header: str) -> Dict[str, str]:
628
- """Extract metadata from gene header"""
629
  metadata = {}
630
  parts = header.split()
631
-
632
  for part in parts:
633
  if '[' in part and ']' in part:
634
  key_value = part[1:-1].split('=', 1)
635
  if len(key_value) == 2:
636
  metadata[key_value[0]] = key_value[1]
637
-
638
  return metadata
639
 
640
  def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
641
- """Parse gene location string, handling both forward and complement strands"""
642
  try:
643
- # Remove 'complement(' and ')' if present
644
  clean_loc = location_str.replace('complement(', '').replace(')', '')
645
-
646
- # Split on '..' and convert to integers
647
  if '..' in clean_loc:
648
  start, end = map(int, clean_loc.split('..'))
649
  return start, end
@@ -654,48 +706,41 @@ def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
654
  return None, None
655
 
656
  def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
657
- """Compute statistical measures for gene SHAP values"""
658
  return {
659
- 'avg_shap': float(np.mean(gene_shap)),
660
- 'median_shap': float(np.median(gene_shap)),
661
- 'std_shap': float(np.std(gene_shap)),
662
- 'max_shap': float(np.max(gene_shap)),
663
- 'min_shap': float(np.min(gene_shap)),
664
- 'pos_fraction': float(np.mean(gene_shap > 0))
665
  }
666
 
667
  def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
668
  """
669
- Create a simple genome diagram using PIL, forcing a minimum color intensity
670
- so that small SHAP values don't appear white.
671
  """
672
- from PIL import Image, ImageDraw, ImageFont
673
-
674
- # Validate inputs
675
  if not gene_results or genome_length <= 0:
676
  img = Image.new('RGB', (800, 100), color='white')
677
  draw = ImageDraw.Draw(img)
678
  draw.text((10, 40), "Error: Invalid input data", fill='black')
679
  return img
680
-
681
- # Ensure all gene coordinates are valid integers
682
  for gene in gene_results:
683
  gene['start'] = max(0, int(gene['start']))
684
  gene['end'] = min(genome_length, int(gene['end']))
685
  if gene['start'] >= gene['end']:
686
- print(f"Warning: Invalid coordinates for gene {gene.get('gene_name','?')}: {gene['start']}-{gene['end']}")
687
 
688
- # Image dimensions
689
  width = 1500
690
  height = 600
691
  margin = 50
692
  track_height = 40
693
 
694
- # Create image with white background
695
  img = Image.new('RGB', (width, height), 'white')
696
  draw = ImageDraw.Draw(img)
697
 
698
- # Try to load font, fall back to default if unavailable
699
  try:
700
  font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
701
  title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
@@ -703,24 +748,16 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
703
  font = ImageFont.load_default()
704
  title_font = ImageFont.load_default()
705
 
706
- # Draw title
707
- draw.text((margin, margin // 2), "Genome SHAP Analysis", fill='black', font=title_font or font)
708
 
709
- # Draw genome line
710
  line_y = height // 2
711
  draw.line([(int(margin), int(line_y)), (int(width - margin), int(line_y))], fill='black', width=2)
712
 
713
- # Calculate scale factor
714
  scale = float(width - 2 * margin) / float(genome_length)
715
 
716
- # Determine a reasonable step for scale markers
717
  num_ticks = 10
718
- if genome_length < num_ticks:
719
- step = 1
720
- else:
721
- step = genome_length // num_ticks
722
-
723
- # Draw scale markers
724
  for i in range(0, genome_length + 1, step):
725
  x_coord = margin + i * scale
726
  draw.line([
@@ -729,50 +766,33 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
729
  ], fill='black', width=1)
730
  draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)
731
 
732
- # Sort genes by absolute SHAP value for drawing
733
  sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
734
-
735
- # Draw genes
736
  for idx, gene in enumerate(sorted_genes):
737
- # Calculate position and ensure integers
738
  start_x = margin + int(gene['start'] * scale)
739
  end_x = margin + int(gene['end'] * scale)
740
-
741
- # Calculate color based on SHAP value
742
  avg_shap = gene['avg_shap']
743
-
744
- # Convert shap -> color intensity (0 to 255)
745
- # Then clamp to a minimum intensity so it never ends up plain white
746
  intensity = int(abs(avg_shap) * 500)
747
- intensity = max(50, min(255, intensity)) # clamp between 50 and 255
748
 
749
  if avg_shap > 0:
750
- # Red-ish for positive
751
- color = (255, 255 - intensity, 255 - intensity)
752
  else:
753
- # Blue-ish for negative or zero
754
- color = (255 - intensity, 255 - intensity, 255)
755
 
756
- # Draw gene rectangle
757
  draw.rectangle([
758
  (int(start_x), int(line_y - track_height // 2)),
759
  (int(end_x), int(line_y + track_height // 2))
760
  ], fill=color, outline='black')
761
 
762
- # Prepare gene name label
763
  label = str(gene.get('gene_name','?'))
764
-
765
- # Fallback for label size
766
  label_mask = font.getmask(label)
767
  label_width, label_height = label_mask.size
768
 
769
- # Alternate label positions
770
  if idx % 2 == 0:
771
  text_y = line_y - track_height - 15
772
  else:
773
  text_y = line_y + track_height + 5
774
 
775
- # Decide whether to rotate text based on space
776
  gene_width = end_x - start_x
777
  if gene_width > label_width:
778
  text_x = start_x + (gene_width - label_width) // 2
@@ -784,64 +804,113 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
784
  rotated_img = txt_img.rotate(90, expand=True)
785
  img.paste(rotated_img, (int(start_x), int(text_y)), rotated_img)
786
 
787
- # Draw legend
788
- legend_x = margin
789
- legend_y = height - margin
790
- draw.text((int(legend_x), int(legend_y - 60)), "SHAP Values:", fill='black', font=font)
791
-
792
- # Draw legend boxes
793
- box_width = 20
794
- box_height = 20
795
- spacing = 15
796
-
797
- # Strong human-like
798
- draw.rectangle([
799
- (int(legend_x), int(legend_y - 45)),
800
- (int(legend_x + box_width), int(legend_y - 45 + box_height))
801
- ], fill=(255, 0, 0), outline='black')
802
- draw.text((int(legend_x + box_width + spacing), int(legend_y - 45)),
803
- "Strong human-like signal", fill='black', font=font)
804
-
805
- # Weak human-like
806
- draw.rectangle([
807
- (int(legend_x), int(legend_y - 20)),
808
- (int(legend_x + box_width), int(legend_y - 20 + box_height))
809
- ], fill=(255, 200, 200), outline='black')
810
- draw.text((int(legend_x + box_width + spacing), int(legend_y - 20)),
811
- "Weak human-like signal", fill='black', font=font)
812
-
813
- # Weak non-human-like
814
- draw.rectangle([
815
- (int(legend_x + 250), int(legend_y - 45)),
816
- (int(legend_x + 250 + box_width), int(legend_y - 45 + box_height))
817
- ], fill=(200, 200, 255), outline='black')
818
- draw.text((int(legend_x + 250 + box_width + spacing), int(legend_y - 45)),
819
- "Weak non-human-like signal", fill='black', font=font)
820
-
821
- # Strong non-human-like
822
- draw.rectangle([
823
- (int(legend_x + 250), int(legend_y - 20)),
824
- (int(legend_x + 250 + box_width), int(legend_y - 20 + box_height))
825
- ], fill=(0, 0, 255), outline='black')
826
- draw.text((int(legend_x + 250 + box_width + spacing), int(legend_y - 20)),
827
- "Strong non-human-like signal", fill='black', font=font)
828
829
  return img
830
 
831
  def analyze_gene_features(sequence_file: str,
832
  features_file: str,
833
  fasta_text: str = "",
834
- features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
835
- """Analyze SHAP values for each gene feature"""
836
- # First analyze whole sequence
 
 
 
 
 
837
  sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
838
  if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
839
  return f"Error in sequence analysis: {sequence_results[0]}", None, None
840
 
841
- # Get SHAP values
842
  shap_means = sequence_results[3]["shap_means"]
843
-
844
- # Parse gene features
 
845
  try:
846
  if features_text.strip():
847
  genes = parse_gene_features(features_text)
@@ -850,98 +919,100 @@ def analyze_gene_features(sequence_file: str,
850
  genes = parse_gene_features(f.read())
851
  except Exception as e:
852
  return f"Error reading features file: {str(e)}", None, None
853
-
854
- # Analyze each gene
855
  gene_results = []
856
  for gene in genes:
857
- try:
858
- location = gene['metadata'].get('location', '')
859
- if not location:
860
- continue
861
-
862
- start, end = parse_location(location)
863
- if start is None or end is None:
864
- continue
865
-
866
- # Get SHAP values for this region
867
- gene_shap = shap_means[start:end]
868
- stats = compute_gene_statistics(gene_shap)
869
-
870
- gene_results.append({
871
- 'gene_name': gene['metadata'].get('gene', 'Unknown'),
872
- 'location': location,
873
- 'start': start,
874
- 'end': end,
875
- 'locus_tag': gene['metadata'].get('locus_tag', ''),
876
- 'avg_shap': stats['avg_shap'],
877
- 'median_shap': stats['median_shap'],
878
- 'std_shap': stats['std_shap'],
879
- 'max_shap': stats['max_shap'],
880
- 'min_shap': stats['min_shap'],
881
- 'pos_fraction': stats['pos_fraction'],
882
- 'classification': 'Human' if stats['avg_shap'] > 0 else 'Non-human',
883
- 'confidence': abs(stats['avg_shap'])
884
- })
885
-
886
- except Exception as e:
887
- print(f"Error processing gene {gene['metadata'].get('gene', 'Unknown')}: {str(e)}")
888
  continue
889
-
890
  if not gene_results:
891
  return "No valid genes could be processed", None, None
892
-
893
- # Sort genes by absolute SHAP value
894
  sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']), reverse=True)
895
-
896
- # Create results text
897
  results_text = "Gene Analysis Results:\n\n"
898
  results_text += f"Total genes analyzed: {len(gene_results)}\n"
899
- results_text += f"Human-like genes: {sum(1 for g in gene_results if g['classification'] == 'Human')}\n"
900
- results_text += f"Non-human-like genes: {sum(1 for g in gene_results if g['classification'] == 'Non-human')}\n\n"
 
901
 
902
- results_text += "Top 10 most distinctive genes:\n"
903
  for gene in sorted_genes[:10]:
904
  results_text += (
905
  f"Gene: {gene['gene_name']}\n"
906
  f"Location: {gene['location']}\n"
907
  f"Classification: {gene['classification']} "
908
  f"(confidence: {gene['confidence']:.4f})\n"
909
- f"Average SHAP: {gene['avg_shap']:.4f}\n\n"
 
910
  )
911
-
912
- # Create CSV content
913
- csv_content = "gene_name,location,avg_shap,median_shap,std_shap,max_shap,min_shap,"
914
- csv_content += "pos_fraction,classification,confidence,locus_tag\n"
915
-
916
- for gene in gene_results:
917
  csv_content += (
918
- f"{gene['gene_name']},{gene['location']},{gene['avg_shap']:.4f},"
919
- f"{gene['median_shap']:.4f},{gene['std_shap']:.4f},{gene['max_shap']:.4f},"
920
- f"{gene['min_shap']:.4f},{gene['pos_fraction']:.4f},{gene['classification']},"
921
- f"{gene['confidence']:.4f},{gene['locus_tag']}\n"
922
  )
923
-
924
- # Save CSV to temp file
925
  try:
926
  temp_dir = tempfile.gettempdir()
927
  temp_path = os.path.join(temp_dir, f"gene_analysis_{os.urandom(4).hex()}.csv")
928
-
929
  with open(temp_path, 'w') as f:
930
  f.write(csv_content)
931
  except Exception as e:
932
  print(f"Error saving CSV: {str(e)}")
933
  temp_path = None
934
-
935
- # Create visualization
936
  try:
937
- diagram_img = create_simple_genome_diagram(gene_results, len(shap_means))
 
 
 
938
  except Exception as e:
939
  print(f"Error creating visualization: {str(e)}")
940
- # Create error image
941
  diagram_img = Image.new('RGB', (800, 100), color='white')
942
  draw = ImageDraw.Draw(diagram_img)
943
  draw.text((10, 40), f"Error creating visualization: {str(e)}", fill='black')
944
-
945
  return results_text, temp_path, diagram_img
946
 
947
  ###############################################################################
@@ -949,13 +1020,14 @@ def analyze_gene_features(sequence_file: str,
949
  ###############################################################################
950
 
951
  def prepare_csv_download(data, filename="analysis_results.csv"):
952
- """Prepare CSV data for download"""
 
 
953
  if isinstance(data, str):
954
  return data.encode(), filename
955
  elif isinstance(data, (list, dict)):
956
  import csv
957
  from io import StringIO
958
-
959
  output = StringIO()
960
  writer = csv.DictWriter(output, fieldnames=data[0].keys())
961
  writer.writeheader()
@@ -979,22 +1051,22 @@ css = """
979
 
980
  with gr.Blocks(css=css) as iface:
981
  gr.Markdown("""
982
- # Virus Host Classifier
983
- **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
984
- **Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.
985
- **Step 3**: Analyze gene features and their contributions.
986
- **Step 4**: Compare sequences and analyze differences.
987
-
988
- **Color Scale**: Negative SHAP = Blue, Zero = White, Positive SHAP = Red.
989
  """)
990
 
991
  with gr.Tab("1) Full-Sequence Analysis"):
992
  with gr.Row():
993
  with gr.Column(scale=1):
994
  file_input = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
995
- text_input = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
996
  top_k = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of top k-mers to display")
997
- win_size = gr.Slider(minimum=100, maximum=5000, value=500, step=100, label="Window size for 'most pushing' subregions")
998
  analyze_btn = gr.Button("Analyze Sequence", variant="primary")
999
  with gr.Column(scale=2):
1000
  results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
@@ -1013,8 +1085,7 @@ with gr.Blocks(css=css) as iface:
1013
  with gr.Tab("2) Subregion Exploration"):
1014
  gr.Markdown("""
1015
  **Subregion Analysis**
1016
- Select start/end positions to view local SHAP signals, distribution, GC content, etc.
1017
- The heatmap uses the same Blue-White-Red scale.
1018
  """)
1019
  with gr.Row():
1020
  region_start = gr.Number(label="Region Start", value=0)
@@ -1024,7 +1095,7 @@ with gr.Blocks(css=css) as iface:
1024
  with gr.Row():
1025
  subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
1026
  subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
1027
- download_subregion = gr.File(label="Download Subregion Analysis", visible=False, elem_classes="download-button")
1028
 
1029
  region_btn.click(
1030
  analyze_subregion,
@@ -1035,60 +1106,48 @@ with gr.Blocks(css=css) as iface:
1035
  with gr.Tab("3) Gene Features Analysis"):
1036
  gr.Markdown("""
1037
  **Analyze Gene Features**
1038
- Upload a FASTA file and corresponding gene features file to analyze SHAP values per gene.
1039
- Gene features should be in the format:
1040
- ```
1041
- >gene_name [gene=X] [locus_tag=Y] [location=start..end] or [location=complement(start..end)]
1042
- SEQUENCE
1043
- ```
1044
- The genome viewer will show genes color-coded by their contribution:
1045
- - Red: Genes pushing toward human origin
1046
- - Blue: Genes pushing toward non-human origin
1047
- - Color intensity indicates strength of signal
1048
  """)
1049
  with gr.Row():
1050
  with gr.Column(scale=1):
1051
- gene_fasta_file = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
1052
- gene_fasta_text = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
1053
  with gr.Column(scale=1):
1054
- features_file = gr.File(label="Upload gene features file", file_types=[".txt"], type="filepath")
1055
- features_text = gr.Textbox(label="Or paste gene features", placeholder=">gene_1 [gene=U12]...\nACGT...", lines=5)
1056
-
1057
  analyze_genes_btn = gr.Button("Analyze Gene Features", variant="primary")
1058
  gene_results = gr.Textbox(label="Gene Analysis Results", lines=12, interactive=False)
1059
- gene_diagram = gr.Image(label="Genome Diagram with Gene Features")
1060
  download_gene_results = gr.File(label="Download Gene Analysis (CSV)", visible=True)
1061
 
1062
  analyze_genes_btn.click(
1063
  analyze_gene_features,
1064
- inputs=[gene_fasta_file, features_file, gene_fasta_text, features_text],
1065
  outputs=[gene_results, download_gene_results, gene_diagram]
1066
  )
1067
 
1068
  with gr.Tab("4) Comparative Analysis"):
1069
  gr.Markdown("""
1070
  **Compare Two Sequences**
1071
- Upload or paste two FASTA sequences to compare their SHAP patterns.
1072
- The sequences will be normalized to the same length for comparison.
1073
-
1074
- **Color Scale**:
1075
- - Red: Sequence 2 more human-like
1076
- - Blue: Sequence 1 more human-like
1077
- - White: No substantial difference
1078
  """)
1079
  with gr.Row():
1080
  with gr.Column(scale=1):
1081
- file_input1 = gr.File(label="Upload first FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
1082
- text_input1 = gr.Textbox(label="Or paste first FASTA sequence", placeholder=">sequence1\nACGTACGT...", lines=5)
1083
  with gr.Column(scale=1):
1084
- file_input2 = gr.File(label="Upload second FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
1085
- text_input2 = gr.Textbox(label="Or paste second FASTA sequence", placeholder=">sequence2\nACGTACGT...", lines=5)
1086
  compare_btn = gr.Button("Compare Sequences", variant="primary")
1087
  comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
1088
  with gr.Row():
1089
  diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
1090
  diff_hist = gr.Image(label="Distribution of SHAP Differences")
1091
- download_comparison = gr.File(label="Download Comparison Results", visible=False, elem_classes="download-button")
1092
 
1093
  compare_btn.click(
1094
  analyze_sequence_comparison,
@@ -1097,25 +1156,12 @@ with gr.Blocks(css=css) as iface:
1097
  )
1098
 
1099
  gr.Markdown("""
1100
- ### Interface Features
1101
- - **Overall Classification** (human vs non-human) using k-mer frequencies
1102
- - **SHAP Analysis** shows which k-mers push classification toward or away from human
1103
- - **White-Centered SHAP Gradient**:
1104
- - Negative (blue), 0 (white), Positive (red)
1105
- - Symmetrical color range around 0
1106
- - **Identify Subregions** with strongest push for human or non-human
1107
- - **Gene Feature Analysis**:
1108
- - Analyze individual genes' contributions
1109
- - Interactive genome viewer
1110
- - Gene-level statistics and classification
1111
- - **Sequence Comparison**:
1112
- - Compare two sequences to identify regions of difference
1113
- - Normalized comparison to handle different lengths
1114
- - Statistical summary of differences
1115
- - **Data Export**:
1116
- - Download results as CSV files
1117
- - Save analysis outputs for further processing
1118
  """)
1119
-
1120
  if __name__ == "__main__":
1121
  iface.launch()
 
67
  return sequences
68
 
69
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
70
+ """
71
+ Convert a sequence into a frequency vector of all possible 4-mer combinations.
72
+ """
73
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
74
  kmer_dict = {km: i for i, km in enumerate(kmers)}
75
  vec = np.zeros(len(kmers), dtype=np.float32)
 
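For reference, a minimal, self-contained sketch of the 4-mer frequency vector this hunk documents (the normalization of counts to frequencies is an assumption; the hunk is truncated before that step):

```python
import numpy as np
from itertools import product

def kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    # One slot per possible k-mer over ACGT (4**4 = 256 for k=4)
    kmer_dict = {''.join(p): i for i, p in enumerate(product("ACGT", repeat=k))}
    vec = np.zeros(len(kmer_dict), dtype=np.float32)
    for i in range(len(sequence) - k + 1):
        idx = kmer_dict.get(sequence[i:i + k])
        if idx is not None:          # skip k-mers containing N or other symbols
            vec[idx] += 1
    total = vec.sum()
    return vec / total if total > 0 else vec
```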
87
  ###############################################################################
88
 
89
  def calculate_shap_values(model, x_tensor):
90
+ """
91
+ A simple ablation-based SHAP approximation. Zero out each position
92
+ and measure the impact on the 'human' probability.
93
+ """
94
  model.eval()
95
  with torch.no_grad():
96
  baseline_output = model(x_tensor)
97
  baseline_probs = torch.softmax(baseline_output, dim=1)
98
+ baseline_prob = baseline_probs[0, 1].item() # Probability for 'human'
99
  shap_values = []
100
  x_zeroed = x_tensor.clone()
101
  for i in range(x_tensor.shape[1]):
 
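The ablation loop itself is cut off by the hunk; below is a hedged sketch of the approach the new docstring describes. The sign convention (baseline minus ablated probability, so positive = pushes toward 'human') is an assumption, not necessarily the app's exact code:

```python
import torch

def ablation_scores(model, x_tensor):
    model.eval()
    with torch.no_grad():
        baseline = torch.softmax(model(x_tensor), dim=1)[0, 1].item()  # P(human)
        scores = []
        for i in range(x_tensor.shape[1]):
            x_zeroed = x_tensor.clone()
            x_zeroed[0, i] = 0.0                        # ablate one k-mer feature
            prob = torch.softmax(model(x_zeroed), dim=1)[0, 1].item()
            scores.append(baseline - prob)              # positive = feature pushed toward 'human'
    return scores
```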
113
  ###############################################################################
114
 
115
  def compute_positionwise_scores(sequence, shap_values, k=4):
116
+ """
117
+ Distribute each k-mer's SHAP contribution across its k underlying positions.
118
+ """
119
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
120
  kmer_dict = {km: i for i, km in enumerate(kmers)}
121
  seq_len = len(sequence)
 
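A minimal sketch of the position-wise spreading the docstring describes: each k-mer's SHAP value is added to the k bases it covers, then averaged by coverage. The exact averaging in app.py may differ slightly:

```python
import numpy as np
from itertools import product

def positionwise_scores(sequence, shap_values, k=4):
    kmer_dict = {''.join(p): i for i, p in enumerate(product("ACGT", repeat=k))}
    scores = np.zeros(len(sequence), dtype=np.float32)
    counts = np.zeros(len(sequence), dtype=np.float32)
    for i in range(len(sequence) - k + 1):
        idx = kmer_dict.get(sequence[i:i + k])
        if idx is None:
            continue
        scores[i:i + k] += shap_values[idx]   # spread the k-mer's SHAP over its k bases
        counts[i:i + k] += 1
    # average by how many k-mers cover each base; uncovered bases stay 0
    return np.divide(scores, counts, out=np.zeros_like(scores), where=counts > 0)
```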
136
  ###############################################################################
137
 
138
  def find_extreme_subregion(shap_means, window_size=500, mode="max"):
139
+ """
140
+ Use a sliding window to find the subregion with the highest (or lowest) average SHAP.
141
+ """
142
  n = len(shap_means)
143
  if n == 0:
144
  return (0, 0, 0.0)
 
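The window scan is truncated by the hunk; for reference, an equivalent cumulative-sum formulation of "find the window with the highest (or lowest) mean SHAP" (app.py may loop over windows directly):

```python
import numpy as np

def windowed_extreme(shap_means, window_size=500, mode="max"):
    n = len(shap_means)
    if n == 0:
        return (0, 0, 0.0)
    w = min(window_size, n)
    csum = np.concatenate(([0.0], np.cumsum(shap_means)))
    win_means = (csum[w:] - csum[:-w]) / w             # mean of every length-w window
    idx = int(np.argmax(win_means)) if mode == "max" else int(np.argmin(win_means))
    return (idx, idx + w, float(win_means[idx]))
```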
165
  ###############################################################################
166
 
167
  def fig_to_image(fig):
168
+ """
169
+ Render a Matplotlib figure to a PIL Image.
170
+ """
171
  buf = io.BytesIO()
172
  fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
173
  buf.seek(0)
 
176
  return img
177
 
178
  def get_zero_centered_cmap():
179
+ """
180
+ Create a symmetrical (blue-white-red) colormap around zero.
181
+ """
182
  colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
183
  return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
184
 
185
  def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
186
+ """
187
+ Plot an inline heatmap for the chosen region (or entire genome if start/end not provided).
188
+ """
189
  if start is not None and end is not None:
190
  local_shap = shap_means[start:end]
191
  subtitle = f" (positions {start}-{end})"
 
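The key detail behind the blue-white-red scale is that the color limits are symmetric around zero, so white always corresponds to SHAP = 0. A standalone sketch with stand-in data (not the app's plotting code):

```python
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

cmap = mcolors.LinearSegmentedColormap.from_list(
    "blue_white_red", [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')])
shap_means = np.random.randn(2000) * 0.02             # stand-in per-base SHAP values
extent = float(np.abs(shap_means).max())              # symmetric limits around 0
fig, ax = plt.subplots(figsize=(10, 1.5))
ax.imshow(shap_means.reshape(1, -1), aspect='auto', cmap=cmap,
          vmin=-extent, vmax=extent)
ax.set_yticks([])
ax.set_xlabel("Position")
```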
211
  return fig
212
 
213
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
214
+ """
215
+ Show bar chart of top k-mers by absolute SHAP value.
216
+ """
217
  plt.rcParams.update({'font.size': 10})
218
  fig = plt.figure(figsize=(10, 5))
219
  indices = np.argsort(np.abs(shap_values))[-top_k:]
 
229
  return fig
230
 
231
  def plot_shap_histogram(shap_array, title="SHAP Distribution in Region", num_bins=30):
232
+ """
233
+ Plot a histogram of SHAP values in some region.
234
+ """
235
  fig, ax = plt.subplots(figsize=(6, 4))
236
  ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black')
237
  ax.axvline(0, color='red', linestyle='--', label='0.0')
 
243
  return fig
244
 
245
  def compute_gc_content(sequence):
246
+ """
247
+ Compute GC content (%) for a given sequence.
248
+ """
249
  if not sequence:
250
+ return 0.0
251
  gc_count = sequence.count('G') + sequence.count('C')
252
  return (gc_count / len(sequence)) * 100.0
253
 
 
256
  ###############################################################################
257
 
258
  def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
259
+ """
260
+ Perform the main classification, SHAP analysis, and extreme subregion detection
261
+ for a single sequence.
262
+ """
263
+ # 1) Read input
264
  if fasta_text.strip():
265
  text = fasta_text.strip()
266
  elif file_obj is not None:
 
272
  else:
273
  return ("Please provide a FASTA sequence.", None, None, None, None, None)
274
 
275
+ # 2) Parse FASTA
276
  sequences = parse_fasta(text)
277
  if not sequences:
278
  return ("No valid FASTA sequences found.", None, None, None, None, None)
279
  header, seq = sequences[0]
280
 
281
+ # 3) Load model, scaler, and run inference
282
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
283
  try:
 
284
  state_dict = torch.load('model.pt', map_location=device)
285
  model = VirusClassifier(256).to(device)
286
  model.load_state_dict(state_dict)
 
297
  classification = "Human" if prob_human > 0.5 else "Non-human"
298
  confidence = max(prob_human, prob_nonhuman)
299
 
300
+ # 4) Per-base SHAP & subregion detection
301
  shap_means = compute_positionwise_scores(seq, shap_values, k=4)
302
  max_start, max_end, max_avg = find_extreme_subregion(shap_means, window_size, mode="max")
303
  min_start, min_end, min_avg = find_extreme_subregion(shap_means, window_size, mode="min")
304
 
305
+ # 5) Prepare result text
306
  results_text = (
307
  f"Sequence: {header}\n"
308
  f"Length: {len(seq):,} bases\n"
 
316
  f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
317
  )
318
 
319
+ # 6) Create bar & heatmap figures
320
  kmers = [''.join(p) for p in product("ACGT", repeat=4)]
321
  bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
322
  bar_img = fig_to_image(bar_fig)
 
324
  heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
325
  heatmap_img = fig_to_image(heatmap_fig)
326
 
327
+ # 7) Build the "state" dictionary so we can do subregion analysis
 
328
  state_dict_out = {"seq": seq, "shap_means": shap_means}
329
 
330
+ # Return 6 items to match your Gradio output
331
  return (results_text, bar_img, heatmap_img, state_dict_out, header, None)
332
 
333
  ###############################################################################
 
335
  ###############################################################################
336
 
337
  def analyze_subregion(state, header, region_start, region_end):
338
+ """
339
+ Examine a subregion’s SHAP distribution, GC content, etc.
340
+ """
341
  if not state or "seq" not in state or "shap_means" not in state:
342
  return ("No sequence data found. Please run Step 1 first.", None, None, None)
343
  seq = state["seq"]
 
348
  region_end = max(0, min(region_end, len(seq)))
349
  if region_end <= region_start:
350
  return ("Invalid region range. End must be > Start.", None, None, None)
351
+
352
  region_seq = seq[region_start:region_end]
353
  region_shap = shap_means[region_start:region_end]
354
+
355
  gc_percent = compute_gc_content(region_seq)
356
  avg_shap = float(np.mean(region_shap))
357
  positive_fraction = np.mean(region_shap > 0)
358
  negative_fraction = np.mean(region_shap < 0)
359
+
360
  if avg_shap > 0.05:
361
  region_classification = "Likely pushing toward human"
362
  elif avg_shap < -0.05:
363
  region_classification = "Likely pushing toward non-human"
364
  else:
365
  region_classification = "Near neutral (no strong push)"
366
+
367
  region_info = (
368
  f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
369
  f"Region length: {len(region_seq)} bases\n"
 
373
  f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
374
  f"Subregion interpretation: {region_classification}\n"
375
  )
376
+
377
  heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion SHAP", start=region_start, end=region_end)
378
  heatmap_img = fig_to_image(heatmap_fig)
379
+
380
  hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
381
  hist_img = fig_to_image(hist_fig)
382
+
383
+ # Return 4 items to match your Gradio output
384
  return (region_info, heatmap_img, hist_img, None)
385
 
386
  ###############################################################################
387
+ # 9. COMPARISON ANALYSIS FUNCTIONS (Step 4)
388
  ###############################################################################
389
 
 
 
 
 
 
390
  def compute_shap_difference(shap1_norm, shap2_norm):
391
+ """
392
+ Compute the SHAP difference (Seq2 - Seq1).
393
+ """
394
  return shap2_norm - shap1_norm
395
 
396
  def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
397
  """
398
+ Plot a 1D heatmap of differences using relative positions 0-100%.
399
  """
400
  heatmap_data = shap_diff.reshape(1, -1)
401
  extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
 
424
 
425
  def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
426
  """
427
+ Plot a histogram of SHAP values with a configurable number of bins.
428
  """
429
  fig, ax = plt.subplots(figsize=(6, 4))
430
  ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
 
438
 
439
  def calculate_adaptive_parameters(len1, len2):
440
  """
441
+ Choose smoothing & interpolation parameters automatically based on length difference.
 
442
  """
443
  length_diff = abs(len1 - len2)
444
  max_length = max(len1, len2)
445
  min_length = min(len1, len2)
446
  length_ratio = min_length / max_length
447
 
448
+ # Base number of points
449
  base_points = min(2000, max(500, max_length // 100))
450
 
 
451
  if length_diff < 500:
452
  resolution_factor = 2.0
453
  num_points = min(3000, base_points * 2)
 
465
  num_points = max(500, base_points // 2)
466
  smooth_window = max(100, length_diff // 500)
467
 
 
468
  smooth_window = int(smooth_window * (1 + (1 - length_ratio)))
 
469
  return int(num_points), int(smooth_window), resolution_factor
470
 
471
  def sliding_window_smooth(values, window_size=50):
472
  """
473
+ A custom smoothing approach, including exponential decay at edges.
474
  """
475
  if window_size < 3:
476
  return values
 
 
477
  window = np.ones(window_size)
478
  decay = np.exp(-np.linspace(0, 3, window_size // 2))
479
  window[:window_size // 2] = decay
480
  window[-(window_size // 2):] = decay[::-1]
481
  window = window / window.sum()
482
 
 
483
  smoothed = np.convolve(values, window, mode='valid')
 
 
484
  pad_size = len(values) - len(smoothed)
485
  pad_left = pad_size // 2
486
  pad_right = pad_size - pad_left
 
494
 
495
  def normalize_shap_lengths(shap1, shap2):
496
  """
497
+ Smooth, interpolate, and return arrays of the same length for direct comparison.
498
  """
 
499
  num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
500
 
 
501
  shap1_smooth = sliding_window_smooth(shap1, smooth_window)
502
  shap2_smooth = sliding_window_smooth(shap2, smooth_window)
503
 
 
504
  x1 = np.linspace(0, 1, len(shap1_smooth))
505
  x2 = np.linspace(0, 1, len(shap2_smooth))
506
  x_norm = np.linspace(0, 1, num_points)
 
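The core of the length normalization is resampling both SHAP tracks onto a shared relative axis so sequences of different lengths can be compared position-for-position; a minimal sketch with the smoothing step omitted:

```python
import numpy as np

def to_common_length(shap1, shap2, num_points=1000):
    x_norm = np.linspace(0, 1, num_points)
    s1 = np.interp(x_norm, np.linspace(0, 1, len(shap1)), shap1)
    s2 = np.interp(x_norm, np.linspace(0, 1, len(shap2)), shap2)
    return s1, s2, s2 - s1   # difference is Seq2 - Seq1, as in compute_shap_difference
```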
512
 
513
  def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
514
  """
515
+ Compare two sequences using the previously defined analysis pipeline
516
+ and produce difference visualizations & stats.
517
  """
518
  try:
519
  # Analyze first sequence
 
526
  if isinstance(res2[0], str) and "Error" in res2[0]:
527
  return (f"Error in sequence 2: {res2[0]}", None, None, None)
528
 
 
529
  shap1 = res1[3]["shap_means"]
530
  shap2 = res2[3]["shap_means"]
531
 
 
532
  len1, len2 = len(shap1), len(shap2)
533
  length_diff = abs(len1 - len2)
534
  length_ratio = min(len1, len2) / max(len1, len2)
535
+
536
+ # Normalize both to the same length
537
  shap1_norm, shap2_norm, smooth_window = normalize_shap_lengths(shap1, shap2)
538
  shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
539
 
540
+ # Compute stats
541
  base_threshold = 0.05
542
  adaptive_threshold = base_threshold * (1 + (1 - length_ratio))
543
  if length_diff > 50000:
544
  adaptive_threshold *= 1.5
545
 
 
546
  avg_diff = np.mean(shap_diff)
547
  std_diff = np.std(shap_diff)
548
  max_diff = np.max(shap_diff)
 
550
  substantial_diffs = np.abs(shap_diff) > adaptive_threshold
551
  frac_different = np.mean(substantial_diffs)
552
 
553
+ # Extract classification from text
554
  try:
555
  classification1 = res1[0].split('Classification: ')[1].split('\n')[0].strip()
556
  classification2 = res2[0].split('Classification: ')[1].split('\n')[0].strip()
 
558
  classification1 = "Unknown"
559
  classification2 = "Unknown"
560
 
 
561
  comparison_text = (
562
  "Sequence Comparison Results:\n"
563
  f"Sequence 1: {res1[4]}\n"
 
584
  "- White regions: Similar between sequences"
585
  )
586
 
 
587
  heatmap_fig = plot_comparative_heatmap(
588
  shap_diff,
589
  title=f"SHAP Difference Heatmap (window: {smooth_window})"
590
  )
591
  heatmap_img = fig_to_image(heatmap_fig)
592
 
 
593
  num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
594
  hist_fig = plot_shap_histogram(
595
  shap_diff,
 
598
  )
599
  hist_img = fig_to_image(hist_fig)
600
 
 
601
  return (comparison_text, heatmap_img, hist_img, None)
602
 
603
  except Exception as e:
 
605
  return (error_msg, None, None, None)
606
 
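For concreteness, a worked example of the adaptive threshold used above (formulas follow the code; the sequence lengths are illustrative):

```python
len1, len2 = 30_000, 24_000
length_ratio = min(len1, len2) / max(len1, len2)       # 0.8
adaptive_threshold = 0.05 * (1 + (1 - length_ratio))   # 0.05 * 1.2 = 0.06
# length_diff = 6,000, which is <= 50,000, so the extra 1.5x scaling is not applied
```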
607
  ###############################################################################
608
+ # 10. ADDITIONAL / ADVANCED VISUALIZATIONS & STATISTICS
609
  ###############################################################################
610
 
611
+ def n50_length(sequence):
612
+ """
613
+ Calculate the N50 for a single continuous sequence (for demonstration).
614
+ For a single sequence, N50 is typically the length if it's just one piece,
615
+ but let's do a simplistic example.
616
+ """
617
+ # If you had contigs, you'd do a sorted list, cumulative sums, etc.
618
+ # We'll do a trivial approach here:
619
+ return len(sequence) # Because we have only one contiguous region
620
+
621
+ def sequence_complexity(sequence):
622
+ """
623
+ Compute a simple measure of 'sequence complexity'.
624
+ Here, we define complexity as the Shannon entropy over the nucleotides.
625
+ """
626
+ from math import log2
627
+ length = len(sequence)
628
+ if length == 0:
629
+ return 0.0
630
+ freq = {}
631
+ for base in sequence:
632
+ freq[base] = freq.get(base, 0) + 1
633
+ complexity = 0.0
634
+ for base, count in freq.items():
635
+ p = count / length
636
+ complexity -= p * log2(p)
637
+ return complexity
638
+
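Quick sanity checks for sequence_complexity (Shannon entropy in bits): a homopolymer scores 0 and a perfectly uniform ACGT mix scores 2.

```python
print(sequence_complexity("AAAAAAAA"))    # 0.0 bits
print(sequence_complexity("ACGT" * 10))   # 2.0 bits (uniform over 4 bases)
```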
639
+ def advanced_gene_statistics(gene_shap: np.ndarray, gene_seq: str) -> Dict[str, float]:
640
+ """
641
+ Additional stats: N50, complexity, etc.
642
+ """
643
+ stats = {}
644
+ stats['n50'] = len(gene_seq) # trivial for a single gene region
645
+ stats['entropy'] = sequence_complexity(gene_seq)
646
+ stats['avg_shap'] = float(np.mean(gene_shap))
647
+ stats['max_shap'] = float(np.max(gene_shap)) if len(gene_shap) else 0.0
648
+ stats['min_shap'] = float(np.min(gene_shap)) if len(gene_shap) else 0.0
649
+ return stats
650
+
651
+ ###############################################################################
652
+ # 11. GENE FEATURE ANALYSIS
653
+ ###############################################################################
654
 
655
  def parse_gene_features(text: str) -> List[Dict[str, Any]]:
656
+ """Parse gene features from text file in a FASTA-like format."""
657
  genes = []
658
  current_header = None
659
  current_sequence = []
 
662
  line = line.strip()
663
  if not line:
664
  continue
 
665
  if line.startswith('>'):
666
  if current_header:
667
  genes.append({
 
673
  current_sequence = []
674
  else:
675
  current_sequence.append(line.upper())
 
676
  if current_header:
677
  genes.append({
678
  'header': current_header,
679
  'sequence': ''.join(current_sequence),
680
  'metadata': parse_gene_metadata(current_header)
681
  })
 
682
  return genes
683
 
684
  def parse_gene_metadata(header: str) -> Dict[str, str]:
685
+ """Extract metadata from gene header line."""
686
  metadata = {}
687
  parts = header.split()
 
688
  for part in parts:
689
  if '[' in part and ']' in part:
690
  key_value = part[1:-1].split('=', 1)
691
  if len(key_value) == 2:
692
  metadata[key_value[0]] = key_value[1]
 
693
  return metadata
694
 
695
  def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
696
+ """Parse gene location string, handling forward and complement strands."""
697
  try:
 
698
  clean_loc = location_str.replace('complement(', '').replace(')', '')
 
 
699
  if '..' in clean_loc:
700
  start, end = map(int, clean_loc.split('..'))
701
  return start, end
 
706
  return None, None
707
 
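Example inputs for parse_location, matching the [location=...] format used in the features file; malformed strings fall back to (None, None) via the except branch shown above:

```python
parse_location("266..13483")                 # -> (266, 13483)
parse_location("complement(13563..14594)")   # -> (13563, 14594)
```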
708
  def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
709
+ """Basic statistical measures for gene SHAP values."""
710
  return {
711
+ 'avg_shap': float(np.mean(gene_shap)) if len(gene_shap) else 0.0,
712
+ 'median_shap': float(np.median(gene_shap)) if len(gene_shap) else 0.0,
713
+ 'std_shap': float(np.std(gene_shap)) if len(gene_shap) else 0.0,
714
+ 'max_shap': float(np.max(gene_shap)) if len(gene_shap) else 0.0,
715
+ 'min_shap': float(np.min(gene_shap)) if len(gene_shap) else 0.0,
716
+ 'pos_fraction': float(np.mean(gene_shap > 0)) if len(gene_shap) else 0.0
717
  }
718
 
719
  def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
720
  """
721
+ A quick PIL-based diagram to show genes along the genome.
722
+ Color intensity = magnitude of SHAP. Red/Blue = sign of SHAP.
723
  """
 
 
 
724
  if not gene_results or genome_length <= 0:
725
  img = Image.new('RGB', (800, 100), color='white')
726
  draw = ImageDraw.Draw(img)
727
  draw.text((10, 40), "Error: Invalid input data", fill='black')
728
  return img
729
+
 
730
  for gene in gene_results:
731
  gene['start'] = max(0, int(gene['start']))
732
  gene['end'] = min(genome_length, int(gene['end']))
733
  if gene['start'] >= gene['end']:
734
+ print(f"Warning: Invalid coordinates for gene {gene.get('gene_name','?')}")
735
 
 
736
  width = 1500
737
  height = 600
738
  margin = 50
739
  track_height = 40
740
 
 
741
  img = Image.new('RGB', (width, height), 'white')
742
  draw = ImageDraw.Draw(img)
743
 
 
744
  try:
745
  font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
746
  title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
 
748
  font = ImageFont.load_default()
749
  title_font = ImageFont.load_default()
750
 
751
+ draw.text((margin, margin // 2), "Genome SHAP Analysis (Simple)", fill='black', font=title_font or font)
 
752
 
 
753
  line_y = height // 2
754
  draw.line([(int(margin), int(line_y)), (int(width - margin), int(line_y))], fill='black', width=2)
755
 
 
756
  scale = float(width - 2 * margin) / float(genome_length)
757
 
758
+ # Scale markers
759
  num_ticks = 10
760
+ step = max(1, genome_length // num_ticks)
 
 
 
 
 
761
  for i in range(0, genome_length + 1, step):
762
  x_coord = margin + i * scale
763
  draw.line([
 
766
  ], fill='black', width=1)
767
  draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)
768
 
 
769
  sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
 
 
770
  for idx, gene in enumerate(sorted_genes):
 
771
  start_x = margin + int(gene['start'] * scale)
772
  end_x = margin + int(gene['end'] * scale)
 
 
773
  avg_shap = gene['avg_shap']
 
 
 
774
  intensity = int(abs(avg_shap) * 500)
775
+ intensity = max(50, min(255, intensity))
776
 
777
  if avg_shap > 0:
778
+ color = (255, 255 - intensity, 255 - intensity) # Reddish (positive SHAP)
 
779
  else:
780
+ color = (255 - intensity, 255 - intensity, 255) # Bluish (negative or zero SHAP)
 
781
 
 
782
  draw.rectangle([
783
  (int(start_x), int(line_y - track_height // 2)),
784
  (int(end_x), int(line_y + track_height // 2))
785
  ], fill=color, outline='black')
786
 
 
787
  label = str(gene.get('gene_name','?'))
 
 
788
  label_mask = font.getmask(label)
789
  label_width, label_height = label_mask.size
790
 
 
791
  if idx % 2 == 0:
792
  text_y = line_y - track_height - 15
793
  else:
794
  text_y = line_y + track_height + 5
795
 
 
796
  gene_width = end_x - start_x
797
  if gene_width > label_width:
798
  text_x = start_x + (gene_width - label_width) // 2
 
804
  rotated_img = txt_img.rotate(90, expand=True)
805
  img.paste(rotated_img, (int(start_x), int(text_y)), rotated_img)
806
 
807
+ return img
808
+
809
+ def create_advanced_genome_diagram(gene_results: List[Dict[str, Any]],
810
+ genome_length: int,
811
+ shap_means: np.ndarray,
812
+ diagram_title: str = "Advanced Genome Diagram") -> Image.Image:
813
+ """
814
+ An advanced genome diagram using Biopython's GenomeDiagram.
815
+ We'll create tracks for genes and a 'SHAP line plot' track.
816
+ """
817
+ if not gene_results or genome_length <= 0 or len(shap_means) == 0:
818
+ # Fallback if data is invalid
819
+ img = Image.new('RGB', (800, 100), color='white')
820
+ d = ImageDraw.Draw(img)
821
+ d.text((10, 40), "Error: Not enough data for advanced diagram", fill='black')
822
+ return img
823
+
824
+ diagram = GenomeDiagram.Diagram(diagram_title)
825
+ gene_track = diagram.new_track(1, name="Genes", greytrack=False, height=0.5)
826
+ gene_set = gene_track.new_set()
827
+
828
+ # Add each gene as a feature
829
+ for gene in gene_results:
830
+ start = max(0, int(gene['start']))
831
+ end = min(genome_length, int(gene['end']))
832
+ avg_shap = gene['avg_shap']
833
+ # Color scale: negative = blue, positive = red
834
+ intensity = abs(avg_shap) * 500
835
+ intensity = max(50, min(255, intensity))
836
+ if avg_shap >= 0:
837
+ color_hex = colors.Color(1.0, 1.0 - intensity/255.0, 1.0 - intensity/255.0)
838
+ else:
839
+ color_hex = colors.Color(1.0 - intensity/255.0, 1.0 - intensity/255.0, 1.0)
840
+
841
+ feature = SeqFeature(FeatureLocation(start, end, strand=1))  # strand on the location works across Biopython versions
842
+ gene_set.add_feature(
843
+ feature,
844
+ color=color_hex,
845
+ label=True,
846
+ name=str(gene.get('gene_name','?')),
847
+ label_size=8,
848
+ label_color=colors.black
849
+ )
850
+
851
+ # Add a track for the SHAP line
852
+ shap_track = diagram.new_track(2, name="SHAP Score", greytrack=False, height=0.3)
853
+ shap_set = shap_track.new_set("graph")
854
+ # We'll plot the entire shap_means array.
855
+ # X coords = [0..genome_length], Y coords = shap_means
856
+ # We'll keep negative values below baseline, positive above.
857
+
858
+ # Normalizing for visualization
859
+ max_abs = max(abs(shap_means.min()), abs(shap_means.max()))
860
+ if max_abs == 0:
861
+ scaled_shap = [0]*len(shap_means)
862
+ else:
863
+ scaled_shap = (shap_means / max_abs * 50).tolist() # scale to +/- 50
864
 
865
+ # GraphSet.new_graph() expects (position, value) pairs rather than a bare value list
+ graph_data = list(enumerate(scaled_shap))
+ shap_set.new_graph(
866
+ data=graph_data,
867
+ name="shap_line",
868
+ style="line",
869
+ color=colors.darkgreen,
870
+ altcolor=colors.red,
871
+ linewidth=1
872
+ )
873
+
874
+ # Draw to a temporary file
875
+ with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmpf:
876
+ diagram.draw(format="linear", pagesize='A3', fragments=1, start=0, end=genome_length)
877
+ diagram.write(tmpf.name, "PDF")
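+ # Alternative: diagram.write() also supports "SVG"/"EPS" directly, and "PNG" when
+ # ReportLab's renderPM backend is installed, which would avoid the pdf2image step.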
878
+
879
+ # Convert PDF to a PIL image (requires poppler or similar).
880
+ # If you do not have poppler, you can skip PDF -> image or use Cairo.
881
+ try:
882
+ import pdf2image
883
+ pages = pdf2image.convert_from_path(tmpf.name, dpi=100)
884
+ img = pages[0] if pages else Image.new('RGB', (800, 100), color='white')
885
+ except Exception as e:  # pdf2image missing, or its poppler backend unavailable
886
+ img = Image.new('RGB', (800, 100), color='white')
887
+ d = ImageDraw.Draw(img)
888
+ d.text((10, 40), f"Could not render advanced diagram as image: {e}", fill='black')
889
+
890
+ # Cleanup
891
+ os.remove(tmpf.name)
892
  return img
893
 
894
  def analyze_gene_features(sequence_file: str,
895
  features_file: str,
896
  fasta_text: str = "",
897
+ features_text: str = "",
898
+ diagram_mode: str = "advanced"
899
+ ) -> Tuple[str, Optional[str], Optional[Image.Image]]:
900
+ """
901
+ Analyze each gene in the features file, compute gene-level SHAP stats,
902
+ produce tabular output, and create an optional genome diagram.
903
+ """
904
+ # 1) Analyze the entire sequence with the top-level function
905
  sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
906
  if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
907
  return f"Error in sequence analysis: {sequence_results[0]}", None, None
908
 
909
+ seq = sequence_results[3]["seq"]
910
  shap_means = sequence_results[3]["shap_means"]
911
+ genome_length = len(seq)
912
+
913
+ # 2) Read gene features
914
  try:
915
  if features_text.strip():
916
  genes = parse_gene_features(features_text)
  else:
  with open(features_file, 'r') as f:
919
  genes = parse_gene_features(f.read())
920
  except Exception as e:
921
  return f"Error reading features file: {str(e)}", None, None
922
+
 
923
  gene_results = []
924
  for gene in genes:
925
+ location = gene['metadata'].get('location', '')
926
+ if not location:
 
927
  continue
928
+ start, end = parse_location(location)
929
+ if start is None or end is None or start >= end or end > genome_length:
930
+ continue
931
+ gene_shap = shap_means[start:end]
932
+ basic_stats = compute_gene_statistics(gene_shap)
933
+ # Additional stats
934
+ gene_seq = seq[start:end]
935
+ adv_stats = advanced_gene_statistics(gene_shap, gene_seq)
936
+
937
+ # Merge basic + advanced stats
938
+ all_stats = {**basic_stats, **adv_stats}
939
+
940
+ classification = 'Human' if basic_stats['avg_shap'] > 0 else 'Non-human'
941
+ locus_tag = gene['metadata'].get('locus_tag', '')
942
+ gene_name = gene['metadata'].get('gene', 'Unknown')
943
+
944
+ gene_dict = {
945
+ 'gene_name': gene_name,
946
+ 'location': location,
947
+ 'start': start,
948
+ 'end': end,
949
+ 'locus_tag': locus_tag,
950
+ 'avg_shap': all_stats['avg_shap'],
951
+ 'median_shap': basic_stats['median_shap'],
952
+ 'std_shap': basic_stats['std_shap'],
953
+ 'max_shap': basic_stats['max_shap'],
954
+ 'min_shap': basic_stats['min_shap'],
955
+ 'pos_fraction': basic_stats['pos_fraction'],
956
+ 'n50': all_stats['n50'],
957
+ 'entropy': all_stats['entropy'],
958
+ 'classification': classification,
959
+ 'confidence': abs(all_stats['avg_shap'])
960
+ }
961
+ gene_results.append(gene_dict)
962
+
963
  if not gene_results:
964
  return "No valid genes could be processed", None, None
965
+
966
+ # 3) Summaries
967
  sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']), reverse=True)
 
 
968
  results_text = "Gene Analysis Results:\n\n"
969
  results_text += f"Total genes analyzed: {len(gene_results)}\n"
970
+ num_human = sum(1 for g in gene_results if g['classification'] == 'Human')
971
+ results_text += f"Human-like genes: {num_human}\n"
972
+ results_text += f"Non-human-like genes: {len(gene_results) - num_human}\n\n"
973
 
974
+ results_text += "Top 10 most distinctive genes (by avg SHAP magnitude):\n"
975
  for gene in sorted_genes[:10]:
976
  results_text += (
977
  f"Gene: {gene['gene_name']}\n"
978
  f"Location: {gene['location']}\n"
979
  f"Classification: {gene['classification']} "
980
  f"(confidence: {gene['confidence']:.4f})\n"
981
+ f"Average SHAP: {gene['avg_shap']:.4f}\n"
982
+ f"N50: {gene['n50']}, Entropy: {gene['entropy']:.3f}\n\n"
983
  )
984
+
985
+ # 4) Make CSV
986
+ csv_content = "gene_name,location,start,end,locus_tag,avg_shap,median_shap,std_shap,"
987
+ csv_content += "max_shap,min_shap,pos_fraction,n50,entropy,classification,confidence\n"
988
+ for g in gene_results:
 
989
  csv_content += (
990
+ f"{g['gene_name']},{g['location']},{g['start']},{g['end']},{g['locus_tag']},"
991
+ f"{g['avg_shap']:.4f},{g['median_shap']:.4f},{g['std_shap']:.4f},"
992
+ f"{g['max_shap']:.4f},{g['min_shap']:.4f},{g['pos_fraction']:.4f},"
993
+ f"{g['n50']},{g['entropy']:.4f},{g['classification']},{g['confidence']:.4f}\n"
994
  )
 
 
995
  try:
996
  temp_dir = tempfile.gettempdir()
997
  temp_path = os.path.join(temp_dir, f"gene_analysis_{os.urandom(4).hex()}.csv")
 
998
  with open(temp_path, 'w') as f:
999
  f.write(csv_content)
1000
  except Exception as e:
1001
  print(f"Error saving CSV: {str(e)}")
1002
  temp_path = None
1003
+
1004
+ # 5) Create diagram
1005
  try:
1006
+ if diagram_mode == "advanced":
1007
+ diagram_img = create_advanced_genome_diagram(gene_results, genome_length, shap_means)
1008
+ else:
1009
+ diagram_img = create_simple_genome_diagram(gene_results, genome_length)
1010
  except Exception as e:
1011
  print(f"Error creating visualization: {str(e)}")
 
1012
  diagram_img = Image.new('RGB', (800, 100), color='white')
1013
  draw = ImageDraw.Draw(diagram_img)
1014
  draw.text((10, 40), f"Error creating visualization: {str(e)}", fill='black')
1015
+
1016
  return results_text, temp_path, diagram_img
1017
 
1018
  ###############################################################################
 
1020
  ###############################################################################
1021
 
1022
  def prepare_csv_download(data, filename="analysis_results.csv"):
1023
+ """
1024
+ Convert data to CSV for Gradio download button.
1025
+ """
1026
  if isinstance(data, str):
1027
  return data.encode(), filename
1028
  elif isinstance(data, (list, dict)):
1029
  import csv
1030
  from io import StringIO
 
1031
  output = StringIO()
1032
  writer = csv.DictWriter(output, fieldnames=data[0].keys())
1033
  writer.writeheader()
 
1051
 
1052
  with gr.Blocks(css=css) as iface:
1053
  gr.Markdown("""
1054
+ # Virus Host Classifier + Extended Genome Visualization
1055
+ **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme subregions.
1056
+ **Step 2**: Explore subregions (local SHAP, GC content, histogram).
1057
+ **Step 3**: Analyze gene features (per-gene SHAP, advanced stats, improved diagrams).
1058
+ **Step 4**: Compare sequences for SHAP differences.
1059
+
1060
+ **Color Scale**: Negative SHAP = Blue, 0 = White, Positive = Red.
1061
  """)
1062
 
1063
  with gr.Tab("1) Full-Sequence Analysis"):
1064
  with gr.Row():
1065
  with gr.Column(scale=1):
1066
  file_input = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
1067
+ text_input = gr.Textbox(label="Or paste FASTA", placeholder=">name\nACGT...", lines=5)
1068
  top_k = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of top k-mers to display")
1069
+ win_size = gr.Slider(minimum=100, maximum=5000, value=500, step=100, label="Subregion Window Size")
1070
  analyze_btn = gr.Button("Analyze Sequence", variant="primary")
1071
  with gr.Column(scale=2):
1072
  results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
 
1085
  with gr.Tab("2) Subregion Exploration"):
1086
  gr.Markdown("""
1087
  **Subregion Analysis**
1088
+ View the local SHAP heatmap, GC content, and the distribution of SHAP values for a chosen region.
 
1089
  """)
1090
  with gr.Row():
1091
  region_start = gr.Number(label="Region Start", value=0)
 
1095
  with gr.Row():
1096
  subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
1097
  subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
1098
+ download_subregion = gr.File(label="Download Subregion", visible=False, elem_classes="download-button")
1099
 
1100
  region_btn.click(
1101
  analyze_subregion,
 
1106
  with gr.Tab("3) Gene Features Analysis"):
1107
  gr.Markdown("""
1108
  **Analyze Gene Features**
1109
+ - Upload a FASTA file and a gene features file.
1110
+ - See per-gene SHAP, classification, N50, entropy, etc.
1111
+ - Choose a diagram mode (simple or advanced).
 
 
 
 
 
 
 
1112
  """)
1113
  with gr.Row():
1114
  with gr.Column(scale=1):
1115
+ gene_fasta_file = gr.File(label="FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
1116
+ gene_fasta_text = gr.Textbox(label="Or paste FASTA sequence", lines=5)
1117
  with gr.Column(scale=1):
1118
+ features_file = gr.File(label="Gene features file", file_types=[".txt"], type="filepath")
1119
+ features_text = gr.Textbox(label="Or paste gene features", lines=5)
1120
+ diagram_mode = gr.Radio(choices=["simple", "advanced"], value="advanced", label="Diagram Mode")
1121
  analyze_genes_btn = gr.Button("Analyze Gene Features", variant="primary")
1122
  gene_results = gr.Textbox(label="Gene Analysis Results", lines=12, interactive=False)
1123
+ gene_diagram = gr.Image(label="Genome Diagram")
1124
  download_gene_results = gr.File(label="Download Gene Analysis (CSV)", visible=True)
1125
 
1126
  analyze_genes_btn.click(
1127
  analyze_gene_features,
1128
+ inputs=[gene_fasta_file, features_file, gene_fasta_text, features_text, diagram_mode],
1129
  outputs=[gene_results, download_gene_results, gene_diagram]
1130
  )
1131
 
1132
  with gr.Tab("4) Comparative Analysis"):
1133
  gr.Markdown("""
1134
  **Compare Two Sequences**
1135
+ - Upload or paste two FASTA sequences.
1136
+ - We'll compare SHAP patterns (normalized for different lengths).
 
 
 
 
 
1137
  """)
1138
  with gr.Row():
1139
  with gr.Column(scale=1):
1140
+ file_input1 = gr.File(label="1st FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
1141
+ text_input1 = gr.Textbox(label="Or paste 1st FASTA", lines=5)
1142
  with gr.Column(scale=1):
1143
+ file_input2 = gr.File(label="2nd FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
1144
+ text_input2 = gr.Textbox(label="Or paste 2nd FASTA", lines=5)
1145
  compare_btn = gr.Button("Compare Sequences", variant="primary")
1146
  comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
1147
  with gr.Row():
1148
  diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
1149
  diff_hist = gr.Image(label="Distribution of SHAP Differences")
1150
+ download_comparison = gr.File(label="Download Comparison", visible=False, elem_classes="download-button")
1151
 
1152
  compare_btn.click(
1153
  analyze_sequence_comparison,
 
1156
  )
1157
 
1158
  gr.Markdown("""
1159
+ ### Notes & Features
1160
+ - **Advanced Genome Diagram** uses Biopython’s `GenomeDiagram` (requires `pdf2image` if you want it as an image).
1161
+ - **Additional Stats**: N50, Shannon entropy, etc.
1162
+ - **Auto-scaling** for comparative analysis with adaptive smoothing.
1163
+ - **Data Export**: Download CSV of analysis results.
 
1164
  """)
1165
+
1166
  if __name__ == "__main__":
1167
  iface.launch()