hiyata committed on
Commit
82425ee
·
verified ·
1 Parent(s): 05b9733

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -90
app.py CHANGED
@@ -9,7 +9,6 @@ import matplotlib.colors as mcolors
9
  import io
10
  from PIL import Image
11
  from scipy.interpolate import interp1d
12
- import numpy as np
13
 
14
  ###############################################################################
15
  # 1. MODEL DEFINITION
@@ -317,91 +316,90 @@ def compute_gc_content(sequence):
317
  return (gc_count / len(sequence)) * 100.0
318
 
319
  ###############################################################################
320
- # 7. MAIN ANALYSIS STEP (Gradio Step 1)
321
  ###############################################################################
322
 
323
- def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
324
  """
325
- Compare two sequences by analyzing their SHAP differences.
326
- Returns comparison text and visualizations.
327
  """
328
- # Process first sequence
329
- results1 = analyze_sequence(file1, fasta_text=fasta1)
330
- if isinstance(results1[0], str) and "Error" in results1[0]:
331
- return (f"Error in sequence 1: {results1[0]}", None, None)
332
-
333
- # Process second sequence
334
- results2 = analyze_sequence(file2, fasta_text=fasta2)
335
- if isinstance(results2[0], str) and "Error" in results2[0]:
336
- return (f"Error in sequence 2: {results2[0]}", None, None)
337
-
338
- # Get SHAP means from state dictionaries
339
- shap1 = results1[3]["shap_means"]
340
- shap2 = results2[3]["shap_means"]
341
-
342
- # Normalize lengths
343
- shap1_norm, shap2_norm = normalize_shap_lengths(shap1, shap2)
344
-
345
- # Compute difference (positive = seq2 more human-like)
346
- shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
347
-
348
- # Calculate some statistics
349
- avg_diff = np.mean(shap_diff)
350
- std_diff = np.std(shap_diff)
351
- max_diff = np.max(shap_diff)
352
- min_diff = np.min(shap_diff)
353
-
354
- # Calculate what fraction of positions show substantial differences
355
- threshold = 0.05 # Arbitrary threshold for "substantial" difference
356
- substantial_diffs = np.abs(shap_diff) > threshold
357
- frac_different = np.mean(substantial_diffs)
358
-
359
- # Generate comparison text
360
- # Extract classifications without using split on newline
361
- classification1 = results1[0].split('Classification: ')[1].split('(')[0].strip()
362
- classification2 = results2[0].split('Classification: ')[1].split('(')[0].strip()
363
-
364
- # Build the text using format method
365
- comparison_text = (
366
- "Sequence Comparison Results:\n"
367
- "Sequence 1: {}\n"
368
- "Length: {:,} bases\n"
369
- "Classification: {}\n\n"
370
- "Sequence 2: {}\n"
371
- "Length: {:,} bases\n"
372
- "Classification: {}\n\n"
373
- "Comparison Statistics:\n"
374
- "Average SHAP difference: {:.4f}\n"
375
- "Standard deviation: {:.4f}\n"
376
- "Max difference: {:.4f} (Seq2 more human-like)\n"
377
- "Min difference: {:.4f} (Seq1 more human-like)\n"
378
- "Fraction of positions with substantial differences: {:.2%}\n\n"
379
- "Interpretation:\n"
380
- "Positive values (red) indicate regions where Sequence 2 is more human-like\n"
381
- "Negative values (blue) indicate regions where Sequence 1 is more human-like"
382
- ).format(
383
- results1[4], len(shap1), classification1,
384
- results2[4], len(shap2), classification2,
385
- avg_diff, std_diff, max_diff, min_diff, frac_different
386
- )
387
-
388
- # Create comparison heatmap
389
- heatmap_fig = plot_comparative_heatmap(shap_diff)
390
- heatmap_img = fig_to_image(heatmap_fig)
391
-
392
- # Create histogram of differences
393
- hist_fig = plot_shap_histogram(
394
- shap_diff,
395
- title="Distribution of SHAP Differences"
396
- )
397
- hist_img = fig_to_image(hist_fig)
398
-
399
- return comparison_text, heatmap_img, hist_img
400
-
401
-
402
- ###############################################################################
403
- # 8. SUBREGION ANALYSIS (Gradio Step 2)
404
- ###############################################################################
405
 
406
  def analyze_subregion(state, header, region_start, region_end):
