hiyata commited on
Commit
afbf1c6
·
verified ·
1 Parent(s): 0d2d632

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -70
app.py CHANGED
@@ -4,10 +4,13 @@ import joblib
4
  import numpy as np
5
  from itertools import product
6
  import torch.nn as nn
 
 
7
  import matplotlib.pyplot as plt
8
  import io
9
  from PIL import Image
10
- import shap # Requires: pip install shap
 
11
 
12
  ###############################################################################
13
  # Model Definition
@@ -32,13 +35,15 @@ class VirusClassifier(nn.Module):
32
  def forward(self, x):
33
  return self.network(x)
34
 
 
35
  ###############################################################################
36
  # Torch Model Wrapper for SHAP
37
  ###############################################################################
38
  class TorchModelWrapper:
39
  """
40
  A simple callable that takes a PyTorch model and device,
41
- and allows SHAP to pass in numpy arrays, which we convert to torch tensors.
 
42
  """
43
  def __init__(self, model: nn.Module, device='cpu'):
44
  self.model = model
@@ -87,7 +92,6 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
87
  Convert a single nucleotide sequence to a k-mer frequency vector
88
  of length 4^k (e.g., for k=4, length=256).
89
  """
90
- from itertools import product
91
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
92
  kmer_dict = {km: i for i, km in enumerate(kmers)}
93
  vec = np.zeros(len(kmers), dtype=np.float32)
@@ -118,47 +122,54 @@ def create_freq_sigma_plot(
118
  Creates a bar plot showing top-10 k-mers (by absolute SHAP value),
119
  with frequency (%) and sigma from mean on a twin-axis.
120
 
121
- single_shap_values: shape=(256,) shap values for this sample
122
- raw_freq_vector: shape=(256,) original frequencies for this sample
123
- scaled_vector: shape=(256,) scaled (Z-score) values for this sample
124
- kmer_list: list of all k-mers (length=256)
125
  """
126
- abs_vals = np.abs(single_shap_values)
 
127
  top_k = 10
128
- top_indices = np.argsort(abs_vals)[-top_k:][::-1] # top 10 by absolute shap
 
129
  top_data = []
130
  for idx in top_indices:
 
131
  top_data.append({
132
- "kmer": kmer_list[idx],
133
- "shap": single_shap_values[idx],
134
- "abs_shap": abs_vals[idx],
135
- "frequency": raw_freq_vector[idx] * 100.0, # percentage
136
- "sigma": scaled_vector[idx]
137
  })
138
 
139
  # Sort top_data by abs_shap descending
140
  top_data.sort(key=lambda x: x["abs_shap"], reverse=True)
141
 
 
142
  kmers = [d["kmer"] for d in top_data]
143
  freqs = [d["frequency"] for d in top_data]
144
  sigmas = [d["sigma"] for d in top_data]
145
- # color by sign (positive=green, negative=red)
146
  colors = ["green" if d["shap"] >= 0 else "red" for d in top_data]
147
 
148
- import matplotlib.pyplot as plt
149
  x = np.arange(len(kmers))
150
  width = 0.4
151
 
152
  fig, ax = plt.subplots(figsize=(8, 5))
153
  # Frequency
154
- ax.bar(x - width/2, freqs, width, color=colors, alpha=0.7, label="Frequency (%)")
 
 
155
  ax.set_ylabel("Frequency (%)", color='black')
156
- if freqs:
157
  ax.set_ylim(0, max(freqs)*1.2)
158
 
159
  # Twin axis for sigma
160
  ax2 = ax.twinx()
161
- ax2.bar(x + width/2, sigmas, width, color="gray", alpha=0.5, label="σ from Mean")
 
 
162
  ax2.set_ylabel("Standard Deviations (σ)", color='black')
163
 
164
  ax.set_xticks(x)
@@ -182,9 +193,9 @@ def run_classification_and_shap(file_obj):
182
  Reads one or more FASTA sequences from file_obj or text.
183
  Returns:
184
  - Table of results (list of dicts) for each sequence
185
- - shap_values object (SHAP values for the entire batch)
186
- - array/batch of scaled vectors (for use in the waterfall selection)
187
- - list of k-mers (for indexing)
188
  - error message or None
