hiyata committed on
Commit
b5edb58
·
verified ·
1 Parent(s): 3b775b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -393
app.py CHANGED
@@ -2,17 +2,12 @@ import gradio as gr
2
  import torch
3
  import joblib
4
  import numpy as np
5
- import shap
6
- import random
7
  from itertools import product
8
  import torch.nn as nn
9
  import matplotlib.pyplot as plt
10
  import io
11
  from PIL import Image
12
 
13
- ###############################################################################
14
- # Model Definition
15
- ###############################################################################
16
  class VirusClassifier(nn.Module):
17
  def __init__(self, input_shape: int):
18
  super(VirusClassifier, self).__init__()
@@ -34,28 +29,38 @@ class VirusClassifier(nn.Module):
34
  return self.network(x)
35
 
36
  def get_feature_importance(self, x):
37
- """
38
- Calculate gradient-based feature importance, specifically for the
39
- 'human' class (index=1) by computing gradient of that probability wrt x.
40
- """
41
  x.requires_grad_(True)
42
  output = self.network(x)
43
  probs = torch.softmax(output, dim=1)
44
 
45
- # Probability of 'human' class (index=1)
46
  human_prob = probs[..., 1]
47
  if x.grad is not None:
48
  x.grad.zero_()
49
  human_prob.backward()
50
- importance = x.grad # shape: (batch_size, n_features)
51
 
52
  return importance, float(human_prob)
53
 
54
- ###############################################################################
55
- # Utility Functions
56
- ###############################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def parse_fasta(text):
58
- """Parses text input in FASTA format into a list of (header, sequence)."""
59
  sequences = []
60
  current_header = None
61
  current_sequence = []
@@ -75,213 +80,97 @@ def parse_fasta(text):
75
  sequences.append((current_header, ''.join(current_sequence)))
76
  return sequences
77
 
78
- def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
79
- """Convert a single nucleotide sequence to a k-mer frequency vector."""
80
- kmers = [''.join(p) for p in product("ACGT", repeat=k)]
81
- kmer_dict = {km: i for i, km in enumerate(kmers)}
82
- vec = np.zeros(len(kmers), dtype=np.float32)
83
 
