hiyata commited on
Commit
a6886ca
·
verified ·
1 Parent(s): b5edb58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +360 -155
app.py CHANGED
@@ -2,11 +2,13 @@ import gradio as gr
2
  import torch
3
  import joblib
4
  import numpy as np
5
- from itertools import product
6
  import torch.nn as nn
7
  import matplotlib.pyplot as plt
8
  import io
9
  from PIL import Image
 
 
 
10
 
11
  class VirusClassifier(nn.Module):
12
  def __init__(self, input_shape: int):
@@ -28,39 +30,40 @@ class VirusClassifier(nn.Module):
28
  def forward(self, x):
29
  return self.network(x)
30
 
31
- def get_feature_importance(self, x):
32
- """Calculate feature importance using gradient-based method"""
33
- x.requires_grad_(True)
 
 
 
 
34
  output = self.network(x)
35
  probs = torch.softmax(output, dim=1)
36
 
37
- # Get importance for human class (index 1)
38
- human_prob = probs[..., 1]
 
 
39
  if x.grad is not None:
40
  x.grad.zero_()
41
- human_prob.backward()
42
- importance = x.grad
43
 
44
- return importance, float(human_prob)
45
-
46
- def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
47
- """Convert sequence to k-mer frequency vector"""
48
- kmers = [''.join(p) for p in product("ACGT", repeat=k)]
49
- kmer_dict = {km: i for i, km in enumerate(kmers)}
50
- vec = np.zeros(len(kmers), dtype=np.float32)
51
-
52
- for i in range(len(sequence) - k + 1):
53
- kmer = sequence[i:i+k]
54
- if kmer in kmer_dict:
55
- vec[kmer_dict[kmer]] += 1
56
-
57
- total_kmers = len(sequence) - k + 1
58
- if total_kmers > 0:
59
- vec = vec / total_kmers
60
 
61
- return vec
62
 
63
- def parse_fasta(text):
 
 
 
64
  sequences = []
65
  current_header = None
66
  current_sequence = []
@@ -80,15 +83,109 @@ def parse_fasta(text):
80
  sequences.append((current_header, ''.join(current_sequence)))
81
  return sequences
82
 
