File size: 17,009 Bytes
5263bd3
 
 
 
 
4a7c026
 
40fe6da
a6886ca
 
 
afbf1c6
5263bd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b775b7
a6886ca
 
 
 
 
 
 
3b775b7
 
 
a6886ca
 
 
 
3b775b7
 
 
a6886ca
 
 
 
 
 
 
 
 
 
b5edb58
a6886ca
b5edb58
a6886ca
 
 
 
870813f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6886ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c88c65
a6886ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e19501
a6886ca
b5edb58
 
 
 
 
 
 
a6886ca
 
 
 
 
 
b5edb58
a6886ca
 
 
 
 
 
 
b5edb58
a6886ca
b5edb58
 
 
 
 
 
 
 
 
a6886ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5edb58
 
 
 
 
a6886ca
b5edb58
6c88c65
b5edb58
3b775b7
a6886ca
 
3b775b7
a6886ca
 
 
 
 
b5edb58
a6886ca
 
 
 
 
 
 
 
 
3b775b7
 
 
8c49ca8
a6886ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b775b7
a6886ca
3b775b7
7e19501
 
 
 
a6886ca
3b775b7
a6886ca
 
 
 
 
 
b5edb58
a6886ca
723da6d
b5edb58
a6886ca
b5edb58
9d48283
3b775b7
a6886ca
 
b5edb58
a6886ca
 
 
 
 
0d2d632
a6886ca
 
 
b5edb58
a6886ca
 
b5edb58
a6886ca
 
 
 
b5edb58
a6886ca
3b775b7
b5edb58
 
 
a6886ca
 
 
 
 
 
b5edb58
a6886ca
 
 
 
 
 
 
 
b5edb58
a6886ca
 
 
 
 
 
 
 
b5edb58
a6886ca
 
 
3b775b7
a6886ca
b5edb58
3b775b7
 
a6886ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5edb58
a6886ca
b5edb58
a6886ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5edb58
a6886ca
 
 
 
 
 
 
 
 
 
b5edb58
a6886ca
 
 
 
 
 
 
 
 
 
 
 
 
d192dd4
a6886ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b775b7
a6886ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d2d632
723da6d
a6886ca
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
import gradio as gr
import torch
import joblib
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import io
from PIL import Image
from itertools import product

# --------------- Model Definition ---------------

class VirusClassifier(nn.Module):
    """
    Feed-forward MLP that maps a feature vector to two class logits.

    In this file the input is a scaled k-mer frequency vector, and output
    index 1 is treated as 'human' while index 0 is 'non-human'.
    """

    def __init__(self, input_shape: int):
        """
        Args:
            input_shape: Number of input features per row (callers in this
                file pass ``4 ** k_size``).
        """
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2)
        )

    def forward(self, x):
        """Return the raw logits (no softmax applied)."""
        return self.network(x)

    def get_gradient_importance(self, x, class_index=1):
        """
        Compute a raw-gradient ("saliency") importance score per input feature.

        The score is d P(class_index) / d x, where P is the softmax
        probability. By default the 'human' class (index 1) is used.

        Args:
            x: Input batch of shape (batch, input_shape). It is cloned and
                detached, so the caller's tensor is never modified.
            class_index: Output class whose probability is differentiated.

        Returns:
            ``(importance, prob)``:
            - ``importance`` is a detached tensor with the same shape as ``x``.
            - ``prob`` is a Python ``float`` for a single-row input (the
              original contract). For a multi-row input it is a detached 1-D
              tensor of per-row probabilities.

        Notes:
            Works even when called inside ``torch.no_grad()``. It does not
            write ``.grad`` into the model parameters. Per-row gradients are
            exact in eval mode (BatchNorm uses running statistics and Dropout
            is disabled, so rows do not interact); in train mode BatchNorm
            couples the rows of a batch.
        """
        x = x.clone().detach().requires_grad_(True)

        # Gradients are required here even if the caller disabled them.
        with torch.enable_grad():
            probs = torch.softmax(self.network(x), dim=1)
            target_prob = probs[..., class_index]
            # autograd.grad returns d(target)/dx without accumulating .grad
            # on the parameters (backward() would on every call). Summing
            # over the batch gives a scalar and leaves per-row gradients intact.
            (importance,) = torch.autograd.grad(target_prob.sum(), x)

        # Optional: multiply by the input for a more "integrated gradients"-like
        # measure: importance = importance * x.detach()

        target_prob = target_prob.detach()
        if target_prob.numel() == 1:
            return importance.detach(), float(target_prob)
        return importance.detach(), target_prob