84
- for i in range(len(sequence) - k + 1):
85
- kmer = sequence[i:i+k]
86
- if kmer in kmer_dict:
87
- vec[kmer_dict[kmer]] += 1
88
-
89
- total_kmers = len(sequence) - k + 1
90
- if total_kmers > 0:
91
- vec = vec / total_kmers # normalize frequencies
92
-
93
- return vec
94
-
95
- ###############################################################################
96
- # Additional Plots
97
- ###############################################################################
98
- def create_probability_bar_plot(prob_human, prob_nonhuman):
99
- """
100
- Simple bar plot comparing human vs. non-human probabilities.
101
- """
102
- labels = ["Non-human", "Human"]
103
- probs = [prob_nonhuman, prob_human]
104
- colors = ["red", "green"]
105
-
106
- fig, ax = plt.subplots(figsize=(6, 4))
107
- ax.bar(labels, probs, color=colors, alpha=0.7)
108
- ax.set_ylim(0, 1)
109
- for i, v in enumerate(probs):
110
- ax.text(i, v+0.02, f"{v:.3f}", ha='center', color='black', fontsize=11)
111
-
112
- ax.set_title("Predicted Probabilities")
113
- ax.set_ylabel("Probability")
114
- plt.tight_layout()
115
- return fig
116
-
117
- def create_frequency_sigma_plot(important_kmers, title):
118
- """
119
- Creates a bar plot of the top k-mers (by importance) showing
120
- frequency (%) and σ from mean.
121
- """
122
- # Sort by absolute impact
123
- sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
124
- kmers = [k["kmer"] for k in sorted_kmers]
125
- frequencies = [k["occurrence"] for k in sorted_kmers] # in %
126
- sigmas = [k["sigma"] for k in sorted_kmers]
127
- directions = [k["direction"] for k in sorted_kmers]
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  x = np.arange(len(kmers))
130
- width = 0.4
131
-
132
- fig, ax_bar = plt.subplots(figsize=(10, 5))
133
-
134
- # Bar for frequency
135
- bars_freq = ax_bar.bar(
136
- x - width/2, frequencies, width, alpha=0.7,
137
- color=["green" if d=="human" else "red" for d in directions],
138
- label="Frequency (%)"
139
- )
140
- ax_bar.set_ylabel("Frequency (%)")
141
- ax_bar.set_ylim(0, max(frequencies) * 1.2 if len(frequencies) > 0 else 1)
142
-
143
- # Twin axis for σ
144
- ax_bar_twin = ax_bar.twinx()
145
- bars_sigma = ax_bar_twin.bar(
146
- x + width/2, sigmas, width, alpha=0.5, color="gray", label="σ from Mean"
147
- )
148
- ax_bar_twin.set_ylabel("Standard Deviations (σ)")
149
-
150
- ax_bar.set_title(f"Frequency & σ from Mean for Top k-mers — {title}")
151
- ax_bar.set_xticks(x)
152
- ax_bar.set_xticklabels(kmers, rotation=45, ha='right')
153
-
154
- # Combined legend
155
- lines1, labels1 = ax_bar.get_legend_handles_labels()
156
- lines2, labels2 = ax_bar_twin.get_legend_handles_labels()
157
- ax_bar.legend(lines1 + lines2, labels1 + labels2, loc="upper right")
158
-
159
- plt.tight_layout()
160
- return fig
161
-
162
- def create_importance_bar_plot(important_kmers, title):
163
- """
164
- Create a simple bar chart showing the absolute gradient magnitude
165
- for the top k-mers, sorted descending.
166
- """
167
- sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
168
- kmers = [k['kmer'] for k in sorted_kmers]
169
- impacts = [k['impact'] for k in sorted_kmers]
170
- directions = [k["direction"] for k in sorted_kmers]
171
-
172
- x = np.arange(len(kmers))
173
-
174
- fig, ax = plt.subplots(figsize=(10, 5))
175
- bar_colors = ["green" if d=="human" else "red" for d in directions]
176
-
177
- ax.bar(x, impacts, color=bar_colors, alpha=0.7, edgecolor='black')
178
- ax.set_xticks(x)
179
- ax.set_xticklabels(kmers, rotation=45, ha='right')
180
- ax.set_title(f"Absolute Feature Importance (Top k-mers) — {title}")
181
- ax.set_ylabel("Gradient Magnitude")
182
- ax.grid(axis="y", alpha=0.3)
183
-
184
- plt.tight_layout()
185
- return fig
186
-
187
- ###############################################################################
188
- # SHAP Beeswarm
189
- ###############################################################################
190
- def create_shap_beeswarm_plot(
191
- model,
192
- input_vector: np.ndarray,
193
- background_data: np.ndarray,
194
- feature_names: list
195
- ):
196
- """
197
- Creates a SHAP beeswarm plot using KernelExplainer for the given model and data.
198
 
199
- Parameters
200
- ----------
201
- model : nn.Module
202
- Trained PyTorch model (binary classifier).
203
- input_vector : np.ndarray
204
- The 1-sample input (or multiple samples) we want SHAP values for.
205
- background_data : np.ndarray
206
- Background samples for KernelExplainer. Should have shape (N, #features).
207
- feature_names : list
208
- Names for each feature (k-mers).
209
 
210
- Returns
211
- -------
212
- fig : matplotlib Figure
213
- Beeswarm plot figure.
214
- """
215
-
216
- # We'll define a prediction function that shap can call
217
- # The model outputs logits for shape [N, 2]
218
- # We want the raw outputs for each class. SHAP will handle the link function if needed.
219
- def predict_fn(data):
220
- """
221
- data: shape (N, #features)
222
- returns: shape (N, 2) for 2-class logits
223
- """
224
- with torch.no_grad():
225
- x = torch.FloatTensor(data)
226
- logits = model(x)
227
- return logits.detach().cpu().numpy()
228
-
229
- # Create KernelExplainer
230
- explainer = shap.KernelExplainer(
231
- model=predict_fn,
232
- data=background_data
233
- )
234
-
235
- # Compute SHAP values
236
- # For a 2-class model, shap_values is a list of length 2 => [class0 array, class1 array]
237
- # Each array is shape (N, #features).
238
- shap_values = explainer.shap_values(input_vector)
239
-
240
- # We’ll produce a beeswarm for the 'human' class (class index=1).
241
- # If we have only 1 sample, the beeswarm won't be too interesting, but let's do it anyway.
242
- class_idx = 1 # 'human'
243
 
