hiyata committed on
Commit
0e7de0c
·
verified ·
1 Parent(s): 8c49ca8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +294 -197
app.py CHANGED
@@ -33,14 +33,14 @@ class VirusClassifier(nn.Module):
33
 
34
  def get_feature_importance(self, x):
35
  """
36
- Calculate gradient-based feature importance.
37
- We'll compute the gradient of the 'human' probability w.r.t. the input vector.
38
  """
39
  x.requires_grad_(True)
40
  output = self.network(x)
41
  probs = torch.softmax(output, dim=1)
42
 
43
- # Gradient wrt 'human' class probability (index=1)
44
  human_prob = probs[..., 1]
45
  if x.grad is not None:
46
  x.grad.zero_()
@@ -94,127 +94,160 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
94
  ###############################################################################
95
  # Visualization
96
  ###############################################################################
97
- def create_visualization(important_kmers, human_prob, title):
98
  """
99
- Create a multi-panel figure showing:
100
- 1) A waterfall-like plot for how each top k-mer shifts the probability from 0.5
101
- (the baseline) to the final 'human' probability.
102
- 2) A side-by-side bar plot for frequency (%) and σ from mean for each important k-mer.
 
103
  """
104
 
105
- # Figure & GridSpec Layout
106
- fig = plt.figure(figsize=(14, 10))
107
- gs = plt.GridSpec(2, 2, width_ratios=[1.2, 1], height_ratios=[1.2, 1], hspace=0.35, wspace=0.3)
108
-
109
- # -------------------------------------------------------------------------
110
- # 1. Waterfall-like Plot (top-left subplot)
111
- # -------------------------------------------------------------------------
112
- ax_waterfall = plt.subplot(gs[0, 0])
113
 
114
- # Start from baseline prob=0.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  baseline = 0.5
116
- current_prob = baseline
117
- steps = [("Baseline", current_prob, 0.0)]
118
-
119
- # Build up the step changes
120
- for kmer in important_kmers:
121
- direction_multiplier = 1 if kmer["direction"] == "human" else -1
122
- change = kmer["impact"] * 0.05 * direction_multiplier
123
- # ^ scale changes so that the sum doesn't overshadow the final probability.
124
- current_prob += change
125
- steps.append((kmer["kmer"], current_prob, change))
126
-
127
- # X-values for step plot
128
- x_vals = range(len(steps))
129
- y_vals = [s[1] for s in steps]
 
 
 
 
 
130
 
131
- ax_waterfall.step(x_vals, y_vals, where='post', color='blue', linewidth=2, label='Probability')
132
- ax_waterfall.plot(x_vals, y_vals, 'b.', markersize=8)
 
 
 
 
133
 
134
- # Reference lines
135
- ax_waterfall.axhline(y=baseline, color='gray', linestyle='--', label='Baseline=0.5')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
- # Annotate each step
138
- for i, (kmer, prob, change) in enumerate(steps):
139
- if i == 0: # baseline
140
- ax_waterfall.annotate(kmer, (i, prob), textcoords="offset points", xytext=(0, -15), ha='center', color='black')
141
- continue
142
-
143
- color = "green" if change > 0 else "red"
144
- ax_waterfall.annotate(
145
- f"{kmer}\n({change:+.3f})",
146
- (i, prob),
147
- textcoords="offset points",
148
- xytext=(0, -15),
149
- ha='center',
150
- color=color,
151
- fontsize=9
152
- )
153
 
154
- ax_waterfall.set_ylim(0, 1)
155
- ax_waterfall.set_xlabel("k-mer Step")
156
- ax_waterfall.set_ylabel("Running Probability (Human)")
157
- ax_waterfall.set_title(f"K-mer Waterfall Plot Final Probability: {human_prob:.3f}")
158
- ax_waterfall.grid(alpha=0.3)
159
- ax_waterfall.legend()
160
-
161
- # -------------------------------------------------------------------------
162
- # 2. Frequency & σ from Mean (top-right subplot)
163
- # -------------------------------------------------------------------------
164
- ax_bar = plt.subplot(gs[0, 1])
165
-
166
- kmers = [k["kmer"] for k in important_kmers]
167
- frequencies = [k["occurrence"] for k in important_kmers] # in %
168
- sigmas = [k["sigma"] for k in important_kmers]
169
- directions = [k["direction"] for k in important_kmers]
170
 
