File size: 11,699 Bytes
98950ac
f321ade
41f7b15
98c48e9
 
 
 
 
 
 
9976695
 
02a9726
5b50998
02a9726
5b50998
 
 
 
 
 
 
 
 
 
 
98c48e9
98950ac
 
 
 
 
 
 
 
 
 
 
 
 
 
f321ade
41f7b15
f321ade
 
 
 
41f7b15
 
f321ade
 
 
 
 
 
 
 
41f7b15
f321ade
 
 
 
 
41f7b15
 
f321ade
 
41f7b15
f321ade
 
 
 
 
 
 
 
 
 
98950ac
 
ef67a66
b8175a8
 
 
 
ef67a66
b8175a8
 
 
 
ef67a66
e557ff0
 
41f7b15
 
b8175a8
5b50998
b8175a8
e557ff0
 
 
 
f321ade
 
e557ff0
 
 
b8175a8
41f7b15
b8175a8
 
 
 
86d28da
b8175a8
 
 
86d28da
b8175a8
 
 
 
98c48e9
 
 
 
 
 
 
b8175a8
 
 
 
 
 
 
 
98c48e9
b8175a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86d28da
b8175a8
 
 
 
 
 
 
5b50998
98c48e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9976695
 
 
304c449
 
9976695
304c449
 
9976695
304c449
9976695
 
304c449
 
 
9976695
 
 
 
9213ba3
9976695
 
 
 
 
 
 
 
9213ba3
9976695
 
 
 
 
 
 
9213ba3
 
 
 
 
304c449
9213ba3
304c449
9213ba3
304c449
9213ba3
 
 
 
304c449
9976695
9213ba3
 
9976695
a5b0df3
98950ac
 
9976695
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9213ba3
 
 
9976695
 
 
 
 
 
371f76e
 
9976695
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9213ba3
9976695
 
 
98950ac
02a9726
7fccc04
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
from Bio.PDB import MMCIFParser, PDBIO
from folding_studio.client import Client
from folding_studio.query.boltz import BoltzQuery, BoltzParameters
from pathlib import Path
import gradio as gr
import hashlib
import logging
import numpy as np
import os
import plotly.graph_objects as go
import pandas as pd
from scipy.stats import spearmanr

from molecule import molecule

# Configure root logging once, at import time: INFO level and above, each
# record rendered as "<timestamp> - <LEVEL> - <message>" through a single
# console stream handler.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)  # module logger shared by every helper below


def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
    """Re-save an mmCIF structure file in PDB format.

    The structure is read with Biopython's ``MMCIFParser`` (under the id
    ``"structure"``) and written back out with ``PDBIO``.

    Args:
        cif_path (str): Location of the source ``.cif`` file.
        pdb_path (str): Destination path for the generated ``.pdb`` file.
    """
    parsed = MMCIFParser().get_structure("structure", cif_path)

    writer = PDBIO()
    writer.set_structure(parsed)
    writer.save(pdb_path)

def call_boltz(seq_file: Path | str, api_key: str, output_dir: Path) -> None:
    """Run a Boltz structure prediction through the Folding Studio API.

    Builds a ``BoltzQuery`` from ``seq_file`` with fixed inference parameters,
    saves those parameters into ``output_dir``, submits the request under the
    project code read from the ``FOLDING_PROJECT_CODE`` environment variable,
    then downloads and unzips the results into ``output_dir``.

    Args:
        seq_file: Path to the input FASTA file.
        api_key: Folding Studio API key used to authenticate the client.
        output_dir: Directory receiving the saved parameters and the results;
            created if it does not exist yet.

    Raises:
        RuntimeError: If ``FOLDING_PROJECT_CODE`` is unset or empty.
    """
    # Fixed inference settings used for every request made by this app
    # (hard-coded here; nothing is read from a CLI or from the UI).
    parameters = {
        "recycling_steps": 3,
        "sampling_steps": 200,
        "diffusion_samples": 1,
        "step_scale": 1.638,
        "msa_pairing_strategy": "greedy",
        "write_full_pae": False,
        "write_full_pde": False,
        "use_msa_server": True,
        "seed": 0,
        "custom_msa_paths": None,
    }

    # Resolve the project code before any network or file work, so a
    # misconfigured deployment fails fast with an actionable message instead
    # of a bare KeyError after authenticating.
    project_code = os.environ.get("FOLDING_PROJECT_CODE")
    if not project_code:
        raise RuntimeError(
            "The FOLDING_PROJECT_CODE environment variable must be set to "
            "submit Folding Studio requests."
        )

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Create a client using API key
    logger.info("Authenticating client with API key")
    client = Client.from_api_key(api_key=api_key)

    # Define query
    seq_file = Path(seq_file)
    query = BoltzQuery.from_file(seq_file, query_name="gradio", parameters=BoltzParameters(**parameters))
    query.save_parameters(output_dir)

    logger.info("Payload: %s", query.payload)

    # Send a request
    logger.info("Sending request to Folding Studio API")
    response = client.send_request(query, project_code=project_code)

    # Access confidence data
    logger.info("Confidence data: %s", response.confidence_data)

    response.download_results(output_dir=output_dir, force=True, unzip=True)
    logger.info("Results downloaded to %s", output_dir)