244
- # If we only have one sample, place it in an array for shap summary plotting:
245
- # We can do shap_values[class_idx].shape => (1, #features) for a single sample
246
- # Beeswarm typically expects multiple samples. We'll plot anyway.
247
- shap.plots.beeswarm(
248
- shap_values[class_idx],
249
- feature_names=feature_names,
250
- show=False
251
- )
252
-
253
- fig = plt.gcf()
254
- fig.set_size_inches(8, 6)
255
- plt.title("SHAP Beeswarm Plot (Class: Human)")
256
-
257
  plt.tight_layout()
258
  return fig
259
 
260
- ###############################################################################
261
- # Prediction Function
262
- ###############################################################################
263
  def predict(file_obj):
264
- """
265
- Main function for Gradio:
266
- 1. Reads the uploaded FASTA file or text.
267
- 2. Loads the model and scaler.
268
- 3. Generates predictions, probabilities, and top k-mers.
269
- 4. Creates multiple outputs:
270
- - Text summary (Markdown)
271
- - Probability Bar Plot
272
- - SHAP Beeswarm Plot
273
- - Frequency & σ Plot
274
- - Absolute Feature Importance Bar Plot
275
- """
276
- # 0. Basic file read
277
  if file_obj is None:
278
- return (
279
- "Please upload a FASTA file.",
280
- None,
281
- None,
282
- None,
283
- None
284
- )
285
 
286
  try:
287
  if isinstance(file_obj, str):
@@ -289,202 +178,106 @@ def predict(file_obj):
289
  else:
290
  text = file_obj.decode('utf-8')
291
  except Exception as e:
292
- return (
293
- f"Error reading file: {str(e)}",
294
- None,
295
- None,
296
- None,
297
- None
298
- )
299
-
300
- # 1. Parse FASTA
301
- sequences = parse_fasta(text)
302
- if len(sequences) == 0:
303
- return (
304
- "No valid FASTA sequences found. Please check your input.",
305
- None,
306
- None,
307
- None,
308
- None
309
- )
310
- header, seq = sequences[0] # We'll classify only the first sequence
311
 
312
- # 2. Prepare model, scaler, and input
313
  k = 4
314
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
315
  try:
316
- raw_freq_vector = sequence_to_kmer_vector(seq, k=k)
317
-
318
- # Load model & scaler
319
- model = VirusClassifier(input_shape=4**k).to(device)
320
- state_dict = torch.load("model.pt", map_location=device)
321
  model.load_state_dict(state_dict)
322
- scaler = joblib.load("scaler.pkl")
323
  model.eval()
 
 
324
 
325
- scaled_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
326
- X_tensor = torch.FloatTensor(scaled_vector).to(device)
327
-
328
- # 3. Predict
 
 
 
 
 
 
 
 
329
  with torch.no_grad():
330
- logits = model(X_tensor)
331
- probs = torch.softmax(logits, dim=1)
332
- human_prob = float(probs[0][1])
333
- non_human_prob = float(probs[0][0])
334
- pred_label = "human" if human_prob >= non_human_prob else "non-human"
335
- confidence = float(max(probs[0]))
336
-
337
- # 4. Gradient-based feature importance
338
- importance, hum_prob_grad = model.get_feature_importance(X_tensor)
339
- importances = importance[0].cpu().numpy() # shape: (#features,)
340
- abs_importances = np.abs(importances)
341
-
342
- # 5. Gather k-mer strings
343
- kmers_list = [''.join(p) for p in product("ACGT", repeat=k)]
344
- # top 10 by absolute importance
345
  top_k = 10
346
- top_idxs = np.argsort(abs_importances)[-top_k:][::-1]
 
347
  important_kmers = []
348
- for idx in top_idxs:
349
- direction = "human" if importances[idx] > 0 else "non-human"
350
- freq_percent = float(raw_freq_vector[idx] * 100.0)
351
- sigma_val = float(scaled_vector[0][idx]) # scaled / standardized val
 
 
 
352
  important_kmers.append({
353
- 'kmer': kmers_list[idx],
354
- 'idx': idx,
355
- 'impact': abs_importances[idx],
356
  'direction': direction,
357
- 'occurrence': freq_percent,
358
- 'sigma': sigma_val
359
  })