171
- # X-locations
172
  x = np.arange(len(kmers))
173
  width = 0.4
174
 
175
- # We will create twin axes: one for frequency, one for σ
176
- bars1 = ax_bar.bar(x - width/2, frequencies, width, label='Frequency (%)',
177
- alpha=0.7, color=['green' if d=='human' else 'red' for d in directions])
 
 
 
 
 
178
  ax_bar.set_ylabel("Frequency (%)")
179
  ax_bar.set_ylim(0, max(frequencies) * 1.2 if frequencies else 1)
180
- ax_bar.set_title("Frequency vs. σ from Mean")
181
 
182
  # Twin axis for σ
183
  ax_bar_twin = ax_bar.twinx()
184
- bars2 = ax_bar_twin.bar(x + width/2, sigmas, width, label='σ from Mean',
185
- alpha=0.5, color='gray')
 
186
  ax_bar_twin.set_ylabel("Standard Deviations (σ)")
187
 
 
188
  ax_bar.set_xticks(x)
189
- ax_bar.set_xticklabels(kmers, rotation=45, ha='right', fontsize=9)
190
-
191
- # Combine legends
192
  lines1, labels1 = ax_bar.get_legend_handles_labels()
193
  lines2, labels2 = ax_bar_twin.get_legend_handles_labels()
194
- ax_bar.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
195
-
196
- # -------------------------------------------------------------------------
197
- # 3. Top Feature Importances (Bottom, spanning both columns)
198
- # -------------------------------------------------------------------------
199
- ax_imp = plt.subplot(gs[1, :])
200
 
201
- # Sort by absolute impact
 
 
 
 
 
 
 
202
  sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
203
- top_kmer_labels = [k['kmer'] for k in sorted_kmers]
204
- top_kmer_impacts = [k['impact'] for k in sorted_kmers]
205
- top_kmer_dirs = [k['direction'] for k in sorted_kmers]
206
 
207
- x_imp = np.arange(len(top_kmer_impacts))
208
- bar_colors = ['green' if d == 'human' else 'red' for d in top_kmer_dirs]
 
 
209
 
210
- ax_imp.bar(x_imp, top_kmer_impacts, color=bar_colors, alpha=0.7)
211
- ax_imp.set_xticks(x_imp)
212
- ax_imp.set_xticklabels(top_kmer_labels, rotation=45, ha='right', fontsize=9)
213
- ax_imp.set_title("Absolute Feature Importance (Top k-mers)")
214
- ax_imp.set_ylabel("Importance (gradient magnitude)")
215
- ax_imp.grid(alpha=0.3, axis='y')
216
 
217
- plt.suptitle(title, fontsize=14, y=1.02)
218
  plt.tight_layout()
219
  return fig
220
 
@@ -224,149 +257,213 @@ def create_visualization(important_kmers, human_prob, title):
224
  ###############################################################################
225
  def predict(file_obj):
226
  """
227
- Main function that Gradio will call:
228
- 1. Reads the uploaded FASTA file (or text).
229
  2. Loads the model and scaler.
230
  3. Generates predictions, probabilities, and top k-mers.
231
- 4. Creates a summary text and a matplotlib figure for visualization.
 
 
 
 
232
  """
 
233
  if file_obj is None:
234
- return "Please upload a FASTA file.", None
 
 
 
 
 
235
 
236
- # Read text from file
237
  try:
 
238
  if isinstance(file_obj, str):
239
  text = file_obj
240
  else:
 
241
  text = file_obj.decode('utf-8')
242
  except Exception as e:
243
- return f"Error reading file: {str(e)}", None
 
 
 
 
 
244
 