def _first_match(paths) -> Path | None:
    """Return the lexicographically smallest path in ``paths``, or None if there is none."""
    return min(paths, default=None)


def predict(sequence: str, api_key: str) -> tuple[str, go.Figure]:
    """Predict a protein structure with Boltz and build its visualizations.

    Results are cached on disk in a directory named after the SHA-1 of the
    sequence. The Folding Studio API is only called when no
    ``*_model_0.cif`` file exists there yet.

    Args:
        sequence (str): Amino acid sequence. All whitespace (e.g. line breaks
            from a pasted multi-line sequence) is removed before use.
        api_key (str): Folding API key; only required when the prediction is
            not already cached.

    Returns:
        tuple[str, go.Figure]: HTML iframe containing the 3D molecular
        visualization, and the per-residue pLDDT plot.

    Raises:
        gr.Error: If the sequence is empty, the API key is missing for an
            uncached prediction, or an expected output file is not found.
    """
    # Spaces/newlines are never residues; removing them also keeps the cache
    # key identical for the same sequence pasted with different spacing.
    sequence = "".join(sequence.split())
    if not sequence:
        raise gr.Error("Please enter an amino acid sequence.")

    # Set up unique output directory based on sequence hash
    seq_id = hashlib.sha1(sequence.encode()).hexdigest()
    seq_file = Path(f"sequence_{seq_id}.fasta")
    _write_fasta_file(seq_file, sequence)
    output_dir = Path(f"sequence_{seq_id}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Reuse an earlier prediction for the same sequence when one is on disk.
    if _first_match(output_dir.rglob("*_model_0.cif")) is None:
        if not api_key:
            raise gr.Error("A Folding API key is required to run a new prediction.")
        logger.info("Predicting %s", seq_file.stem)
        call_boltz(seq_file=seq_file, api_key=api_key, output_dir=output_dir)
        logger.info("Prediction done. Output directory: %s", output_dir)
    else:
        logger.info("Prediction already exists. Output directory: %s", output_dir)

    pred_cif = _first_match(output_dir.rglob("*_model_0.cif"))
    if pred_cif is None:
        raise gr.Error(f"No predicted structure (*_model_0.cif) was found in {output_dir}.")
    logger.info("Output file: %s", pred_cif)

    # Convert the output CIF to PDB for the molecule viewer.
    converted_pdb_path = str(output_dir / "pred.pdb")
    convert_cif_to_pdb(str(pred_cif), converted_pdb_path)
    logger.info("Converted PDB file: %s", converted_pdb_path)

    # Generate molecular visualization
    mol = _create_molecule_visualization(converted_pdb_path, sequence)

    plddt_file = _first_match(pred_cif.parent.glob("plddt_*.npz"))
    if plddt_file is None:
        raise gr.Error(f"No pLDDT file (plddt_*.npz) was found in {pred_cif.parent}.")
    logger.info("plddt file: %s", plddt_file)
    # np.load on an .npz returns an archive object holding an open file; the
    # context manager closes it once the "plddt" array has been read.
    with np.load(plddt_file) as npz:
        plddt_vals = npz["plddt"]

    return _wrap_in_iframe(mol), add_plddt_plot(plddt_vals=plddt_vals)


def _write_fasta_file(filepath: Path, sequence: str) -> None:
    """Write sequence to FASTA file."""
    with open(filepath, "w") as f:
        f.write(f">A|protein\n{sequence}")


def _create_molecule_visualization(pdb_path: Path, sequence: str) -> str:
    """Render one predicted structure to viewer HTML via ``molecule``.

    The sequence is passed as a single entry with every residue selected
    (indices 1..len(sequence)); its metric fields are all set to zero.
    """
    n_residues = len(sequence)
    zero_metrics = {
        "Score": 0,
        "RMSD": 0,
        "Recovery": 0,
        "Mean pLDDT": 0,
        "seq": sequence,
    }
    all_residues = [*range(1, n_residues + 1)]

    return molecule(
        str(pdb_path),
        lenSeqs=1,
        num_res=n_residues,
        selectedResidues=all_residues,
        allSeqs=[sequence],
        sequences=[zero_metrics],
    )


def _wrap_in_iframe(content: str) -> str:
    """Wrap content in an HTML iframe with appropriate styling and permissions."""
    return f"""<iframe 
        name="result" 
        style="width: 100%; height: 100vh;"
        allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;"
        sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups allow-top-navigation-by-user-activation allow-downloads"
        allowfullscreen=""
        allowpaymentrequest=""
        frameborder="0"
        srcdoc='{content}'
    ></iframe>"""

def add_plddt_plot(plddt_vals: np.ndarray | list[float]) -> go.Figure:
    """Plot per-residue pLDDT values.

    Args:
        plddt_vals: One pLDDT value per residue (``predict`` passes the
            ``"plddt"`` array loaded from the prediction's ``.npz`` file).
            Values are plotted against 0-based residue indices.

    Returns:
        go.Figure: A single-trace Plotly figure titled "pLDDT".
    """
    plddt_trace = go.Scatter(
        x=np.arange(len(plddt_vals)),
        y=plddt_vals,
        hovertemplate="<i>pLDDT</i>: %{y:.2f} <br><i>Residue index:</i> %{x}<br>",
        name="seq",
        visible=True,
    )

    plddt_fig = go.Figure(data=[plddt_trace])
    plddt_fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
    )
    return plddt_fig