360
-
361
- # 6. Generate text summary
362
- text_summary = (
363
- f"**Sequence Header**: {header}\n\n"
364
- f"**Predicted Label**: {pred_label}\n"
365
- f"**Confidence**: {confidence:.4f}\n\n"
366
- f"**Human Probability**: {human_prob:.4f}\n"
367
- f"**Non-human Probability**: {non_human_prob:.4f}\n\n"
368
- "### Most Influential k-mers:\n"
369
- )
370
- for km in important_kmers:
371
- direction_text = f"(pushes toward {km['direction']})"
372
- freq_text = f"{km['occurrence']:.2f}%"
373
- sigma_text = (
374
- f"{abs(km['sigma']):.2f}σ "
375
- + ("above" if km['sigma'] > 0 else "below")
376
- + " mean"
377
- )
378
- text_summary += (
379
- f"- **{km['kmer']}**: impact={km['impact']:.4f}, {direction_text}, "
380
- f"occurrence={freq_text}, ({sigma_text})\n"
381
- )
382
-
383
- # 7. Probability Bar Plot
384
- fig_prob = create_probability_bar_plot(human_prob, non_human_prob)
385
- buf_prob = io.BytesIO()
386
- fig_prob.savefig(buf_prob, format='png', bbox_inches='tight', dpi=120)
387
- buf_prob.seek(0)
388
- prob_img = Image.open(buf_prob)
389
- plt.close(fig_prob)
390
-
391
- # 8. SHAP Beeswarm Plot
392
- # We need some background data for KernelExplainer. Let's create a small random sample
393
- # or sample from the scaled_vector itself in a repeated manner. Real usage: choose a valid background set.
394
- background_size = 5 # keep small for speed
395
- # We'll pick random sequences from normal(0,1) or from scaled_vector repeated
396
- background_data = []
397
- for _ in range(background_size):
398
- # Option A: random small variations around scaled_vector
399
- # new_sample = scaled_vector[0] + np.random.normal(0, 0.5, size=scaled_vector.shape[1])
400
- # Option B: just clone the same scaled vector multiple times
401
- new_sample = scaled_vector[0]
402
- background_data.append(new_sample)
403
- background_data = np.stack(background_data, axis=0) # shape (5, #features)
404
-
405
- fig_bee = create_shap_beeswarm_plot(
406
- model=model,
407
- input_vector=scaled_vector, # our single sample
408
- background_data=background_data, # background for KernelExplainer
409
- feature_names=kmers_list
410
- )
411
- buf_bee = io.BytesIO()
412
- fig_bee.savefig(buf_bee, format='png', bbox_inches='tight', dpi=120)
413
- buf_bee.seek(0)
414
- bee_img = Image.open(buf_bee)
415
- plt.close(fig_bee)
416
-
417
- # 9. Frequency & σ Plot
418
- fig_freq = create_frequency_sigma_plot(important_kmers, header)
419
- buf_freq = io.BytesIO()
420
- fig_freq.savefig(buf_freq, format='png', bbox_inches='tight', dpi=120)
421
- buf_freq.seek(0)
422
- freq_img = Image.open(buf_freq)
423
- plt.close(fig_freq)
424
-
425
- # 10. Absolute Feature Importance Bar Plot
426
- fig_imp = create_importance_bar_plot(important_kmers, header)
427
- buf_imp = io.BytesIO()
428
- fig_imp.savefig(buf_imp, format='png', bbox_inches='tight', dpi=120)
429
- buf_imp.seek(0)
430
- imp_img = Image.open(buf_imp)
431
- plt.close(fig_imp)
432
-
433
- return text_summary, prob_img, bee_img, freq_img, imp_img
434
-
435
- except Exception as e:
436
- return (
437
- f"Error during prediction or visualization: {str(e)}",
438
- None,
439
- None,
440
- None,
441
- None
442
- )
443
-
444
-
445
- ###############################################################################
446
- # Gradio Interface
447
- ###############################################################################
448
- with gr.Blocks(title="Advanced Virus Host Classifier with SHAP Beeswarm") as demo:
449
- gr.Markdown(
450
- """
451
- # Advanced Virus Host Classifier (SHAP Beeswarm Edition)
452
-
453
- **Upload a FASTA file** containing a single nucleotide sequence.
454
- The model will predict whether this sequence is **human** or **non-human**,
455
- provide a confidence score, and highlight the most influential k-mers.
456
- We also produce a **SHAP beeswarm** plot for the features.
457
 
458
- ---
459
- **Note**: Beeswarm plots are usually most insightful with multiple samples.
460
- Here, we demonstrate usage with a single sample plus a small synthetic background.
461
- """
462
- )
463
-
464
- with gr.Row():
465
- file_in = gr.File(label="Upload FASTA", type="binary")
466
- btn = gr.Button("Run Prediction")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
 
