hiyata committed on
Commit
56468ea
·
verified ·
1 Parent(s): d76e76a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -160
app.py CHANGED
@@ -32,7 +32,6 @@ class VirusClassifier(nn.Module):
32
  def forward(self, x):
33
  return self.network(x)
34
 
35
-
36
  ###############################################################################
37
  # 2. FASTA PARSING & K-MER FEATURE ENGINEERING
38
  ###############################################################################
@@ -59,7 +58,7 @@ def parse_fasta(text):
59
  return sequences
60
 
61
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
62
- """Convert a sequence to a k-mer frequency vector."""
63
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
64
  kmer_dict = {km: i for i, km in enumerate(kmers)}
65
  vec = np.zeros(len(kmers), dtype=np.float32)
@@ -75,7 +74,6 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
75
 
76
  return vec
77
 
78
-
79
  ###############################################################################
80
  # 3. SHAP-VALUE (ABLATION) CALCULATION
81
  ###############################################################################
@@ -83,30 +81,29 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
83
  def calculate_shap_values(model, x_tensor):
84
  """
85
  Calculate SHAP values using a simple ablation approach.
86
- Returns shap values and model prediction.
87
  """
88
  model.eval()
89
  with torch.no_grad():
90
- # Get baseline prediction
91
  baseline_output = model(x_tensor)
92
  baseline_probs = torch.softmax(baseline_output, dim=1)
93
  baseline_prob = baseline_probs[0, 1].item() # Probability of 'human' class
94
 
95
- # Calculate impact of zeroing each feature
96
  shap_values = []
97
  x_zeroed = x_tensor.clone()
98
  for i in range(x_tensor.shape[1]):
99
- original_value = x_zeroed[0, i].item()
100
  x_zeroed[0, i] = 0.0
101
  output = model(x_zeroed)
102
  probs = torch.softmax(output, dim=1)
103
  prob = probs[0, 1].item()
104
  impact = baseline_prob - prob
105
  shap_values.append(impact)
106
- x_zeroed[0, i] = original_value # restore
107
  return np.array(shap_values), baseline_prob
108
 
109
-
110
  ###############################################################################
111
  # 4. PER-BASE SHAP AGGREGATION
112
  ###############################################################################
@@ -116,7 +113,6 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
116
  Returns an array of per-base SHAP contributions by averaging
117
  the k-mer SHAP values of all k-mers covering that base.
118
  """
119
- # Create the list of k-mers (in lexicographic order)
120
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
121
  kmer_dict = {km: i for i, km in enumerate(kmers)}
122
 
@@ -136,79 +132,44 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
136
 
137
  return shap_means
138
 
139
-
140
  ###############################################################################
141
- # 5. HEATMAP PLOTS
142
  ###############################################################################
143
 
144
- def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap"):
 
 
 
 
 
 
 
 
 
145
  """
146
  Plots a 1D heatmap of per-base SHAP contributions.
147
  Negative = push toward Non-Human, Positive = push toward Human.
 