# --------------- Utility Functions ---------------

def parse_fasta(text: str):
    """
    Parse FASTA-formatted text into a list of ``(header, sequence)`` pairs.

    - A header line starts with ``>``. The header is the remainder of the
      (whitespace-stripped) line after ``>``. A bare ``>`` yields the header
      ``""``; its sequence is kept rather than silently dropped.
    - Sequence lines are upper-cased and concatenated until the next header.
    - Blank lines are ignored, as are sequence lines appearing before the
      first header.
    - ``\\n``, ``\\r\\n`` and bare ``\\r`` line endings are all accepted.

    Args:
        text: Raw FASTA text.

    Returns:
        List of ``(header, sequence)`` tuples in file order.
    """
    sequences = []
    current_header = None
    current_sequence = []

    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith('>'):
            # Compare against None: an empty header "" is still a record.
            if current_header is not None:
                sequences.append((current_header, ''.join(current_sequence)))
            current_header = line[1:]
            current_sequence = []
        elif current_header is not None:
            current_sequence.append(line.upper())

    if current_header is not None:
        sequences.append((current_header, ''.join(current_sequence)))
    return sequences

def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """
    Build a normalized k-mer frequency vector for ``sequence``.

    The vector has one float32 slot per k-mer over the alphabet ``ACGT``,
    ordered as ``itertools.product("ACGT", repeat=k)``. Each slot holds the
    count of that k-mer divided by the total number of length-``k`` windows
    in the sequence. Windows containing other characters (e.g. ``N``) are
    not counted but still contribute to that denominator. A sequence
    shorter than ``k`` yields an all-zero vector.
    """
    index_of = {
        ''.join(letters): position
        for position, letters in enumerate(product("ACGT", repeat=k))
    }
    counts = np.zeros(len(index_of), dtype=np.float32)

    window_count = len(sequence) - k + 1
    for start in range(window_count):
        position = index_of.get(sequence[start:start + k])
        if position is not None:
            counts[position] += 1

    if window_count <= 0:
        return counts
    return counts / window_count

def compute_sequence_stats(sequence: str):
    """
    Summarise a nucleotide sequence.

    Returns a dict with:
      - 'length': number of characters in ``sequence``
      - 'gc_content': percentage of all characters that are 'G' or 'C'
        (the integer 0 for an empty sequence)
      - 'counts': occurrences of 'A', 'C', 'G' and 'T', in that key order

    Counting is case-sensitive: only upper-case bases are counted.
    """
    length = len(sequence)
    if not length:
        return {
            'length': 0,
            'gc_content': 0,
            'counts': {base: 0 for base in "ACGT"},
        }

    counts = {base: sequence.count(base) for base in "ACGT"}
    gc_content = (counts['G'] + counts['C']) / length * 100.0
    return {'length': length, 'gc_content': gc_content, 'counts': counts}

# --------------- Visualization Functions ---------------

