hiyata committed on
Commit
ef80028
·
verified ·
1 Parent(s): f1d4be6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -310
app.py CHANGED
@@ -8,10 +8,6 @@ import matplotlib.pyplot as plt
8
  import io
9
  from PIL import Image
10
 
11
- ##############################################################################
12
- # MODEL DEFINITION
13
- ##############################################################################
14
-
15
  class VirusClassifier(nn.Module):
16
  def __init__(self, input_shape: int):
17
  super(VirusClassifier, self).__init__()
@@ -32,10 +28,6 @@ class VirusClassifier(nn.Module):
32
  def forward(self, x):
33
  return self.network(x)
34
 
35
- ##############################################################################
36
- # UTILITIES
37
- ##############################################################################
38
-
39
  def parse_fasta(text):
40
  """
41
  Parses FASTA formatted text into a list of (header, sequence).
@@ -61,7 +53,7 @@ def parse_fasta(text):
61
 
62
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
63
  """
64
- Convert a sequence to a k-mer frequency vector of size len(ACGT^k).
65
  """
66
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
67
  kmer_dict = {km: i for i, km in enumerate(kmers)}
@@ -78,355 +70,218 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
78
 
79
  return vec
80
 
81
- def ablation_importance(model, x_tensor):
82
  """
83
- Calculates a simple ablation-based importance measure for each feature:
84
- 1. Compute baseline human probability p_base.
85
- 2. For each feature i, set x[i] = 0, re-run inference, compute new p, and
86
- measure delta = p_base - p.
87
- 3. Return array of deltas (positive means that removing that feature
88
- *decreases* the probability => that feature was pushing it higher).
89
  """
90
  model.eval()
91
  with torch.no_grad():
92
- # Baseline probability
93
- output = model(x_tensor)
94
- probs = torch.softmax(output, dim=1)
95
- p_base = probs[0, 1].item()
96
-
97
- # Store the delta importances
98
- importances = np.zeros(x_tensor.shape[1], dtype=np.float32)
99
-
100
- # For efficiency, we do ablation one feature at a time
101
- for i in range(x_tensor.shape[1]):
102
- x_copy = x_tensor.clone()
103
- x_copy[0, i] = 0.0 # Ablate this feature
104
- with torch.no_grad():
105
- output_ablation = model(x_copy)
106
- probs_ablation = torch.softmax(output_ablation, dim=1)
107
- p_ablation = probs_ablation[0, 1].item()
108
- # Delta
109
- importances[i] = p_base - p_ablation
110
-
111
- return importances, p_base
112
-
113
- ##############################################################################
114
- # PLOTTING
115
- ##############################################################################
116
 
117
- def create_step_and_frequency_plot(important_kmers, human_prob, title):
118
  """