468
- # We will create multiple tabs for our outputs
469
- with gr.Tabs():
470
- with gr.Tab("Prediction Results"):
471
- md_out = gr.Markdown()
472
- with gr.Tab("Probability Plot"):
473
- prob_out = gr.Image()
474
- with gr.Tab("SHAP Beeswarm Plot"):
475
- bee_out = gr.Image()
476
- with gr.Tab("Frequency & σ Plot"):
477
- freq_out = gr.Image()
478
- with gr.Tab("Importance Bar Plot"):
479
- imp_out = gr.Image()
480
 
481
- # Link the button
482
- btn.click(
483
- fn=predict,
484
- inputs=[file_in],
485
- outputs=[md_out, prob_out, bee_out, freq_out, imp_out]
486
- )
 
 
 
487
 
488
  if __name__ == "__main__":
489
- # By default, share=False. You can set share=True for external access.
490
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
2
  import torch
3
  import joblib
4
  import numpy as np
 
 
5
  from itertools import product
6
  import torch.nn as nn
7
  import matplotlib.pyplot as plt
8
  import io
9
  from PIL import Image
10
 
 
 
 
11
  class VirusClassifier(nn.Module):
12
  def __init__(self, input_shape: int):
13
  super(VirusClassifier, self).__init__()
 
29
  return self.network(x)
30
 
31
  def get_feature_importance(self, x):
32
+ """Calculate feature importance using gradient-based method"""
 
 
 
33
  x.requires_grad_(True)
34
  output = self.network(x)
35
  probs = torch.softmax(output, dim=1)
36
 
37
+ # Get importance for human class (index 1)
38
  human_prob = probs[..., 1]
39
  if x.grad is not None:
40
  x.grad.zero_()
41
  human_prob.backward()
42
+ importance = x.grad
43
 
44
  return importance, float(human_prob)
45
 
46
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """Convert a nucleotide sequence to a normalized k-mer frequency vector.

    Args:
        sequence: Nucleotide string. Matching is case-insensitive, so
            lower-case bases are counted the same as upper-case ones.
        k: Length of the k-mers to count.

    Returns:
        A float32 array of length ``4 ** k``. Position ``i`` corresponds to
        the i-th k-mer in ``itertools.product("ACGT", repeat=k)`` order and
        holds that k-mer's count divided by the total number of length-k
        windows in the sequence. Windows containing any character other than
        A/C/G/T are not counted but still contribute to the denominator, so
        the vector may sum to less than 1. A sequence shorter than ``k``
        yields an all-zero vector.
    """
    kmer_index = {''.join(p): i for i, p in enumerate(product("ACGT", repeat=k))}
    vec = np.zeros(len(kmer_index), dtype=np.float32)

    # Lower-case input would otherwise match no k-mer at all and silently
    # produce an all-zero feature vector.
    sequence = sequence.upper()

    total_kmers = len(sequence) - k + 1
    for start in range(total_kmers):
        idx = kmer_index.get(sequence[start:start + k])
        if idx is not None:
            vec[idx] += 1

    # Normalize by the number of windows (not by the number of matched
    # k-mers) to keep the feature scale identical to the original code.
    if total_kmers > 0:
        vec = vec / total_kmers

    return vec
62
+
63
  def parse_fasta(text):
 
64
  sequences = []
65
  current_header = None
66
  current_sequence = []
 
80
  sequences.append((current_header, ''.join(current_sequence)))
81
  return sequences
82
 