189
  """
190
  # 1. Basic read
@@ -194,12 +205,12 @@ def run_classification_and_shap(file_obj):
194
  try:
195
  text = file_obj.decode("utf-8")
196
  except Exception as e:
197
- return None, None, f"Error reading file: {str(e)}"
198
 
199
  # 2. Parse FASTA
200
  sequences = parse_fasta(text)
201
  if len(sequences) == 0:
202
- return None, None, "No valid FASTA sequences found!"
203
 
204
  # 3. Convert each sequence to k-mer vector
205
  k = 4
@@ -219,15 +230,14 @@ def run_classification_and_shap(file_obj):
219
  device = "cuda" if torch.cuda.is_available() else "cpu"
220
 
221
  model = VirusClassifier(input_shape=4**k).to(device)
222
- # Set weights_only=True to suppress the future pickle warning
223
  state_dict = torch.load("model.pt", map_location=device, weights_only=True)
224
  model.load_state_dict(state_dict)
225
  model.eval()
226
 
227
  scaler = joblib.load("scaler.pkl")
228
-
229
  except Exception as e:
230
- return None, None, f"Error loading model or scaler: {str(e)}"
231
 
232
  # 5. Scale data
233
  scaled_data = scaler.transform(all_raw_vectors) # shape=(num_seqs, 256)
@@ -236,6 +246,7 @@ def run_classification_and_shap(file_obj):
236
  X_tensor = torch.FloatTensor(scaled_data).to(device)
237
  with torch.no_grad():
238
  logits = model(X_tensor)
 
239
  probs = torch.softmax(logits, dim=1).cpu().numpy()
240
  preds = np.argmax(probs, axis=1) # 0 or 1
241
 
@@ -243,29 +254,30 @@ def run_classification_and_shap(file_obj):
243
  for i, (hdr, seq) in enumerate(zip(headers, seqs)):
244
  results_table.append({
245
  "header": hdr,
246
- "sequence": seq[:50] + ("..." if len(seq)>50 else ""), # truncated
247
  "pred_label": "human" if preds[i] == 1 else "non-human",
248
  "human_prob": float(probs[i][1]),
249
  "non_human_prob": float(probs[i][0]),
250
- "confidence": float(max(probs[i]))
251
  })
252
 
253
  # 7. SHAP Explainer
254
- # We'll pick a background subset if there are many sequences
255
  if scaled_data.shape[0] > 50:
256
  background_data = scaled_data[:50]
257
  else:
258
  background_data = scaled_data
259
 
260
- # Wrap the model so it can handle numpy -> tensor
261
  wrapped_model = TorchModelWrapper(model, device)
262
  explainer = shap.Explainer(wrapped_model, background_data)
263
- shap_values = explainer(scaled_data) # shape=(num_samples, num_features)
 
 
264
 
265
- # k-mer list
266
- from itertools import product
267
  kmer_list = [''.join(p) for p in product("ACGT", repeat=k)]
268
 
 
269
  return (results_table, shap_values, scaled_data, kmer_list, None)
270
 
271
 
@@ -274,9 +286,8 @@ def run_classification_and_shap(file_obj):
274
  ###############################################################################
275
  def main_predict(file_obj):
276
  """
277
- This function is triggered by the 'Run' button in Gradio.
278
- It returns a markdown of all sequences/predictions and
279
- the shap values plus data needed for subsequent plots.
280
  """
281
  results, shap_vals, scaled_data, kmer_list, err = run_classification_and_shap(file_obj)
282
  if err:
@@ -294,28 +305,44 @@ def main_predict(file_obj):
294
  f"| {i} | {row['header']} | {row['pred_label']} | "
295
  f"{row['confidence']:.4f} | {row['human_prob']:.4f} | {row['non_human_prob']:.4f} |\n"
296
  )
297
- md += "\nSelect a sequence index below to view SHAP Waterfall & Frequency plots."
298
 
299
  return (md, shap_vals, scaled_data, kmer_list, results)
300
 
 
301
  def update_waterfall_plot(selected_index, shap_values_obj):
302
  """
303
- Build a waterfall plot for the user-selected sample using shap.plots.waterfall.
 
 
 