119
- Creates a combined step plot (showing how each k-mer modifies the probability)
120
- and a frequency vs. sigma bar chart.
121
  """
122
- fig = plt.figure(figsize=(15, 10))
 
123
 
124
- # Create grid for subplots
125
- gs = plt.GridSpec(2, 1, height_ratios=[1.5, 1], hspace=0.3)
 
 
126
 
127
- # 1. Probability Step Plot
128
- ax1 = plt.subplot(gs[0])
129
- current_prob = 0.5
130
- steps = [('Start', current_prob, 0)]
131
-
132
- for kmer_info in important_kmers:
133
- change = kmer_info['impact'] # positive => pushes up, negative => pushes down
134
- current_prob += change
135
- steps.append((kmer_info['kmer'], current_prob, change))
136
-
137
- x = range(len(steps))
138
- y = [step[1] for step in steps]
139
-
140
- # Plot steps
141
- ax1.step(x, y, 'b-', where='post', label='Probability', linewidth=2)
142
- ax1.plot(x, y, 'b.', markersize=10)
143
-
144
- # Add reference line
145
- ax1.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')
146
-
147
- # Customize plot
148
- ax1.grid(True, linestyle='--', alpha=0.7)
149
- ax1.set_ylim(0, 1)
150
- ax1.set_ylabel('Human Probability')
151
- ax1.set_title(f'K-mer Contributions to Prediction (final prob: {human_prob:.3f})')
152
-
153
- # Add labels for each point
154
- for i, (kmer, prob, change) in enumerate(steps):
155
- # Add k-mer label
156
- ax1.annotate(kmer,
157
- (i, prob),
158
- xytext=(0, 10 if i % 2 == 0 else -20),
159
- textcoords='offset points',
160
- ha='center',
161
- rotation=45)
162
-
163
- # Add change value
164
- if i > 0:
165
- change_text = f'{change:+.3f}'
166
- color = 'green' if change > 0 else 'red'
167
- ax1.annotate(change_text,
168
- (i, prob),
169
- xytext=(0, -20 if i % 2 == 0 else 10),
170
- textcoords='offset points',
171
- ha='center',
172
- color=color)
173
 
174
- ax1.legend()
 
 
 
 
175
 
176
- # 2. K-mer Frequency and Sigma Plot
177
- ax2 = plt.subplot(gs[1])
178
-
179
- # Prepare data
180
- kmers = [k['kmer'] for k in important_kmers]
181
- frequencies = [k['occurrence'] for k in important_kmers]
182
- sigmas = [k['sigma'] for k in important_kmers]
183
-
184
- # Color the bars: if impact>0 => green, else red
185
- colors = ['g' if k['impact'] > 0 else 'r' for k in important_kmers]
186
-
187
- # Create bar plot for frequencies
188
- x = np.arange(len(kmers))
189
- width = 0.35
190
-
191
- ax2.bar(x - width/2, frequencies, width, label='Frequency (%)', color=colors, alpha=0.6)
192
-
193
- # Twin axis for sigma
194
- ax2_twin = ax2.twinx()
195
- # To highlight positive or negative sigma, pick color accordingly
196
- sigma_colors = []
197
- for s, c in zip(sigmas, colors):
198
- if s >= 0:
199
- sigma_colors.append('blue') # above average
200
- else:
201
- sigma_colors.append('gray') # below average
202
-
203
- ax2_twin.bar(x + width/2, sigmas, width, label='σ from Mean', color=sigma_colors, alpha=0.3)
204
-
205
- # Customize plot
206
- ax2.set_xticks(x)
207
- ax2.set_xticklabels(kmers, rotation=45)
208
- ax2.set_ylabel('Frequency (%)')
209
- ax2_twin.set_ylabel('Standard Deviations (σ) from Mean')
210
- ax2.set_title('K-mer Frequencies and Statistical Significance')
211
-
212
- # Add legends
213
- lines1, labels1 = ax2.get_legend_handles_labels()
214
- lines2, labels2 = ax2_twin.get_legend_handles_labels()
215
- ax2.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
216
-
217
- plt.tight_layout()
218
  return fig
219
 
220
- def create_shap_like_bar_plot(impact_values, kmer_list, top_k):
221
  """
222
- Creates a horizontal bar plot showing the top_k features by absolute impact.
223
- impact_values: array of float (length=256).
224
- kmer_list: list of all k=4 kmers in order.
225
- top_k: integer, how many top features to display.
226
  """
227
- # Sort by absolute impact
228
- indices_sorted = np.argsort(np.abs(impact_values))[::-1]
229
- top_indices = indices_sorted[:top_k]
230
 
231
- top_impacts = impact_values[top_indices]
232
- top_kmers = [kmer_list[i] for i in top_indices]
 
233
 
234
- fig = plt.figure(figsize=(8, 6))
235
- plt.barh(range(len(top_impacts)), top_impacts, color=['green' if i > 0 else 'red' for i in top_impacts])
236
- plt.yticks(range(len(top_impacts)), top_kmers)
237
- plt.xlabel("Impact on Human Probability (Ablation)")
238
- plt.title(f"Top {top_k} K-mers by Absolute Impact")
239
- plt.gca().invert_yaxis() # Highest at top
240
- plt.tight_layout()
241
- return fig
242
-
243
- def create_global_bar_plot(impact_values, kmer_list):
244
- """
245
- Creates a bar plot for ALL features (256) to see the global distribution.
246
- """
247
- fig = plt.figure(figsize=(12, 6))
248
- indices_sorted = np.argsort(np.abs(impact_values))[::-1]
249
- sorted_impacts = impact_values[indices_sorted]
250
- sorted_kmers = [kmer_list[i] for i in indices_sorted]
251
 
252
- plt.bar(range(len(sorted_impacts)), sorted_impacts,
253
- color=['green' if i > 0 else 'red' for i in sorted_impacts])
254
- plt.title("Global Impact of All 256 K-mers (Ablation Method)")
255
- plt.xlabel("K-mer (sorted by |impact|)")
256
- plt.ylabel("Impact on Human Probability")
257
- # Optionally, we can skip labeling all 256 on x-axis.
258
- # But we can show only the top/bottom or none for clarity.
259
- plt.tight_layout()
260
  return fig
261
 
262
- ##############################################################################
263
- # MAIN PREDICTION FUNCTION
264
- ##############################################################################
265
-
266
- def predict(file_obj, top_kmers=10, advanced_plots=False, fasta_text=""):
267
  """
