hiyata commited on
Commit
8c49ca8
·
verified ·
1 Parent(s): 6c88c65

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -138
app.py CHANGED
@@ -8,6 +8,9 @@ import matplotlib.pyplot as plt
8
  import io
9
  from PIL import Image
10
 
 
 
 
11
  class VirusClassifier(nn.Module):
12
  def __init__(self, input_shape: int):
13
  super(VirusClassifier, self).__init__()
@@ -29,38 +32,28 @@ class VirusClassifier(nn.Module):
29
  return self.network(x)
30
 
31
  def get_feature_importance(self, x):
32
- """Calculate feature importance using gradient-based method"""
 
 
 
33
  x.requires_grad_(True)
34
  output = self.network(x)
35
  probs = torch.softmax(output, dim=1)
36
 
37
- # Get importance for human class (index 1)
38
  human_prob = probs[..., 1]
39
  if x.grad is not None:
40
  x.grad.zero_()
41
  human_prob.backward()
42
- importance = x.grad
43
 
44
  return importance, float(human_prob)
45
 
46
- def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
47
- """Convert sequence to k-mer frequency vector"""
48
- kmers = [''.join(p) for p in product("ACGT", repeat=k)]
49
- kmer_dict = {km: i for i, km in enumerate(kmers)}
50
- vec = np.zeros(len(kmers), dtype=np.float32)
51
-
52
- for i in range(len(sequence) - k + 1):
53
- kmer = sequence[i:i+k]
54
- if kmer in kmer_dict:
55
- vec[kmer_dict[kmer]] += 1
56
-
57
- total_kmers = len(sequence) - k + 1
58
- if total_kmers > 0:
59
- vec = vec / total_kmers
60
-
61
- return vec
62
-
63
  def parse_fasta(text):
 
64
  sequences = []
65
  current_header = None
66
  current_sequence = []
@@ -80,98 +73,167 @@ def parse_fasta(text):
80
  sequences.append((current_header, ''.join(current_sequence)))
81
  return sequences
82
 
83
- def create_visualization(important_kmers, human_prob, title):
84
- """Create a comprehensive visualization of k-mer impacts"""
85
- fig = plt.figure(figsize=(15, 10))
86
-
87
- # Create grid for subplots
88
- gs = plt.GridSpec(2, 1, height_ratios=[1.5, 1], hspace=0.3)
89
 