83
def create_visualization(important_kmers, human_prob, title):
    """Build a two-panel figure summarising the most influential k-mers.

    The top panel is a step plot that starts at 0.5 and adds each k-mer's
    signed impact in list order (negative when the direction is
    'non-human'). The bottom panel shows, per k-mer, a frequency bar on the
    left axis next to a sigma bar on a twin right axis.

    Args:
        important_kmers: List of dicts with the keys 'kmer', 'impact',
            'direction', 'occurrence' and 'sigma'.
        human_prob: Probability printed in the top panel's title.
        title: Accepted but not used anywhere in the figure.

    Returns:
        The matplotlib figure containing both panels.
    """
    fig = plt.figure(figsize=(15, 10))
    layout = plt.GridSpec(2, 1, height_ratios=[1.5, 1], hspace=0.3)

    # ---- Panel 1: running probability as impacts are applied ----
    prob_ax = plt.subplot(layout[0])

    running = 0.5
    trail = [('Start', running, 0)]
    for entry in important_kmers:
        sign = -1 if entry['direction'] == 'non-human' else 1
        delta = entry['impact'] * sign
        running += delta
        trail.append((entry['kmer'], running, delta))

    positions = range(len(trail))
    levels = [level for _, level, _ in trail]

    prob_ax.step(positions, levels, 'b-', where='post', label='Probability', linewidth=2)
    prob_ax.plot(positions, levels, 'b.', markersize=10)
    prob_ax.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')

    prob_ax.grid(True, linestyle='--', alpha=0.7)
    prob_ax.set_ylim(0, 1)
    prob_ax.set_ylabel('Human Probability')
    prob_ax.set_title(f'K-mer Contributions to Prediction (final prob: {human_prob:.3f})')

    # Alternate label offsets between neighbouring points so the k-mer name
    # and the change value land on opposite sides of the marker.
    for pos, (name, level, delta) in enumerate(trail):
        on_even = pos % 2 == 0
        prob_ax.annotate(
            name,
            (pos, level),
            xytext=(0, 10 if on_even else -20),
            textcoords='offset points',
            ha='center',
            rotation=45,
        )

        # The synthetic 'Start' point has no change to report.
        if pos == 0:
            continue

        prob_ax.annotate(
            f'{delta:+.3f}',
            (pos, level),
            xytext=(0, -20 if on_even else 10),
            textcoords='offset points',
            ha='center',
            color='green' if delta > 0 else 'red',
        )

    prob_ax.legend()

    # ---- Panel 2: frequency and sigma bars per k-mer ----
    freq_ax = plt.subplot(layout[1])

    names = []
    freqs = []
    sigma_vals = []
    freq_colors = []
    for entry in important_kmers:
        names.append(entry['kmer'])
        freqs.append(entry['occurrence'])
        sigma_vals.append(entry['sigma'])
        freq_colors.append('g' if entry['direction'] == 'human' else 'r')

    # Sigma bars keep the direction colour only for positive values.
    sigma_colors = [
        shade if value > 0 else 'gray'
        for shade, value in zip(freq_colors, sigma_vals)
    ]

    slots = np.arange(len(names))
    bar_width = 0.35
    half_width = bar_width / 2

    freq_ax.bar(slots - half_width, freqs, bar_width, label='Frequency (%)', color=freq_colors, alpha=0.6)
    sigma_ax = freq_ax.twinx()
    sigma_ax.bar(slots + half_width, sigma_vals, bar_width, label='σ from mean', color=sigma_colors, alpha=0.3)

    freq_ax.set_xticks(slots)
    freq_ax.set_xticklabels(names, rotation=45)
    freq_ax.set_ylabel('Frequency (%)')
    sigma_ax.set_ylabel('Standard Deviations (σ) from Mean')
    freq_ax.set_title('K-mer Frequencies and Statistical Significance')

    # Merge the legend entries of both y-axes into a single box.
    freq_handles, freq_labels = freq_ax.get_legend_handles_labels()
    sigma_handles, sigma_labels = sigma_ax.get_legend_handles_labels()
    freq_ax.legend(freq_handles + sigma_handles, freq_labels + sigma_labels, loc='upper right')

    plt.tight_layout()
    return fig
170
 
 
 
 
171
  def predict(file_obj):
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  if file_obj is None:
173
+ return "Please upload a FASTA file", None
 
 
 
 
 
 
174
 
175
  try:
176
  if isinstance(file_obj, str):
 
178
  else:
179
  text = file_obj.decode('utf-8')
180
  except Exception as e:
181
+ return f"Error reading file: {str(e)}", None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
 
183
  k = 4