148
  """
149
- heatmap_data = shap_means.reshape(1, -1) # shape (1, seq_len)
150
- fig, ax = plt.subplots(figsize=(12, 2))
 
 
 
151
 
 
 
 
152
  cax = ax.imshow(heatmap_data, aspect='auto', cmap='RdBu_r')
153
  cbar = plt.colorbar(cax, orientation='horizontal', pad=0.2)
154
  cbar.set_label('SHAP Contribution')
155
 
156
  ax.set_yticks([])
157
  ax.set_xlabel('Position in Sequence')
158
- ax.set_title(title)
159
- plt.tight_layout()
160
- return fig
161
-
162
def get_top_signal_region(shap_means, window_size=500):
    """
    Locate the fixed-length window with the largest total absolute SHAP signal.

    Args:
        shap_means: 1D array-like of per-base SHAP contributions.
        window_size: Number of consecutive positions in the window. Coerced
            to ``int`` so slider values delivered as floats still work.

    Returns:
        ``(start_index, end_index)`` as a half-open range.
        * The whole sequence ``(0, len(shap_means))`` when the window is at
          least as long as the sequence.
        * ``(0, 0)`` (an empty region) when the window size is not positive.
        * Otherwise the leftmost window whose sum of ``|shap|`` is maximal.
    """
    abs_values = np.abs(np.asarray(shap_means, dtype=np.float64))
    seq_len = len(abs_values)
    window_size = int(window_size)

    if window_size >= seq_len:
        return 0, seq_len  # window covers everything
    if window_size <= 0:
        return 0, 0  # nothing to search for

    # prefix[i] holds the sum of the first i values, so every window sum is a
    # single subtraction; this avoids the drift of an incrementally updated
    # running total and removes the Python-level loop.
    prefix = np.concatenate(([0.0], np.cumsum(abs_values)))
    window_sums = prefix[window_size:] - prefix[:-window_size]

    # argmax returns the first maximum, matching the original strict ">" rule.
    max_start = int(np.argmax(window_sums))
    return max_start, max_start + window_size
186
-
187
def plot_zoomed_heatmap(shap_means, window_size=500, title="Zoomed SHAP Region"):
    """
    Plot a 1D heatmap of the window with the strongest absolute SHAP signal.

    The window is chosen by ``get_top_signal_region`` and only that slice of
    ``shap_means`` is drawn.

    Args:
        shap_means: 1D NumPy array of per-base SHAP contributions.
        window_size: Length of the window to search for and display.
        title: Title placed above the heatmap.

    Returns:
        The Matplotlib figure containing the heatmap and its colorbar.
    """
    start, end = get_top_signal_region(shap_means, window_size)
    sub_means = shap_means[start:end].reshape(1, -1)

    fig, ax = plt.subplots(figsize=(12, 2))
    # extent maps image columns onto sequence coordinates, so the tick labels
    # show real positions instead of restarting at 0 for the zoomed window.
    cax = ax.imshow(sub_means, aspect='auto', cmap='RdBu_r',
                    extent=(start, end, 0, 1))
    # Attach the colorbar to this figure explicitly rather than relying on
    # pyplot's notion of the "current" figure.
    cbar = fig.colorbar(cax, ax=ax, orientation='horizontal', pad=0.2)
    cbar.set_label('SHAP Contribution')

    ax.set_yticks([])
    ax.set_xlabel(f'Position in Sequence (zoomed in {start} - {end})')
    ax.set_title(title)

    plt.tight_layout()
    return fig
206
 
207
-
208
- ###############################################################################
209
- # 6. OTHER PLOT: TOP-K K-MER BAR PLOT
210
- ###############################################################################
211
-
212
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
213
  """Create a bar plot of the most important k-mers."""
214
  plt.rcParams.update({'font.size': 10})
@@ -223,31 +184,24 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
223
 
224
  plt.barh(range(len(values)), values, color=colors)
225
  plt.yticks(range(len(values)), features)
226
- plt.xlabel('SHAP value (impact on model output)')
227
  plt.title(f'Top {top_k} Most Influential k-mers')
228
  plt.gca().invert_yaxis()
229
  return fig
230
 
231
- ###############################################################################
232
- # 7. HELPER FUNCTION: FIG TO IMAGE
233
- ###############################################################################
234
-
235
- def fig_to_image(fig):
236
- """Convert a Matplotlib figure to a PIL Image."""
237
- import io
238
- buf = io.BytesIO()
239
- fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
240
- buf.seek(0)
241
- img = Image.open(buf)
242
- plt.close(fig)
243
- return img
244
 
245
  ###############################################################################
246
- # 8. MAIN PREDICTION FUNCTION
247
  ###############################################################################
248
 
249
- def predict(file_obj, top_kmers=10, fasta_text="", zoom_window=500):
250
- """Main prediction function for Gradio interface."""
251
  # Handle input
252
  if fasta_text.strip():
253
  text = fasta_text.strip()
@@ -256,14 +210,14 @@ def predict(file_obj, top_kmers=10, fasta_text="", zoom_window=500):
256
  with open(file_obj, 'r') as f:
257
  text = f.read()
258
  except Exception as e:
259
- return f"Error reading file: {str(e)}", None, None, None
260
  else:
261
- return "Please provide a FASTA sequence.", None, None, None
262
 
263
  # Parse FASTA
264
  sequences = parse_fasta(text)
265
  if not sequences:
266
- return "No valid FASTA sequences found.", None, None, None
267
 
268
  header, seq = sequences[0]
269
 
@@ -274,49 +228,101 @@ def predict(file_obj, top_kmers=10, fasta_text="", zoom_window=500):
274
  model.load_state_dict(torch.load('model.pt', map_location=device))
275
  scaler = joblib.load('scaler.pkl')
276
  except Exception as e:
277
- return f"Error loading model: {str(e)}", None, None, None
278
 
279
- # Generate features
280
  freq_vector = sequence_to_kmer_vector(seq)
281
  scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
282
  x_tensor = torch.FloatTensor(scaled_vector).to(device)
283
 
284
- # Calculate SHAP values and get prediction
285
  shap_values, prob_human = calculate_shap_values(model, x_tensor)
 
286
 
287
- # Prediction text
288
- results = [
289
- f"Sequence: {header}",
290
- f"Prediction: {'Human' if prob_human > 0.5 else 'Non-human'} Origin",
291
- f"Confidence: {max(prob_human, 1 - prob_human):.3f}",
292
- f"Human Probability: {prob_human:.3f}"
293
- ]
294
-
295
- # Create k-mer list (4-mers in lexicographic order)
 
 
 
 
296
  kmers = [''.join(p) for p in product("ACGT", repeat=4)]
297
-
298
- # 1) Top-k k-mer bar plot
299
- importance_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
300
- importance_img = fig_to_image(importance_fig)
301
-
302
- # 2) Full-genome per-base SHAP heatmap
303
  shap_means = compute_positionwise_scores(seq, shap_values, k=4)
304
- heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide Per-base SHAP")
305
  heatmap_img = fig_to_image(heatmap_fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
- # 3) Zoomed region (optional, using the largest absolute SHAP region)
308
- if zoom_window > 0:
309
- zoom_fig = plot_zoomed_heatmap(shap_means, window_size=zoom_window,
310
- title=f"Top SHAP Region (window={zoom_window})")
311
- zoom_img = fig_to_image(zoom_fig)
312
- else:
313
- zoom_img = None
 
 
 
 
 
 
 
 
 
314
 
315
- return "\n".join(results), importance_img, heatmap_img, zoom_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
 
318
  ###############################################################################
319
- # 9. BUILD GRADIO INTERFACE
320
  ###############################################################################
321
 
322
  css = """
