File size: 14,022 Bytes
5263bd3
 
 
 
 
 
4a7c026
 
40fe6da
5263bd3
8c49ca8
 
 
5263bd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdd8a58
 
8c49ca8
 
 
 
cdd8a58
 
40fe6da
cdd8a58
8c49ca8
40fe6da
6c88c65
 
40fe6da
8c49ca8
40fe6da
 
5263bd3
8c49ca8
 
 
870813f
8c49ca8
870813f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c49ca8
 
 
 
 
6c88c65
8c49ca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c88c65
8c49ca8
6c88c65
8c49ca8
 
 
6c88c65
8c49ca8
6c88c65
8c49ca8
 
 
 
 
 
 
 
 
 
 
6c88c65
8c49ca8
 
 
6c88c65
8c49ca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c88c65
8c49ca8
6c88c65
8c49ca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c88c65
8c49ca8
 
 
 
6c88c65
8c49ca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c88c65
 
 
8c49ca8
 
 
 
63d967d
8c49ca8
 
 
 
 
 
 
63d967d
8c49ca8
870813f
8c49ca8
723da6d
 
 
 
 
 
4a7c026
723da6d
8c49ca8
6a3b036
cdd8a58
 
 
8c49ca8
723da6d
 
6a3b036
9d48283
 
723da6d
 
 
8c49ca8
723da6d
4a7c026
 
 
723da6d
8c49ca8
723da6d
8c49ca8
 
4a7c026
8c49ca8
 
 
4a7c026
 
 
8c49ca8
 
17c9ecb
 
 
4a7c026
8c49ca8
 
 
 
 
4a7c026
8c49ca8
6c88c65
8c49ca8
6c88c65
8c49ca8
 
 
 
 
 
 
6c88c65
8c49ca8
 
b0fba50
6c88c65
8c49ca8
 
6c88c65
 
 
 
8c49ca8
4a7c026
 
6c88c65
8c49ca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a7c026
8c49ca8
6c88c65
8c49ca8
6c88c65
 
 
4a7c026
723da6d
8c49ca8
723da6d
4a7c026
5263bd3
8c49ca8
 
 
5263bd3
870813f
723da6d
6c88c65
8c49ca8
6c88c65
 
8c49ca8
 
 
 
 
 
 
5263bd3
 
723da6d
8c49ca8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
import gradio as gr
import torch
import joblib
import numpy as np
from itertools import product
import torch.nn as nn
import matplotlib.pyplot as plt
import io
from PIL import Image

###############################################################################
# Model Definition
###############################################################################
class VirusClassifier(nn.Module):
    """Small feed-forward network mapping a feature vector to two logits.

    The ``predict`` function in this file reads output index 0 as
    "non-human" and index 1 as "human".
    """

    def __init__(self, input_shape: int):
        """Build the layer stack.

        Args:
            input_shape: Number of input features per sample.
        """
        super().__init__()
        # The attribute name `network` and the layer order determine the
        # state_dict keys, so they must stay as-is for saved weights to load.
        self.network = nn.Sequential(
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2)
        )

    def forward(self, x):
        """Return the raw (pre-softmax) logits produced by the layer stack."""
        return self.network(x)

    def get_feature_importance(self, x):
        """
        Calculate gradient-based feature importance.

        Computes the gradient of the 'human' probability (softmax index 1)
        with respect to the input vector.

        The gradient is taken on a detached copy of ``x`` via
        ``torch.autograd.grad``, so the caller's tensor is left untouched
        (its ``requires_grad`` flag and ``.grad`` are not modified) and no
        gradients are accumulated into the model parameters.

        Args:
            x: Input tensor holding a single sample, shape (1, n_features).

        Returns:
            Tuple ``(importance, human_prob)`` where ``importance`` is the
            gradient tensor with the same shape as ``x`` and ``human_prob``
            is the 'human' probability as a Python float.
        """
        # Work on a private leaf tensor so the caller's input is not mutated.
        inputs = x.detach().clone().requires_grad_(True)
        logits = self.network(inputs)
        human_prob = torch.softmax(logits, dim=1)[..., 1]

        # `.sum()` reduces the one-element tensor to a scalar for autograd.
        (importance,) = torch.autograd.grad(human_prob.sum(), inputs)

        return importance, float(human_prob.detach())