304
  """
305
  if shap_values_obj is None:
306
  return None
307
 
308
  import matplotlib.pyplot as plt
309
- import shap
310
 
311
  try:
312
  selected_index = int(selected_index)
313
  except:
314
  selected_index = 0
315
 
316
- # Create the figure by calling shap.plots.waterfall
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  shap_plots_fig = plt.figure(figsize=(8, 5))
318
- shap.plots.waterfall(shap_values_obj[selected_index], max_display=14, show=False)
319
  buf = io.BytesIO()
320
  plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
321
  buf.seek(0)
@@ -324,18 +351,35 @@ def update_waterfall_plot(selected_index, shap_values_obj):
324
 
325
  return wf_img
326
 
 
327
  def update_beeswarm_plot(shap_values_obj):
328
  """
329
- Build a beeswarm plot across all samples.
 
 
330
  """
331
  if shap_values_obj is None:
332
  return None
333
 
334
  import matplotlib.pyplot as plt
335
- import shap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
  beeswarm_fig = plt.figure(figsize=(8, 5))
338
- shap.plots.beeswarm(shap_values_obj, show=False)
339
  buf = io.BytesIO()
340
  plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
341
  buf.seek(0)
@@ -344,14 +388,17 @@ def update_beeswarm_plot(shap_values_obj):
344
 
345
  return bs_img
346
 
 
347
  def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
348
  """
349
- Create the frequency & sigma bar chart for the selected sequence's top-10 k-mers.
350
- We must re-parse the raw freq vector for that sequence, or store it from earlier.
351
  """
352
  if shap_values_obj is None or scaled_data is None or kmer_list is None:
353
  return None
354
 
 
 
355
  try:
356
  selected_index = int(selected_index)
357
  except:
@@ -364,20 +411,23 @@ def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, fi
364
  text = file_obj.decode('utf-8')
365
 
366
  sequences = parse_fasta(text)
 
367
  if selected_index >= len(sequences):
368
  selected_index = 0
369
 
370
  seq = sequences[selected_index][1]
371
- raw_vec = sequence_to_kmer_vector(seq, k=4)
372
 
373
- single_shap_values = shap_values_obj.values[selected_index]
 
374
  freq_sigma_fig = create_freq_sigma_plot(
375
- single_shap_values,
376
- raw_freq_vector=raw_vec,
377
  scaled_vector=scaled_data[selected_index],
378
  kmer_list=kmer_list,
379
  title=f"Sample #{selected_index} — {sequences[selected_index][0]}"
380
  )
 
381
  buf = io.BytesIO()
382
  freq_sigma_fig.savefig(buf, format='png', bbox_inches='tight', dpi=120)
383
  buf.seek(0)
@@ -391,16 +441,19 @@ def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, fi
391
  # Gradio Interface
392
  ###############################################################################
393
  with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
394
- shap.initjs() # load shap JS if needed for interactive HTML (optional)
395
 
396
  gr.Markdown(
397
  """
398
- # **Virus Host Classifier with SHAP**
399
- **Upload a FASTA file** with one or more nucleotide sequences.
400
  This app will:
401
  1. Predict each sequence's **host** (human vs. non-human).
402
- 2. Provide **SHAP** explanations (waterfall & beeswarm).
403
- 3. Let you explore **frequency & σ** for top-10 k-mers for a chosen sequence.
 
 
 
