File size: 6,732 Bytes
98950ac
f321ade
41f7b15
98c48e9
 
 
 
 
 
 
02a9726
5b50998
02a9726
5b50998
 
 
 
 
 
 
 
 
 
 
98c48e9
98950ac
 
 
 
 
 
 
 
 
 
 
 
 
 
f321ade
41f7b15
f321ade
 
 
 
41f7b15
 
f321ade
 
 
 
 
 
 
 
41f7b15
f321ade
 
 
 
 
41f7b15
 
f321ade
 
41f7b15
f321ade
 
 
 
 
 
 
 
 
 
98950ac
 
ef67a66
b8175a8
 
 
 
ef67a66
b8175a8
 
 
 
ef67a66
e557ff0
 
41f7b15
 
b8175a8
5b50998
b8175a8
e557ff0
 
 
 
f321ade
 
e557ff0
 
 
b8175a8
41f7b15
b8175a8
 
 
 
86d28da
b8175a8
 
 
86d28da
b8175a8
 
 
 
98c48e9
 
 
 
 
 
 
b8175a8
 
 
 
 
 
 
 
98c48e9
b8175a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86d28da
b8175a8
 
 
 
 
 
 
5b50998
98c48e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5b0df3
98950ac
 
 
 
 
b8175a8
ef67a66
98950ac
 
 
a5b0df3
98c48e9
 
 
 
 
98950ac
 
ef67a66
98c48e9
98950ac
 
02a9726
7fccc04
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
from Bio.PDB import MMCIFParser, PDBIO
from folding_studio.client import Client
from folding_studio.query.boltz import BoltzQuery, BoltzParameters
from pathlib import Path
import gradio as gr
import hashlib
import logging
import numpy as np
import os
import plotly.graph_objects as go

from molecule import molecule

# Logging setup: INFO and above go to the console with a timestamped format.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
_console_handler = logging.StreamHandler()
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=[_console_handler])

# Module-level logger used by the functions below.
logger = logging.getLogger(__name__)


def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
    """Re-serialise an mmCIF structure file as a PDB file using Biopython.

    Args:
        cif_path (str): Location of the .cif file to read
        pdb_path (str): Destination of the .pdb file to write
    """
    # Load the structure; "structure" is the id given to the parsed object
    parsed = MMCIFParser().get_structure("structure", cif_path)

    # Hand the parsed structure to the PDB writer and dump it to disk
    writer = PDBIO()
    writer.set_structure(parsed)
    writer.save(pdb_path)

def call_boltz(seq_file: Path | str, api_key: str, output_dir: Path) -> None:
    """Call Boltz prediction."""
    # Initialize parameters with CLI-provided values
    parameters = {
        "recycling_steps": 3,
        "sampling_steps": 200,
        "diffusion_samples": 1,
        "step_scale": 1.638,
        "msa_pairing_strategy": "greedy",
        "write_full_pae": False,
        "write_full_pde": False,
        "use_msa_server": True,
        "seed": 0,
        "custom_msa_paths": None,
    }
    
    # Create a client using API key
    logger.info("Authenticating client with API key")
    client = Client.from_api_key(api_key=api_key)

    # Define query
    seq_file = Path(seq_file)
    query = BoltzQuery.from_file(seq_file, query_name="gradio", parameters=BoltzParameters(**parameters))
    query.save_parameters(output_dir)

    logger.info("Payload: %s", query.payload)

    # Send a request
    logger.info("Sending request to Folding Studio API")
    response = client.send_request(query, project_code=os.environ["FOLDING_PROJECT_CODE"])

    # Access confidence data
    logger.info("Confidence data: %s", response.confidence_data)

    response.download_results(output_dir=output_dir, force=True, unzip=True)
    logger.info("Results downloaded to %s", output_dir)