###############################################################################
# Utility Functions
###############################################################################
def parse_fasta(text):
    """Parse FASTA-formatted text into a list of (header, sequence) tuples.

    Lines starting with '>' open a new record; the header is the rest of that
    line. All following non-empty lines up to the next header are stripped,
    upper-cased and concatenated into the record's sequence. Sequence lines
    that appear before the first header are ignored.

    Args:
        text: The full FASTA content as a string.

    Returns:
        A list of ``(header, sequence)`` tuples in input order. A record
        whose header line is a bare '>' is kept, with an empty-string header.
    """
    sequences = []
    current_header = None  # None means "no header seen yet"
    current_sequence = []

    for line in text.split('\n'):
        line = line.strip()
        if not line:
            continue
        if line.startswith('>'):
            # Compare against None, not truthiness: an empty header ('>')
            # is still a real record and must not be dropped.
            if current_header is not None:
                sequences.append((current_header, ''.join(current_sequence)))
            current_header = line[1:]
            current_sequence = []
        else:
            current_sequence.append(line.upper())

    if current_header is not None:
        sequences.append((current_header, ''.join(current_sequence)))
    return sequences

def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """Return the normalized k-mer frequency vector of a nucleotide sequence.

    The vector has one float32 slot per k-mer over the alphabet "ACGT", in
    the order produced by ``itertools.product``. Windows containing any other
    character are not counted, but they still contribute to the divisor,
    which is the total number of windows ``len(sequence) - k + 1``. When the
    sequence is shorter than ``k`` the all-zero vector is returned.
    """
    slot_of = {
        ''.join(letters): position
        for position, letters in enumerate(product("ACGT", repeat=k))
    }
    counts = np.zeros(len(slot_of), dtype=np.float32)

    window_count = len(sequence) - k + 1
    for start in range(window_count):
        slot = slot_of.get(sequence[start:start + k])
        if slot is not None:
            counts[slot] += 1

    if window_count <= 0:
        return counts
    return counts / window_count  # normalize frequencies