def fake_predict_and_correlate(
    spr_data_with_scores: pd.DataFrame, score_cols: list[str]
) -> tuple[pd.DataFrame, go.Figure, go.Figure]:
    """Correlate precomputed prediction scores with measured binding affinity.

    "Fake" because no structures are predicted here: the scores are read from
    columns already present in ``spr_data_with_scores``. For each column in
    ``score_cols``, the Spearman rank correlation with the ``"KD (nM)"`` column
    is computed. The input DataFrame is not modified.

    Args:
        spr_data_with_scores: Table with a ``"KD (nM)"`` column and one column
            per entry of ``score_cols``.
        score_cols: Names of the score columns to correlate against KD.

    Returns:
        A tuple ``(table, ranking_fig, scatter_fig)``:
        - ``table``: the KD column followed by the score columns;
        - ``ranking_fig``: horizontal bar chart of the correlations in
          ascending order, with scores whose correlation is NaN left out;
        - ``scatter_fig``: one marker trace of KD vs. score per score column.
    """
    kd_col = "KD (nM)"

    # Spearman correlation depends only on ranks, so raw KD gives the same
    # result as log10(KD). No transformed column is added, which also keeps
    # the caller's (shared) DataFrame untouched.
    corr_rows = []
    for score_col in score_cols:
        logger.info("Computing correlation between %s and KD (nM)", score_col)
        res = spearmanr(spr_data_with_scores[kd_col], spr_data_with_scores[score_col])
        corr_rows.append({"score": score_col, "correlation": res.statistic, "p-value": res.pvalue})
        logger.info("Correlation between %s and KD (nM): %s", score_col, res.statistic)

    # Explicit columns keep the frame valid (and filterable) even when
    # score_cols is empty.
    corr_data = pd.DataFrame(corr_rows, columns=["score", "correlation", "p-value"])
    # Drop scores whose correlation is undefined (NaN).
    corr_data = corr_data[corr_data["correlation"].notna()]
    logger.info("Correlation data: %s", corr_data)
    corr_data = corr_data.sort_values("correlation", ascending=True)

    # Horizontal bar chart ranking the scores by correlation.
    corr_ranking_plot = go.Figure(data=[
        go.Bar(
            x=corr_data["correlation"],
            y=corr_data["score"],
            name="correlation",
            orientation='h',
            hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>"
        )
    ])
    corr_ranking_plot.update_layout(
        title="Correlation with Binding Affinity",
        yaxis_title="Score Type",
        xaxis_title="Spearman Correlation",
        template="simple_white",
        showlegend=False
    )

    # Scatter plot of binding affinity against each score.
    corr_plot = go.Figure(data=[
        go.Scatter(
            x=spr_data_with_scores[kd_col],
            y=spr_data_with_scores[score_col],
            name=f"{kd_col} vs {score_col}",
            mode='markers',  # Only show markers/dots, no lines
            hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>"
        )
        for score_col in score_cols
    ])

    cols_to_show = [kd_col, *score_cols]
    return spr_data_with_scores[cols_to_show], corr_ranking_plot, corr_plot

