hiyata committed on
Commit
0d2d632
·
verified ·
1 Parent(s): 44f2fb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -46
app.py CHANGED
@@ -32,6 +32,28 @@ class VirusClassifier(nn.Module):
32
  def forward(self, x):
33
  return self.network(x)
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  ###############################################################################
37
  # Utility Functions
@@ -65,6 +87,7 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
65
  Convert a single nucleotide sequence to a k-mer frequency vector
66
  of length 4^k (e.g., for k=4, length=256).
67
  """
 
68
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
69
  kmer_dict = {km: i for i, km in enumerate(kmers)}
70
  vec = np.zeros(len(kmers), dtype=np.float32)
@@ -122,6 +145,7 @@ def create_freq_sigma_plot(
122
  # color by sign (positive=green, negative=red)
123
  colors = ["green" if d["shap"] >= 0 else "red" for d in top_data]
124
 
 
125
  x = np.arange(len(kmers))
126
  width = 0.4
127
 
@@ -129,7 +153,8 @@ def create_freq_sigma_plot(
129
  # Frequency
130
  ax.bar(x - width/2, freqs, width, color=colors, alpha=0.7, label="Frequency (%)")
131
  ax.set_ylabel("Frequency (%)", color='black')
132
- ax.set_ylim(0, max(freqs)*1.2 if len(freqs) else 1)
 
133
 
134
  # Twin axis for sigma
135
  ax2 = ax.twinx()
@@ -160,7 +185,7 @@ def run_classification_and_shap(file_obj):
160
  - shap_values object (SHAP values for the entire batch)
161
  - array/batch of scaled vectors (for use in the waterfall selection)
162
  - list of k-mers (for indexing)
163
- - possibly the model or other context
164
  """
165
  # 1. Basic read
166
  if isinstance(file_obj, str):
@@ -192,12 +217,15 @@ def run_classification_and_shap(file_obj):
192
  # 4. Load model & scaler
193
  try:
194
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
195
  model = VirusClassifier(input_shape=4**k).to(device)
196
- state_dict = torch.load("model.pt", map_location=device)
 
197
  model.load_state_dict(state_dict)
198
  model.eval()
199
 
200
  scaler = joblib.load("scaler.pkl")
 
201
  except Exception as e:
202
  return None, None, f"Error loading model or scaler: {str(e)}"
203
 
@@ -224,20 +252,18 @@ def run_classification_and_shap(file_obj):
224
 
225
  # 7. SHAP Explainer
226
  # We'll pick a background subset if there are many sequences
227
- # (For performance, we might limit to e.g. 50 samples max)
228
  if scaled_data.shape[0] > 50:
229
  background_data = scaled_data[:50]
230
  else:
231
  background_data = scaled_data
232
 
233
- # Use the "new" unified shap.Explainer approach
234
- # We pass in a function that does the forward pass. Or pass the model directly.
235
- # For PyTorch models, shap can do a direct 'model' approach with a mask.
236
- # We'll do a simple "use shap.Explainer" with data=background_data
237
- explainer = shap.Explainer(model, background_data)
238
  shap_values = explainer(scaled_data) # shape=(num_samples, num_features)
239
 
240
  # k-mer list
 
241
  kmer_list = [''.join(p) for p in product("ACGT", repeat=k)]
242
 
243
  return (results_table, shap_values, scaled_data, kmer_list, None)
@@ -249,8 +275,8 @@ def run_classification_and_shap(file_obj):
249
  def main_predict(file_obj):
250
  """
251
  This function is triggered by the 'Run' button in Gradio.
252
- It returns a markdown of all sequences/predictions and stores
253
- data needed for the subsequent SHAP visualizations.
254
  """
255
  results, shap_vals, scaled_data, kmer_list, err = run_classification_and_shap(file_obj)
256
  if err:
@@ -270,32 +296,26 @@ def main_predict(file_obj):
270
  )
271
  md += "\nSelect a sequence index below to view SHAP Waterfall & Frequency plots."
272
 
273
- # Return the string, and also the shap values plus data needed
274
- # We'll store these to SessionState via Gradio's "State" or we can
275
- # pass them out as hidden fields.
276
  return (md, shap_vals, scaled_data, kmer_list, results)
277
 
278
-
279
  def update_waterfall_plot(selected_index, shap_values_obj):
280
  """
281
- Build a waterfall plot for the user-selected sample.
282
  """
283
  if shap_values_obj is None:
284
  return None
285
 
 
 
 
286
  try:
287
  selected_index = int(selected_index)
288
  except:
289
  selected_index = 0
290
 