@@ -327,57 +333,87 @@ css = """
327
 
328
  with gr.Blocks(css=css) as iface:
329
  gr.Markdown("""
330
- # Virus Host Classifier
331
- Predicts whether a viral sequence is of human or non-human origin using k-mer analysis.
 
332
  """)
333
 
334
- with gr.Row():
335
- with gr.Column(scale=1):
336
- file_input = gr.File(
337
- label="Upload FASTA file",
338
- file_types=[".fasta", ".fa", ".txt"],
339
- type="filepath"
340
- )
341
- text_input = gr.Textbox(
342
- label="Or paste FASTA sequence",
343
- placeholder=">sequence_name\nACGTACGT...",
344
- lines=5
345
- )
346
- top_k = gr.Slider(
347
- minimum=5,
348
- maximum=30,
349
- value=10,
350
- step=1,
351
- label="Number of top k-mers to display"
352
- )
353
- zoom_window = gr.Slider(
354
- minimum=0,
355
- maximum=5000,
356
- value=500,
357
- step=100,
358
- label="Zoom Window Size (0 to disable zoom plot)"
359
- )
360
- submit_btn = gr.Button("Analyze Sequence", variant="primary")
361
-
362
- with gr.Column(scale=2):
363
- results_box = gr.Textbox(label="Analysis Results", lines=5)
364
- kmer_plot = gr.Image(label="Top k-mer SHAP")
365
- full_heatmap = gr.Image(label="Genome-wide SHAP Heatmap")
366
- zoomed_heatmap = gr.Image(label="Zoomed SHAP Region (largest signal)")
 
 
 
 
 
 
 
 
367
 
368
- submit_btn.click(
369
- predict,
370
- inputs=[file_input, top_k, text_input, zoom_window],
371
- outputs=[results_box, kmer_plot, full_heatmap, zoomed_heatmap]
372
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
  gr.Markdown("""