###############################################################################
# Visualization
###############################################################################
def create_visualization(important_kmers, human_prob, title):
    """
    Create a three-panel figure summarising the most influential k-mers.

    Panels:
    1) Top-left: a waterfall-like step plot showing how each k-mer shifts a
       running probability away from the 0.5 baseline. Each step is the
       k-mer's impact scaled by 0.05 and signed by its direction, so the
       running value is illustrative and need not end at ``human_prob``.
    2) Top-right: side-by-side bars of frequency (%) and σ from mean for
       each k-mer, on twin y-axes.
    3) Bottom (spanning both columns): the k-mers sorted by impact,
       coloured by direction.

    All drawing goes through the explicit ``fig`` / axes objects rather than
    pyplot's implicit "current figure", so the output cannot land on a
    different figure if pyplot's current figure changes mid-function.

    Args:
        important_kmers: List of dicts with keys 'kmer', 'impact',
            'direction' ('human' or anything else), 'occurrence' and 'sigma'.
        human_prob: Final 'human' probability, shown in the waterfall title.
        title: Text used as the figure's super-title.

    Returns:
        The matplotlib Figure. It is created with ``plt.figure`` and is not
        closed here; the caller should close it when done.
    """

    # Figure & GridSpec Layout
    fig = plt.figure(figsize=(14, 10))
    gs = plt.GridSpec(2, 2, width_ratios=[1.2, 1], height_ratios=[1.2, 1], hspace=0.35, wspace=0.3)

    # -------------------------------------------------------------------------
    # 1. Waterfall-like Plot (top-left subplot)
    # -------------------------------------------------------------------------
    ax_waterfall = fig.add_subplot(gs[0, 0])

    # Start from baseline prob=0.5
    baseline = 0.5
    current_prob = baseline
    steps = [("Baseline", current_prob, 0.0)]

    # Build up the step changes
    for kmer in important_kmers:
        direction_multiplier = 1 if kmer["direction"] == "human" else -1
        # Scale changes so that the sum doesn't overshadow the final probability.
        change = kmer["impact"] * 0.05 * direction_multiplier
        current_prob += change
        steps.append((kmer["kmer"], current_prob, change))

    # X-values for step plot
    x_vals = range(len(steps))
    y_vals = [s[1] for s in steps]

    ax_waterfall.step(x_vals, y_vals, where='post', color='blue', linewidth=2, label='Probability')
    ax_waterfall.plot(x_vals, y_vals, 'b.', markersize=8)

    # Reference lines
    ax_waterfall.axhline(y=baseline, color='gray', linestyle='--', label='Baseline=0.5')

    # Annotate each step
    for i, (kmer, prob, change) in enumerate(steps):
        if i == 0:  # baseline
            ax_waterfall.annotate(kmer, (i, prob), textcoords="offset points", xytext=(0, -15), ha='center', color='black')
            continue

        color = "green" if change > 0 else "red"
        ax_waterfall.annotate(
            f"{kmer}\n({change:+.3f})",
            (i, prob),
            textcoords="offset points",
            xytext=(0, -15),
            ha='center',
            color=color,
            fontsize=9
        )

    ax_waterfall.set_ylim(0, 1)
    ax_waterfall.set_xlabel("k-mer Step")
    ax_waterfall.set_ylabel("Running Probability (Human)")
    ax_waterfall.set_title(f"K-mer Waterfall Plot — Final Probability: {human_prob:.3f}")
    ax_waterfall.grid(alpha=0.3)
    ax_waterfall.legend()

    # -------------------------------------------------------------------------
    # 2. Frequency & σ from Mean (top-right subplot)
    # -------------------------------------------------------------------------
    ax_bar = fig.add_subplot(gs[0, 1])

    kmers = [k["kmer"] for k in important_kmers]
    frequencies = [k["occurrence"] for k in important_kmers]  # in %
    sigmas = [k["sigma"] for k in important_kmers]
    directions = [k["direction"] for k in important_kmers]

    # X-locations
    x = np.arange(len(kmers))
    width = 0.4

    # Twin axes: the left axis carries frequency, the right axis carries σ
    ax_bar.bar(x - width/2, frequencies, width, label='Frequency (%)',
               alpha=0.7, color=['green' if d == 'human' else 'red' for d in directions])
    ax_bar.set_ylabel("Frequency (%)")
    ax_bar.set_ylim(0, max(frequencies) * 1.2 if frequencies else 1)
    ax_bar.set_title("Frequency vs. σ from Mean")

    # Twin axis for σ
    ax_bar_twin = ax_bar.twinx()
    ax_bar_twin.bar(x + width/2, sigmas, width, label='σ from Mean',
                    alpha=0.5, color='gray')
    ax_bar_twin.set_ylabel("Standard Deviations (σ)")

    ax_bar.set_xticks(x)
    ax_bar.set_xticklabels(kmers, rotation=45, ha='right', fontsize=9)

    # Combine legends from both y-axes into one box
    lines1, labels1 = ax_bar.get_legend_handles_labels()
    lines2, labels2 = ax_bar_twin.get_legend_handles_labels()
    ax_bar.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

    # -------------------------------------------------------------------------
    # 3. Top Feature Importances (Bottom, spanning both columns)
    # -------------------------------------------------------------------------
    ax_imp = fig.add_subplot(gs[1, :])

    # Sort by impact, largest first
    sorted_kmers = sorted(important_kmers, key=lambda entry: entry['impact'], reverse=True)
    top_kmer_labels = [k['kmer'] for k in sorted_kmers]
    top_kmer_impacts = [k['impact'] for k in sorted_kmers]
    top_kmer_dirs = [k['direction'] for k in sorted_kmers]

    x_imp = np.arange(len(top_kmer_impacts))
    bar_colors = ['green' if d == 'human' else 'red' for d in top_kmer_dirs]

    ax_imp.bar(x_imp, top_kmer_impacts, color=bar_colors, alpha=0.7)
    ax_imp.set_xticks(x_imp)
    ax_imp.set_xticklabels(top_kmer_labels, rotation=45, ha='right', fontsize=9)
    ax_imp.set_title("Absolute Feature Importance (Top k-mers)")
    ax_imp.set_ylabel("Importance (gradient magnitude)")
    ax_imp.grid(alpha=0.3, axis='y')

    fig.suptitle(title, fontsize=14, y=1.02)
    fig.tight_layout()
    return fig