291
- # We'll create the figure by calling shap.plots.waterfall
292
- # Convert shap_values_obj to the new shap interface
293
- # shap_values_obj is a shap._explanation.Explanation typically
294
-
295
- # We can create a figure with shap.plots.waterfall and capture it as an image
296
  shap_plots_fig = plt.figure(figsize=(8, 5))
297
- shap.plots.waterfall(shap_values_obj[selected_index], max_display=14,
298
- show=False) # show=False so it doesn't pop in the notebook
299
  buf = io.BytesIO()
300
  plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
301
  buf.seek(0)
@@ -304,7 +324,6 @@ def update_waterfall_plot(selected_index, shap_values_obj):
304
 
305
  return wf_img
306
 
307
-
308
  def update_beeswarm_plot(shap_values_obj):
309
  """
310
  Build a beeswarm plot across all samples.
@@ -312,6 +331,9 @@ def update_beeswarm_plot(shap_values_obj):
312
  if shap_values_obj is None:
313
  return None
314
 
 
 
 
315
  beeswarm_fig = plt.figure(figsize=(8, 5))
316
  shap.plots.beeswarm(shap_values_obj, show=False)
317
  buf = io.BytesIO()
@@ -322,11 +344,10 @@ def update_beeswarm_plot(shap_values_obj):
322
 
323
  return bs_img
324
 
325
-
326
  def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
327
  """
328
  Create the frequency & sigma bar chart for the selected sequence's top-10 k-mers.
329
- We'll need to also compute the raw_freq_vector from the original unscaled data.
330
  """
331
  if shap_values_obj is None or scaled_data is None or kmer_list is None:
332
  return None
@@ -336,23 +357,17 @@ def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, fi
336
  except:
337
  selected_index = 0
338
 
339
- # We must re-generate the raw freq vector from the original input file
340
- # or store it from earlier. Let's just re-run parse for that single sequence:
341
- # But simpler is: run_classification_and_shap was storing all_raw_vectors...
342
- # Let's do a quick approach: run_classification_and_shap already computed it
343
- # but we didn't store it. We'll re-run the parse logic to get the raw freq again.
344
-
345
- # For memory / speed reasons, better is to store it.
346
- # For simplicity, let's parse again quickly:
347
  if isinstance(file_obj, str):
348
  text = file_obj
349
  else:
350
  text = file_obj.decode('utf-8')
 
351
  sequences = parse_fasta(text)
352
- # the selected_index might be out of range, so let's clamp it
353
  if selected_index >= len(sequences):
354
  selected_index = 0
355
- seq = sequences[selected_index][1] # get the sequence
 
356
  raw_vec = sequence_to_kmer_vector(seq, k=4)
357
 
358
  single_shap_values = shap_values_obj.values[selected_index]
@@ -376,11 +391,11 @@ def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, fi
376
  # Gradio Interface
377
  ###############################################################################
378
  with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
379
- shap.initjs() # load shap JS for interactive plots in some contexts (optional)
380
 
381
  gr.Markdown(
382
  """
383
- # **Advanced Virus Host Classifier with SHAP**
384
  **Upload a FASTA file** with one or more nucleotide sequences.
385
  This app will:
386
  1. Predict each sequence's **host** (human vs. non-human).
@@ -407,7 +422,7 @@ with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
407
  md_out = gr.Markdown()
408
 
409
  with gr.Tab("SHAP Waterfall"):
410
- # We'll let user pick the sequence index from a dropdown or slider
411
  with gr.Row():
412
  seq_index_dropdown = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
413
  update_wf_btn = gr.Button("Update Waterfall")
@@ -424,34 +439,39 @@ with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
424
  fs_plot = gr.Image(label="Top-10 Frequency & σ Chart")
425
 
426
  # --- Button Logic ---
 
427
  run_btn.click(
428
  fn=main_predict,
429
  inputs=[file_input],
430
  outputs=[md_out, shap_values_state, scaled_data_state, kmer_list_state, results_state]
431
  )
432
- run_btn.click( # Also store the raw file data for later freq plots
 
433
  fn=lambda x: x,
434
  inputs=file_input,
435
  outputs=file_data_state
436
  )
437
 
 
438
  update_wf_btn.click(
439
  fn=update_waterfall_plot,
440
  inputs=[seq_index_dropdown, shap_values_state],
441
  outputs=[wf_plot]
442
  )
443
- update_fs_btn.click(
444
- fn=update_freq_plot,
445
- inputs=[seq_index_dropdown2, shap_values_state, scaled_data_state, kmer_list_state, file_data_state],
446
- outputs=[fs_plot]
447
- )
448
 
449
- # We can auto-generate the beeswarm right after classification as well
450
  run_btn.click(
451
  fn=update_beeswarm_plot,
452
  inputs=[shap_values_state],
453
  outputs=[bs_plot]
454
  )
