hiyata commited on
Commit
03f2bb5
·
verified ·
1 Parent(s): 9a00943

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -40
app.py CHANGED
@@ -140,28 +140,29 @@ def find_extreme_subregion(shap_means, window_size=500, mode="max"):
140
  """
141
  Finds the subregion of length `window_size` that has the maximum
142
  (mode="max") or minimum (mode="min") average SHAP.
143
- Returns (best_start, best_end, avg_shap).
144
  """
145
  n = len(shap_means)
 
 
 
146
  if window_size >= n:
147
- # If the window is bigger than the entire sequence, return the whole seq
148
- avg_val = np.mean(shap_means) if n > 0 else 0.0
149
  return (0, n, avg_val)
150
 
151
- # For efficiency, we can do a rolling sum approach
152
- csum = np.cumsum(shap_means)
153
- # csum[i] = sum of shap_means[0..i-1]
154
- def window_sum(start):
155
- end = start + window_size
156
- return csum[end] - csum[start]
157
-
158
  best_start = 0
159
- # Initialize the best with the first window
160
- best_sum = window_sum(0)
161
  best_avg = best_sum / window_size
162
 
163
  for start in range(1, n - window_size + 1):
164
- wsum = window_sum(start)
165
  wavg = wsum / window_size
166
  if mode == "max":
167
  if wavg > best_avg:
@@ -172,7 +173,7 @@ def find_extreme_subregion(shap_means, window_size=500, mode="max"):
172
  best_avg = wavg
173
  best_start = start
174
 
175
- return (best_start, best_start + window_size, best_avg)
176
 
177
  ###############################################################################
178
  # 6. PLOTTING / UTILITIES
@@ -192,10 +193,9 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
192
  Plots a 1D heatmap of per-base SHAP contributions.
193
  Negative = push toward Non-Human, Positive = push toward Human.
194
  Optionally can show only a subrange (start:end).
195
-
196
- We adjust layout so the colorbar is well below the x-axis:
197
  - orientation='horizontal', pad=0.35
198
- - plt.subplots_adjust(bottom=0.4)
199
  """
200
  if start is not None and end is not None:
201
  shap_means = shap_means[start:end]
@@ -294,11 +294,15 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
294
  # Load model and scaler
295
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
296
  try:
 
 
297
  model = VirusClassifier(256).to(device)
298
- model.load_state_dict(torch.load('model.pt', map_location=device))
 
 
299
  scaler = joblib.load('scaler.pkl')
300
  except Exception as e:
301
- return (f"Error loading model: {str(e)}", None, None, None, None)
302
 
303
  # Vectorize + scale
304
  freq_vector = sequence_to_kmer_vector(seq)
@@ -344,13 +348,13 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
344
  heatmap_img = fig_to_image(heatmap_fig)
345
 
346
  # Store data for subregion analysis
347
- state_dict = {
348
  "seq": seq,
349
  "shap_means": shap_means
350
  }
351
 
352
- # We now return 5 items (not 6):
353
- return (results_text, bar_img, heatmap_img, state_dict, header)
354
 
355
  ###############################################################################
356
  # 8. SUBREGION ANALYSIS (Gradio Step 2)
@@ -475,16 +479,10 @@ with gr.Blocks(css=css) as iface:
475
  kmer_img = gr.Image(label="Top k-mer SHAP")
476
  genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
477
 
478
- # State for step 2
479
  seq_state = gr.State()
480
  header_state = gr.State()
481
 
482
- # analyze_sequence(...) now returns 5 items, so we have 5 outputs.
483
- # 1) results_text
484
- # 2) bar_img
485
- # 3) heatmap_img
486
- # 4) state_dict
487
- # 5) header
488
  analyze_btn.click(
489
  analyze_sequence,
490
  inputs=[file_input, top_k, text_input, win_size],
@@ -517,17 +515,12 @@ with gr.Blocks(css=css) as iface:
517
  )
518
 
519
  gr.Markdown("""
520
- ### What does this interface provide?
521
- 1. **Overall Classification** (human vs non-human), using a learned model on k-mer frequencies.
522
- 2. **SHAP Analysis** (ablation-based) to see which k-mer features push classification toward or away from "human".
523
- 3. **Genome-Wide SHAP Heatmap**: Each base's average SHAP across overlapping k-mers.
524
- 4. **Subregion Exploration**:
525
- - Local SHAP signals (heatmap & histogram)
526
- - GC content, fraction of bases pushing "human" vs "non-human"
527
- - Simple logic-based interpretation based on average SHAP
528
- 5. **Identification of the most 'human-pushing' subregion** (max average SHAP)
529
- and the most 'non-human–pushing' subregion (min average SHAP),
530
- each of a chosen window size.
531
  """)
532
 
533
  if __name__ == "__main__":
 
140
  """