###############################################################################
# Prediction Function
###############################################################################
def predict(file_obj):
    """
    Main function that Gradio will call.

    Steps:
      1. Reads the uploaded FASTA content (bytes are decoded as UTF-8; a str
         is used as-is).
      2. Loads the model weights ('model.pt') and the scaler ('scaler.pkl').
      3. Classifies the FIRST sequence in the input and ranks the top 10
         k-mers by absolute gradient importance.
      4. Builds a Markdown summary and renders the visualization to a PNG.

    Args:
        file_obj: Uploaded file content as ``bytes`` or ``str``, or ``None``
            when nothing was uploaded.

    Returns:
        Tuple ``(text, image)``. On success ``text`` is the Markdown summary
        and ``image`` is a PIL image of the figure. On any failure ``text``
        is a human-readable error message and ``image`` is ``None``; the
        function never raises.
    """
    if file_obj is None:
        return "Please upload a FASTA file.", None

    # Read text from file
    try:
        text = file_obj if isinstance(file_obj, str) else file_obj.decode('utf-8')
    except Exception as e:
        return f"Error reading file: {str(e)}", None

    # K-mer vocabulary; list position == feature index in the model input
    k = 4
    kmers = [''.join(p) for p in product("ACGT", repeat=k)]

    # Load model & scaler
    # NOTE(review): torch.load / joblib.load unpickle their input; only ever
    # point these at trusted, locally shipped artifacts.
    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = VirusClassifier(len(kmers)).to(device)
        state_dict = torch.load('model.pt', map_location=device)
        model.load_state_dict(state_dict)
        scaler = joblib.load('scaler.pkl')
        model.eval()
    except Exception as e:
        return f"Error loading model or scaler: {str(e)}", None

    try:
        # Parse FASTA
        sequences = parse_fasta(text)
        if not sequences:
            return "No valid FASTA sequences found. Please check your input.", None

        header, seq = sequences[0]  # For simplicity, only the first sequence is classified

        # A sequence shorter than k has no k-mers; classifying the resulting
        # all-zero vector would yield a meaningless prediction.
        if len(seq) < k:
            return (
                f"Sequence is too short ({len(seq)} nt); at least {k} nucleotides are required.",
                None,
            )

        # Transform sequence to scaled k-mer vector
        raw_freq_vector = sequence_to_kmer_vector(seq, k)
        kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
        X_tensor = torch.FloatTensor(kmer_vector).to(device)

        # Inference
        with torch.no_grad():
            output = model(X_tensor)
            probs = torch.softmax(output, dim=1)

        # Feature importance (needs autograd, so it runs outside no_grad)
        importance, _ = model.get_feature_importance(X_tensor)
        kmer_importance = importance[0].cpu().numpy()  # shape: (256,)

        # Top k-mers by absolute importance, largest -> smallest
        top_k = 10
        top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]

        important_kmers = []
        for idx in top_indices:
            idx = int(idx)
            signed_importance = kmer_importance[idx]
            important_kmers.append({
                'kmer': kmers[idx],  # direct lookup: position == feature index
                'impact': float(abs(signed_importance)),
                'direction': 'human' if signed_importance > 0 else 'non-human',
                'occurrence': float(raw_freq_vector[idx] * 100),  # frequency in %
                'sigma': float(kmer_vector[0][idx]),  # scaled value (Z-score if standard scaler)
            })

        human_prob = float(probs[0][1])
        non_human_prob = float(probs[0][0])
        pred_label = 'human' if human_prob > non_human_prob else 'non-human'
        conf = max(human_prob, non_human_prob)  # confidence in the predicted class

        # Generate text results
        summary = (
            f"**Sequence Header**: {header}\n\n"
            f"**Predicted Label**: {pred_label}\n"
            f"**Confidence**: {conf:.4f}\n\n"
            f"**Human Probability**: {human_prob:.4f}\n"
            f"**Non-human Probability**: {non_human_prob:.4f}\n\n"
            "### Most Influential k-mers:\n"
        )
        kmer_lines = []
        for entry in important_kmers:
            direction_text = f"pushes toward {entry['direction']}"
            occurrence_text = f"{entry['occurrence']:.2f}% of sequence"
            sigma_text = (
                f"{abs(entry['sigma']):.2f}σ "
                + ("above" if entry['sigma'] > 0 else "below")
                + " mean"
            )
            kmer_lines.append(
                f"- **{entry['kmer']}**: "
                f"impact = {entry['impact']:.4f}, {direction_text}, "
                f"occurrence = {occurrence_text}, "
                f"({sigma_text})\n"
            )
        results_text = summary + "".join(kmer_lines)

        # Create figure and convert it to an image. The figure is closed in
        # `finally` so a failure while saving does not leak it.
        fig = create_visualization(important_kmers, human_prob, f"{header}")
        try:
            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
            buf.seek(0)
            plot_image = Image.open(buf)
        finally:
            plt.close(fig)

    except Exception as e:
        return f"Error during prediction or visualization: {str(e)}", None

    return results_text, plot_image

###############################################################################
# Gradio Interface
###############################################################################
# Components are created in the same order as before (file input, then the
# two outputs) and handed to the Interface by name.
_fasta_upload = gr.File(label="Upload FASTA file", type="binary")
_result_panels = [
    gr.Markdown(label="Prediction Results"),
    gr.Image(label="K-mer Analysis Visualization"),
]
_app_description = (
    "Upload a FASTA file containing a single nucleotide sequence. "
    "This model will predict whether the virus host is **human** or **non-human**, "
    "provide a confidence score, and highlight the most influential k-mers in the classification."
)

# `predict` receives the raw uploaded bytes and returns (markdown, image).
iface = gr.Interface(
    fn=predict,
    inputs=_fasta_upload,
    outputs=_result_panels,
    title="Virus Host Classifier",
    description=_app_description,
    allow_flagging="never",
)

if __name__ == "__main__":
    # NOTE(review): listens on all interfaces and requests a share link —
    # confirm both are intended for the target deployment.
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)