455
 
 
 
 
 
 
 
 
456
  if __name__ == "__main__":
457
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
32
  def forward(self, x):
33
  return self.network(x)
34
 
35
+ ###############################################################################
36
+ # Torch Model Wrapper for SHAP
37
+ ###############################################################################
38
class TorchModelWrapper:
    """
    Callable adapter that lets numpy-based callers (such as SHAP) invoke a
    PyTorch model.

    The wrapper converts an incoming numpy batch to a float32 tensor on the
    configured device, runs a gradient-free forward pass, and returns the
    model output as a numpy array.
    """

    def __init__(self, model: nn.Module, device='cpu'):
        """
        model:  the PyTorch module to call.
        device: device the input tensor is moved to before the forward pass
                (the model itself is not moved here; the caller is expected
                to have placed it on the same device).
        """
        self.model = model
        self.device = device

    def __call__(self, x_np: np.ndarray):
        """
        x_np: array-like batch of inputs, normally a numpy array of
              shape=(batch_size, num_features).
        Returns: numpy array holding the model output for that batch.
        """
        # torch.from_numpy rejects anything that is not an ndarray and also
        # rejects views with negative strides (e.g. x[::-1]). Normalizing to a
        # contiguous float32 array first makes the wrapper accept any
        # array-like input while producing the same values as before.
        batch = np.ascontiguousarray(x_np, dtype=np.float32)
        x_torch = torch.from_numpy(batch).to(self.device)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            out = self.model(x_torch).cpu().numpy()
        return out
56
+
57
 
58
  ###############################################################################
59
  # Utility Functions
 
87
  Convert a single nucleotide sequence to a k-mer frequency vector
88
  of length 4^k (e.g., for k=4, length=256).
89
  """
90
+ from itertools import product
91
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
92
  kmer_dict = {km: i for i, km in enumerate(kmers)}
93
  vec = np.zeros(len(kmers), dtype=np.float32)
 
145
  # color by sign (positive=green, negative=red)
146
  colors = ["green" if d["shap"] >= 0 else "red" for d in top_data]
147
 
148
+ import matplotlib.pyplot as plt
149
  x = np.arange(len(kmers))
150
  width = 0.4
151
 
 
153
  # Frequency
154
  ax.bar(x - width/2, freqs, width, color=colors, alpha=0.7, label="Frequency (%)")
155
  ax.set_ylabel("Frequency (%)", color='black')
156
+ if freqs:
157
+ ax.set_ylim(0, max(freqs)*1.2)
158
 
159
  # Twin axis for sigma
160
  ax2 = ax.twinx()
 
185
  - shap_values object (SHAP values for the entire batch)
186
  - array/batch of scaled vectors (for use in the waterfall selection)
187
  - list of k-mers (for indexing)
188
+ - error message or None
189
  """
190
  # 1. Basic read
191
  if isinstance(file_obj, str):
 
217
  # 4. Load model & scaler
218
  try:
219
  device = "cuda" if torch.cuda.is_available() else "cpu"
220
+
221
  model = VirusClassifier(input_shape=4**k).to(device)
222
+ # weights_only=True restricts torch.load to tensors and primitive types (safer unpickling) and also silences the FutureWarning
223
+ state_dict = torch.load("model.pt", map_location=device, weights_only=True)
224
  model.load_state_dict(state_dict)
225
  model.eval()
226
 
227
  scaler = joblib.load("scaler.pkl")
228
+
229
  except Exception as e:
230
  return None, None, f"Error loading model or scaler: {str(e)}"
231
 
 
252
 
253
  # 7. SHAP Explainer
254
  # We'll pick a background subset if there are many sequences
 
255
  if scaled_data.shape[0] > 50:
256
  background_data = scaled_data[:50]
257
  else:
258
  background_data = scaled_data
259
 
260
+ # Wrap the model so it can handle numpy -> tensor
261
+ wrapped_model = TorchModelWrapper(model, device)
262
+ explainer = shap.Explainer(wrapped_model, background_data)
 
 
263
  shap_values = explainer(scaled_data) # shape=(num_samples, num_features)
264
 
265
  # k-mer list
266
+ from itertools import product
267
  kmer_list = [''.join(p) for p in product("ACGT", repeat=k)]
268
 
269
  return (results_table, shap_values, scaled_data, kmer_list, None)
 
275
  def main_predict(file_obj):
276
  """
277
  This function is triggered by the 'Run' button in Gradio.
278
+ It returns a markdown of all sequences/predictions and
279
+ the shap values plus data needed for subsequent plots.
280
  """
281
  results, shap_vals, scaled_data, kmer_list, err = run_classification_and_shap(file_obj)