268
- Main prediction function called by Gradio.
269
- - file_obj: optional uploaded FASTA file
270
- - top_kmers: number of top k-mers to display in the main SHAP-like plot
271
- - advanced_plots: bool, whether to return global bar plots
272
- - fasta_text: optional direct-pasted FASTA text
273
  """
274
- # Priority: If user pasted text, use that; otherwise use uploaded file.
275
  if fasta_text.strip():
276
  text = fasta_text.strip()
277
- else:
278
- if file_obj is None:
279
- return "No FASTA input provided", None, None, None
280
  try:
281
- if isinstance(file_obj, str):
282
- text = file_obj
283
- else:
284
- text = file_obj.decode('utf-8')
285
  except Exception as e:
286
- return f"Error reading file: {str(e)}", None, None, None
 
 
287
 
288
  # Parse FASTA
289
  sequences = parse_fasta(text)
290
- if len(sequences) == 0:
291
- return "No valid FASTA sequences found", None, None, None
 
292
  header, seq = sequences[0]
293
 
294
- # Load model + scaler
295
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
296
- model = VirusClassifier(256).to(device)
297
  try:
298
- state_dict = torch.load('model.pt', map_location=device)
299
- model.load_state_dict(state_dict)
300
  scaler = joblib.load('scaler.pkl')
301
  except Exception as e:
302
- return f"Error loading model or scaler: {str(e)}", None, None, None
303
 
304
- # Prepare the vector
305
- raw_freq_vector = sequence_to_kmer_vector(seq, k=4)
306
- scaled_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
307
- X_tensor = torch.FloatTensor(scaled_vector).to(device)
308
 
309
- # Compute ablation-based importances
310
- importances, p_base = ablation_importance(model, X_tensor)
311
- # p_base is baseline human probability
312
-
313
- # We also want frequency in % and sigma from mean
314
- # If your scaler is e.g. StandardScaler, then "scaled_vector[0][i]" is
315
- # how many std devs from the mean that feature is.
316
- # We'll gather info in a list of dicts for each k-mer.
317
- kmers_4 = [''.join(p) for p in product("ACGT", repeat=4)]
318
- kmer_dict = {km: i for i, km in enumerate(kmers_4)}
319
-
320
- # We'll sort by absolute impact to get the top 10 by default.
321
- abs_sorted_idx = np.argsort(np.abs(importances))[::-1]
322
- # But for the final step/frequency plot we only show top_kmers
323
- top_indices = abs_sorted_idx[:top_kmers]
324
-
325
- # Build a list of the top k-mers
326
  important_kmers = []
327
- for idx in top_indices:
328
- # "impact" is how much that feature changed the probability
329
- impact = importances[idx]
330
- # raw frequency => raw_freq_vector[idx] * 100 for %
331
- freq_pct = float(raw_freq_vector[idx] * 100.0)
332
- # sigma => scaled_vector[0][idx]
333
- sigma_val = float(scaled_vector[0][idx])
334
-
335
  important_kmers.append({
336
- 'kmer': kmers_4[idx],
337
- 'impact': impact,
338
- 'occurrence': freq_pct,
339
- 'sigma': sigma_val
340
  })
341
-
342
- # For text output
343
- # We decide final class based on model's direct output
344
- with torch.no_grad():
345
- output = model(X_tensor)
346
- probs = torch.softmax(output, dim=1)
347
- pred_class = 1 if probs[0,1] > probs[0,0] else 0
348
- pred_label = 'human' if pred_class == 1 else 'non-human'
349
- human_prob = probs[0,1].item()
350
- nonhuman_prob = probs[0,0].item()
351
- confidence = max(human_prob, nonhuman_prob)
352
 
353
- results_text = (f"Sequence: {header}\n"
354
- f"Prediction: {pred_label}\n"
355
- f"Confidence: {confidence:.4f}\n"
356
- f"Human probability: {human_prob:.4f}\n"
357
- f"Non-human probability: {nonhuman_prob:.4f}\n"
358
- f"Most influential k-mers (by ablation impact):\n")
 
 
359
 
360
- for kmer_info in important_kmers:
361
- # sign => if impact>0 => removing it lowers p(human), so it was pushing p(human) up
362
- direction = "UP (toward human)" if kmer_info['impact'] > 0 else "DOWN (toward non-human)"
363
- results_text += (
364
- f" {kmer_info['kmer']}: {direction}, "
365
- f"Impact={kmer_info['impact']:.4f}, "
366
- f"Occ={kmer_info['occurrence']:.2f}% of seq, "
367
- f"{abs(kmer_info['sigma']):.2f}σ "
368
- + ("above" if kmer_info['sigma']>0 else "below")
369
- + " mean\n"
370
  )
371
 
372
- # PLOT 1: A SHAP-like bar plot for the top K features
373
- shap_fig = create_shap_like_bar_plot(importances, kmers_4, top_kmers)
374
-
375
- # PLOT 2: Step + frequency plot for the top K features
376
- step_fig = create_step_and_frequency_plot(important_kmers, human_prob, header)
377
-
378
- # PLOT 3 (optional advanced): global bar plot of all 256 features
379
- global_fig = None
380
- if advanced_plots:
381
- global_fig = create_global_bar_plot(importances, kmers_4)
382
-
383
- # Convert figures to PIL Images
384
  def fig_to_image(fig):
385
  buf = io.BytesIO()
386
- fig.savefig(buf, format='png', bbox_inches='tight', dpi=200)
387
  buf.seek(0)
388
- im = Image.open(buf)
389
  plt.close(fig)
390
- return im
391
-
392
- shap_img = fig_to_image(shap_fig)
393
- step_img = fig_to_image(step_fig)
394
- if global_fig is not None:
395
- global_img = fig_to_image(global_fig)
396
- else:
397
- global_img = None
398
-
399
- return results_text, shap_img, step_img, global_img
400
-
401
- ##############################################################################
402
- # GRADIO INTERFACE
403
- ##############################################################################
404
-
405
- title_text = "Virus Host Classifier"
406
- description_text = """
407
- Upload or paste a FASTA sequence to predict if it's likely **human** or **non-human** origin.
408
- - **k=4** k-mers are used as features.
409
- - We display ablation-based feature importance for interpretability.
410
- - Advanced plots can be toggled to see the global distribution of all 256 k-mer impacts.
411
  """
412
 
413
- iface = gr.Interface(
414
- fn=predict,
415
- inputs=[
416
- gr.File(label="Upload FASTA file", type="binary", optional=True),
417
- gr.Slider(label="Number of top k-mers to show", minimum=1, maximum=50, value=10, step=1),
418
- gr.Checkbox(label="Show advanced (global) plots?", value=False),
419
- gr.Textbox(label="Or paste FASTA text here", lines=5, placeholder=">header\nACGTACGT...")
420
- ],
421
- outputs=[
422
- gr.Textbox(label="Results", lines=10),
423
- gr.Image(label="SHAP-like Top-k K-mer Bar Plot"),
424
- gr.Image(label="Step & Frequency Plot (Top-k)"),
425
- gr.Image(label="Global 256-K-mer Plot (advanced)", optional=True)
426
- ],
427
- title=title_text,
428
- description=description_text
429
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
 
431
  if __name__ == "__main__":
432
- iface.launch(share=True)
 
8
  import io
9
  from PIL import Image
10
 
 
 
 
 
11
  class VirusClassifier(nn.Module):
12
  def __init__(self, input_shape: int):
13
  super(VirusClassifier, self).__init__()
 
28
  def forward(self, x):
29
  return self.network(x)
30
 
 
 
 
 
31
  def parse_fasta(text):
32
  """