407
  """
@@ -468,9 +466,8 @@ def analyze_subregion(state, header, region_start, region_end):
468
 
469
  return (region_info, heatmap_img, hist_img)
470
 
471
-
472
  ###############################################################################
473
- # NEW SECTION: COMPARATIVE ANALYSIS FUNCTIONS
474
  ###############################################################################
475
 
476
  def normalize_shap_lengths(shap1, shap2, num_points=1000):
@@ -583,7 +580,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
583
  # Compute difference (positive = seq2 more human-like)
584
  shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
585
 
586
- # Calculate some statistics
587
  avg_diff = np.mean(shap_diff)
588
  std_diff = np.std(shap_diff)
589
  max_diff = np.max(shap_diff)
@@ -594,7 +591,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
594
  substantial_diffs = np.abs(shap_diff) > threshold
595
  frac_different = np.mean(substantial_diffs)
596
 
597
- # Extract classifications safely without using f-strings with backslashes
598
  classification1 = results1[0].split('Classification: ')[1].split('\n')[0].strip()
599
  classification2 = results2[0].split('Classification: ')[1].split('\n')[0].strip()
600
 
@@ -603,7 +600,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
603
  len2_formatted = "{:,}".format(len(shap2))
604
  frac_formatted = "{:.2%}".format(frac_different)
605
 
606
- # Build comparison text without f-strings containing backslashes
607
  comparison_text = (
608
  "Sequence Comparison Results:\n"
609
  f"Sequence 1: {results1[4]}\n"
@@ -635,7 +632,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
635
  hist_img = fig_to_image(hist_fig)
636
 
637
  return comparison_text, heatmap_img, hist_img
638
-
639
  ###############################################################################
640
  # 9. BUILD GRADIO INTERFACE
641
  ###############################################################################
@@ -694,7 +691,6 @@ with gr.Blocks(css=css) as iface:
694
  seq_state = gr.State()
695
  header_state = gr.State()
696
 
697
- # analyze_sequence(...) returns 5 items
698
  analyze_btn.click(
699
  analyze_sequence,
700
  inputs=[file_input, top_k, text_input, win_size],
@@ -781,6 +777,7 @@ with gr.Blocks(css=css) as iface:
781
  inputs=[file_input1, file_input2, text_input1, text_input2],
782
  outputs=[comparison_text, diff_heatmap, diff_hist]
783
  )
 
784
  gr.Markdown("""
785
  ### Interface Features
786
  - **Overall Classification** (human vs non-human) using k-mer frequencies.
@@ -793,7 +790,28 @@ with gr.Blocks(css=css) as iface:
793
  - GC content
794
  - Fraction of positions pushing human vs. non-human
795
  - Simple logic-based classification
 
 
 
 
796
  """)
797
 
 
 
 
 
798
  if __name__ == "__main__":
799
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import io
10
  from PIL import Image
11
  from scipy.interpolate import interp1d
 
12
 
13
  ###############################################################################
14
  # 1. MODEL DEFINITION
 
316
  return (gc_count / len(sequence)) * 100.0
317
 
318
  ###############################################################################
319
+ # 7. SEQUENCE ANALYSIS FUNCTIONS
320
  ###############################################################################
321
 
322
+ def analyze_sequence(file_path, top_k=10, fasta_text="", window_size=500):
323
  """
324
+ Analyze a virus sequence from a FASTA file or text input.
325
+ Returns (results_text, kmer_plot, heatmap_plot, state_dict, header)
326
  """
327
+ try:
328
+ # Load model and k-mer info
329
+ model = VirusClassifier(256) # 4^4 = 256 k-mers for k=4
330
+ model.load_state_dict(torch.load("model.pt"))
331
+ model.eval()
332
+ kmers = [''.join(p) for p in product("ACGT", repeat=4)]
333
+
334
+ # Process input (file takes precedence over text)
335
+ if file_path:
336
+ with open(file_path, 'r') as f:
337
+ fasta_text = f.read()
338
+
339
+ if not fasta_text.strip():
340
+ return ("Error: No sequence provided", None, None, {}, "")
341
+
342
+ # Parse FASTA
343
+ sequences = parse_fasta(fasta_text)
344
+ if not sequences:
345
+ return ("Error: No valid FASTA sequences found", None, None, {}, "")
346
+
347
+ header, sequence = sequences[0] # Take first sequence
348
+
349
+ # Convert to k-mer frequencies
350
+ x = sequence_to_kmer_vector(sequence)
351
+ x_tensor = torch.tensor(x).float().unsqueeze(0)
352
+
353
+ # Get model prediction
354
+ with torch.no_grad():
355
+ output = model(x_tensor)
356
+ probs = torch.softmax(output, dim=1)
357
+ pred_human = probs[0, 1].item()
358
+
359
+ # Calculate SHAP values
360
+ shap_values, prob = calculate_shap_values(model, x_tensor)
361
+
362
+ # Find most extreme regions
363
+ shap_means = compute_positionwise_scores(sequence, shap_values)
364
+ start_max, end_max, avg_max = find_extreme_subregion(shap_means, window_size, mode="max")
365
+ start_min, end_min, avg_min = find_extreme_subregion(shap_means, window_size, mode="min")
366
+
367
+ # Format results text
368
+ classification = "Human" if pred_human > 0.5 else "Non-human"
369
+ results = (
370
+ f"Classification: {classification} "
371
+ f"(probability of human = {pred_human:.3f})\n\n"
372
+ f"Sequence length: {len(sequence):,} bases\n"
373
+ f"Overall GC content: {compute_gc_content(sequence):.1f}%\n\n"
374
+ f"Most human-like {window_size}bp region:\n"
375
+ f"Position {start_max:,} to {end_max:,}\n"
376
+ f"Average SHAP: {avg_max:.4f}\n"
377
+ f"GC content: {compute_gc_content(sequence[start_max:end_max]):.1f}%\n\n"
378
+ f"Least human-like {window_size}bp region:\n"
379
+ f"Position {start_min:,} to {end_min:,}\n"
380
+ f"Average SHAP: {avg_min:.4f}\n"
381
+ f"GC content: {compute_gc_content(sequence[start_min:end_min]):.1f}%"
382
+ )
383
+
384
+ # Create k-mer importance plot
385
+ kmer_fig = create_importance_bar_plot(shap_values, kmers, top_k)
386
+ kmer_img = fig_to_image(kmer_fig)
387
+
388
+ # Create genome-wide heatmap
389
+ heatmap_fig = plot_linear_heatmap(shap_means)
390
+ heatmap_img = fig_to_image(heatmap_fig)
391
+
392
+ # Store data for subregion analysis
393
+ state = {
394
+ "seq": sequence,
395
+ "shap_means": shap_means
396
+ }
397
+
398
+ return results, kmer_img, heatmap_img, state, header
399
+
400
+ except Exception as e:
401
+ error_msg = f"Error analyzing sequence: {str(e)}"
402
+ return (error_msg, None, None, {}, "")
 
403
 
404
  def analyze_subregion(state, header, region_start, region_end):
405
  """
 