def plot_shap_like_bars(kmers, importance_values, top_k=10):
    """
    Draw a horizontal bar chart, in the spirit of a SHAP summary plot, of the
    ``top_k`` features with the largest absolute importance.

    - k-mers on the y-axis, most important at the top
    - bar length = absolute importance (gradient magnitude)
    - green bars have a positive value (push towards 'human'); red bars have
      a zero or negative value (push towards 'non-human')

    Args:
        kmers: Sequence of feature names aligned with ``importance_values``.
        importance_values: 1-D array-like of signed importance scores (a plain
            list is accepted as well as a numpy array).
        top_k: Maximum number of bars to draw; float values such as ``10.0``
            are accepted.

    Returns:
        The matplotlib Figure. The caller is responsible for closing it.
    """
    values = np.asarray(importance_values)
    # Indices ordered by descending absolute importance, truncated to top_k.
    top_indices = np.argsort(np.abs(values))[::-1][:int(top_k)]

    top_kmers = [kmers[i] for i in top_indices]
    top_importances = values[top_indices]

    fig, ax = plt.subplots(figsize=(8, 6))
    colors = ['green' if val > 0 else 'red' for val in top_importances]
    positions = np.arange(len(top_kmers))
    ax.barh(positions, np.abs(top_importances), color=colors)
    ax.set_yticks(positions)
    ax.set_yticklabels(top_kmers)
    ax.invert_yaxis()  # So that the highest value is at the top
    ax.set_xlabel("Feature Importance (Gradient Magnitude)")
    # Report the number of bars actually drawn (can be fewer than top_k).
    ax.set_title(f"Top-{len(top_kmers)} SHAP-like Feature Importances")
    fig.tight_layout()
    return fig