33
  Parses FASTA formatted text into a list of (header, sequence).
 
53
 
54
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
55
  """
56
+ Convert a sequence to a k-mer frequency vector.
57
  """
58
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
59
  kmer_dict = {km: i for i, km in enumerate(kmers)}
 
70
 
71
  return vec
72
 
73
def calculate_shap_values(model, x_tensor):
    """
    Calculate SHAP-like values using a simple ablation approach.

    These are NOT true Shapley values: each feature is zeroed on its own and
    the resulting drop in the class-1 ("human") probability is recorded.

    Args:
        model: Callable torch module mapping a ``(batch, n_features)`` tensor
            to ``(batch, n_classes)`` logits. It is switched to eval mode.
        x_tensor: Input of shape ``(1, n_features)``. Only row 0 is analysed
            and the tensor is not modified.

    Returns:
        Tuple ``(values, baseline_prob)`` where ``values`` is a float64
        ``np.ndarray`` of length ``n_features`` holding
        ``baseline_prob - prob_with_feature_i_zeroed`` (positive means the
        feature was pushing the prediction toward class 1), and
        ``baseline_prob`` is the unperturbed class-1 probability as a float.
    """
    model.eval()
    n_features = x_tensor.shape[1]

    with torch.no_grad():
        baseline_output = model(x_tensor)
        baseline_prob = torch.softmax(baseline_output, dim=1)[0, 1].item()

        if n_features == 0:
            # Nothing to ablate; avoid feeding an empty batch to the model.
            return np.array([], dtype=np.float64), baseline_prob

        # Build every ablated variant at once: row i is the input with
        # feature i set to zero. One batched forward pass replaces
        # n_features separate model calls.
        perturbed = x_tensor[:1].repeat(n_features, 1)
        diag = torch.arange(n_features, device=x_tensor.device)
        perturbed[diag, diag] = 0
        ablated_probs = torch.softmax(model(perturbed), dim=1)[:, 1]

    # Subtract in float64 so the result matches plain Python-float arithmetic.
    ablated = ablated_probs.cpu().numpy().astype(np.float64)
    return baseline_prob - ablated, baseline_prob
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
def create_importance_plot(shap_values, kmers, top_k=10):
    """
    Create horizontal bar plot of feature importance.

    Args:
        shap_values: 1-D sequence of per-feature impact scores.
        kmers: k-mer labels indexed in the same order as ``shap_values``.
        top_k: Number of features to display. Cast to ``int`` so a float
            value (e.g. from a UI slider) can be used for slicing.

    Returns:
        The matplotlib figure. The caller is responsible for closing it.
    """
    # The bare 'seaborn' style name was deprecated in matplotlib 3.6 in favour
    # of 'seaborn-v0_8'; pick whichever this installation actually provides.
    style = next(
        (name for name in ("seaborn-v0_8", "seaborn") if name in plt.style.available),
        "default",
    )
    plt.style.use(style)
    fig = plt.figure(figsize=(10, 8))

    shap_values = np.asarray(shap_values)
    top_k = max(int(top_k), 0)

    # Sort by absolute importance, largest first. invert_yaxis() below then
    # places the most influential k-mer at the top of the chart.
    indices = np.argsort(np.abs(shap_values))[::-1][:top_k]
    values = shap_values[indices]
    features = [kmers[i] for i in indices]

    # Green pushes the prediction up, red pushes it down.
    colors = ['#2ecc71' if v > 0 else '#e74c3c' for v in values]

    positions = range(len(values))
    plt.barh(positions, values, color=colors)
    plt.yticks(positions, features)
    plt.xlabel('Impact on Prediction (SHAP value)')
    plt.title(f'Top {top_k} Most Influential k-mers')
    plt.gca().invert_yaxis()

    return fig