375
- ### Visualization Guide
376
- - **Top k-mer SHAP**: Shows the most influential k-mers and their SHAP values.
377
- - **Genome-wide SHAP Heatmap**: Per-base SHAP values across the entire sequence.
378
- - Red = push toward human
379
- - Blue = push toward non-human
380
- - **Zoomed SHAP Region**: Shows the subregion of length 'Zoom Window Size' that has the highest absolute SHAP sum.
 
 
 
 
 
381
  """)
382
 
383
  if __name__ == "__main__":
 
32
  def forward(self, x):
33
  return self.network(x)
34
 
 
35
  ###############################################################################
36
  # 2. FASTA PARSING & K-MER FEATURE ENGINEERING
37
  ###############################################################################
 
58
  return sequences
59
 
60
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
61
+ """Convert a sequence to a k-mer frequency vector for classification."""
62
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
63
  kmer_dict = {km: i for i, km in enumerate(kmers)}
64
  vec = np.zeros(len(kmers), dtype=np.float32)
 
74
 
75
  return vec
76
 
 
77
  ###############################################################################
78
  # 3. SHAP-VALUE (ABLATION) CALCULATION
79
  ###############################################################################
 
81
  def calculate_shap_values(model, x_tensor):
82
  """
83
  Calculate SHAP values using a simple ablation approach.
84
+ Returns shap_values, prob_human
85
  """
86
  model.eval()
87
  with torch.no_grad():
88
+ # Baseline
89
  baseline_output = model(x_tensor)
90
  baseline_probs = torch.softmax(baseline_output, dim=1)
91
  baseline_prob = baseline_probs[0, 1].item() # Probability of 'human' class
92
 
93
+ # Zeroing each feature to measure impact
94
  shap_values = []
95
  x_zeroed = x_tensor.clone()
96
  for i in range(x_tensor.shape[1]):
97
+ original_val = x_zeroed[0, i].item()
98
  x_zeroed[0, i] = 0.0
99
  output = model(x_zeroed)
100
  probs = torch.softmax(output, dim=1)
101
  prob = probs[0, 1].item()
102
  impact = baseline_prob - prob
103
  shap_values.append(impact)
104
+ x_zeroed[0, i] = original_val # restore
105
  return np.array(shap_values), baseline_prob
106
 
 
107
  ###############################################################################
108
  # 4. PER-BASE SHAP AGGREGATION
109
  ###############################################################################
 
113
  Returns an array of per-base SHAP contributions by averaging
114
  the k-mer SHAP values of all k-mers covering that base.
115
  """
 
116
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
117
  kmer_dict = {km: i for i, km in enumerate(kmers)}
118
 
 
132
 
133
  return shap_means
134
 
 
135
  ###############################################################################
136
+ # 5. PLOTTING / UTILITIES
137
  ###############################################################################
138
 
