hiyata committed on
Commit
90c03ec
·
verified ·
1 Parent(s): c5accc7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -81
app.py CHANGED
@@ -320,92 +320,84 @@ def compute_gc_content(sequence):
320
  # 7. MAIN ANALYSIS STEP (Gradio Step 1)
321
  ###############################################################################
322
 
323
def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
    """
    Classify one FASTA sequence and explain the prediction with SHAP.

    Input is taken from ``fasta_text`` when it contains non-whitespace text;
    otherwise ``file_obj`` is opened as a path and read. Only the first record
    returned by ``parse_fasta`` is analyzed.

    Args:
        file_obj: Path of a FASTA file, or None when text is supplied.
        top_kmers: Number of k-mers passed to the importance bar plot.
        fasta_text: Raw FASTA text; takes priority over ``file_obj``.
        window_size: Window length handed to ``find_extreme_subregion`` when
            searching for the strongest positive / negative SHAP subregions.

    Returns:
        A 5-tuple ``(results_text, bar_img, heatmap_img, state, header)``.
        ``state`` is a dict with keys ``"seq"`` and ``"shap_means"`` that is
        kept for the later subregion step. On any failure the tuple is
        ``(message, None, None, None, None)``.
    """
    def _failure(message):
        # Every failure path returns the same 5-slot shape as a success so
        # callers can unpack the result unconditionally.
        return (message, None, None, None, None)

    # Handle input: pasted text wins over an uploaded file.
    if fasta_text.strip():
        text = fasta_text.strip()
    elif file_obj is not None:
        try:
            with open(file_obj, 'r') as f:
                text = f.read()
        except Exception as e:
            return _failure(f"Error reading file: {str(e)}")
    else:
        return _failure("Please provide a FASTA sequence.")

    # Parse FASTA
    sequences = parse_fasta(text)
    if not sequences:
        return _failure("No valid FASTA sequences found.")

    header, seq = sequences[0]

    # Load model and scaler
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        # weights_only=True restricts unpickling to tensor data (safer load).
        state_dict = torch.load('model.pt', map_location=device, weights_only=True)
        model = VirusClassifier(256).to(device)
        model.load_state_dict(state_dict)
        # Inference only: put dropout / batch-norm layers (if the classifier
        # has any) into evaluation mode so predictions are deterministic.
        model.eval()

        scaler = joblib.load('scaler.pkl')
    except Exception as e:
        return _failure(f"Error loading model/scaler: {str(e)}")

    # Vectorize + scale (the scaler expects a 2-D array, hence the reshape).
    freq_vector = sequence_to_kmer_vector(seq)
    scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
    x_tensor = torch.FloatTensor(scaled_vector).to(device)

    # SHAP + classification
    shap_values, prob_human = calculate_shap_values(model, x_tensor)
    prob_nonhuman = 1.0 - prob_human

    classification = "Human" if prob_human > 0.5 else "Non-human"
    confidence = max(prob_human, prob_nonhuman)

    # Per-base SHAP
    shap_means = compute_positionwise_scores(seq, shap_values, k=4)

    # Find the most "human-pushing" region
    (max_start, max_end, max_avg) = find_extreme_subregion(shap_means, window_size, mode="max")
    # Find the most "non-human-pushing" region
    (min_start, min_end, min_avg) = find_extreme_subregion(shap_means, window_size, mode="min")

    # Build results text
    results_text = (
        f"Sequence: {header}\n"
        f"Length: {len(seq):,} bases\n"
        f"Classification: {classification}\n"
        f"Confidence: {confidence:.3f}\n"
        f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
        f"---\n"
        f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
        f"Start: {max_start}, End: {max_end}, Avg SHAP: {max_avg:.4f}\n\n"
        f"**Most Non-Human–Pushing {window_size}-bp Subregion**:\n"
        f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
    )

    # K-mer importance plot; labels are all 4-mers over ACGT in product order.
    kmers = [''.join(p) for p in product("ACGT", repeat=4)]
    bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
    bar_img = fig_to_image(bar_fig)

    # Full-genome SHAP heatmap
    heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
    heatmap_img = fig_to_image(heatmap_fig)

    # Store data for subregion analysis
    state_dict_out = {
        "seq": seq,
        "shap_means": shap_means
    }

    return (results_text, bar_img, heatmap_img, state_dict_out, header)
409
 
410
  ###############################################################################
411
  # 8. SUBREGION ANALYSIS (Gradio Step 2)
 