245
- # Build k-mer dictionary
 
 
 
 
 
 
 
 
 
 
 
 
246
  k = 4
247
- kmers = [''.join(p) for p in product("ACGT", repeat=k)]
248
- kmer_dict = {km: i for i, km in enumerate(kmers)}
249
-
250
- # Load model & scaler
251
  try:
252
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
253
- model = VirusClassifier(256).to(device)
 
 
 
 
 
254
  state_dict = torch.load('model.pt', map_location=device)
255
  model.load_state_dict(state_dict)
256
  scaler = joblib.load('scaler.pkl')
257
  model.eval()
258
- except Exception as e:
259
- return f"Error loading model or scaler: {str(e)}", None
260
 
261
- results_text = ""
262
- plot_image = None
263
-
264
- try:
265
- # Parse FASTA
266
- sequences = parse_fasta(text)
267
- if len(sequences) == 0:
268
- return "No valid FASTA sequences found. Please check your input.", None
269
-
270
- header, seq = sequences[0] # For simplicity, we'll only classify the first sequence
271
 
272
- # Transform sequence to scaled k-mer vector
273
- raw_freq_vector = sequence_to_kmer_vector(seq)
274
- kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
275
- X_tensor = torch.FloatTensor(kmer_vector).to(device)
276
-
277
- # Inference
278
  with torch.no_grad():
279
- output = model(X_tensor)
280
- probs = torch.softmax(output, dim=1)
281
-
282
- # Feature Importance
 
 
 
 
 
283
  importance, hum_prob_grad = model.get_feature_importance(X_tensor)
284
- kmer_importance = importance[0].cpu().numpy() # shape: (256,)
 
 
 
 
 
 
285
 
286
- # Top k-mers by absolute importance
 
287
  top_k = 10
288
- top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1] # largest -> smallest
289
  important_kmers = []
290
-
291
- for idx in top_indices:
292
- # find corresponding k-mer by index
293
- for kmer_str, i_ in kmer_dict.items():
294
- if i_ == idx:
295
- kmer_name = kmer_str
296
- break
297
-
298
- imp_val = float(abs(kmer_importance[idx]))
299
- direction = 'human' if kmer_importance[idx] > 0 else 'non-human'
300
- freq = float(raw_freq_vector[idx] * 100) # frequency in %
301
- sigma = float(kmer_vector[0][idx]) # scaled value (Z-score if standard scaler)
302
-
303
  important_kmers.append({
304
- 'kmer': kmer_name,
305
- 'impact': imp_val,
 
306
  'direction': direction,
307
- 'occurrence': freq,
308
- 'sigma': sigma
309
  })
310
 
311
- pred_class = 1 if probs[0][1] > probs[0][0] else 0
312
- pred_label = 'human' if pred_class == 1 else 'non-human'
313
- human_prob = float(probs[0][1])
314
- non_human_prob = float(probs[0][0])
315
- conf = float(max(probs[0])) # confidence in the predicted class
316
-
317
- # Generate text results
318
- results_text = (
319
  f"**Sequence Header**: {header}\n\n"
320
  f"**Predicted Label**: {pred_label}\n"
321
- f"**Confidence**: {conf:.4f}\n\n"
322
  f"**Human Probability**: {human_prob:.4f}\n"
323
  f"**Non-human Probability**: {non_human_prob:.4f}\n\n"
324
  "### Most Influential k-mers:\n"
325
  )
326
- for k in important_kmers:
327
- direction_text = f"pushes toward {k['direction']}"
328
- occurrence_text = f"{k['occurrence']:.2f}% of sequence"
329
- sigma_text = f"{abs(k['sigma']):.2f}σ " + ("above" if k['sigma'] > 0 else "below") + " mean"
330
- results_text += (
331
- f"- **{k['kmer']}**: "
332
- f"impact = {k['impact']:.4f}, {direction_text}, "
333
- f"occurrence = {occurrence_text}, "
334
- f"({sigma_text})\n"
335
  )
336
 