466
 
467
  return (region_info, heatmap_img, hist_img)
468
 
 
469
  ###############################################################################
470
+ # 8. COMPARISON ANALYSIS FUNCTIONS
471
  ###############################################################################
472
 
473
  def normalize_shap_lengths(shap1, shap2, num_points=1000):
 
580
  # Compute difference (positive = seq2 more human-like)
581
  shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
582
 
583
+ # Calculate statistics
584
  avg_diff = np.mean(shap_diff)
585
  std_diff = np.std(shap_diff)
586
  max_diff = np.max(shap_diff)
 
591
  substantial_diffs = np.abs(shap_diff) > threshold
592
  frac_different = np.mean(substantial_diffs)
593
 
594
+ # Extract classifications safely
595
  classification1 = results1[0].split('Classification: ')[1].split('\n')[0].strip()
596
  classification2 = results2[0].split('Classification: ')[1].split('\n')[0].strip()
597
 
 
600
  len2_formatted = "{:,}".format(len(shap2))
601
  frac_formatted = "{:.2%}".format(frac_different)
602
 
603
+ # Build comparison text
604
  comparison_text = (
605
  "Sequence Comparison Results:\n"
606
  f"Sequence 1: {results1[4]}\n"
 
632
  hist_img = fig_to_image(hist_fig)
633
 
634
  return comparison_text, heatmap_img, hist_img
635
+
636
  ###############################################################################
637
  # 9. BUILD GRADIO INTERFACE
638
  ###############################################################################
 
691
  seq_state = gr.State()
692
  header_state = gr.State()
693
 
 
694
  analyze_btn.click(
695
  analyze_sequence,
696
  inputs=[file_input, top_k, text_input, win_size],
 
777
  inputs=[file_input1, file_input2, text_input1, text_input2],
778
  outputs=[comparison_text, diff_heatmap, diff_hist]
779
  )
780
+
781
  gr.Markdown("""
782
  ### Interface Features
783
  - **Overall Classification** (human vs non-human) using k-mer frequencies.
 
790
  - GC content
791
  - Fraction of positions pushing human vs. non-human
792
  - Simple logic-based classification
793
+ - **Sequence Comparison**:
794
+ - Compare two sequences to identify regions of difference
795
+ - Normalized comparison to handle different sequence lengths
796
+ - Statistical summary of differences
797
  """)
798
 
799
+ ###############################################################################
800
+ # 10. MAIN EXECUTION
801
+ ###############################################################################
802
+
803
  if __name__ == "__main__":
804
+ # Set up any global configurations if needed
805
+ plt.style.use('default')
806
+ plt.rcParams['figure.figsize'] = [10, 6]
807
+ plt.rcParams['figure.dpi'] = 100
808
+ plt.rcParams['font.size'] = 10
809
+
810
+ # Launch the interface
811
+ iface.launch(
812
+ share=False, # Set to True to create a public link
813
+ server_name="0.0.0.0", # Listen on all network interfaces
814
+ server_port=7860, # Default Gradio port
815
+ show_api=False, # Hide API docs
816
+ debug=False # Set to True for debugging
817
+ )