|
from Bio.PDB import MMCIFParser, PDBIO |
|
from folding_studio.client import Client |
|
from folding_studio.query.boltz import BoltzQuery, BoltzParameters |
|
from pathlib import Path |
|
import gradio as gr |
|
import hashlib |
|
import logging |
|
import numpy as np |
|
import os |
|
import plotly.graph_objects as go |
|
|
|
from molecule import molecule |
|
|
|
|
|
# Configure the root logger once for the whole app: INFO and above, written to
# a stream handler with a timestamp / level / message layout.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by the functions below.
logger = logging.getLogger(__name__)
|
|
|
|
|
def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
    """Rewrite an mmCIF structure file as a PDB file using Biopython.

    The structure is parsed with ``MMCIFParser`` (under the id
    ``"structure"``) and written back out with ``PDBIO``.

    Args:
        cif_path (str): Path of the ``.cif`` file to read.
        pdb_path (str): Path of the ``.pdb`` file to write.
    """
    parsed = MMCIFParser().get_structure("structure", cif_path)

    writer = PDBIO()
    writer.set_structure(parsed)
    writer.save(pdb_path)
|
|
|
def call_boltz(seq_file: Path | str, api_key: str, output_dir: Path) -> None:
    """Run one Boltz prediction through the Folding Studio API.

    Steps, in order:

    1. Authenticate a ``Client`` with ``api_key``.
    2. Build a ``BoltzQuery`` named ``"gradio"`` from ``seq_file`` with fixed
       inference parameters, and save those parameters into ``output_dir``.
    3. Submit the query under the project code read from the
       ``FOLDING_PROJECT_CODE`` environment variable.
    4. Download the results into ``output_dir`` with ``force=True`` and
       ``unzip=True``.

    Args:
        seq_file (Path | str): FASTA file describing the input sequence.
        api_key (str): Folding Studio API key.
        output_dir (Path): Directory receiving the saved parameters and the
            downloaded results.

    Raises:
        KeyError: If ``FOLDING_PROJECT_CODE`` is not set in the environment.
    """
    logger.info("Authenticating client with API key")
    folding_client = Client.from_api_key(api_key=api_key)

    boltz_query = BoltzQuery.from_file(
        Path(seq_file),
        query_name="gradio",
        parameters=BoltzParameters(
            recycling_steps=3,
            sampling_steps=200,
            diffusion_samples=1,
            step_scale=1.638,
            msa_pairing_strategy="greedy",
            write_full_pae=False,
            write_full_pde=False,
            use_msa_server=True,
            seed=0,
            custom_msa_paths=None,
        ),
    )
    boltz_query.save_parameters(output_dir)

    logger.info("Payload: %s", boltz_query.payload)

    logger.info("Sending request to Folding Studio API")
    project_code = os.environ["FOLDING_PROJECT_CODE"]
    result = folding_client.send_request(boltz_query, project_code=project_code)

    logger.info("Confidence data: %s", result.confidence_data)

    result.download_results(output_dir=output_dir, force=True, unzip=True)
    logger.info("Results downloaded to %s", output_dir)
|
|
|
|
|
def _find_first(directory: Path, pattern: str, *, recursive: bool = False) -> Path | None:
    """Return the lexicographically first path under *directory* matching *pattern*, or None."""
    matches = directory.rglob(pattern) if recursive else directory.glob(pattern)
    # min() gives a deterministic pick (raw glob order is filesystem-dependent)
    # and avoids an IndexError when nothing matches.
    return min(matches, default=None)


def predict(sequence: str, api_key: str) -> tuple[str, go.Figure]:
    """Predict a protein structure from an amino acid sequence with Boltz.

    Results are cached on disk. The FASTA file and the output directory are
    named after the SHA-1 of the (whitespace-stripped) sequence. The API is
    only called when no ``*_model_0.cif`` file exists under that directory yet.

    Args:
        sequence (str): Amino acid sequence to predict a structure for. All
            whitespace (spaces, line breaks from pasted text) is removed.
        api_key (str): Folding API key. Only required when no cached
            prediction exists for the sequence.

    Returns:
        tuple[str, go.Figure]: An HTML iframe containing the 3D molecular
        visualization, and a Plotly figure of the pLDDT values.

    Raises:
        gr.Error: If the sequence is empty, if an API key is needed but
            missing, or if the expected prediction outputs cannot be found.
    """
    # Pasted sequences often contain spaces or line breaks. Left in place, they
    # would end up in the FASTA file, in the cache key, and in len(sequence),
    # which is passed to the viewer as the residue count.
    sequence = "".join(sequence.split())
    if not sequence:
        raise gr.Error("Please enter an amino acid sequence.")

    seq_id = hashlib.sha1(sequence.encode()).hexdigest()
    seq_file = Path(f"sequence_{seq_id}.fasta")
    _write_fasta_file(seq_file, sequence)
    output_dir = Path(f"sequence_{seq_id}")
    output_dir.mkdir(parents=True, exist_ok=True)

    if _find_first(output_dir, "*_model_0.cif", recursive=True) is None:
        # Fail early with a readable message instead of an opaque auth error.
        if not (api_key or "").strip():
            raise gr.Error("Please provide a Folding API key.")
        logger.info("Predicting %s", seq_file.stem)
        call_boltz(seq_file=seq_file, api_key=api_key, output_dir=output_dir)
        logger.info("Prediction done. Output directory: %s", output_dir)
    else:
        logger.info("Prediction already exists. Output directory: %s", output_dir)

    pred_cif = _find_first(output_dir, "*_model_0.cif", recursive=True)
    if pred_cif is None:
        raise gr.Error(f"No predicted structure (*_model_0.cif) found in {output_dir}.")
    logger.info("Output file: %s", pred_cif)

    converted_pdb_path = str(output_dir / "pred.pdb")
    convert_cif_to_pdb(str(pred_cif), converted_pdb_path)
    logger.info("Converted PDB file: %s", converted_pdb_path)

    mol = _create_molecule_visualization(converted_pdb_path, sequence)

    plddt_file = _find_first(pred_cif.parent, "plddt_*.npz")
    if plddt_file is None:
        raise gr.Error(f"No pLDDT file (plddt_*.npz) found next to {pred_cif}.")
    logger.info("plddt file: %s", plddt_file)

    # np.load on an .npz returns an NpzFile that keeps the file handle open.
    # Read the array out and close it deterministically.
    with np.load(plddt_file) as npz:
        plddt_vals = npz["plddt"]

    return _wrap_in_iframe(mol), add_plddt_plot(plddt_vals=plddt_vals)