def plot_kmer_distribution(kmer_freq_vector, kmers):
    """
    Bar chart of every entry in ``kmer_freq_vector``, indexed by position.

    ``kmers`` is accepted for backward compatibility but not used: the x-axis
    shows positions only, and its tick labels are hidden.

    Returns the matplotlib Figure; the caller should close it.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    positions = range(len(kmer_freq_vector))
    ax.bar(positions, kmer_freq_vector, color='blue', alpha=0.6)
    ax.set(
        xlabel="K-mer Index",
        ylabel="Frequency",
        title="K-mer Frequency Distribution",
    )
    ax.set_xticks([])
    plt.tight_layout()
    return fig

def create_step_visualization(important_kmers, human_prob):
    """
    Step plot of a running value that starts at 0.5 and is moved by each k-mer.

    Each entry's ``impact`` is subtracted when its ``direction`` is
    'non-human' and added otherwise, in the given order.

    NOTE: callers in this file pass raw gradient magnitudes as ``impact``.
    The running value is therefore illustrative: it need not end at
    ``human_prob`` or stay inside [0, 1]. The y-axis always shows [0, 1] and
    is widened when the running value leaves that range, so no step is
    clipped out of view.

    Args:
        important_kmers: Iterable of dicts with keys 'kmer', 'impact' and
            'direction'.
        human_prob: Final model probability, shown in the title.

    Returns:
        The matplotlib Figure. The caller is responsible for closing it.
    """
    fig = plt.figure(figsize=(8, 5))
    ax = fig.add_subplot(111)

    # Start from 0.5; each step records (label, running value, change).
    current_prob = 0.5
    steps = [('Start', current_prob, 0)]

    for entry in important_kmers:
        change = entry['impact'] * (-1 if entry['direction'] == 'non-human' else 1)
        current_prob += change
        steps.append((entry['kmer'], current_prob, change))

    x_vals = range(len(steps))
    y_vals = [s[1] for s in steps]

    ax.step(x_vals, y_vals, 'b-', where='post', label='Probability', linewidth=2)
    ax.plot(x_vals, y_vals, 'b.', markersize=10)

    # Reference line at 0.5
    ax.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')

    # Keep [0, 1] in view, but widen it when the running value leaves that
    # range; a fixed ylim(0, 1) would silently hide those points.
    margin = 0.05
    lowest, highest = min(y_vals), max(y_vals)
    lower = 0.0 if lowest >= 0.0 else lowest - margin
    upper = 1.0 if highest <= 1.0 else highest + margin
    ax.set_ylim(lower, upper)

    ax.set_ylabel('Human Probability')
    ax.set_title(f'K-mer Contributions (final p={human_prob:.3f})')
    ax.grid(True, linestyle='--', alpha=0.7)

    for i, (label, prob, change) in enumerate(steps):
        # Alternate label offsets above/below to reduce overlap.
        ax.annotate(label,
                    (i, prob),
                    xytext=(0, 10 if i % 2 == 0 else -20),
                    textcoords='offset points',
                    ha='center',
                    rotation=45)

        if i > 0:
            change_text = f'{change:+.3f}'
            color = 'green' if change > 0 else 'red'
            ax.annotate(change_text,
                        (i, prob),
                        xytext=(0, -20 if i % 2 == 0 else 10),
                        textcoords='offset points',
                        ha='center',
                        color=color)

    ax.legend()
    fig.tight_layout()
    return fig

def plot_kmer_freq_and_sigma(important_kmers):
    """
    Grouped bar chart comparing each listed k-mer's frequency with its scaled value.

    - Left axis: frequency in percent (the 'occurrence' key).
    - Right axis: the 'sigma' value, labelled as standard deviations from
      the mean.
    - Bars are green when 'direction' is 'human' and red otherwise.
    - Sigma bars with a non-positive value are drawn gray.

    Args:
        important_kmers: List of dicts with keys 'kmer', 'occurrence',
            'sigma' and 'direction'.

    Returns:
        The matplotlib Figure. The caller is responsible for closing it.
    """
    fig, freq_ax = plt.subplots(figsize=(8, 5))

    labels = [entry['kmer'] for entry in important_kmers]
    freqs = [entry['occurrence'] for entry in important_kmers]
    sigma_vals = [entry['sigma'] for entry in important_kmers]
    palette = ['green' if entry['direction'] == 'human' else 'red' for entry in important_kmers]
    sigma_palette = [shade if value > 0 else 'gray' for shade, value in zip(palette, sigma_vals)]

    positions = np.arange(len(labels))
    bar_width = 0.35
    offset = bar_width / 2

    # Frequency bars on the primary axis, sigma bars on a twin axis.
    freq_ax.bar(positions - offset, freqs, bar_width, label='Frequency (%)', color=palette, alpha=0.6)
    sigma_ax = freq_ax.twinx()
    sigma_ax.bar(positions + offset, sigma_vals, bar_width, label='σ from mean',
                 color=sigma_palette, alpha=0.3)

    freq_ax.set_xticks(positions)
    freq_ax.set_xticklabels(labels, rotation=45)
    freq_ax.set_ylabel('Frequency (%)')
    sigma_ax.set_ylabel('Standard Deviations (σ) from Mean')
    freq_ax.set_title("K-mer Frequencies & Statistical Significance")

    # Merge both axes' legend entries into a single legend.
    handles, names = freq_ax.get_legend_handles_labels()
    extra_handles, extra_names = sigma_ax.get_legend_handles_labels()
    freq_ax.legend(handles + extra_handles, names + extra_names, loc='best')

    plt.tight_layout()
    return fig

# --------------- Main Prediction Logic ---------------

def _figure_to_image(fig):
    """Render ``fig`` to an in-memory PNG, close it, and return it as a PIL Image."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    finally:
        # Release the figure even if rendering fails, so repeated requests
        # do not accumulate open matplotlib figures.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def predict_fasta(
    file_obj,
    k_size=4,
    top_k=10,
    advanced_analysis=False
):
    """
    Classify every sequence in an uploaded FASTA and build a report plus plots.

    Args:
        file_obj: FASTA content, either as ``str`` or as an object with
            ``.decode`` (e.g. ``bytes`` from a binary upload). ``None`` means
            nothing was uploaded.
        k_size: k-mer length; must match the model/scaler loaded from
            ``model.pt`` / ``scaler.pkl``, otherwise an error message is
            returned. Float values such as ``4.0`` are converted to ``int``.
        top_k: Number of most influential k-mers reported per sequence.
            Converted to ``int``; negative values are treated as 0.
        advanced_analysis: If true, also add the SHAP-like bar chart and the
            full k-mer distribution plot for each sequence.

    Returns:
        ``(report_text, images)``:
        - ``report_text`` is the combined text report for all sequences.
        - ``images`` is a list of PIL Images: two per sequence, or four when
          ``advanced_analysis`` is set.
        On an input or loading problem, returns ``(error_message, [])``.
    """
    # 1. Read raw text from file or string
    if file_obj is None:
        return "Please upload a FASTA file", []

    try:
        if isinstance(file_obj, str):
            text = file_obj
        else:
            text = file_obj.decode('utf-8', errors='replace')
    except Exception as e:
        return f"Error reading file: {str(e)}", []

    # UI sliders may hand over floats (e.g. 4.0); product(repeat=...),
    # nn.Linear's input width and array slicing all require ints.
    try:
        k_size = int(k_size)
        top_k = max(int(top_k), 0)
    except (TypeError, ValueError) as e:
        return f"Invalid parameters: {e}", []

    # 2. Parse the FASTA
    sequences = parse_fasta(text)
    if not sequences:
        return "No valid FASTA sequences found!", []

    # 3. Load model & scaler (trusted local files shipped with the app)
    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = VirusClassifier(input_shape=(4 ** k_size)).to(device)
        state_dict = torch.load('model.pt', map_location=device)
        model.load_state_dict(state_dict)
        model.eval()

        scaler = joblib.load('scaler.pkl')
    except Exception as e:
        return f"Error loading model/scaler: {str(e)}", []

    # 4. Feature names, in the same order as sequence_to_kmer_vector's slots
    all_kmers = [''.join(p) for p in product("ACGT", repeat=k_size)]

    # 5. Iterate over sequences and build output
    final_text_report = []
    plots = []

    for idx, (header, seq) in enumerate(sequences, start=1):
        seq_stats = compute_sequence_stats(seq)

        # Convert sequence -> raw freq -> scaled freq
        raw_kmer_freq = sequence_to_kmer_vector(seq, k=k_size)
        scaled_kmer_freq = scaler.transform(raw_kmer_freq.reshape(1, -1))
        X_tensor = torch.FloatTensor(scaled_kmer_freq).to(device)

        # Predict
        with torch.no_grad():
            probs = torch.softmax(model(X_tensor), dim=1)

        pred_class = torch.argmax(probs, dim=1).item()
        pred_label = 'human' if pred_class == 1 else 'non-human'
        human_prob = float(probs[0][1])
        non_human_prob = float(probs[0][0])
        confidence = float(torch.max(probs[0]).item())

        # Gradient of P(human) with respect to each scaled input feature
        importance, _ = model.get_gradient_importance(X_tensor, class_index=1)
        importance = importance[0].cpu().numpy()  # shape: (num_features,)

        # Top-k features by absolute gradient, largest first
        top_indices = np.argsort(np.abs(importance))[::-1][:top_k]

        top_kmers_info = []
        for i in top_indices:
            imp_val = float(importance[i])
            top_kmers_info.append({
                'kmer': all_kmers[i],
                'impact': abs(imp_val),
                'direction': 'human' if imp_val > 0 else 'non-human',
                'occurrence': float(raw_kmer_freq[i] * 100.0),  # in percent
                # Scaled feature value; this equals standard deviations from
                # the mean only if the scaler is a StandardScaler.
                'sigma': float(scaled_kmer_freq[0][i]),
            })

        # Text summary for this sequence
        seq_report = [
            f"=== Sequence {idx} ===",
            f"Header: {header}",
            f"Length: {seq_stats['length']}",
            f"GC Content: {seq_stats['gc_content']:.2f}%",
            f"A: {seq_stats['counts']['A']}, C: {seq_stats['counts']['C']}, G: {seq_stats['counts']['G']}, T: {seq_stats['counts']['T']}",
            f"Prediction: {pred_label} (Confidence: {confidence:.4f})",
            f"  Human Probability: {human_prob:.4f}",
            f"  Non-human Probability: {non_human_prob:.4f}",
            f"\nTop-{top_k} Influential k-mers (by gradient magnitude):",
        ]
        for tkm in top_kmers_info:
            seq_report.append(
                f"  {tkm['kmer']}: pushes towards {tkm['direction']} "
                f"(impact={tkm['impact']:.4f}), occurrence={tkm['occurrence']:.2f}%, "
                f"sigma={tkm['sigma']:.2f}"
            )

        final_text_report.append("\n".join(seq_report))

        # 6. Generate plots (for each sequence)
        if advanced_analysis:
            # 6A. SHAP-like bar chart
            plots.append(_figure_to_image(plot_shap_like_bars(
                kmers=all_kmers,
                importance_values=importance,
                top_k=top_k
            )))
            # 6B. k-mer distribution histogram
            plots.append(_figure_to_image(plot_kmer_distribution(raw_kmer_freq, all_kmers)))

        # 6C. Step visualization, largest absolute impact first
        top_kmers_info_step = sorted(top_kmers_info, key=lambda x: x['impact'], reverse=True)
        plots.append(_figure_to_image(create_step_visualization(top_kmers_info_step, human_prob)))

        # 6D. Frequency vs. sigma bar chart
        plots.append(_figure_to_image(plot_kmer_freq_and_sigma(top_kmers_info_step)))

    # Combine all text results
    combined_text = "\n\n".join(final_text_report)
    return combined_text, plots