141
  Finds the subregion of length `window_size` that has the maximum
142
  (mode="max") or minimum (mode="min") average SHAP.
143
+ Returns (best_start, best_end, best_avg).
144
  """
145
  n = len(shap_means)
146
+ if n == 0:
147
+ # Edge case: empty array
148
+ return (0, 0, 0.0)
149
  if window_size >= n:
150
+ # If the window is bigger than the entire sequence, return entire seq
151
+ avg_val = float(np.mean(shap_means))
152
  return (0, n, avg_val)
153
 
154
+ # We'll build csum as length n+1 so csum[i] = sum of shap_means[:i]
155
+ # That means sum in [start, start+window_size) = csum[start+window_size] - csum[start].
156
+ csum = np.zeros(n + 1, dtype=np.float32)
157
+ csum[1:] = np.cumsum(shap_means)
158
+
 
 
159
  best_start = 0
160
+ # Initialize with the first window: [0, window_size)
161
+ best_sum = csum[window_size] - csum[0]
162
  best_avg = best_sum / window_size
163
 
164
  for start in range(1, n - window_size + 1):
165
+ wsum = csum[start + window_size] - csum[start]
166
  wavg = wsum / window_size
167
  if mode == "max":
168
  if wavg > best_avg:
 
173
  best_avg = wavg
174
  best_start = start
175
 
176
+ return (best_start, best_start + window_size, float(best_avg))
177
 
178
  ###############################################################################
179
  # 6. PLOTTING / UTILITIES
 
193
  Plots a 1D heatmap of per-base SHAP contributions.
194
  Negative = push toward Non-Human, Positive = push toward Human.
195
  Optionally can show only a subrange (start:end).
196
+ Adjust layout so the colorbar is well below the x-axis:
 
197
  - orientation='horizontal', pad=0.35
198
+ - plt.subplots_adjust(bottom=0.4)
199
  """
200
  if start is not None and end is not None:
201
  shap_means = shap_means[start:end]
 
294
  # Load model and scaler
295
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
296
  try:
297
+ # Use weights_only=True to address the FutureWarning about untrusted pickle data
298
+ state_dict = torch.load('model.pt', map_location=device, weights_only=True)
299
  model = VirusClassifier(256).to(device)
300
+ model.load_state_dict(state_dict)
301
+
302
+ # Load scaler (warning if version mismatch)
303
  scaler = joblib.load('scaler.pkl')
304
  except Exception as e:
305
+ return (f"Error loading model/scaler: {str(e)}", None, None, None, None)
306
 
307
  # Vectorize + scale
308
  freq_vector = sequence_to_kmer_vector(seq)
 
348
  heatmap_img = fig_to_image(heatmap_fig)
349
 
350
  # Store data for subregion analysis
351
+ state_dict_out = {
352
  "seq": seq,
353
  "shap_means": shap_means
354
  }
355
 
356
+ # Return exactly 5 items
357
+ return (results_text, bar_img, heatmap_img, state_dict_out, header)
358
 
359
  ###############################################################################
360
  # 8. SUBREGION ANALYSIS (Gradio Step 2)
 
479
  kmer_img = gr.Image(label="Top k-mer SHAP")
480
  genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
481
 
 
482
  seq_state = gr.State()
483
  header_state = gr.State()
484
 
485
+ # analyze_sequence(...) returns 5 items.
 
 
 
 
 
486
  analyze_btn.click(
487
  analyze_sequence,
488
  inputs=[file_input, top_k, text_input, win_size],
 
515
  )
516
 
517
  gr.Markdown("""
518
+ ### Interface Features
519
+ - **Overall Classification** (human vs non-human) using k-mer frequencies.
520
+ - **Top k-mer SHAP**: which k-mers push the classifier output.
521
+ - **Genome-Wide SHAP Heatmap**: each base's average SHAP across overlapping k-mers.
522
+ - **Identify Subregions** (sliding window) with the strongest push for human or non-human.
523
+ - **Subregion Exploration**: local SHAP heatmap & histogram, GC content, fraction of positions pushing human vs. non-human.
 
 
 
 
 
524
  """)
525
 
526
  if __name__ == "__main__":