404
  """
405
  )
406
 
@@ -408,23 +461,20 @@ with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
408
  file_input = gr.File(label="Upload FASTA", type="binary")
409
  run_btn = gr.Button("Run Classification")
410
 
411
- # Store intermediate results in "States" for usage in subsequent tabs
412
  shap_values_state = gr.State()
413
  scaled_data_state = gr.State()
414
  kmer_list_state = gr.State()
415
  results_state = gr.State()
416
- # We'll also store the "raw input" so we can reconstruct freq data for each sample
417
  file_data_state = gr.State()
418
 
419
- # TABS for outputs
420
  with gr.Tabs():
421
  with gr.Tab("Results Table"):
422
  md_out = gr.Markdown()
423
 
424
  with gr.Tab("SHAP Waterfall"):
425
- # We'll let user pick the sequence index from a dropdown or input
426
  with gr.Row():
427
- seq_index_dropdown = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
428
  update_wf_btn = gr.Button("Update Waterfall")
429
 
430
  wf_plot = gr.Image(label="SHAP Waterfall Plot")
@@ -434,44 +484,43 @@ with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
434
 
435
  with gr.Tab("Top-10 Frequency & Sigma"):
436
  with gr.Row():
437
- seq_index_dropdown2 = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
438
  update_fs_btn = gr.Button("Update Frequency Chart")
439
  fs_plot = gr.Image(label="Top-10 Frequency & σ Chart")
440
 
441
- # --- Button Logic ---
442
- # 1) The main classification run
443
  run_btn.click(
444
  fn=main_predict,
445
  inputs=[file_input],
446
  outputs=[md_out, shap_values_state, scaled_data_state, kmer_list_state, results_state]
447
  )
448
- # Also store raw file data for subsequent freq usage
449
  run_btn.click(
450
  fn=lambda x: x,
451
  inputs=file_input,
452
  outputs=file_data_state
453
  )
454
 
455
- # 2) Waterfall update
456
  update_wf_btn.click(
457
  fn=update_waterfall_plot,
458
- inputs=[seq_index_dropdown, shap_values_state],
459
  outputs=[wf_plot]
460
  )
461
 
462
- # 3) Beeswarm update
463
  run_btn.click(
464
  fn=update_beeswarm_plot,
465
  inputs=[shap_values_state],
466
  outputs=[bs_plot]
467
  )
468
 
469
- # 4) Frequency top-10 update
470
  update_fs_btn.click(
471
  fn=update_freq_plot,
472
- inputs=[seq_index_dropdown2, shap_values_state, scaled_data_state, kmer_list_state, file_data_state],
473
  outputs=[fs_plot]
474
  )
475
 
476
  if __name__ == "__main__":
477
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
 
4
  import numpy as np
5
  from itertools import product
6
  import torch.nn as nn
7
+ import matplotlib
8
+ matplotlib.use("Agg") # In case we're running in a no-display environment
9
  import matplotlib.pyplot as plt
10
  import io
11
  from PIL import Image
12
+ import shap
13
+
14
 
15
  ###############################################################################
16
  # Model Definition
 
35
  def forward(self, x):
36
  return self.network(x)
37
 
38
+
39
  ###############################################################################
40
  # Torch Model Wrapper for SHAP
41
  ###############################################################################
42
  class TorchModelWrapper:
43
  """
44
  A simple callable that takes a PyTorch model and device,
45
+ allowing SHAP to pass in NumPy arrays. We convert them
46
+ to torch tensors, run the model, and return NumPy outputs.
47
  """
48
  def __init__(self, model: nn.Module, device='cpu'):
49
  self.model = model
 
92
  Convert a single nucleotide sequence to a k-mer frequency vector
93
  of length 4^k (e.g., for k=4, length=256).
94
  """
 
95
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
96
  kmer_dict = {km: i for i, km in enumerate(kmers)}
97
  vec = np.zeros(len(kmers), dtype=np.float32)
 
122
  Creates a bar plot showing top-10 k-mers (by absolute SHAP value),
123
  with frequency (%) and sigma from mean on a twin-axis.
124
 