# --------------- Gradio Interface ---------------

def run_prediction(
    file_obj,
    k_size,
    top_k,
    advanced_analysis
):
    """
    Gradio click handler that forwards the UI inputs to ``predict_fasta``.

    Returns the ``(text, list_of_PIL_images)`` pair, matching the
    ``[results_text, plots_gallery]`` outputs wired to the Predict button.
    """
    report, images = predict_fasta(
        file_obj=file_obj,
        k_size=k_size,
        top_k=top_k,
        advanced_analysis=advanced_analysis,
    )
    return report, images


with gr.Blocks() as demo:
    gr.Markdown("# Virus Host Classifier (Improved!)")
    gr.Markdown(
        "Upload a FASTA file and configure k-mer size, number of top features, "
        "and whether to run advanced analysis (plots of SHAP-like bars & k-mer distribution)."
    )

    with gr.Row():
        with gr.Column():
            fasta_file = gr.File(label="Upload FASTA file", type="binary")
            kmer_slider = gr.Slider(minimum=2, maximum=6, value=4, step=1, label="K-mer Size")
            topk_slider = gr.Slider(minimum=5, maximum=20, value=10, step=1, label="Top-k Features")
            advanced_check = gr.Checkbox(value=False, label="Advanced Analysis")
            predict_button = gr.Button("Predict")

        with gr.Column():
            results_text = gr.Textbox(
                label="Results",
                lines=20,
                placeholder="Prediction results will appear here..."
            )

    # All plots for all sequences are shown in a two-column gallery.
    # ``Component.style()`` was deprecated in late Gradio 3.x and removed in
    # Gradio 4, where layout options are constructor arguments instead. Try
    # the modern form first; fall back for older releases that reject it.
    try:
        plots_gallery = gr.Gallery(label="Analysis Plots", columns=2, height="auto")
    except TypeError:
        plots_gallery = gr.Gallery(label="Analysis Plots").style(grid=[2], height="auto")

    predict_button.click(
        fn=run_prediction,
        inputs=[fasta_file, kmer_slider, topk_slider, advanced_check],
        outputs=[results_text, plots_gallery]
    )

if __name__ == "__main__":
    demo.launch(share=True)