113
 
114
def create_contribution_plot(important_kmers, final_prob):
    """
    Create waterfall plot showing cumulative feature contributions.

    Starting from a neutral 0.5, each k-mer's ``'impact'`` is added in the
    order given and the running total is drawn as a line.

    Args:
        important_kmers: Iterable of dicts with at least the keys ``'kmer'``
            (tick label) and ``'impact'`` (signed contribution).
        final_prob: Model's predicted probability. Currently not drawn; kept
            so existing callers keep working.

    Returns:
        The matplotlib figure. The caller is responsible for closing it.
    """
    # The bare 'seaborn' style name was deprecated in matplotlib 3.6 in favour
    # of 'seaborn-v0_8'; pick whichever this installation actually provides.
    style = next(
        (name for name in ("seaborn-v0_8", "seaborn") if name in plt.style.available),
        "default",
    )
    plt.style.use(style)
    fig = plt.figure(figsize=(12, 6))

    base_prob = 0.5
    cumulative = [base_prob]
    labels = ['Base']

    for kmer_info in important_kmers:
        cumulative.append(cumulative[-1] + kmer_info['impact'])
        labels.append(kmer_info['kmer'])

    plt.plot(range(len(cumulative)), cumulative, 'b-o', linewidth=2)
    # Dashed reference line at the neutral probability.
    plt.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)

    plt.xticks(range(len(labels)), labels, rotation=45)
    plt.ylim(0, 1)
    plt.grid(True, alpha=0.3)
    plt.title('Cumulative Feature Contributions')
    plt.ylabel('Probability of Human Origin')

    return fig