83
- def create_visualization(important_kmers, human_prob, title):
84
- """Create a comprehensive visualization of k-mer impacts"""
85
- fig = plt.figure(figsize=(15, 10))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- # Create grid for subplots
88
- gs = plt.GridSpec(2, 1, height_ratios=[1.5, 1], hspace=0.3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- # 1. Probability Step Plot
91
- ax1 = plt.subplot(gs[0])
92
  current_prob = 0.5
93
  steps = [('Start', current_prob, 0)]
94
 
@@ -96,188 +193,296 @@ def create_visualization(important_kmers, human_prob, title):
96
  change = kmer['impact'] * (-1 if kmer['direction'] == 'non-human' else 1)
97
  current_prob += change
98
  steps.append((kmer['kmer'], current_prob, change))
 
 
 
 
 
 
99
 
100
- x = range(len(steps))
101
- y = [step[1] for step in steps]
102
-
103
- # Plot steps
104
- ax1.step(x, y, 'b-', where='post', label='Probability', linewidth=2)
105
- ax1.plot(x, y, 'b.', markersize=10)
106
-
107
- # Add reference line
108
- ax1.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')
109
-
110
- # Customize plot
111
- ax1.grid(True, linestyle='--', alpha=0.7)
112
- ax1.set_ylim(0, 1)
113
- ax1.set_ylabel('Human Probability')
114
- ax1.set_title(f'K-mer Contributions to Prediction (final prob: {human_prob:.3f})')
115
-
116
- # Add labels for each point
117
  for i, (kmer, prob, change) in enumerate(steps):
118
- # Add k-mer label
119
- ax1.annotate(kmer,
120
  (i, prob),
121
  xytext=(0, 10 if i % 2 == 0 else -20),
122
  textcoords='offset points',
123
  ha='center',
124
  rotation=45)
125
 
126
- # Add change value
127
  if i > 0:
128
  change_text = f'{change:+.3f}'
129
  color = 'green' if change > 0 else 'red'
130
- ax1.annotate(change_text,
131
- (i, prob),
132
- xytext=(0, -20 if i % 2 == 0 else 10),
133
- textcoords='offset points',
134
- ha='center',
135
- color=color)
136
-
137
- ax1.legend()
138
-
139
- # 2. K-mer Frequency and Sigma Plot
140
- ax2 = plt.subplot(gs[1])
 
 
 
 
 
 
 
141
 
142
  # Prepare data
143
  kmers = [k['kmer'] for k in important_kmers]
144
  frequencies = [k['occurrence'] for k in important_kmers]
145
  sigmas = [k['sigma'] for k in important_kmers]
146
- colors = ['g' if k['direction'] == 'human' else 'r' for k in important_kmers]
147
 
148
- # Create bar plot for frequencies
149
  x = np.arange(len(kmers))
150
  width = 0.35
151
 
152
- ax2.bar(x - width/2, frequencies, width, label='Frequency (%)', color=colors, alpha=0.6)
153
- ax2_twin = ax2.twinx()
154
- ax2_twin.bar(x + width/2, sigmas, width, label='σ from mean', color=[c if s > 0 else 'gray' for c, s in zip(colors, sigmas)], alpha=0.3)
155
 
156
- # Customize plot
157
- ax2.set_xticks(x)
158
- ax2.set_xticklabels(kmers, rotation=45)
159
- ax2.set_ylabel('Frequency (%)')
160
- ax2_twin.set_ylabel('Standard Deviations (σ) from Mean')
161
- ax2.set_title('K-mer Frequencies and Statistical Significance')
162
 
163
- # Add legends
164
- lines1, labels1 = ax2.get_legend_handles_labels()
165
- lines2, labels2 = ax2_twin.get_legend_handles_labels()
166
- ax2.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
 
 
 
 
 
167
 
168
  plt.tight_layout()
169
  return fig
170
 
171
- def predict(file_obj):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  if file_obj is None:
173
- return "Please upload a FASTA file", None
174
 
175
  try:
176
  if isinstance(file_obj, str):
177
  text = file_obj
178
  else:
179
- text = file_obj.decode('utf-8')
180
  except Exception as e:
181
- return f"Error reading file: {str(e)}", None
182
-
183
- k = 4
184
- kmers = [''.join(p) for p in product("ACGT", repeat=k)]
185
- kmer_dict = {km: i for i, km in enumerate(kmers)}
 
186
 
 
187
  try:
188
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
189
- model = VirusClassifier(256).to(device)
190
  state_dict = torch.load('model.pt', map_location=device)
191
  model.load_state_dict(state_dict)
192
- scaler = joblib.load('scaler.pkl')
193
  model.eval()
 
 
194
  except Exception as e:
195
- return f"Error loading model: {str(e)}", None
 
 
 
 
196
 
197
- results_text = ""
198
- plot_image = None
 
199
 
200
- try:
201
- sequences = parse_fasta(text)
202
- header, seq = sequences[0]
203
 
204
- raw_freq_vector = sequence_to_kmer_vector(seq)
205
- kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
206
- X_tensor = torch.FloatTensor(kmer_vector).to(device)
 
207
 
208
- # Get model predictions
209
  with torch.no_grad():
210
  output = model(X_tensor)
211
  probs = torch.softmax(output, dim=1)
212
 
213
- # Get feature importance
214
- importance, _ = model.get_feature_importance(X_tensor)
215
- kmer_importance = importance[0].cpu().numpy()
 
 
 
216
 
217
- # Get top k-mers
218
- top_k = 10
219
- top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
 
 
 
 
 
220
 
221
- important_kmers = []
222
- for idx in top_indices:
223
- kmer = list(kmer_dict.keys())[list(kmer_dict.values()).index(idx)]
224
- imp = float(abs(kmer_importance[idx]))
225
- direction = 'human' if kmer_importance[idx] > 0 else 'non-human'
226
- freq = float(raw_freq_vector[idx] * 100) # Convert to percentage
227
- sigma = float(kmer_vector[0][idx])
 
228
 
229
- important_kmers.append({
230
- 'kmer': kmer,
231
- 'impact': imp,
232
  'direction': direction,
233
- 'occurrence': freq,
234
  'sigma': sigma
235
  })
236
 
237
- # Generate text results
238
- pred_class = 1 if probs[0][1] > probs[0][0] else 0
239
- pred_label = 'human' if pred_class == 1 else 'non-human'
240
- human_prob = float(probs[0][1])
241
-
242
- results_text = f"""Sequence: {header}
243
- Prediction: {pred_label}
244
- Confidence: {float(max(probs[0])):0.4f}
245
- Human probability: {human_prob:0.4f}
246
- Non-human probability: {float(probs[0][0]):0.4f}
247
- Most influential k-mers (ranked by importance):"""
 
 
 
 
 
 
248
 
249
- for kmer in important_kmers:
250
- results_text += f"\n {kmer['kmer']}: "
251
- results_text += f"pushes toward {kmer['direction']} (impact={kmer['impact']:.4f}), "
252
- results_text += f"occurrence={kmer['occurrence']:.2f}% of sequence "
253
- results_text += f"(appears {abs(kmer['sigma']):.2f}σ "
254
- results_text += "more" if kmer['sigma'] > 0 else "less"
255
- results_text += " than average)"
256
 
257
- # Create visualization
258
- fig = create_visualization(important_kmers, human_prob, header)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
- # Save plot
261
- buf = io.BytesIO()
262
- fig.savefig(buf, format='png', bbox_inches='tight', dpi=300)
263
- buf.seek(0)
264
- plot_image = Image.open(buf)
265
- plt.close(fig)
 
 
 
 
266
 
267
- except Exception as e:
268
- return f"Error processing sequences: {str(e)}", None
 
 
 
 
 
 
 
 
 
 
 
269
 
270
- return results_text, plot_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
- iface = gr.Interface(
273
- fn=predict,
274
- inputs=gr.File(label="Upload FASTA file", type="binary"),
275
- outputs=[
276
- gr.Textbox(label="Results"),
277
- gr.Image(label="K-mer Analysis Visualization")
278
- ],
279
- title="Virus Host Classifier"
280
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
  if __name__ == "__main__":
283
- iface.launch(share=True)
 
 
2
  import torch
3
  import joblib
4
  import numpy as np
 
5
  import torch.nn as nn
6
  import matplotlib.pyplot as plt
7
  import io
8
  from PIL import Image
9
+ from itertools import product
10
+
11
+ # --------------- Model Definition ---------------
12
 
13
  class VirusClassifier(nn.Module):
14
  def __init__(self, input_shape: int):
 
30
  def forward(self, x):
31
  return self.network(x)
32
 
33
+ def get_gradient_importance(self, x, class_index=1):
34
+ """
35
+ Calculate gradient-based importance for each input feature.
36
+ By default, we compute the gradient wrt the 'human' class (index=1).
37
+ This method is akin to a raw gradient or 'saliency' approach.
38
+ """
39
+ x = x.clone().detach().requires_grad_(True)
40
  output = self.network(x)
41
  probs = torch.softmax(output, dim=1)
42
 
43
+ # Probability of the specified class
44
+ target_prob = probs[..., class_index]
45
+
46
+ # Zero existing gradients if any
47
  if x.grad is not None:
48
  x.grad.zero_()
 
 
49
 
50
+ # Backprop on that probability
51
+ target_prob.backward()
52
+
53
+ # Raw gradient is now in x.grad
54
+ importance = x.grad.detach()
55
+
56
+ # Optional: Multiply by input to get a more "integrated gradients"-like measure
57
+ # importance = importance * x.detach()
58
+
59
+ return importance, float(target_prob)
 
 
 
 
 
 
60
 
61
+ # --------------- Utility Functions ---------------
62
 
63
+ def parse_fasta(text: str):
64
+ """
65
+ Parse a FASTA string and return a list of (header, sequence) pairs.
66
+ """
67
  sequences = []
68
  current_header = None
69
  current_sequence = []
 
83
  sequences.append((current_header, ''.join(current_sequence)))
84
  return sequences
85
 
86
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """
    Convert a nucleotide sequence into a normalized k-mer frequency vector.

    The vector has one entry per possible k-mer over the alphabet ACGT
    (4**k entries, in itertools.product order).  Each entry is the count of
    that k-mer divided by the total number of sliding windows, so entries
    sum to at most 1 (windows containing characters outside ACGT, e.g. N,
    are skipped).

    Args:
        sequence: nucleotide sequence; case-insensitive — lowercase FASTA
            input is normalized to uppercase before counting.
        k: k-mer length (default 4).

    Returns:
        np.ndarray of shape (4**k,), dtype float32.  All zeros when the
        sequence is shorter than k.
    """
    # Normalize case so lowercase FASTA records are counted correctly
    # (previously they silently produced an all-zero vector).
    sequence = sequence.upper()

    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
    kmer_dict = {km: i for i, km in enumerate(kmers)}
    vec = np.zeros(len(kmers), dtype=np.float32)

    for i in range(len(sequence) - k + 1):
        kmer = sequence[i:i + k]
        if kmer in kmer_dict:
            vec[kmer_dict[kmer]] += 1

    total_kmers = len(sequence) - k + 1
    if total_kmers > 0:
        vec = vec / total_kmers

    return vec
106
+
107
def compute_sequence_stats(sequence: str):
    """
    Compute summary statistics for a nucleotide sequence.

    Counting is case-insensitive: lowercase FASTA input is normalized to
    uppercase before counting (previously lowercase bases were missed).

    Args:
        sequence: nucleotide sequence (may contain non-ACGT characters,
            which count towards length but not towards base counts).

    Returns:
        dict with:
            'length': total sequence length
            'gc_content': percentage of G+C over the whole length
            'counts': dict of A/C/G/T counts
    """
    length = len(sequence)
    if length == 0:
        # Degenerate case: avoid division by zero below.
        return {
            'length': 0,
            'gc_content': 0,
            'counts': {'A': 0, 'C': 0, 'G': 0, 'T': 0}
        }

    upper_seq = sequence.upper()  # count 'a'/'c'/'g'/'t' as well
    counts = {base: upper_seq.count(base) for base in 'ACGT'}
    gc_content = (counts['G'] + counts['C']) / length * 100.0

    return {
        'length': length,
        'gc_content': gc_content,
        'counts': counts
    }
135
+
136
+ # --------------- Visualization Functions ---------------
137
+
138
def plot_shap_like_bars(kmers, importance_values, top_k=10):
    """
    Horizontal bar chart of the most important k-mers, SHAP-summary style.

    The top_k features with the largest absolute importance are plotted,
    strongest at the top.  Bar length is the gradient magnitude; green bars
    push the prediction towards 'human', red towards 'non-human'.
    """
    # Rank features by |importance| and keep the strongest top_k.
    ranked = np.argsort(np.abs(importance_values))[::-1][:top_k]
    labels = [kmers[i] for i in ranked]
    values = importance_values[ranked]

    fig, ax = plt.subplots(figsize=(8, 6))
    bar_colors = ['green' if v > 0 else 'red' for v in values]
    positions = range(len(labels))
    ax.barh(positions, np.abs(values), color=bar_colors)
    ax.set_yticks(positions)
    ax.set_yticklabels(labels)
    ax.invert_yaxis()  # strongest feature on top
    ax.set_xlabel("Feature Importance (Gradient Magnitude)")
    ax.set_title(f"Top-{top_k} SHAP-like Feature Importances")
    plt.tight_layout()
    return fig
165
+
166
def plot_kmer_distribution(kmer_freq_vector, kmers):
    """
    Quick overview histogram of every k-mer frequency in the vector.

    The x-axis is the raw k-mer index (tick labels are suppressed because
    there are 4**k of them); the y-axis is the normalized frequency.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    indices = range(len(kmer_freq_vector))
    ax.bar(indices, kmer_freq_vector, color='blue', alpha=0.6)
    ax.set_xlabel("K-mer Index")
    ax.set_ylabel("Frequency")
    ax.set_title("K-mer Frequency Distribution")
    ax.set_xticks([])  # far too many k-mers to label individually
    plt.tight_layout()
    return fig
179
+
180
def create_step_visualization(important_kmers, human_prob):
    """
    Step plot showing how each top k-mer nudges the human probability.

    Starting from a neutral 0.5, each k-mer's impact is added (towards
    'human') or subtracted (towards 'non-human'), so the path visualizes
    the cumulative contribution ending near the final probability.
    """
    fig = plt.figure(figsize=(8, 5))
    ax = fig.add_subplot(111)

    # Walk the probability from the neutral starting point.
    current_prob = 0.5
    steps = [('Start', current_prob, 0)]
    for kmer in important_kmers:
        change = kmer['impact'] * (-1 if kmer['direction'] == 'non-human' else 1)
        current_prob += change
        steps.append((kmer['kmer'], current_prob, change))

    x_vals = range(len(steps))
    y_vals = [s[1] for s in steps]

    ax.step(x_vals, y_vals, 'b-', where='post', label='Probability', linewidth=2)
    ax.plot(x_vals, y_vals, 'b.', markersize=10)

    # Reference line at the neutral probability.
    ax.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')
    ax.set_ylim(0, 1)
    ax.set_ylabel('Human Probability')
    ax.set_title(f'K-mer Contributions (final p={human_prob:.3f})')
    ax.grid(True, linestyle='--', alpha=0.7)

    # Annotate each step; alternate offsets so labels don't collide.
    for i, (kmer, prob, change) in enumerate(steps):
        ax.annotate(kmer,
                    (i, prob),
                    xytext=(0, 10 if i % 2 == 0 else -20),
                    textcoords='offset points',
                    ha='center',
                    rotation=45)
        if i > 0:
            change_text = f'{change:+.3f}'
            color = 'green' if change > 0 else 'red'
            ax.annotate(change_text,
                        (i, prob),
                        xytext=(0, -20 if i % 2 == 0 else 10),
                        textcoords='offset points',
                        ha='center',
                        color=color)

    ax.legend()
    plt.tight_layout()
    return fig
231
+
232
def plot_kmer_freq_and_sigma(important_kmers):
    """
    Paired bar chart: k-mer occurrence frequency vs. deviation from mean.

    The left bars (primary axis) show each k-mer's occurrence as a
    percentage of the sequence; the right bars (twin axis) show how many
    standard deviations the scaled feature sits from the training mean.
    Green/red encodes push towards human / non-human; sigma bars for
    below-mean k-mers are greyed out.
    """
    fig, ax = plt.subplots(figsize=(8, 5))

    # Prepare data
    labels = [entry['kmer'] for entry in important_kmers]
    freqs = [entry['occurrence'] for entry in important_kmers]
    sigmas = [entry['sigma'] for entry in important_kmers]
    colors = ['green' if entry['direction'] == 'human' else 'red'
              for entry in important_kmers]

    x = np.arange(len(labels))
    width = 0.35

    # Frequency bars on the primary axis.
    ax.bar(x - width / 2, freqs, width, label='Frequency (%)',
           color=colors, alpha=0.6)

    # Sigma bars on a twin axis so the two scales don't fight.
    ax_sigma = ax.twinx()
    sigma_colors = [c if s > 0 else 'gray' for c, s in zip(colors, sigmas)]
    ax_sigma.bar(x + width / 2, sigmas, width, label='σ from mean',
                 color=sigma_colors, alpha=0.3)

    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45)
    ax.set_ylabel('Frequency (%)')
    ax_sigma.set_ylabel('Standard Deviations (σ) from Mean')
    ax.set_title("K-mer Frequencies & Statistical Significance")

    # Merge the legends of both axes into a single box.
    lines1, labels1 = ax.get_legend_handles_labels()
    lines2, labels2 = ax_sigma.get_legend_handles_labels()
    ax.legend(lines1 + lines2, labels1 + labels2, loc='best')

    plt.tight_layout()
    return fig
270
 
271
+ # --------------- Main Prediction Logic ---------------
272
+
273
def _fig_to_image(fig, dpi=150):
    """Render a matplotlib figure to a PIL Image and release the figure."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=dpi)
    buf.seek(0)
    image = Image.open(buf)
    plt.close(fig)
    return image


def predict_fasta(
    file_obj,
    k_size=4,
    top_k=10,
    advanced_analysis=False
):
    """
    Classify every sequence in an uploaded FASTA file as human / non-human.

    Args:
        file_obj: uploaded file contents (bytes) or a raw FASTA string.
        k_size: k-mer length used to featurize each sequence.  NOTE: the
            bundled model expects 4**k_size input features; choosing a k
            different from the one the model was trained with will fail at
            state-dict load time and be reported as a load error.
        top_k: number of most-influential k-mers to report per sequence.
        advanced_analysis: when True, additionally emit the SHAP-like bar
            chart and the full k-mer distribution histogram per sequence.

    Returns:
        (report_text, plots): combined textual report for all sequences,
        and a list of PIL Images with the generated charts.
    """
    # 1. Read raw text from file or string.
    if file_obj is None:
        return "Please upload a FASTA file", []

    try:
        if isinstance(file_obj, str):
            text = file_obj
        else:
            text = file_obj.decode('utf-8', errors='replace')
    except Exception as e:
        return f"Error reading file: {str(e)}", []

    # 2. Parse the FASTA.
    sequences = parse_fasta(text)
    if not sequences:
        return "No valid FASTA sequences found!", []

    # 3. Load model & scaler once, reused for all sequences.
    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = VirusClassifier(input_shape=(4 ** k_size)).to(device)
        state_dict = torch.load('model.pt', map_location=device)
        model.load_state_dict(state_dict)
        model.eval()

        scaler = joblib.load('scaler.pkl')
    except Exception as e:
        return f"Error loading model/scaler: {str(e)}", []

    # 4. Reference list of all possible k-mers; index order matches the
    #    feature vector produced by sequence_to_kmer_vector.
    all_kmers = [''.join(p) for p in product("ACGT", repeat=k_size)]

    # 5. Per-sequence prediction and reporting.
    final_text_report = []
    plots = []

    for idx, (header, seq) in enumerate(sequences, start=1):
        seq_stats = compute_sequence_stats(seq)

        # Featurize: raw k-mer frequencies -> scaled model features.
        raw_kmer_freq = sequence_to_kmer_vector(seq, k=k_size)
        scaled_kmer_freq = scaler.transform(raw_kmer_freq.reshape(1, -1))
        X_tensor = torch.FloatTensor(scaled_kmer_freq).to(device)

        # Predict class probabilities.
        with torch.no_grad():
            output = model(X_tensor)
            probs = torch.softmax(output, dim=1)

        # Determine class (index 1 = human by convention of the model head).
        pred_class = torch.argmax(probs, dim=1).item()
        pred_label = 'human' if pred_class == 1 else 'non-human'
        human_prob = float(probs[0][1])
        non_human_prob = float(probs[0][0])
        confidence = float(torch.max(probs[0]).item())

        # Gradient-based importance w.r.t. the 'human' class.
        importance, _ = model.get_gradient_importance(X_tensor, class_index=1)
        importance = importance[0].cpu().numpy()  # shape: (num_features,)

        # Top-k features by absolute gradient magnitude.
        top_indices = np.argsort(np.abs(importance))[::-1][:top_k]

        top_kmers_info = []
        for i in top_indices:
            imp_val = float(importance[i])
            top_kmers_info.append({
                'kmer': all_kmers[i],
                'impact': abs(imp_val),
                'direction': 'human' if imp_val > 0 else 'non-human',
                'occurrence': float(raw_kmer_freq[i] * 100.0),  # percent
                # Scaled value ~= stdevs from the training mean when the
                # scaler is a StandardScaler -- TODO confirm scaler type.
                'sigma': float(scaled_kmer_freq[0][i])
            })

        # Text summary for this sequence.
        seq_report = []
        seq_report.append(f"=== Sequence {idx} ===")
        seq_report.append(f"Header: {header}")
        seq_report.append(f"Length: {seq_stats['length']}")
        seq_report.append(f"GC Content: {seq_stats['gc_content']:.2f}%")
        seq_report.append(f"A: {seq_stats['counts']['A']}, C: {seq_stats['counts']['C']}, G: {seq_stats['counts']['G']}, T: {seq_stats['counts']['T']}")
        seq_report.append(f"Prediction: {pred_label} (Confidence: {confidence:.4f})")
        seq_report.append(f"  Human Probability: {human_prob:.4f}")
        seq_report.append(f"  Non-human Probability: {non_human_prob:.4f}")
        seq_report.append(f"\nTop-{top_k} Influential k-mers (by gradient magnitude):")
        for tkm in top_kmers_info:
            seq_report.append(
                f"  {tkm['kmer']}: pushes towards {tkm['direction']} "
                f"(impact={tkm['impact']:.4f}), occurrence={tkm['occurrence']:.2f}%, "
                f"sigma={tkm['sigma']:.2f}"
            )

        final_text_report.append("\n".join(seq_report))

        # 6. Plots for this sequence.
        if advanced_analysis:
            # 6A. SHAP-like bar chart over all features.
            plots.append(_fig_to_image(plot_shap_like_bars(
                kmers=all_kmers,
                importance_values=importance,
                top_k=top_k
            )))

            # 6B. Full k-mer frequency distribution.
            plots.append(_fig_to_image(
                plot_kmer_distribution(raw_kmer_freq, all_kmers)))

        # 6C. Step visualization of the top k-mers, largest impact first.
        top_kmers_info_step = sorted(top_kmers_info,
                                     key=lambda x: x['impact'], reverse=True)
        plots.append(_fig_to_image(
            create_step_visualization(top_kmers_info_step, human_prob)))

        # 6D. Frequency vs. sigma bar chart.
        plots.append(_fig_to_image(
            plot_kmer_freq_and_sigma(top_kmers_info_step)))

    # Combine all per-sequence reports.
    combined_text = "\n\n".join(final_text_report)
    return combined_text, plots
432
+
433
+ # --------------- Gradio Interface ---------------
434
 
435
def run_prediction(
    file_obj,
    k_size,
    top_k,
    advanced_analysis
):
    """
    Thin Gradio adapter: forwards the UI inputs to predict_fasta and
    returns its (text, List[Image]) result unchanged.
    """
    return predict_fasta(
        file_obj=file_obj,
        k_size=k_size,
        top_k=top_k,
        advanced_analysis=advanced_analysis
    )
453
+
454
+
455
# Build the Gradio UI: input controls on the left, results on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Virus Host Classifier (Improved!)")
    gr.Markdown(
        "Upload a FASTA file and configure k-mer size, number of top features, "
        "and whether to run advanced analysis (plots of SHAP-like bars & k-mer distribution)."
    )

    with gr.Row():
        with gr.Column():
            fasta_file = gr.File(label="Upload FASTA file", type="binary")
            kmer_slider = gr.Slider(minimum=2, maximum=6, value=4, step=1, label="K-mer Size")
            topk_slider = gr.Slider(minimum=5, maximum=20, value=10, step=1, label="Top-k Features")
            advanced_check = gr.Checkbox(value=False, label="Advanced Analysis")
            predict_button = gr.Button("Predict")

        with gr.Column():
            results_text = gr.Textbox(
                label="Results",
                lines=20,
                placeholder="Prediction results will appear here..."
            )

            # Gallery holds a variable number of per-sequence plots.
            # FIX: Component.style() was deprecated in Gradio 3.x and removed
            # in 4.x -- layout options are constructor arguments now.
            plots_gallery = gr.Gallery(label="Analysis Plots", columns=2, height="auto")

    predict_button.click(
        fn=run_prediction,
        inputs=[fasta_file, kmer_slider, topk_slider, advanced_check],
        outputs=[results_text, plots_gallery]
    )

if __name__ == "__main__":
    demo.launch(share=True)