def predict(sequence: str, api_key: str) -> "tuple[str, go.Figure]":
    """Predict a protein structure with Boltz and build the two UI outputs.

    Results are stored in a directory named after the SHA-1 of the sequence. When
    a ``*_model_0.cif`` file is already present there, the remote prediction is
    skipped and the existing files are reused.

    Args:
        sequence (str): Amino acid sequence to predict structure for. Whitespace
            (spaces, tabs, line breaks) is removed before use.
        api_key (str): Folding API key

    Returns:
        tuple[str, go.Figure]: HTML iframe containing the 3D molecular
        visualization, and the pLDDT figure.

    Raises:
        ValueError: If the sequence is empty once whitespace is removed.
        FileNotFoundError: If no predicted structure or pLDDT file can be found
            in the output directory.
    """
    # Pasted sequences often contain line breaks or spaces. Remove them so that
    # len(sequence), which sizes the visualization, counts residues only.
    sequence = "".join(sequence.split()) if sequence else ""
    if not sequence:
        raise ValueError("Please provide a non-empty amino acid sequence.")

    # Set up unique output directory based on sequence hash
    seq_id = hashlib.sha1(sequence.encode()).hexdigest()
    seq_file = Path(f"sequence_{seq_id}.fasta")
    _write_fasta_file(seq_file, sequence)
    output_dir = Path(f"sequence_{seq_id}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Check if prediction already exists
    model_pattern = "*_model_0.cif"
    if not any(output_dir.rglob(model_pattern)):
        # Run Boltz prediction
        logger.info("Predicting %s", seq_file.stem)
        call_boltz(seq_file=seq_file, api_key=api_key, output_dir=output_dir)
        logger.info("Prediction done. Output directory: %s", output_dir)
    else:
        logger.info("Prediction already exists. Output directory: %s", output_dir)

    # Locate the predicted structure; sorted() makes the choice deterministic
    pred_cifs = sorted(output_dir.rglob(model_pattern))
    if not pred_cifs:
        raise FileNotFoundError(
            f"No predicted structure matching '{model_pattern}' found in {output_dir}"
        )
    pred_cif = pred_cifs[0]
    logger.info("Output file: %s", pred_cif)

    # Convert output CIF to PDB
    converted_pdb_path = str(output_dir / "pred.pdb")
    convert_cif_to_pdb(str(pred_cif), converted_pdb_path)
    logger.info("Converted PDB file: %s", converted_pdb_path)

    # Generate molecular visualization
    mol = _create_molecule_visualization(
        converted_pdb_path,
        sequence,
    )

    # The pLDDT archive is looked up next to the predicted structure
    plddt_files = sorted(pred_cif.parent.glob("plddt_*.npz"))
    if not plddt_files:
        raise FileNotFoundError(f"No 'plddt_*.npz' file found in {pred_cif.parent}")
    plddt_file = plddt_files[0]
    logger.info("plddt file: %s", plddt_file)
    # Use the archive as a context manager so its file handle is closed
    with np.load(plddt_file) as plddt_archive:
        plddt_vals = plddt_archive["plddt"]

    return _wrap_in_iframe(mol), add_plddt_plot(plddt_vals=plddt_vals)


def _write_fasta_file(filepath: Path, sequence: str) -> None:
    """Save the sequence as a single-record FASTA file with the header ``>A|protein``."""
    # Header line, then the sequence; no newline is appended after the sequence
    fasta_text = "\n".join((">A|protein", sequence))
    with open(filepath, mode="w") as handle:
        handle.write(fasta_text)


def _create_molecule_visualization(pdb_path: Path, sequence: str) -> str:
    """Build the viewer markup for a PDB file by delegating to ``molecule``.

    A single sequence is passed, every residue (1..len(sequence)) is selected,
    and all per-sequence metrics are filled with zero.
    """
    residue_count = len(sequence)
    every_residue = list(range(1, residue_count + 1))
    # Metric fields are set to 0 here; only "seq" carries real data
    sequence_entry = {
        "Score": 0,
        "RMSD": 0,
        "Recovery": 0,
        "Mean pLDDT": 0,
        "seq": sequence,
    }
    return molecule(
        str(pdb_path),
        lenSeqs=1,
        num_res=residue_count,
        selectedResidues=every_residue,
        allSeqs=[sequence],
        sequences=[sequence_entry],
    )


def _wrap_in_iframe(content: str) -> str:
    """Embed ``content`` as the ``srcdoc`` of a full-width, sandboxed HTML iframe."""
    # NOTE(review): content is inserted verbatim inside a single-quoted attribute;
    # presumably the caller guarantees it contains no single quotes — confirm.
    attribute_lines = (
        'name="result" ',
        'style="width: 100%; height: 100vh;"',
        'allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;"',
        'sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups allow-top-navigation-by-user-activation allow-downloads"',
        'allowfullscreen=""',
        'allowpaymentrequest=""',
        'frameborder="0"',
        f"srcdoc='{content}'",
    )
    pad = " " * 8
    attributes = "".join(f"{pad}{line}\n" for line in attribute_lines)
    return "<iframe \n" + attributes + "    ></iframe>"

def add_plddt_plot(plddt_vals: "np.ndarray | list[float]") -> "go.Figure":
    """Build a line plot of pLDDT values against their position.

    Args:
        plddt_vals (np.ndarray | list[float]): pLDDT values in sequence order

    Returns:
        go.Figure: Plotly figure with a single trace named "seq"; the x axis is
        the 0-based position of each value, labelled "Residue index".
    """
    plddt_trace = go.Scatter(
        x=np.arange(len(plddt_vals)),
        y=plddt_vals,
        hovertemplate="<i>pLDDT</i>: %{y:.2f} <br><i>Residue index:</i> %{x}<br>",
        name="seq",
        visible=True,
    )

    plddt_fig = go.Figure(data=[plddt_trace])
    plddt_fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
    )
    return plddt_fig

# Gradio interface: sequence and API key inputs on top, viewer and pLDDT plot below.
with gr.Blocks(title="Folding Studio: structure prediction with Boltz-1") as demo:
    gr.Markdown("# Input")
    with gr.Row(), gr.Column():
        sequence = gr.Textbox(label="Sequence", value="")
        api_key = gr.Textbox(label="Folding API Key", type="password")

    gr.Markdown("# Output")
    with gr.Row():
        predict_btn = gr.Button("Predict")
    with gr.Row():
        with gr.Column():
            mol_output = gr.HTML()
        with gr.Column():
            metrics_plot = gr.Plot(label="pLDDT")

    # predict returns (iframe HTML, pLDDT figure), matching the two outputs in order
    predict_btn.click(
        fn=predict,
        inputs=[sequence, api_key],
        outputs=[mol_output, metrics_plot],
    )

# Start the web server as soon as the module is executed
demo.launch()