139
 
140
def predict(file_obj, top_kmers=10, fasta_text=""):
    """
    Main prediction function for the Gradio interface.

    Args:
        file_obj: Optional uploaded FASTA file. May be raw bytes, a filesystem
            path, FASTA text, or a file-like object exposing ``.name``,
            depending on how the ``gr.File`` component is configured.
        top_kmers: Number of most influential k-mers to report and plot.
        fasta_text: Optional pasted FASTA text; takes priority over
            ``file_obj`` when non-empty.

    Returns:
        Tuple ``(results_text, importance_image, contribution_image)``. On
        any input or loading error the text carries the message and both
        images are ``None``. Only the first sequence in the input is analysed.
    """
    # Handle input: pasted text wins over an uploaded file.
    pasted = (fasta_text or "").strip()
    if pasted:
        text = pasted
    elif file_obj is not None:
        try:
            text = _read_fasta_upload(file_obj)
        except Exception as e:
            return f"Error reading file: {str(e)}", None, None
    else:
        return "Please provide a FASTA sequence either by file upload or text input.", None, None

    # Parse FASTA
    sequences = parse_fasta(text)
    if not sequences:
        return "No valid FASTA sequences found in input.", None, None

    header, seq = sequences[0]

    # Process sequence
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        model = VirusClassifier(256).to(device)
        model.load_state_dict(torch.load('model.pt', map_location=device))
        scaler = joblib.load('scaler.pkl')
    except Exception as e:
        return f"Error loading model: {str(e)}", None, None

    # Generate features
    freq_vector = sequence_to_kmer_vector(seq)
    scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
    x_tensor = torch.FloatTensor(scaled_vector).to(device)

    # Calculate SHAP values and predictions
    shap_values, human_prob = calculate_shap_values(model, x_tensor)

    # Generate k-mer information, most influential first. The slider value
    # may arrive as a float, which cannot be used as a slice bound.
    kmers = [''.join(p) for p in product("ACGT", repeat=4)]
    top_kmers = max(int(top_kmers), 1)
    important_indices = np.argsort(np.abs(shap_values))[::-1][:top_kmers]

    important_kmers = [
        {
            'kmer': kmers[idx],
            'impact': shap_values[idx],
            'frequency': freq_vector[idx] * 100,
            'significance': scaled_vector[0][idx],
        }
        for idx in important_indices
    ]

    # Format results text
    results = [
        f"Sequence: {header}",
        f"Prediction: {'Human' if human_prob > 0.5 else 'Non-human'} Origin",
        f"Confidence: {max(human_prob, 1-human_prob):.3f}",
        f"Human Probability: {human_prob:.3f}",
        "\nTop Contributing k-mers:",
    ]

    for kmer in important_kmers:
        direction = "→ Human" if kmer['impact'] > 0 else "→ Non-human"
        results.append(
            f"• {kmer['kmer']}: {direction} "
            f"(impact: {kmer['impact']:.3f}, "
            f"freq: {kmer['frequency']:.2f}%)"
        )

    # Generate plots
    shap_plot = create_importance_plot(shap_values, kmers, top_kmers)
    contribution_plot = create_contribution_plot(important_kmers, human_prob)

    # Convert plots to images
    def fig_to_image(fig):
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
        buf.seek(0)
        img = Image.open(buf)
        plt.close(fig)  # release the figure so repeated requests do not leak
        return img

    return "\n".join(results), fig_to_image(shap_plot), fig_to_image(contribution_plot)