282
  if err:
 
296
  )
297
  md += "\nSelect a sequence index below to view SHAP Waterfall & Frequency plots."
298
 
 
 
 
299
  return (md, shap_vals, scaled_data, kmer_list, results)
300
 
 
301
  def update_waterfall_plot(selected_index, shap_values_obj):
302
  """
303
+ Build a waterfall plot for the user-selected sample using shap.plots.waterfall.
304
  """
305
  if shap_values_obj is None:
306
  return None
307
 
308
+ import matplotlib.pyplot as plt
309
+ import shap
310
+
311
  try:
312
  selected_index = int(selected_index)
313
  except:
314
  selected_index = 0
315
 
316
+ # Create the figure by calling shap.plots.waterfall
 
 
 
 
317
  shap_plots_fig = plt.figure(figsize=(8, 5))
318
+ shap.plots.waterfall(shap_values_obj[selected_index], max_display=14, show=False)
 
319
  buf = io.BytesIO()
320
  plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
321
  buf.seek(0)
 
324
 
325
  return wf_img
326
 
 
327
  def update_beeswarm_plot(shap_values_obj):
328
  """
329
  Build a beeswarm plot across all samples.
 
331
  if shap_values_obj is None:
332
  return None
333
 
334
+ import matplotlib.pyplot as plt
335
+ import shap
336
+
337
  beeswarm_fig = plt.figure(figsize=(8, 5))
338
  shap.plots.beeswarm(shap_values_obj, show=False)
339
  buf = io.BytesIO()
 
344
 
345
  return bs_img
346
 
 
347
  def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
348
  """
349
  Create the frequency & sigma bar chart for the selected sequence's top-10 k-mers.
350
+ We must re-parse the raw freq vector for that sequence, or store it from earlier.
351
  """
352
  if shap_values_obj is None or scaled_data is None or kmer_list is None:
353
  return None
 
357
  except:
358
  selected_index = 0
359
 
360
+ # Re-parse the FASTA to get the corresponding sequence
 
 
 
 
 
 
 
361
  if isinstance(file_obj, str):
362
  text = file_obj
363
  else:
364
  text = file_obj.decode('utf-8')
365
+
366
  sequences = parse_fasta(text)
 
367
  if selected_index >= len(sequences):
368
  selected_index = 0
369
+
370
+ seq = sequences[selected_index][1]
371
  raw_vec = sequence_to_kmer_vector(seq, k=4)
372
 
373
  single_shap_values = shap_values_obj.values[selected_index]
 
391
  # Gradio Interface
392
  ###############################################################################
393
  with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
394
+ shap.initjs() # load shap JS if needed for interactive HTML (optional)
395
 
396
  gr.Markdown(
397
  """
398
+ # **Virus Host Classifier with SHAP**
399
  **Upload a FASTA file** with one or more nucleotide sequences.
400
  This app will:
401
  1. Predict each sequence's **host** (human vs. non-human).
 
422
  md_out = gr.Markdown()
423
 
424
  with gr.Tab("SHAP Waterfall"):
425
+ # We'll let user pick the sequence index from a dropdown or input
426
  with gr.Row():
427
  seq_index_dropdown = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
428
  update_wf_btn = gr.Button("Update Waterfall")
 
439
  fs_plot = gr.Image(label="Top-10 Frequency & σ Chart")
440
 
441
  # --- Button Logic ---
442
+ # 1) The main classification run
443
  run_btn.click(
444
  fn=main_predict,
445
  inputs=[file_input],
446
  outputs=[md_out, shap_values_state, scaled_data_state, kmer_list_state, results_state]
447
  )
448
+ # Also store raw file data for subsequent freq usage
449
+ run_btn.click(
450
  fn=lambda x: x,
451
  inputs=file_input,
452
  outputs=file_data_state
453
  )
454
 
455
+ # 2) Waterfall update
456
  update_wf_btn.click(
457
  fn=update_waterfall_plot,
458
  inputs=[seq_index_dropdown, shap_values_state],
459
  outputs=[wf_plot]
460
  )
 
 
 
 
 
461
 
462
+ # 3) Beeswarm update
463
  run_btn.click(
464
  fn=update_beeswarm_plot,
465
  inputs=[shap_values_state],
466
  outputs=[bs_plot]
467
  )
468
 
469
+ # 4) Frequency top-10 update
470
+ update_fs_btn.click(
471
+ fn=update_freq_plot,
472
+ inputs=[seq_index_dropdown2, shap_values_state, scaled_data_state, kmer_list_state, file_data_state],
473
+ outputs=[fs_plot]
474
+ )
475
+
476
  if __name__ == "__main__":
477
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)