337
- # Create figure
338
- fig = create_visualization(important_kmers, human_prob, f"{header}")
339
-
340
- # Convert figure to image
341
- buf = io.BytesIO()
342
- fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
343
- buf.seek(0)
344
- plot_image = Image.open(buf)
345
- plt.close(fig)
346
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  except Exception as e:
348
- return f"Error during prediction or visualization: {str(e)}", None
 
 
 
 
 
349
 
350
- return results_text, plot_image
351
 
352
  ###############################################################################
353
  # Gradio Interface
354
  ###############################################################################
355
- iface = gr.Interface(
356
- fn=predict,
357
- inputs=gr.File(label="Upload FASTA file", type="binary"),
358
- outputs=[
359
- gr.Markdown(label="Prediction Results"),
360
- gr.Image(label="K-mer Analysis Visualization")
361
- ],
362
- title="Virus Host Classifier",
363
- description=(
364
- "Upload a FASTA file containing a single nucleotide sequence. "
365
- "This model will predict whether the virus host is **human** or **non-human**, "
366
- "provide a confidence score, and highlight the most influential k-mers in the classification."
367
- ),
368
- allow_flagging="never",
369
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
  if __name__ == "__main__":
372
- iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
33
 
34
  def get_feature_importance(self, x):
35
  """
36
+ Calculate gradient-based feature importance, specifically for the
37
+ 'human' class (index=1) by computing gradient of that probability wrt x.
38
  """
39
  x.requires_grad_(True)
40
  output = self.network(x)
41
  probs = torch.softmax(output, dim=1)
42
 
43
+ # Probability of 'human' class (index=1)
44
  human_prob = probs[..., 1]
45
  if x.grad is not None:
46
  x.grad.zero_()
 
94
  ###############################################################################
95
  # Visualization
96
  ###############################################################################
97
def create_shap_waterfall_plot(important_kmers, all_kmer_importance, human_prob, title):
    """
    Draw a SHAP-like waterfall chart of gradient contributions.

    The chart starts at a baseline of 0.5, applies one bar for "Other" (the
    summed signed importance of every feature that is not in
    ``important_kmers``) and then one bar per top k-mer in descending order
    of ``impact``. Each contribution is multiplied by a fixed display scale
    of 0.05 before being added to the running value, so the end of the last
    bar is a purely additive gradient total and is NOT forced to equal the
    model's probability; both numbers are shown in the plot title.

    Args:
        important_kmers: list of dicts with keys 'kmer', 'idx', 'impact'
            (absolute importance) and 'direction' ('human' or anything else,
            which is treated as pushing away from 'human').
        all_kmer_importance: iterable of signed importance values for every
            feature, indexed consistently with the 'idx' entries.
        human_prob: model probability for the 'human' class (title only).
        title: label for the sequence; appended to the plot title when
            non-empty.

    Returns:
        The matplotlib Figure. The caller is responsible for closing it.
    """
    baseline = 0.5
    display_scale = 0.05  # purely cosmetic: keeps the bars readable on a 0-1 axis

    ranked = sorted(important_kmers, key=lambda km: km['impact'], reverse=True)

    # Everything that is not one of the top k-mers is lumped into one bar.
    top_ids = {int(km['idx']) for km in ranked}
    other_sum = float(
        sum(val for i, val in enumerate(all_kmer_importance) if i not in top_ids)
    )

    # (label, signed contribution). 'impact' is stored as an absolute value,
    # so the sign is re-derived from 'direction'.
    bars = [("Other", other_sum)]
    bars.extend(
        (km['kmer'], km['impact'] if km['direction'] == 'human' else -km['impact'])
        for km in ranked
    )

    # Accumulate the running value and remember where each bar starts/ends.
    segments = []
    running_val = baseline
    for label, contrib in bars:
        new_val = running_val + display_scale * contrib
        color = "green" if contrib >= 0 else "red"
        segments.append((label, running_val, new_val, color))
        running_val = new_val
    final_prob = running_val

    fig, ax = plt.subplots(figsize=(10, 6))

    for pos, (_, start_val, end_val, color) in enumerate(segments):
        height = end_val - start_val
        ax.bar(pos, height, bottom=start_val, color=color, edgecolor='black', alpha=0.7)
        ax.text(
            pos, (start_val + end_val) / 2, f"{height:+.3f}",
            ha='center', va='center', color='white', fontsize=8
        )

    ax.axhline(y=baseline, color='black', linestyle='--', linewidth=1)
    ax.set_xticks(np.arange(len(segments)))
    ax.set_xticklabels([seg[0] for seg in segments], rotation=45, ha='right')

    # Keep the familiar 0-1 range, but widen it when the additive total
    # wanders outside so that no bar is silently clipped.
    levels = [baseline] + [seg[2] for seg in segments]
    ax.set_ylim(min(0.0, min(levels)), max(1.0, max(levels)))

    ax.set_ylabel("Running Probability (Human)")
    heading = (
        f"SHAP-like Waterfall — Final Probability: {final_prob:.3f} "
        f"(Model Probability: {human_prob:.3f})"
    )
    if title:
        # Match the sibling plots, which also identify the sequence.
        heading = f"{heading}\n{title}"
    ax.set_title(heading)

    plt.tight_layout()
    return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
def create_frequency_sigma_plot(important_kmers, title):
    """
    Plot frequency (%) and σ-from-mean side by side for the top k-mers.

    K-mers are ordered by descending 'impact'. Frequency bars use the left
    y-axis and are coloured green for direction 'human' and red otherwise;
    σ bars are grey and use a twin right-hand y-axis.

    Args:
        important_kmers: list of dicts with keys 'kmer', 'impact',
            'occurrence' (percent), 'sigma' and 'direction'.
        title: text appended to the plot title.

    Returns:
        The matplotlib Figure (not closed here).
    """
    ranked = sorted(important_kmers, key=lambda entry: entry['impact'], reverse=True)
    kmers = [entry["kmer"] for entry in ranked]
    frequencies = [entry["occurrence"] for entry in ranked]  # already in percent
    sigmas = [entry["sigma"] for entry in ranked]
    directions = [entry["direction"] for entry in ranked]

    positions = np.arange(len(kmers))
    width = 0.4
    freq_colors = ["green" if d == "human" else "red" for d in directions]

    fig, ax_bar = plt.subplots(figsize=(10, 6))

    # Left axis: how often each k-mer occurs in the sequence.
    ax_bar.bar(
        positions - width/2,
        frequencies,
        width,
        alpha=0.7,
        color=freq_colors,
        label="Frequency (%)",
    )
    ax_bar.set_ylabel("Frequency (%)")
    upper = max(frequencies) * 1.2 if frequencies else 1
    ax_bar.set_ylim(0, upper)

    # Right axis: the scaled value of each k-mer.
    ax_bar_twin = ax_bar.twinx()
    ax_bar_twin.bar(
        positions + width/2,
        sigmas,
        width,
        alpha=0.5,
        color="gray",
        label="σ from Mean",
    )
    ax_bar_twin.set_ylabel("Standard Deviations (σ)")

    ax_bar.set_title(f"Frequency & σ from Mean for Top k-mers — {title}")
    ax_bar.set_xticks(positions)
    ax_bar.set_xticklabels(kmers, rotation=45, ha='right')

    # One legend covering the handles of both axes.
    handles_left, labels_left = ax_bar.get_legend_handles_labels()
    handles_right, labels_right = ax_bar_twin.get_legend_handles_labels()
    ax_bar.legend(
        handles_left + handles_right,
        labels_left + labels_right,
        loc="upper right",
    )

    plt.tight_layout()
    return fig
228
+
229
def create_importance_bar_plot(important_kmers, title):
    """
    Bar chart of the absolute gradient magnitude of the top k-mers.

    Bars are ordered by descending 'impact' and coloured green when the
    k-mer's direction is 'human', red otherwise.

    Args:
        important_kmers: list of dicts with keys 'kmer', 'impact' and
            'direction'.
        title: text appended to the plot title.

    Returns:
        The matplotlib Figure (not closed here).
    """
    ranked = sorted(important_kmers, key=lambda entry: entry['impact'], reverse=True)
    labels = [entry['kmer'] for entry in ranked]
    magnitudes = [entry['impact'] for entry in ranked]
    pushes = [entry["direction"] for entry in ranked]

    positions = np.arange(len(labels))

    fig, ax = plt.subplots(figsize=(10, 6))
    palette = ["green" if push == "human" else "red" for push in pushes]

    ax.bar(positions, magnitudes, color=palette, alpha=0.7)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.set_title(f"Absolute Feature Importance (Top k-mers) — {title}")
    ax.set_ylabel("Gradient Magnitude")
    ax.grid(axis="y", alpha=0.3)

    plt.tight_layout()
    return fig
253
 
 
257
  ###############################################################################
258
def predict(file_obj):
    """
    Gradio callback: classify the first sequence of an uploaded FASTA input.

    Steps:
        1. Read the upload (raw ``str`` or UTF-8 encoded ``bytes``).
        2. Parse the FASTA text and take the first record.
        3. Build the k-mer frequency vector, load model and scaler, scale.
        4. Run inference and compute gradient-based feature importance.
        5. Build a Markdown summary and three plots.

    Args:
        file_obj: ``None``, a ``str`` of FASTA text, or an object with a
            ``decode`` method (e.g. ``bytes``).

    Returns:
        A 4-tuple ``(summary_markdown, waterfall_img, freq_sigma_img,
        importance_img)``. On any failure the first element is an error
        message and the three images are ``None``; this function does not
        raise for ordinary failures.
    """
    if file_obj is None:
        return _prediction_error("Please upload a FASTA file.")

    # 0. Read the upload.
    try:
        text = file_obj if isinstance(file_obj, str) else file_obj.decode('utf-8')
    except Exception as e:
        return _prediction_error(f"Error reading file: {str(e)}")

    # 1. Parse FASTA. Guarded so a parser failure still yields the 4-tuple
    #    the UI expects instead of an unhandled exception.
    try:
        sequences = parse_fasta(text)
    except Exception as e:
        return _prediction_error(f"Error parsing FASTA: {str(e)}")
    if len(sequences) == 0:
        return _prediction_error(
            "No valid FASTA sequences found. Please check your input."
        )
    # Only the first sequence is classified.
    header, seq = sequences[0]

    k = 4
    top_k = 10
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # 2. Raw k-mer frequencies, then model & scaler.
        raw_freq_vector = sequence_to_kmer_vector(seq, k=k)

        model = VirusClassifier(input_shape=4**k).to(device)
        state_dict = torch.load('model.pt', map_location=device)
        model.load_state_dict(state_dict)
        scaler = joblib.load('scaler.pkl')
        model.eval()

        scaled_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
        X_tensor = torch.FloatTensor(scaled_vector).to(device)

        # 3. Inference (no gradients needed for the forward pass itself).
        with torch.no_grad():
            logits = model(X_tensor)
            probs = torch.softmax(logits, dim=1)
        human_prob = float(probs[0][1])
        non_human_prob = float(probs[0][0])
        pred_label = "human" if human_prob >= non_human_prob else "non-human"
        confidence = float(max(probs[0]))

        # 4. Feature importance. This must run OUTSIDE torch.no_grad():
        #    it differentiates the 'human' probability w.r.t. the input.
        importance, _ = model.get_feature_importance(X_tensor)
        kmer_importances = importance[0].detach().cpu().numpy()

        # Feature i corresponds to the i-th k-mer in lexicographic ACGT order.
        kmers_list = [''.join(p) for p in product("ACGT", repeat=k)]

        # 5. Top k-mers by absolute importance, largest first.
        abs_importance = np.abs(kmer_importances)
        top_idxs = np.argsort(abs_importance)[-top_k:][::-1]
        important_kmers = [
            {
                'kmer': kmers_list[idx],
                'idx': int(idx),  # plain int rather than a numpy integer
                'impact': float(abs_importance[idx]),
                'direction': "human" if kmer_importances[idx] > 0 else "non-human",
                'occurrence': float(raw_freq_vector[idx] * 100),  # percent
                'sigma': float(scaled_vector[0][idx]),  # scaled value
            }
            for idx in top_idxs
        ]

        # 6. Text summary.
        summary_parts = [
            f"**Sequence Header**: {header}\n\n"
            f"**Predicted Label**: {pred_label}\n"
            f"**Confidence**: {confidence:.4f}\n\n"
            f"**Human Probability**: {human_prob:.4f}\n"
            f"**Non-human Probability**: {non_human_prob:.4f}\n\n"
            "### Most Influential k-mers:\n"
        ]
        for km in important_kmers:
            direction_text = f"(pushes toward {km['direction']})"
            freq_text = f"{km['occurrence']:.2f}%"
            sigma_text = (
                f"{abs(km['sigma']):.2f}σ "
                + ("above" if km['sigma'] > 0 else "below")
                + " mean"
            )
            summary_parts.append(
                f"- **{km['kmer']}**: impact={km['impact']:.4f}, {direction_text}, "
                f"occurrence={freq_text}, ({sigma_text})\n"
            )
        summary_text = "".join(summary_parts)

        # 7. Plots, each rendered to an image and closed immediately.
        plot_title = f"{header}"
        waterfall_img = _figure_to_image(
            create_shap_waterfall_plot(
                important_kmers, kmer_importances, human_prob, plot_title
            )
        )
        freq_sigma_img = _figure_to_image(
            create_frequency_sigma_plot(important_kmers, plot_title)
        )
        importance_img = _figure_to_image(
            create_importance_bar_plot(important_kmers, plot_title)
        )

        return summary_text, waterfall_img, freq_sigma_img, importance_img

    except Exception as e:
        # Top-level boundary for the UI callback: report instead of crashing.
        return _prediction_error(
            f"Error during prediction or visualization: {str(e)}"
        )


def _prediction_error(message):
    """Return the 4-tuple predict() emits on failure: message plus three empty images."""
    return message, None, None, None


def _figure_to_image(fig, dpi=120):
    """Render a matplotlib figure to a PIL image and always close the figure."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=dpi)
    finally:
        # Close even when savefig fails so figures do not accumulate.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
430
 
 
431
 
432
  ###############################################################################
433
  # Gradio Interface
434
  ###############################################################################
435
# Build the Gradio UI. Components are laid out in the order they are created
# inside each context manager, so statement order here defines the layout.
with gr.Blocks(title="Advanced Virus Host Classifier") as demo:
    # Page header / usage instructions.
    gr.Markdown(
        """
        # Advanced Virus Host Classifier
        **Upload a FASTA file** containing a single nucleotide sequence.
        The model will predict whether this sequence is **human** or **non-human**,
        provide a confidence score, and highlight the most influential k-mers
        (using a SHAP-like waterfall plot) along with two additional plots.
        """
    )

    # Input row: file upload plus the trigger button.
    with gr.Row():
        # type="binary": presumably the callback receives raw bytes, which
        # matches predict() calling .decode('utf-8') — confirm against Gradio docs.
        file_in = gr.File(label="Upload FASTA", type="binary")
        btn = gr.Button("Run Prediction")

    # We will create multiple tabs for our outputs
    with gr.Tabs():
        with gr.Tab("Prediction Results"):
            md_out = gr.Markdown()
        with gr.Tab("SHAP-like Waterfall Plot"):
            water_out = gr.Image()
        with gr.Tab("Frequency & σ Plot"):
            freq_out = gr.Image()
        with gr.Tab("Importance Bar Plot"):
            imp_out = gr.Image()

    # Link the button: the four outputs are listed in the same order as the
    # 4-tuple returned by predict().
    btn.click(
        fn=predict,
        inputs=[file_in],
        outputs=[md_out, water_out, freq_out, imp_out]
    )

if __name__ == "__main__":
    # Listen on all interfaces on port 7860; share=True additionally asks
    # Gradio for a public share link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)