320
  # 7. MAIN ANALYSIS STEP (Gradio Step 1)
321
  ###############################################################################
322
 
323
def _extract_classification(results_text):
    """Return the text between 'Classification: ' and the first '(' of a result summary."""
    _, marker, tail = results_text.partition('Classification: ')
    if not marker:
        return "Unknown"
    # Deliberately cut at the first '(' rather than at a newline, so the
    # returned string keeps whatever precedes the parenthesised probabilities.
    return tail.split('(')[0].strip()


def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
    """
    Compare two sequences by analyzing their SHAP differences.

    Each input is run through ``analyze_sequence`` (file path and/or pasted
    FASTA text). The two per-position SHAP tracks are then length-normalized
    and subtracted; positive differences mean sequence 2 is more human-like.

    Args:
        file1: FASTA file path for sequence 1, or None.
        file2: FASTA file path for sequence 2, or None.
        fasta1: Pasted FASTA text for sequence 1 (forwarded as ``fasta_text``).
        fasta2: Pasted FASTA text for sequence 2 (forwarded as ``fasta_text``).

    Returns:
        ``(comparison_text, heatmap_img, hist_img)``. When either analysis
        fails, or no positions remain to compare, the tuple is
        ``(error_message, None, None)``.
    """
    # Process first sequence
    results1 = analyze_sequence(file1, fasta_text=fasta1)
    # A failed analysis leaves the state slot as None. Checking that slot
    # (rather than searching the text for "Error") also catches the
    # "Please provide..." / "No valid FASTA..." failures, and is not fooled
    # by a FASTA header that happens to contain the word "Error".
    if results1[3] is None:
        return (f"Error in sequence 1: {results1[0]}", None, None)

    # Process second sequence
    results2 = analyze_sequence(file2, fasta_text=fasta2)
    if results2[3] is None:
        return (f"Error in sequence 2: {results2[0]}", None, None)

    # Get SHAP means from state dictionaries
    shap1 = results1[3]["shap_means"]
    shap2 = results2[3]["shap_means"]

    # Normalize lengths
    shap1_norm, shap2_norm = normalize_shap_lengths(shap1, shap2)

    # Compute difference (positive = seq2 more human-like)
    shap_diff = compute_shap_difference(shap1_norm, shap2_norm)

    # np.max / np.min raise on empty input; report it instead of crashing.
    if np.size(shap_diff) == 0:
        return ("Error: no positions available for comparison.", None, None)

    # Calculate some statistics
    avg_diff = np.mean(shap_diff)
    std_diff = np.std(shap_diff)
    max_diff = np.max(shap_diff)
    min_diff = np.min(shap_diff)

    # Calculate what fraction of positions show substantial differences
    threshold = 0.05  # Arbitrary threshold for "substantial" difference
    substantial_diffs = np.abs(shap_diff) > threshold
    frac_different = np.mean(substantial_diffs)

    # Pull the classification strings out of each summary text.
    classification1 = _extract_classification(results1[0])
    classification2 = _extract_classification(results2[0])

    # Generate comparison text
    comparison_text = (
        "Sequence Comparison Results:\n"
        f"Sequence 1: {results1[4]}\n"
        f"Length: {len(shap1):,} bases\n"
        f"Classification: {classification1}\n\n"
        f"Sequence 2: {results2[4]}\n"
        f"Length: {len(shap2):,} bases\n"
        f"Classification: {classification2}\n\n"
        "Comparison Statistics:\n"
        f"Average SHAP difference: {avg_diff:.4f}\n"
        f"Standard deviation: {std_diff:.4f}\n"
        f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
        f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
        f"Fraction of positions with substantial differences: {frac_different:.2%}\n\n"
        "Interpretation:\n"
        "Positive values (red) indicate regions where Sequence 2 is more human-like\n"
        "Negative values (blue) indicate regions where Sequence 1 is more human-like"
    )

    # Create comparison heatmap
    heatmap_fig = plot_comparative_heatmap(shap_diff)
    heatmap_img = fig_to_image(heatmap_fig)

    # Create histogram of differences
    hist_fig = plot_shap_histogram(
        shap_diff,
        title="Distribution of SHAP Differences"
    )
    hist_img = fig_to_image(hist_fig)

    return comparison_text, heatmap_img, hist_img
400
 
 
 
 
 
 
 
 
401
 
402
  ###############################################################################
403
  # 8. SUBREGION ANALYSIS (Gradio Step 2)