139
def fig_to_image(fig):
    """
    Convert a Matplotlib figure to a PIL Image for Gradio.

    The figure is rendered to an in-memory PNG (150 dpi, tight bounding box)
    and then closed. It is closed even when rendering fails, so repeated
    requests do not accumulate open figures.

    Args:
        fig: The Matplotlib figure to render.

    Returns:
        A PIL Image opened from the in-memory PNG buffer.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    finally:
        # Release the figure whether or not savefig succeeded.
        plt.close(fig)
    buf.seek(0)  # rewind so the image reader starts at the PNG header
    return Image.open(buf)
147
+
148
def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
    """
    Plots a 1D heatmap of per-base SHAP contributions.
    Negative = push toward Non-Human, Positive = push toward Human.

    Args:
        shap_means: 1D NumPy array of per-base SHAP contributions.
        title: Title of the plot.
        start, end: Optional half-open subrange. The slice is applied only
            when BOTH are given; the title then gains a
            " (positions start-end)" suffix. Callers are expected to pass
            non-negative integer bounds.

    Returns:
        The Matplotlib figure containing the heatmap and its colorbar.
    """
    if start is not None and end is not None:
        shap_means = shap_means[start:end]
        subtitle = f" (positions {start}-{end})"
        x_origin = start
    else:
        subtitle = ""
        x_origin = 0

    heatmap_data = shap_means.reshape(1, -1)  # shape (1, region_length)
    region_length = heatmap_data.shape[1]

    fig, ax = plt.subplots(figsize=(12, 2))
    # extent maps image columns onto sequence coordinates so that the x-axis
    # of a subregion plot shows true positions rather than restarting at 0.
    cax = ax.imshow(heatmap_data, aspect='auto', cmap='RdBu_r',
                    extent=(x_origin, x_origin + region_length, 0, 1))
    # Bind the colorbar to this figure explicitly instead of pyplot's
    # implicit "current figure".
    cbar = fig.colorbar(cax, ax=ax, orientation='horizontal', pad=0.2)
    cbar.set_label('SHAP Contribution')

    ax.set_yticks([])
    ax.set_xlabel('Position in Sequence')
    ax.set_title(f"{title}{subtitle}")
    plt.tight_layout()
    return fig
172
 
 
 
 
 
 
173
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
174
  """Create a bar plot of the most important k-mers."""
175
  plt.rcParams.update({'font.size': 10})
 
184
 
185
  plt.barh(range(len(values)), values, color=colors)
186
  plt.yticks(range(len(values)), features)
187
+ plt.xlabel('SHAP Value (impact on model output)')
188
  plt.title(f'Top {top_k} Most Influential k-mers')
189
  plt.gca().invert_yaxis()
190
  return fig
191
 
192
def compute_gc_content(sequence):
    """
    Compute %GC in the sequence (A, C, G, T).

    Counting is case-insensitive, so soft-masked (lowercase) bases are
    included. The denominator is the full sequence length, including any
    non-ACGT characters.

    Args:
        sequence: Nucleotide string.

    Returns:
        GC percentage as a float in [0, 100]; 0.0 for an empty sequence.
    """
    if not sequence:
        return 0.0
    normalized = sequence.upper()
    gc_count = normalized.count('G') + normalized.count('C')
    return (gc_count / len(sequence)) * 100.0
 
 
 
 
 
 
 
198
 
199
  ###############################################################################
200
+ # 6. MAIN ANALYSIS STEP (Gradio Step 1)
201
  ###############################################################################
202
 
203
+ def analyze_sequence(file_obj, top_kmers=10, fasta_text=""):
204
+ """Analyzes the entire genome, returning classification and a heatmap."""
205
  # Handle input
206
  if fasta_text.strip():
207
  text = fasta_text.strip()
 
210
  with open(file_obj, 'r') as f:
211
  text = f.read()
212
  except Exception as e:
213
+ return (f"Error reading file: {str(e)}", None, None, None, None)
214
  else:
215
+ return ("Please provide a FASTA sequence.", None, None, None, None)
216
 
217
  # Parse FASTA
218
  sequences = parse_fasta(text)
219
  if not sequences:
220
+ return ("No valid FASTA sequences found.", None, None, None, None)
221
 
222
  header, seq = sequences[0]
223
 
 
228
  model.load_state_dict(torch.load('model.pt', map_location=device))
229
  scaler = joblib.load('scaler.pkl')
230
  except Exception as e:
231
+ return (f"Error loading model: {str(e)}", None, None, None, None)
232
 
233
+ # Vectorize + scale
234
  freq_vector = sequence_to_kmer_vector(seq)
235
  scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
236
  x_tensor = torch.FloatTensor(scaled_vector).to(device)
237
 
238
+ # SHAP + classification
239
  shap_values, prob_human = calculate_shap_values(model, x_tensor)
240
+ prob_nonhuman = 1.0 - prob_human
241
 
242
+ classification = "Human" if prob_human > 0.5 else "Non-human"
243
+ confidence = max(prob_human, prob_nonhuman)
244
+
245
+ # Build results text
246
+ results_text = (
247
+ f"Sequence: {header}\n"
248
+ f"Length: {len(seq):,} bases\n"
249
+ f"Classification: {classification}\n"
250
+ f"Confidence: {confidence:.3f}\n"
251
+ f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})"
252
+ )
253
+
254
+ # K-mer importance plot
255
  kmers = [''.join(p) for p in product("ACGT", repeat=4)]
256
+ bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
257
+ bar_img = fig_to_image(bar_fig)
258
+
259
+ # Per-base SHAP for entire genome
 
 
260
  shap_means = compute_positionwise_scores(seq, shap_values, k=4)
261
+ heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
262
  heatmap_img = fig_to_image(heatmap_fig)
263
+
264
+ # Return:
265
+ # 1) results text
266
+ # 2) k-mer bar image
267
+ # 3) full-genome heatmap
268
+ # 4) the "state" we need for step 2: (sequence, shap_means)
269
+ # We'll store these in a dictionary so we can pass it around in Gradio.
270
+ state_dict = {
271
+ "seq": seq,
272
+ "shap_means": shap_means
273
+ }
274
+
275
+ return (results_text, bar_img, heatmap_img, state_dict, header)
276
+
277
+ ###############################################################################
278
+ # 7. SUBREGION ANALYSIS (Gradio Step 2)
279
+ ###############################################################################
280
+
281
def analyze_subregion(state, header, region_start, region_end):
    """
    Takes stored data from step 1 and a user-chosen region.
    Returns a subregion heatmap and some stats (like GC content, average SHAP).

    Args:
        state: Dict produced by step 1 with keys "seq" (sequence string) and
            "shap_means" (per-base SHAP array).
        header: Sequence header, used only in the report text.
        region_start, region_end: Requested half-open region. Numeric widget
            values may arrive as floats, or as None when the box is empty,
            so both are coerced to int before use.

    Returns:
        (region_info, heatmap_img) on success, or (error_message, None) when
        the state is missing or the region is invalid.
    """
    if not state or "seq" not in state or "shap_means" not in state:
        return ("No sequence data found. Please run Step 1 first.", None)

    seq = state["seq"]
    shap_means = state["shap_means"]

    # Slice indices must be integers; reject blanks/NaN instead of crashing.
    try:
        region_start = int(region_start)
        region_end = int(region_end)
    except (TypeError, ValueError, OverflowError):
        return ("Invalid region range. Start and End must be numbers.", None)

    # Validate bounds: clamp both ends into [0, len(seq)]
    region_start = max(0, min(region_start, len(seq)))
    region_end = max(0, min(region_end, len(seq)))
    if region_end <= region_start:
        return ("Invalid region range. End must be > Start.", None)

    # Subsequence
    region_seq = seq[region_start:region_end]
    region_shap = shap_means[region_start:region_end]

    # Some stats
    gc_percent = compute_gc_content(region_seq)
    avg_shap = float(np.mean(region_shap))

    region_info = (
        f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
        f"Region length: {len(region_seq)} bases\n"
        f"GC content: {gc_percent:.2f}%\n"
        f"Average SHAP in region: {avg_shap:.4f} "
        f"({'toward human' if avg_shap > 0 else 'toward non-human' if avg_shap < 0 else 'neutral'})"
    )

    # Plot region as small heatmap
    fig = plot_linear_heatmap(shap_means,
                              title="Subregion SHAP",
                              start=region_start,
                              end=region_end)
    heatmap_img = fig_to_image(fig)

    return (region_info, heatmap_img)
322
 
323
 
324
  ###############################################################################
325
+ # 8. BUILD GRADIO INTERFACE
326
  ###############################################################################
327
 
328
  css = """
 