125
+ single_shap_values: shape=(256,) SHAP values for the "human" class
126
+ raw_freq_vector: shape=(256,) original frequencies for this sample
127
+ scaled_vector: shape=(256,) scaled (Z-score) values for this sample
128
+ kmer_list: list of length=256 of all k-mers
129
  """
130
+ # Identify the top 10 k-mers by absolute shap
131
+ abs_vals = np.abs(single_shap_values) # shape=(256,)
132
  top_k = 10
133
+ top_indices = np.argsort(abs_vals)[-top_k:][::-1] # indices of largest -> smallest
134
+
135
  top_data = []
136
  for idx in top_indices:
137
+ idx_int = int(idx) # ensure integer
138
  top_data.append({
139
+ "kmer": kmer_list[idx_int],
140
+ "shap": single_shap_values[idx_int],
141
+ "abs_shap": abs_vals[idx_int],
142
+ "frequency": raw_freq_vector[idx_int] * 100.0, # percentage
143
+ "sigma": scaled_vector[idx_int]
144
  })
145
 
146
  # Sort top_data by abs_shap descending
147
  top_data.sort(key=lambda x: x["abs_shap"], reverse=True)
148
 
149
+ # Prepare for plotting
150
  kmers = [d["kmer"] for d in top_data]
151
  freqs = [d["frequency"] for d in top_data]
152
  sigmas = [d["sigma"] for d in top_data]
153
+ # color by sign (positive=green => pushes "human", negative=red => pushes "non-human")
154
  colors = ["green" if d["shap"] >= 0 else "red" for d in top_data]
155
 
 
156
  x = np.arange(len(kmers))
157
  width = 0.4
158
 
159
  fig, ax = plt.subplots(figsize=(8, 5))
160
  # Frequency
161
+ ax.bar(
162
+ x - width/2, freqs, width, color=colors, alpha=0.7, label="Frequency (%)"
163
+ )
164
  ax.set_ylabel("Frequency (%)", color='black')
165
+ if len(freqs) > 0:
166
  ax.set_ylim(0, max(freqs)*1.2)
167
 
168
  # Twin axis for sigma
169
  ax2 = ax.twinx()
170
+ ax2.bar(
171
+ x + width/2, sigmas, width, color="gray", alpha=0.5, label="σ from Mean"
172
+ )
173
  ax2.set_ylabel("Standard Deviations (σ)", color='black')
174
 
175
  ax.set_xticks(x)
 
193
  Reads one or more FASTA sequences from file_obj or text.
194
  Returns:
195
  - Table of results (list of dicts) for each sequence
196
+ - shap_values object (SHAP values for the entire batch, shape=(num_samples, 2, num_features))
197
+ - array of scaled vectors
198
+ - list of k-mers
199
  - error message or None
200
  """
201
  # 1. Basic read
 
205
  try:
206
  text = file_obj.decode("utf-8")
207
  except Exception as e:
208
+ return None, None, None, None, f"Error reading file: {str(e)}"
209
 
210
  # 2. Parse FASTA
211
  sequences = parse_fasta(text)
212
  if len(sequences) == 0:
213
+ return None, None, None, None, "No valid FASTA sequences found!"
214
 
215
  # 3. Convert each sequence to k-mer vector
216
  k = 4
 
230
  device = "cuda" if torch.cuda.is_available() else "cpu"
231
 
232
  model = VirusClassifier(input_shape=4**k).to(device)
233
+ # Use weights_only=True to suppress future warnings about untrusted pickles
234
  state_dict = torch.load("model.pt", map_location=device, weights_only=True)
235
  model.load_state_dict(state_dict)
236
  model.eval()
237
 
238
  scaler = joblib.load("scaler.pkl")
 
239
  except Exception as e:
240
+ return None, None, None, None, f"Error loading model or scaler: {str(e)}"
241
 
242
  # 5. Scale data
243
  scaled_data = scaler.transform(all_raw_vectors) # shape=(num_seqs, 256)
 
246
  X_tensor = torch.FloatTensor(scaled_data).to(device)
247
  with torch.no_grad():
248
  logits = model(X_tensor)
249
+ # shape=(num_seqs, 2)
250
  probs = torch.softmax(logits, dim=1).cpu().numpy()
251
  preds = np.argmax(probs, axis=1) # 0 or 1
252
 
 
254
  for i, (hdr, seq) in enumerate(zip(headers, seqs)):
255
  results_table.append({
256
  "header": hdr,
257
+ "sequence": seq[:50] + ("..." if len(seq) > 50 else ""),
258
  "pred_label": "human" if preds[i] == 1 else "non-human",
259
  "human_prob": float(probs[i][1]),
260
  "non_human_prob": float(probs[i][0]),
261
+ "confidence": float(np.max(probs[i]))
262
  })
263
 
264
  # 7. SHAP Explainer
265
+ # For large data, pick a smaller background subset
266
  if scaled_data.shape[0] > 50:
267
  background_data = scaled_data[:50]
268
  else:
269
  background_data = scaled_data
270
 
 
271
  wrapped_model = TorchModelWrapper(model, device)
272
  explainer = shap.Explainer(wrapped_model, background_data)
273
+ # shap_values shape=(num_samples, num_features) if single-output
274
+ # but here we have 2 outputs => shape=(num_samples, 2, num_features).
275
+ shap_values = explainer(scaled_data)
276
 
277
+ # Prepare k-mer list
 
278
  kmer_list = [''.join(p) for p in product("ACGT", repeat=k)]
279
 
280
+ # Return everything
281
  return (results_table, shap_values, scaled_data, kmer_list, None)
282
 
283
 
 
286
  ###############################################################################