demo = gr.Blocks(title="Folding Studio: structure prediction with Boltz-1")

with demo:
    with gr.Tabs() as tabs:
        # Tab 1: structure prediction for a single sequence.
        with gr.TabItem("Inference"):
            gr.Markdown("# Input")
            with gr.Row():
                with gr.Column():
                    sequence = gr.Textbox(label="Sequence", value="")
                    api_key = gr.Textbox(label="Folding API Key", type="password")
            gr.Markdown("# Output")
            with gr.Row():
                predict_btn = gr.Button("Predict")
            with gr.Row():
                with gr.Column():
                    mol_output = gr.HTML()
                with gr.Column():
                    metrics_plot = gr.Plot(label="pLDDT")

            predict_btn.click(
                fn=predict,
                inputs=[sequence, api_key],
                outputs=[mol_output, metrics_plot]
            )
        # Tab 2: correlation of precomputed scores with binding affinity.
        with gr.TabItem("Correlations"):
            gr.Markdown("# Upload binding affinity data")
            # Table loaded once when the UI is built, from a path relative to
            # the working directory.
            spr_data_with_scores = pd.read_csv("spr_af_scores_mapped.csv")
            with gr.Row():
                csv_file = gr.File(label="Upload CSV file", file_types=[".csv"])
            with gr.Row():
                dataframe = gr.Dataframe(label="Binding Affinity Data")

            gr.Markdown("# Prediction and correlation")
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        fake_predict_btn = gr.Button("Predict structures of all complexes")
                    with gr.Row():
                        prediction_dataframe = gr.Dataframe(label="Predicted Structures Data")
                with gr.Column():
                    correlation_ranking_plot = gr.Plot(label="Correlation ranking")
                    correlation_plot = gr.Plot(label="Correlation with binding affinity")

            # Score columns correlated against KD (and hidden from the
            # "uploaded" data view).
            cols = [
                "confidence_score_boltz",
                "ptm_boltz",
                "iptm_boltz",
                # "ligand_iptm_boltz",
                # "protein_iptm_boltz",
                "complex_plddt_boltz",
                "complex_iplddt_boltz",
                "complex_pde_boltz",
                "complex_ipde_boltz",
                "interchain_pae_monomer",
                "interface_pae_monomer",
                "overall_pae_monomer",
                "interface_plddt_monomer",
                "average_plddt_monomer",
                "ptm_monomer",
                "interface_ptm_monomer",
                "interchain_pae_multimer",
                "interface_pae_multimer",
                "overall_pae_multimer",
                "interface_plddt_multimer",
                "average_plddt_multimer",
                "ptm_multimer",
                "interface_ptm_multimer"
            ]
            # NOTE(review): the uploaded file's contents are never read; any
            # upload displays the preloaded table without its score columns.
            # Confirm this demo behavior is intended.
            csv_file.change(
                fn=lambda file: spr_data_with_scores.drop(columns=cols) if file else None,
                inputs=csv_file,
                outputs=dataframe
            )
            # With no inputs, Gradio calls fn with zero arguments, so the
            # callback must not declare any (the previous `lambda x:` raised
            # TypeError on every click).
            fake_predict_btn.click(
                fn=lambda: fake_predict_and_correlate(spr_data_with_scores, cols),
                inputs=None,
                outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot]
            )

demo.launch()