184
+ kmers = [''.join(p) for p in product("ACGT", repeat=k)]
185
+ kmer_dict = {km: i for i, km in enumerate(kmers)}
186
+
187
  try:
188
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
189
+ model = VirusClassifier(256).to(device)
190
+ state_dict = torch.load('model.pt', map_location=device)
 
 
191
  model.load_state_dict(state_dict)
192
+ scaler = joblib.load('scaler.pkl')
193
  model.eval()
194
+ except Exception as e:
195
+ return f"Error loading model: {str(e)}", None
196
 
197
+ results_text = ""
198
+ plot_image = None
199
+
200
+ try:
201
+ sequences = parse_fasta(text)
202
+ header, seq = sequences[0]
203
+
204
+ raw_freq_vector = sequence_to_kmer_vector(seq)
205
+ kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
206
+ X_tensor = torch.FloatTensor(kmer_vector).to(device)
207
+
208
+ # Get model predictions
209
  with torch.no_grad():
210
+ output = model(X_tensor)
211
+ probs = torch.softmax(output, dim=1)
212
+
213
+ # Get feature importance
214
+ importance, _ = model.get_feature_importance(X_tensor)
215
+ kmer_importance = importance[0].cpu().numpy()
216
+
217
+ # Get top k-mers
 
 
 
 
 
 
 
218
  top_k = 10
219
+ top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
220
+
221
  important_kmers = []
222
+ for idx in top_indices:
223
+ kmer = list(kmer_dict.keys())[list(kmer_dict.values()).index(idx)]
224
+ imp = float(abs(kmer_importance[idx]))
225
+ direction = 'human' if kmer_importance[idx] > 0 else 'non-human'
226
+ freq = float(raw_freq_vector[idx] * 100) # Convert to percentage
227
+ sigma = float(kmer_vector[0][idx])
228
+
229
  important_kmers.append({
230
+ 'kmer': kmer,
231
+ 'impact': imp,
 
232
  'direction': direction,
233
+ 'occurrence': freq,
234
+ 'sigma': sigma
235
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
+ # Generate text results
238
+ pred_class = 1 if probs[0][1] > probs[0][0] else 0
239
+ pred_label = 'human' if pred_class == 1 else 'non-human'
240
+ human_prob = float(probs[0][1])
241
+
242
+ results_text = f"""Sequence: {header}
243
+ Prediction: {pred_label}
244
+ Confidence: {float(max(probs[0])):0.4f}
245
+ Human probability: {human_prob:0.4f}
246
+ Non-human probability: {float(probs[0][0]):0.4f}
247
+ Most influential k-mers (ranked by importance):"""
248
+
249
+ for kmer in important_kmers:
250
+ results_text += f"\n {kmer['kmer']}: "
251
+ results_text += f"pushes toward {kmer['direction']} (impact={kmer['impact']:.4f}), "
252
+ results_text += f"occurrence={kmer['occurrence']:.2f}% of sequence "
253
+ results_text += f"(appears {abs(kmer['sigma']):.2f}σ "
254
+ results_text += "more" if kmer['sigma'] > 0 else "less"
255
+ results_text += " than average)"
256
+
257
+ # Create visualization
258
+ fig = create_visualization(important_kmers, human_prob, header)
259
+
260
+ # Save plot
261
+ buf = io.BytesIO()
262
+ fig.savefig(buf, format='png', bbox_inches='tight', dpi=300)
263
+ buf.seek(0)
264
+ plot_image = Image.open(buf)
265
+ plt.close(fig)
266
+
267
+ except Exception as e:
268
+ return f"Error processing sequences: {str(e)}", None
269
 
270
+ return results_text, plot_image
 
 
 
 
 
 
 
 
 
 
 
271
 
272
def _build_interface():
    """Assemble the Gradio UI: one FASTA upload in, a text report and a plot out."""
    fasta_upload = gr.File(label="Upload FASTA file", type="binary")
    report_box = gr.Textbox(label="Results")
    plot_view = gr.Image(label="K-mer Analysis Visualization")
    return gr.Interface(
        fn=predict,
        inputs=fasta_upload,
        outputs=[report_box, plot_view],
        title="Virus Host Classifier",
    )


# Module-level handle so the interface is also available when this file is imported.
iface = _build_interface()

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local server.
    iface.launch(share=True)