def _read_fasta_upload(file_obj):
    """Return the text of an uploaded FASTA file, whatever form Gradio hands over."""
    if isinstance(file_obj, (bytes, bytearray)):
        return bytes(file_obj).decode('utf-8')
    if isinstance(file_obj, str):
        # A string starting with '>' is already FASTA content, not a path.
        if file_obj.lstrip().startswith('>'):
            return file_obj
        path = file_obj
    else:
        path = getattr(file_obj, 'name', None)
        if path is None:
            raise TypeError(f"unsupported upload type: {type(file_obj).__name__}")
    with open(path, encoding='utf-8') as handle:
        return handle.read()
223
+
224
# Create Gradio interface
css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.interpretation-container {
margin-top: 20px;
padding: 15px;
border-radius: 8px;
background-color: #f8f9fa;
}
"""

with gr.Blocks(css=css) as iface:
    gr.Markdown("""
    # Virus Host Classifier
    This tool predicts whether a viral sequence is likely of human or non-human origin using k-mer frequency analysis.

    ### Instructions
    1. Upload a FASTA file or paste your sequence in FASTA format
    2. Adjust the number of top k-mers to display (default: 10)
    3. View the prediction results and feature importance visualizations
    """)

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload FASTA file",
                file_types=[".fasta", ".fa", ".txt"],
                # predict() decodes the upload as UTF-8 bytes, so the
                # component must hand over raw bytes rather than a temp file.
                type="binary"
            )
            text_input = gr.Textbox(
                label="Or paste FASTA sequence",
                placeholder=">sequence_name\nACGTACGT...",
                lines=5
            )
            top_k = gr.Slider(
                minimum=5,
                maximum=20,
                value=10,
                step=1,
                label="Number of top k-mers to display"
            )
            submit_btn = gr.Button("Analyze Sequence", variant="primary")

        # Right column: text summary and the two plots returned by predict().
        with gr.Column(scale=2):
            results = gr.Textbox(label="Analysis Results", lines=10)
            shap_plot = gr.Image(label="Feature Importance Plot")
            contribution_plot = gr.Image(label="Cumulative Contribution Plot")

    # Input order must match predict(file_obj, top_kmers, fasta_text).
    submit_btn.click(
        predict,
        inputs=[file_input, top_k, text_input],
        outputs=[results, shap_plot, contribution_plot]
    )

    gr.Markdown("""
    ### About
    - Uses 4-mer frequencies as sequence features
    - Employs SHAP-like values for feature importance interpretation
    - Visualizes cumulative feature contributions to the final prediction
    """)

if __name__ == "__main__":
    iface.launch()