90
- # 1. Probability Step Plot
91
- ax1 = plt.subplot(gs[0])
92
- current_prob = 0.5
93
- steps = [('Start', current_prob, 0)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
 
95
  for kmer in important_kmers:
96
- change = kmer['impact'] * (-1 if kmer['direction'] == 'non-human' else 1)
 
 
97
  current_prob += change
98
- steps.append((kmer['kmer'], current_prob, change))
99
-
100
- x = range(len(steps))
101
- y = [step[1] for step in steps]
102
-
103
- # Plot steps
104
- ax1.step(x, y, 'b-', where='post', label='Probability', linewidth=2)
105
- ax1.plot(x, y, 'b.', markersize=10)
106
-
107
- # Add reference line
108
- ax1.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')
109
-
110
- # Customize plot
111
- ax1.grid(True, linestyle='--', alpha=0.7)
112
- ax1.set_ylim(0, 1)
113
- ax1.set_ylabel('Human Probability')
114
- ax1.set_title(f'K-mer Contributions to Prediction (final prob: {human_prob:.3f})')
115
 
116
- # Add labels for each point
 
 
 
 
 
 
 
 
 
 
117
  for i, (kmer, prob, change) in enumerate(steps):
118
- # Add k-mer label
119
- ax1.annotate(kmer,
120
- (i, prob),
121
- xytext=(0, 10 if i % 2 == 0 else -20),
122
- textcoords='offset points',
123
- ha='center',
124
- rotation=45)
125
 
126
- # Add change value
127
- if i > 0:
128
- change_text = f'{change:+.3f}'
129
- color = 'green' if change > 0 else 'red'
130
- ax1.annotate(change_text,
131
- (i, prob),
132
- xytext=(0, -20 if i % 2 == 0 else 10),
133
- textcoords='offset points',
134
- ha='center',
135
- color=color)
136
-
137
- ax1.legend()
138
-
139
- # 2. K-mer Frequency and Sigma Plot
140
- ax2 = plt.subplot(gs[1])
141
-
142
- # Prepare data
143
- kmers = [k['kmer'] for k in important_kmers]
144
- frequencies = [k['occurrence'] for k in important_kmers]
145
- sigmas = [k['sigma'] for k in important_kmers]
146
- colors = ['g' if k['direction'] == 'human' else 'r' for k in important_kmers]
 
 
 
 
 
 
147
 
148
- # Create bar plot for frequencies
149
  x = np.arange(len(kmers))
150
- width = 0.35
151
-
152
- ax2.bar(x - width/2, frequencies, width, label='Frequency (%)', color=colors, alpha=0.6)
153
- ax2_twin = ax2.twinx()
154
- ax2_twin.bar(x + width/2, sigmas, width, label='σ from mean', color=[c if s > 0 else 'gray' for c, s in zip(colors, sigmas)], alpha=0.3)
155
-
156
- # Customize plot
157
- ax2.set_xticks(x)
158
- ax2.set_xticklabels(kmers, rotation=45)
159
- ax2.set_ylabel('Frequency (%)')
160
- ax2_twin.set_ylabel('Standard Deviations (σ) from Mean')
161
- ax2.set_title('K-mer Frequencies and Statistical Significance')
 
 
 
 
 
162
 
163
- # Add legends
164
- lines1, labels1 = ax2.get_legend_handles_labels()
165
- lines2, labels2 = ax2_twin.get_legend_handles_labels()
166
- ax2.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  plt.tight_layout()
169
  return fig
170
 
 
 
 
 
171
  def predict(file_obj):
 
 
 
 
 
 
 
172
  if file_obj is None:
173
- return "Please upload a FASTA file", None
174
 
 
175
  try:
176
  if isinstance(file_obj, str):
177
  text = file_obj
@@ -180,10 +242,12 @@ def predict(file_obj):
180
  except Exception as e:
181
  return f"Error reading file: {str(e)}", None
182
 
 
183
  k = 4
184
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
185
  kmer_dict = {km: i for i, km in enumerate(kmers)}
186
 
 
187
  try:
188
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
189
  model = VirusClassifier(256).to(device)
@@ -192,92 +256,117 @@ def predict(file_obj):
192
  scaler = joblib.load('scaler.pkl')
193
  model.eval()
194
  except Exception as e:
195
- return f"Error loading model: {str(e)}", None
196
 
197
  results_text = ""
198
  plot_image = None
199
 
200
  try:
 
201
  sequences = parse_fasta(text)
202
- header, seq = sequences[0]
 
203
 
 
 
 
204
  raw_freq_vector = sequence_to_kmer_vector(seq)
205
  kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
206
  X_tensor = torch.FloatTensor(kmer_vector).to(device)
207
-
208
- # Get model predictions
209
  with torch.no_grad():
210
  output = model(X_tensor)
211
  probs = torch.softmax(output, dim=1)
212
 
213
- # Get feature importance
214
- importance, _ = model.get_feature_importance(X_tensor)
215
- kmer_importance = importance[0].cpu().numpy()
216
-
217
- # Get top k-mers
218
  top_k = 10
219
- top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
220
-
221
  important_kmers = []
 
222
  for idx in top_indices:
223
- kmer = list(kmer_dict.keys())[list(kmer_dict.values()).index(idx)]
224
- imp = float(abs(kmer_importance[idx]))
 
 
 
 
 
225
  direction = 'human' if kmer_importance[idx] > 0 else 'non-human'
226
- freq = float(raw_freq_vector[idx] * 100) # Convert to percentage
227
- sigma = float(kmer_vector[0][idx])
228
 
229
  important_kmers.append({
230
- 'kmer': kmer,
231
- 'impact': imp,
232
  'direction': direction,
233
  'occurrence': freq,
234
  'sigma': sigma
235
  })
236
-
237
- # Generate text results
238
  pred_class = 1 if probs[0][1] > probs[0][0] else 0
239
  pred_label = 'human' if pred_class == 1 else 'non-human'
240
  human_prob = float(probs[0][1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
- results_text = f"""Sequence: {header}
243
- Prediction: {pred_label}
244
- Confidence: {float(max(probs[0])):0.4f}
245
- Human probability: {human_prob:0.4f}
246
- Non-human probability: {float(probs[0][0]):0.4f}
247
- Most influential k-mers (ranked by importance):"""
248
-
249
- for kmer in important_kmers:
250
- results_text += f"\n {kmer['kmer']}: "
251
- results_text += f"pushes toward {kmer['direction']} (impact={kmer['impact']:.4f}), "
252
- results_text += f"occurrence={kmer['occurrence']:.2f}% of sequence "
253
- results_text += f"(appears {abs(kmer['sigma']):.2f}σ "
254
- results_text += "more" if kmer['sigma'] > 0 else "less"
255
- results_text += " than average)"
256
-
257
- # Create visualization
258
- fig = create_visualization(important_kmers, human_prob, header)
259
-
260
- # Save plot
261
  buf = io.BytesIO()
262
- fig.savefig(buf, format='png', bbox_inches='tight', dpi=300)
263
  buf.seek(0)
264
  plot_image = Image.open(buf)
265
  plt.close(fig)
266
 
267
  except Exception as e:
268
- return f"Error processing sequences: {str(e)}", None
269
 
270
  return results_text, plot_image
271
 
 
 
 
272
  iface = gr.Interface(
273
  fn=predict,
274
  inputs=gr.File(label="Upload FASTA file", type="binary"),
275
  outputs=[
276
- gr.Textbox(label="Results"),
277
  gr.Image(label="K-mer Analysis Visualization")
278
  ],
279
- title="Virus Host Classifier"
 
 
 
 
 
 
280
  )
281
 
282
  if __name__ == "__main__":
283
- iface.launch(share=True)
 
8
  import io
9
  from PIL import Image
10
 
11
+ ###############################################################################
12
+ # Model Definition
13
+ ###############################################################################
14
  class VirusClassifier(nn.Module):
15
  def __init__(self, input_shape: int):
16
  super(VirusClassifier, self).__init__()
 
32
  return self.network(x)
33
 
34
  def get_feature_importance(self, x):
35
+ """
36
+ Calculate gradient-based feature importance.
37
+ We'll compute the gradient of the 'human' probability w.r.t. the input vector.
38
+ """
39
  x.requires_grad_(True)
40
  output = self.network(x)
41
  probs = torch.softmax(output, dim=1)
42
 
43
+ # Gradient wrt 'human' class probability (index=1)
44
  human_prob = probs[..., 1]
45
  if x.grad is not None:
46
  x.grad.zero_()
47
  human_prob.backward()
48
+ importance = x.grad # shape: (batch_size, n_features)
49
 
50
  return importance, float(human_prob)
51
 
52
+ ###############################################################################
53
+ # Utility Functions
54
+ ###############################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def parse_fasta(text):
56
+ """Parses text input in FASTA format into a list of (header, sequence)."""
57
  sequences = []
58
  current_header = None
59
  current_sequence = []
 
73
  sequences.append((current_header, ''.join(current_sequence)))
74
  return sequences
75
 
76
+ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
77
+ """Convert a single nucleotide sequence to a k-mer frequency vector."""
78
+ kmers = [''.join(p) for p in product("ACGT", repeat=k)]
79
+ kmer_dict = {km: i for i, km in enumerate(kmers)}
80
+ vec = np.zeros(len(kmers), dtype=np.float32)
 
81
 
82
+ for i in range(len(sequence) - k + 1):
83
+ kmer = sequence[i:i+k]
84
+ if kmer in kmer_dict:
85
+ vec[kmer_dict[kmer]] += 1
86
+
87
+ total_kmers = len(sequence) - k + 1
88
+ if total_kmers > 0:
89
+ vec = vec / total_kmers # normalize frequencies
90
+
91
+ return vec
92
+
93
+
94
+ ###############################################################################
95
+ # Visualization
96
+ ###############################################################################
97
+ def create_visualization(important_kmers, human_prob, title):
98
+ """
99
+ Create a multi-panel figure showing:
100
+ 1) A waterfall-like plot for how each top k-mer shifts the probability from 0.5
101
+ (the baseline) to the final 'human' probability.
102
+ 2) A side-by-side bar plot for frequency (%) and σ from mean for each important k-mer.
103
+ """
104
+
105
+ # Figure & GridSpec Layout
106
+ fig = plt.figure(figsize=(14, 10))
107
+ gs = plt.GridSpec(2, 2, width_ratios=[1.2, 1], height_ratios=[1.2, 1], hspace=0.35, wspace=0.3)
108
+
109
+ # -------------------------------------------------------------------------
110
+ # 1. Waterfall-like Plot (top-left subplot)
111
+ # -------------------------------------------------------------------------
112
+ ax_waterfall = plt.subplot(gs[0, 0])
113
+
114
+ # Start from baseline prob=0.5
115
+ baseline = 0.5
116
+ current_prob = baseline
117
+ steps = [("Baseline", current_prob, 0.0)]
118
 
119
+ # Build up the step changes
120
  for kmer in important_kmers:
121
+ direction_multiplier = 1 if kmer["direction"] == "human" else -1
122
+ change = kmer["impact"] * 0.05 * direction_multiplier
123
+ # ^ scale changes so that the sum doesn't overshadow the final probability.
124
  current_prob += change
125
+ steps.append((kmer["kmer"], current_prob, change))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
+ # X-values for step plot
128
+ x_vals = range(len(steps))
129
+ y_vals = [s[1] for s in steps]
130
+
131
+ ax_waterfall.step(x_vals, y_vals, where='post', color='blue', linewidth=2, label='Probability')
132
+ ax_waterfall.plot(x_vals, y_vals, 'b.', markersize=8)
133
+
134
+ # Reference lines
135
+ ax_waterfall.axhline(y=baseline, color='gray', linestyle='--', label='Baseline=0.5')
136
+
137
+ # Annotate each step
138
  for i, (kmer, prob, change) in enumerate(steps):
139
+ if i == 0: # baseline
140
+ ax_waterfall.annotate(kmer, (i, prob), textcoords="offset points", xytext=(0, -15), ha='center', color='black')
141
+ continue
 
 
 
 
142
 
143
+ color = "green" if change > 0 else "red"
144
+ ax_waterfall.annotate(
145
+ f"{kmer}\n({change:+.3f})",
146
+ (i, prob),
147
+ textcoords="offset points",
148
+ xytext=(0, -15),
149
+ ha='center',
150
+ color=color,
151
+ fontsize=9
152
+ )
153
+
154
+ ax_waterfall.set_ylim(0, 1)
155
+ ax_waterfall.set_xlabel("k-mer Step")
156
+ ax_waterfall.set_ylabel("Running Probability (Human)")
157
+ ax_waterfall.set_title(f"K-mer Waterfall Plot — Final Probability: {human_prob:.3f}")
158
+ ax_waterfall.grid(alpha=0.3)
159
+ ax_waterfall.legend()
160
+
161
+ # -------------------------------------------------------------------------
162
+ # 2. Frequency & σ from Mean (top-right subplot)
163
+ # -------------------------------------------------------------------------
164
+ ax_bar = plt.subplot(gs[0, 1])
165
+
166
+ kmers = [k["kmer"] for k in important_kmers]
167
+ frequencies = [k["occurrence"] for k in important_kmers] # in %
168
+ sigmas = [k["sigma"] for k in important_kmers]
169
+ directions = [k["direction"] for k in important_kmers]
170
 
171
+ # X-locations
172
  x = np.arange(len(kmers))
173
+ width = 0.4
174
+
175
+ # We will create twin axes: one for frequency, one for σ
176
+ bars1 = ax_bar.bar(x - width/2, frequencies, width, label='Frequency (%)',
177
+ alpha=0.7, color=['green' if d=='human' else 'red' for d in directions])
178
+ ax_bar.set_ylabel("Frequency (%)")
179
+ ax_bar.set_ylim(0, max(frequencies) * 1.2 if frequencies else 1)
180
+ ax_bar.set_title("Frequency vs. σ from Mean")
181
+
182
+ # Twin axis for σ
183
+ ax_bar_twin = ax_bar.twinx()
184
+ bars2 = ax_bar_twin.bar(x + width/2, sigmas, width, label='σ from Mean',
185
+ alpha=0.5, color='gray')
186
+ ax_bar_twin.set_ylabel("Standard Deviations (σ)")
187
+
188
+ ax_bar.set_xticks(x)
189
+ ax_bar.set_xticklabels(kmers, rotation=45, ha='right', fontsize=9)
190
 
191
+ # Combine legends
192
+ lines1, labels1 = ax_bar.get_legend_handles_labels()
193
+ lines2, labels2 = ax_bar_twin.get_legend_handles_labels()
194
+ ax_bar.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
195
 
196
+ # -------------------------------------------------------------------------
197
+ # 3. Top Feature Importances (Bottom, spanning both columns)
198
+ # -------------------------------------------------------------------------
199
+ ax_imp = plt.subplot(gs[1, :])
200
+
201
+ # Sort by absolute impact
202
+ sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
203
+ top_kmer_labels = [k['kmer'] for k in sorted_kmers]
204
+ top_kmer_impacts = [k['impact'] for k in sorted_kmers]
205
+ top_kmer_dirs = [k['direction'] for k in sorted_kmers]
206
+
207
+ x_imp = np.arange(len(top_kmer_impacts))
208
+ bar_colors = ['green' if d == 'human' else 'red' for d in top_kmer_dirs]
209
+
210
+ ax_imp.bar(x_imp, top_kmer_impacts, color=bar_colors, alpha=0.7)
211
+ ax_imp.set_xticks(x_imp)
212
+ ax_imp.set_xticklabels(top_kmer_labels, rotation=45, ha='right', fontsize=9)
213
+ ax_imp.set_title("Absolute Feature Importance (Top k-mers)")
214
+ ax_imp.set_ylabel("Importance (gradient magnitude)")
215
+ ax_imp.grid(alpha=0.3, axis='y')
216
+
217
+ plt.suptitle(title, fontsize=14, y=1.02)
218
  plt.tight_layout()
219
  return fig
220
 
221
+
222
+ ###############################################################################
223
+ # Prediction Function
224
+ ###############################################################################
225
  def predict(file_obj):
226
+ """
227
+ Main function that Gradio will call:
228
+ 1. Reads the uploaded FASTA file (or text).
229
+ 2. Loads the model and scaler.
230
+ 3. Generates predictions, probabilities, and top k-mers.
231
+ 4. Creates a summary text and a matplotlib figure for visualization.
232
+ """
233
  if file_obj is None:
234
+ return "Please upload a FASTA file.", None
235
 
236
+ # Read text from file
237
  try:
238
  if isinstance(file_obj, str):
239
  text = file_obj
 
242
  except Exception as e:
243
  return f"Error reading file: {str(e)}", None
244
 
245
+ # Build k-mer dictionary
246
  k = 4
247
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
248
  kmer_dict = {km: i for i, km in enumerate(kmers)}
249
 
250
+ # Load model & scaler
251
  try:
252
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
253
  model = VirusClassifier(256).to(device)
 
256
  scaler = joblib.load('scaler.pkl')
257
  model.eval()
258
  except Exception as e:
259
+ return f"Error loading model or scaler: {str(e)}", None
260
 
261
  results_text = ""
262
  plot_image = None
263
 
264
  try:
265
+ # Parse FASTA
266
  sequences = parse_fasta(text)
267
+ if len(sequences) == 0:
268
+ return "No valid FASTA sequences found. Please check your input.", None
269
 
270
+ header, seq = sequences[0] # For simplicity, we'll only classify the first sequence
271
+
272
+ # Transform sequence to scaled k-mer vector
273
  raw_freq_vector = sequence_to_kmer_vector(seq)
274
  kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
275
  X_tensor = torch.FloatTensor(kmer_vector).to(device)
276
+
277
+ # Inference
278
  with torch.no_grad():
279
  output = model(X_tensor)
280
  probs = torch.softmax(output, dim=1)
281
 
282
+ # Feature Importance
283
+ importance, hum_prob_grad = model.get_feature_importance(X_tensor)
284
+ kmer_importance = importance[0].cpu().numpy() # shape: (256,)
285
+
286
+ # Top k-mers by absolute importance
287
  top_k = 10
288
+ top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1] # largest -> smallest
 
289
  important_kmers = []
290
+
291
  for idx in top_indices:
292
+ # find corresponding k-mer by index
293
+ for kmer_str, i_ in kmer_dict.items():
294
+ if i_ == idx:
295
+ kmer_name = kmer_str
296
+ break
297
+
298
+ imp_val = float(abs(kmer_importance[idx]))
299
  direction = 'human' if kmer_importance[idx] > 0 else 'non-human'
300
+ freq = float(raw_freq_vector[idx] * 100) # frequency in %
301
+ sigma = float(kmer_vector[0][idx]) # scaled value (Z-score if standard scaler)
302
 
303
  important_kmers.append({
304
+ 'kmer': kmer_name,
305
+ 'impact': imp_val,
306
  'direction': direction,
307
  'occurrence': freq,
308
  'sigma': sigma
309
  })
310
+
 
311
  pred_class = 1 if probs[0][1] > probs[0][0] else 0
312
  pred_label = 'human' if pred_class == 1 else 'non-human'
313
  human_prob = float(probs[0][1])
314
+ non_human_prob = float(probs[0][0])
315
+ conf = float(max(probs[0])) # confidence in the predicted class
316
+
317
+ # Generate text results
318
+ results_text = (
319
+ f"**Sequence Header**: {header}\n\n"
320
+ f"**Predicted Label**: {pred_label}\n"
321
+ f"**Confidence**: {conf:.4f}\n\n"
322
+ f"**Human Probability**: {human_prob:.4f}\n"
323
+ f"**Non-human Probability**: {non_human_prob:.4f}\n\n"
324
+ "### Most Influential k-mers:\n"
325
+ )
326
+ for k in important_kmers:
327
+ direction_text = f"pushes toward {k['direction']}"
328
+ occurrence_text = f"{k['occurrence']:.2f}% of sequence"
329
+ sigma_text = f"{abs(k['sigma']):.2f}σ " + ("above" if k['sigma'] > 0 else "below") + " mean"
330
+ results_text += (
331
+ f"- **{k['kmer']}**: "
332
+ f"impact = {k['impact']:.4f}, {direction_text}, "
333
+ f"occurrence = {occurrence_text}, "
334
+ f"({sigma_text})\n"
335
+ )
336
+
337
+ # Create figure
338
+ fig = create_visualization(important_kmers, human_prob, f"{header}")
339
 
340
+ # Convert figure to image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  buf = io.BytesIO()
342
+ fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
343
  buf.seek(0)
344
  plot_image = Image.open(buf)
345
  plt.close(fig)
346
 
347
  except Exception as e:
348
+ return f"Error during prediction or visualization: {str(e)}", None
349
 
350
  return results_text, plot_image
351
 
352
+ ###############################################################################
353
+ # Gradio Interface
354
+ ###############################################################################
355
  iface = gr.Interface(
356
  fn=predict,
357
  inputs=gr.File(label="Upload FASTA file", type="binary"),
358
  outputs=[
359
+ gr.Markdown(label="Prediction Results"),
360
  gr.Image(label="K-mer Analysis Visualization")
361
  ],
362
+ title="Virus Host Classifier",
363
+ description=(
364
+ "Upload a FASTA file containing a single nucleotide sequence. "
365
+ "This model will predict whether the virus host is **human** or **non-human**, "
366
+ "provide a confidence score, and highlight the most influential k-mers in the classification."
367
+ ),
368
+ allow_flagging="never",
369
  )
370
 
371
  if __name__ == "__main__":
372
+ iface.launch(server_name="0.0.0.0", server_port=7860, share=True)