287
  def main_predict(file_obj):
288
  """
289
+ Triggered by the 'Run Classification' button in Gradio.
290
+ Returns a markdown table plus states for subsequent plots.
 
291
  """
292
  results, shap_vals, scaled_data, kmer_list, err = run_classification_and_shap(file_obj)
293
  if err:
 
305
  f"| {i} | {row['header']} | {row['pred_label']} | "
306
  f"{row['confidence']:.4f} | {row['human_prob']:.4f} | {row['non_human_prob']:.4f} |\n"
307
  )
308
+ md += "\nSelect a sequence index below to view SHAP Waterfall & Frequency plots (class=1/human)."
309
 
310
  return (md, shap_vals, scaled_data, kmer_list, results)
311
 
312
+
313
  def update_waterfall_plot(selected_index, shap_values_obj):
314
  """
315
+ Build a waterfall plot for the user-selected sample, but ONLY for class=1 (human).
316
+ shap_values_obj has shape=(num_samples, 2, num_features).
317
+ We do shap_values_obj[selected_index, 1] => shape=(num_features,)
318
+ for a single-sample single-class explanation.
319
  """
320
  if shap_values_obj is None:
321
  return None
322
 
323
  import matplotlib.pyplot as plt
 
324
 
325
  try:
326
  selected_index = int(selected_index)
327
  except:
328
  selected_index = 0
329
 
330
+ # We only visualize class=1 ("human") SHAP values
331
+ # shap_values_obj.values shape => (num_samples, 2, num_features)
332
+ single_ex_values = shap_values_obj.values[selected_index, 1, :] # shape=(256,)
333
+ single_ex_base = shap_values_obj.base_values[selected_index, 1] # scalar
334
+ single_ex_data = shap_values_obj.data[selected_index] # shape=(256,)
335
+
336
+ # Construct a shap.Explanation object for just this one sample & class
337
+ single_expl = shap.Explanation(
338
+ values=single_ex_values,
339
+ base_values=single_ex_base,
340
+ data=single_ex_data,
341
+ feature_names=[f"feat_{i}" for i in range(single_ex_values.shape[0])]
342
+ )
343
+
344
  shap_plots_fig = plt.figure(figsize=(8, 5))
345
+ shap.plots.waterfall(single_expl, max_display=14, show=False)
346
  buf = io.BytesIO()
347
  plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
348
  buf.seek(0)
 
351
 
352
  return wf_img
353
 
354
+
355
  def update_beeswarm_plot(shap_values_obj):
356
  """
357
+ Build a beeswarm plot across all samples, but only for class=1 (human).
358
+ We slice shap_values_obj to pick shap_values_obj.values[:, 1, :]
359
+ => shape=(num_samples, num_features).
360
  """
361
  if shap_values_obj is None:
362
  return None
363
 
364
  import matplotlib.pyplot as plt
365
+
366
+ # For multi-output, shap_values_obj.values shape => (num_samples, 2, num_features)
367
+ # We'll create a new Explanation object for class=1:
368
+ class1_vals = shap_values_obj.values[:, 1, :] # shape=(num_samples, num_features)
369
+ class1_base = shap_values_obj.base_values[:, 1] # shape=(num_samples,)
370
+ class1_data = shap_values_obj.data # shape=(num_samples, num_features)
371
+
372
+ # Some versions of shap store data in a 2D array, which is fine
373
+ # We'll re-wrap them in a shap.Explanation:
374
+ class1_expl = shap.Explanation(
375
+ values=class1_vals,
376
+ base_values=class1_base,
377
+ data=class1_data,
378
+ feature_names=[f"feat_{i}" for i in range(class1_vals.shape[1])]
379
+ )
380
 
381
  beeswarm_fig = plt.figure(figsize=(8, 5))
382
+ shap.plots.beeswarm(class1_expl, show=False)
383
  buf = io.BytesIO()
384
  plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
385
  buf.seek(0)
 
388
 
389
  return bs_img
390
 
391
+
392
  def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
393
  """
394
+ Create the frequency & σ bar chart for the selected sequence's top-10 k-mers (by abs SHAP).
395
+ Again, we'll use class=1 SHAP values only.
396
  """
397
  if shap_values_obj is None or scaled_data is None or kmer_list is None:
398
  return None
399
 