|
|
|
|
|
def _write_fasta_file(filepath: Path, sequence: str) -> None: |
|
"""Write sequence to FASTA file.""" |
|
with open(filepath, "w") as f: |
|
f.write(f">A|protein\n{sequence}") |
|
|
|
|
|
def _create_molecule_visualization(pdb_path: Path, sequence: str) -> str:
    """Render the structure at *pdb_path* through the local ``molecule`` helper.

    A single sequence is passed, every residue 1..len(sequence) is selected,
    and the per-sequence metrics (Score, RMSD, Recovery, Mean pLDDT) are
    zero placeholders.

    Args:
        pdb_path (Path): PDB file to display. ``str(pdb_path)`` is passed on,
            so a plain string path works too.
        sequence (str): The sequence that was folded.

    Returns:
        str: Whatever ``molecule`` returns. The caller embeds it as HTML.
    """
    residue_count = len(sequence)
    placeholder_metrics = {
        "Score": 0,
        "RMSD": 0,
        "Recovery": 0,
        "Mean pLDDT": 0,
        "seq": sequence,
    }
    return molecule(
        str(pdb_path),
        lenSeqs=1,
        num_res=residue_count,
        selectedResidues=[idx + 1 for idx in range(residue_count)],
        allSeqs=[sequence],
        sequences=[placeholder_metrics],
    )
|
|
|
|
|
def _wrap_in_iframe(content: str) -> str: |
|
"""Wrap content in an HTML iframe with appropriate styling and permissions.""" |
|
return f"""<iframe |
|
name="result" |
|
style="width: 100%; height: 100vh;" |
|
allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;" |
|
sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups allow-top-navigation-by-user-activation allow-downloads" |
|
allowfullscreen="" |
|
allowpaymentrequest="" |
|
frameborder="0" |
|
srcdoc='{content}' |
|
></iframe>""" |
|
|
|
def add_plddt_plot(plddt_vals: list[float] | np.ndarray) -> go.Figure:
    """Build a Plotly line chart of pLDDT values.

    Args:
        plddt_vals: pLDDT values, plotted against their 0-based position
            index. The caller passes the ``"plddt"`` array from the
            prediction's ``.npz`` file (presumably one value per residue).

    Returns:
        go.Figure: The figure, consumed by a ``gr.Plot`` output in the UI.
    """
    plddt_trace = go.Scatter(
        x=np.arange(len(plddt_vals)),
        y=plddt_vals,
        hovertemplate="<i>pLDDT</i>: %{y:.2f} <br><i>Residue index:</i> %{x}<br>",
        name="seq",
        visible=True,
    )

    plddt_fig = go.Figure(data=[plddt_trace])
    plddt_fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        # Anchor the legend's right edge at x=0.99 so it sits in the bottom-right
        # corner inside the plot. With xanchor="left" at x=0.99 it would start
        # at the right edge and spill outside the plotting area.
        legend=dict(yanchor="bottom", y=0.01, xanchor="right", x=0.99),
    )
    return plddt_fig
|
|
|
# Gradio UI: sequence + API key inputs, a Predict button, and two outputs
# (3D structure viewer HTML and pLDDT plot) wired to `predict`.
demo = gr.Blocks(title="Folding Studio: structure prediction with Boltz-1")

with demo:
    gr.Markdown("# Input")
    with gr.Row():
        with gr.Column():
            sequence = gr.Textbox(label="Sequence", value="")
            api_key = gr.Textbox(label="Folding API Key", type="password")

    gr.Markdown("# Output")
    with gr.Row():
        predict_btn = gr.Button("Predict")
    with gr.Row():
        with gr.Column():
            mol_output = gr.HTML()
        with gr.Column():
            metrics_plot = gr.Plot(label="pLDDT")

    # Event listeners must be registered inside the Blocks context.
    predict_btn.click(
        fn=predict,
        inputs=[sequence, api_key],
        outputs=[mol_output, metrics_plot],
    )

# Only start the server when run as a script (`python app.py`). Importing the
# module (tests, `gradio app.py` reload mode, reuse of `demo`) must not launch
# it as a side effect.
if __name__ == "__main__":
    demo.launch()
|
|
|
|