333
 
334
  with gr.Blocks(css=css) as iface:
335
  gr.Markdown("""
336
+ # Virus Host Classifier (with Interactive Region Viewer)
337
+ **Step 1**: Predict overall viral sequence origin (human vs non-human)
338
+ **Step 2**: Explore subregions to see local SHAP signals and GC content
339
  """)
340
 
341
+ with gr.Tab("1) Full-Sequence Analysis"):
342
+ with gr.Row():
343
+ with gr.Column(scale=1):
344
+ file_input = gr.File(
345
+ label="Upload FASTA file",
346
+ file_types=[".fasta", ".fa", ".txt"],
347
+ type="filepath"
348
+ )
349
+ text_input = gr.Textbox(
350
+ label="Or paste FASTA sequence",
351
+ placeholder=">sequence_name\nACGTACGT...",
352
+ lines=5
353
+ )
354
+ top_k = gr.Slider(
355
+ minimum=5,
356
+ maximum=30,
357
+ value=10,
358
+ step=1,
359
+ label="Number of top k-mers to display"
360
+ )
361
+ analyze_btn = gr.Button("Analyze Sequence", variant="primary")
362
+
363
+ with gr.Column(scale=2):
364
+ results_box = gr.Textbox(
365
+ label="Classification Results", lines=7, interactive=False
366
+ )
367
+ kmer_img = gr.Image(label="Top k-mer SHAP")
368
+ genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
369
+
370
+ # Hidden states that store data for step 2
371
+ # "state" will hold (sequence, shap_means).
372
+ # "header" is optional meta info
373
+ seq_state = gr.State()
374
+ header_state = gr.State()
375
+
376
+ # The "analyze_sequence" function returns 5 values, which we map here:
377
+ analyze_btn.click(
378
+ analyze_sequence,
379
+ inputs=[file_input, top_k, text_input],
380
+ outputs=[results_box, kmer_img, genome_img, seq_state, header_state]
381
+ )
382
 