400
+ import matplotlib.pyplot as plt
401
+
402
  try:
403
  selected_index = int(selected_index)
404
  except:
 
411
  text = file_obj.decode('utf-8')
412
 
413
  sequences = parse_fasta(text)
414
+ # If out of range, clamp to 0
415
  if selected_index >= len(sequences):
416
  selected_index = 0
417
 
418
  seq = sequences[selected_index][1]
419
+ raw_vec = sequence_to_kmer_vector(seq, k=4) # shape=(256,)
420
 
421
+ # SHAP for class=1 => shape=(num_samples, 2, 256)
422
+ single_shap_values = shap_values_obj.values[selected_index, 1, :]
423
  freq_sigma_fig = create_freq_sigma_plot(
424
+ single_shap_values,
425
+ raw_freq_vector=raw_vec,
426
  scaled_vector=scaled_data[selected_index],
427
  kmer_list=kmer_list,
428
  title=f"Sample #{selected_index} — {sequences[selected_index][0]}"
429
  )
430
+
431
  buf = io.BytesIO()
432
  freq_sigma_fig.savefig(buf, format='png', bbox_inches='tight', dpi=120)
433
  buf.seek(0)
 
441
  # Gradio Interface
442
  ###############################################################################
443
  with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
444
+ shap.initjs() # load shap JS if needed for HTML-based plots (optional)
445
 
446
  gr.Markdown(
447
  """
448
+ # **irus Host Classifier**
449
+ Upload a FASTA file with one or more nucleotide sequences.
450
  This app will:
451
  1. Predict each sequence's **host** (human vs. non-human).
452
+ 2. Provide **SHAP** explanations focusing on the 'human' class (index=1).
453
+ 3. Display:
454
+ - A **waterfall** plot per-sequence (top features).
455
+ - A **beeswarm** plot across all sequences (global summary).
456
+ - A **frequency & σ** bar chart for the top-10 k-mers of any selected sequence.
457
  """
458
  )
459
 
 
461
  file_input = gr.File(label="Upload FASTA", type="binary")
462
  run_btn = gr.Button("Run Classification")
463
 
464
+ # Store intermediate results in Gradio states
465
  shap_values_state = gr.State()
466
  scaled_data_state = gr.State()
467
  kmer_list_state = gr.State()
468
  results_state = gr.State()
 
469
  file_data_state = gr.State()
470
 
 
471
  with gr.Tabs():
472
  with gr.Tab("Results Table"):
473
  md_out = gr.Markdown()
474
 
475
  with gr.Tab("SHAP Waterfall"):
 
476
  with gr.Row():
477
+ seq_index_input = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
478
  update_wf_btn = gr.Button("Update Waterfall")
479
 
480
  wf_plot = gr.Image(label="SHAP Waterfall Plot")
 
484
 
485
  with gr.Tab("Top-10 Frequency & Sigma"):
486
  with gr.Row():
487
+ seq_index_input2 = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
488
  update_fs_btn = gr.Button("Update Frequency Chart")
489
  fs_plot = gr.Image(label="Top-10 Frequency & σ Chart")
490
 
491
+ # 1) Main classification
 
492
  run_btn.click(
493
  fn=main_predict,
494
  inputs=[file_input],
495
  outputs=[md_out, shap_values_state, scaled_data_state, kmer_list_state, results_state]
496
  )
 
497
  run_btn.click(
498
  fn=lambda x: x,
499
  inputs=file_input,
500
  outputs=file_data_state
501
  )
502
 
503
+ # 2) Update Waterfall
504
  update_wf_btn.click(
505
  fn=update_waterfall_plot,
506
+ inputs=[seq_index_input, shap_values_state],
507
  outputs=[wf_plot]
508
  )
509
 
510
+ # 3) Update Beeswarm right after classification
511
  run_btn.click(
512
  fn=update_beeswarm_plot,
513
  inputs=[shap_values_state],
514
  outputs=[bs_plot]
515
  )
516
 
517
+ # 4) Update Frequency & σ
518
  update_fs_btn.click(
519
  fn=update_freq_plot,
520
+ inputs=[seq_index_input2, shap_values_state, scaled_data_state, kmer_list_state, file_data_state],
521
  outputs=[fs_plot]
522
  )
523
 
524
  if __name__ == "__main__":
525
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
526
+