383
+ with gr.Tab("2) Subregion Exploration"):
384
+ gr.Markdown("""
385
+ Select start/end positions to view local SHAP signals.
386
+ """)
387
+ with gr.Row():
388
+ region_start = gr.Number(label="Region Start", value=0)
389
+ region_end = gr.Number(label="Region End", value=500)
390
+ region_btn = gr.Button("Analyze Subregion")
391
+
392
+ subregion_info = gr.Textbox(
393
+ label="Subregion Analysis",
394
+ lines=4,
395
+ interactive=False
396
+ )
397
+ subregion_img = gr.Image(label="Subregion SHAP Heatmap")
398
+
399
+ region_btn.click(
400
+ analyze_subregion,
401
+ inputs=[seq_state, header_state, region_start, region_end],
402
+ outputs=[subregion_info, subregion_img]
403
+ )
404
 
405
  gr.Markdown("""
406
+ ### What does this interface provide?
407
+ 1. **Overall Classification** (human vs non-human), using a learned model on k-mer frequencies.
408
+ 2. **SHAP Analysis** (ablation-based) to see which k-mer features push classification toward or away from "human".
409
+ 3. **Genome-Wide SHAP Heatmap**: Each base's average SHAP across overlapping k-mers.
410
+ 4. **Subregion Exploration**:
411
+ - View SHAP signals in a user-chosen region.
412
+ - Calculate local GC content, average SHAP, etc.
413
+
414
+ ### Tips
415
+ - For very large sequences (e.g., >100k bases), the full heatmap might be large; consider downsampling if needed.
416
+ - Adjust *Region Start* and *End* to explore different parts of the genome.
417
  """)
418
 
419
  if __name__ == "__main__":