# Scraped from the Hugging Face file viewer: last commit 371f76e by jfaustin,
# "remove duplicate score" (file size 11.7 kB).
import hashlib
import html
import logging
import os
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from Bio.PDB import MMCIFParser, PDBIO
from folding_studio.client import Client
from folding_studio.query.boltz import BoltzQuery, BoltzParameters
from scipy.stats import spearmanr

from molecule import molecule
# Module-wide logging: INFO and above, timestamped, emitted through a single
# StreamHandler (stderr by default). basicConfig is a no-op if the root logger
# already has handlers.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format=_LOG_FORMAT,
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
    """Rewrite the structure stored in an mmCIF file as a PDB file (Biopython).

    Args:
        cif_path (str): Location of the mmCIF file to read.
        pdb_path (str): Location where the PDB file is written.
    """
    structure = MMCIFParser().get_structure("structure", cif_path)
    writer = PDBIO()
    writer.set_structure(structure)
    writer.save(pdb_path)
def call_boltz(seq_file: Path | str, api_key: str, output_dir: Path) -> None:
"""Call Boltz prediction."""
# Initialize parameters with CLI-provided values
parameters = {
"recycling_steps": 3,
"sampling_steps": 200,
"diffusion_samples": 1,
"step_scale": 1.638,
"msa_pairing_strategy": "greedy",
"write_full_pae": False,
"write_full_pde": False,
"use_msa_server": True,
"seed": 0,
"custom_msa_paths": None,
}
# Create a client using API key
logger.info("Authenticating client with API key")
client = Client.from_api_key(api_key=api_key)
# Define query
seq_file = Path(seq_file)
query = BoltzQuery.from_file(seq_file, query_name="gradio", parameters=BoltzParameters(**parameters))
query.save_parameters(output_dir)
logger.info("Payload: %s", query.payload)
# Send a request
logger.info("Sending request to Folding Studio API")
response = client.send_request(query, project_code=os.environ["FOLDING_PROJECT_CODE"])
# Access confidence data
logger.info("Confidence data: %s", response.confidence_data)
response.download_results(output_dir=output_dir, force=True, unzip=True)
logger.info("Results downloaded to %s", output_dir)
def predict(sequence: str, api_key: str) -> tuple[str, go.Figure]:
    """Predict a protein structure from its amino-acid sequence using Boltz.

    Results are cached on disk in ``sequence_<sha1>/`` (relative to the
    working directory), keyed by the SHA-1 of the whitespace-free sequence.
    If a ``*_model_0.cif`` already exists there, the remote call is skipped.

    Args:
        sequence: Amino-acid sequence; all whitespace is removed before use.
        api_key: Folding Studio API key.

    Returns:
        A ``(html, figure)`` pair: an HTML iframe holding the 3D molecular
        visualization, and a Plotly figure of the per-residue pLDDT values.

    Raises:
        gr.Error: If the sequence is empty, or if the expected prediction
            outputs (model CIF, pLDDT ``.npz``) cannot be found.
    """
    # Pasted sequences often contain line breaks or stray spaces; they would
    # otherwise go into the FASTA record and inflate the residue count.
    sequence = "".join(sequence.split())
    if not sequence:
        raise gr.Error("Please provide a non-empty amino-acid sequence.")

    seq_id = hashlib.sha1(sequence.encode()).hexdigest()
    seq_file = Path(f"sequence_{seq_id}.fasta")
    _write_fasta_file(seq_file, sequence)
    output_dir = Path(f"sequence_{seq_id}")
    output_dir.mkdir(parents=True, exist_ok=True)

    if not any(output_dir.rglob("*_model_0.cif")):
        logger.info("Predicting %s", seq_file.stem)
        call_boltz(seq_file=seq_file, api_key=api_key, output_dir=output_dir)
        logger.info("Prediction done. Output directory: %s", output_dir)
    else:
        logger.info("Prediction already exists. Output directory: %s", output_dir)

    # Sorted so the chosen file does not depend on filesystem listing order.
    pred_cifs = sorted(output_dir.rglob("*_model_0.cif"))
    if not pred_cifs:
        raise gr.Error(f"No predicted structure (*_model_0.cif) found in {output_dir}")
    pred_cif = pred_cifs[0]
    logger.info("Output file: %s", pred_cif)

    converted_pdb_path = output_dir / "pred.pdb"
    convert_cif_to_pdb(str(pred_cif), str(converted_pdb_path))
    logger.info("Converted PDB file: %s", converted_pdb_path)

    mol = _create_molecule_visualization(converted_pdb_path, sequence)

    plddt_files = sorted(pred_cif.parent.glob("plddt_*.npz"))
    if not plddt_files:
        raise gr.Error(f"No pLDDT file (plddt_*.npz) found next to {pred_cif}")
    plddt_file = plddt_files[0]
    logger.info("plddt file: %s", plddt_file)
    # np.load on an .npz returns an NpzFile holding an open zip; close it.
    with np.load(plddt_file) as npz:
        plddt_vals = npz["plddt"]

    return _wrap_in_iframe(mol), add_plddt_plot(plddt_vals=plddt_vals)
def _write_fasta_file(filepath: Path, sequence: str) -> None:
"""Write sequence to FASTA file."""
with open(filepath, "w") as f:
f.write(f">A|protein\n{sequence}")
def _create_molecule_visualization(pdb_path: Path, sequence: str) -> str:
    """Build the molecular-viewer markup for *pdb_path* via the ``molecule`` helper.

    All residues (1-based, 1..len(sequence)) are selected. The per-sequence
    metrics passed along are zero placeholders; only ``seq`` carries real data.
    """
    n_res = len(sequence)
    placeholder_metrics = {
        "Score": 0,
        "RMSD": 0,
        "Recovery": 0,
        "Mean pLDDT": 0,
        "seq": sequence,
    }
    return molecule(
        str(pdb_path),
        lenSeqs=1,
        num_res=n_res,
        selectedResidues=list(range(1, n_res + 1)),
        allSeqs=[sequence],
        sequences=[placeholder_metrics],
    )
def _wrap_in_iframe(content: str) -> str:
"""Wrap content in an HTML iframe with appropriate styling and permissions."""
return f"""<iframe
name="result"
style="width: 100%; height: 100vh;"
allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;"
sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups allow-top-navigation-by-user-activation allow-downloads"
allowfullscreen=""
allowpaymentrequest=""
frameborder="0"
srcdoc='{content}'
></iframe>"""
def add_plddt_plot(plddt_vals: np.ndarray | list[float]) -> go.Figure:
    """Plot per-residue pLDDT values.

    Args:
        plddt_vals: pLDDT values in residue order (a list or 1-D array).

    Returns:
        A Plotly figure with the 0-based residue index on the x-axis and
        pLDDT on the y-axis.
    """
    plddt_trace = go.Scatter(
        x=np.arange(len(plddt_vals)),
        y=plddt_vals,
        hovertemplate="<i>pLDDT</i>: %{y:.2f} <br><i>Residue index:</i> %{x}<br>",
        name="seq",
        visible=True,
    )
    plddt_fig = go.Figure(data=[plddt_trace])
    plddt_fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
    )
    return plddt_fig
def fake_predict_and_correlate(
    spr_data_with_scores: pd.DataFrame, score_cols: list[str]
) -> tuple[pd.DataFrame, go.Figure, go.Figure]:
    """Correlate precomputed prediction scores with measured binding affinity.

    No structures are actually predicted. The scores are read from
    *spr_data_with_scores*, which must contain a ``"KD (nM)"`` column and
    every column named in *score_cols*. The input frame is not modified.

    Args:
        spr_data_with_scores: Binding-affinity data with one column per score.
        score_cols: Names of the score columns to correlate against KD.

    Returns:
        A 3-tuple of:
        - the ``"KD (nM)"`` column followed by the score columns;
        - a horizontal bar chart of Spearman correlations, sorted ascending,
          for scores whose correlation is defined (NaN results are dropped);
        - a scatter plot of each score against KD.
    """
    kd_col = "KD (nM)"

    # Spearman correlation is rank-based, so KD is used as-is: a monotonic
    # log transform would not change the result.
    corr_records = []
    for score_col in score_cols:
        logger.info("Computing correlation between %s and %s", score_col, kd_col)
        res = spearmanr(spr_data_with_scores[kd_col], spr_data_with_scores[score_col])
        corr_records.append(
            {"score": score_col, "correlation": res.statistic, "p-value": res.pvalue}
        )
        logger.info("Correlation between %s and %s: %s", score_col, kd_col, res.statistic)

    # Explicit columns keep the frame well-formed even when score_cols is empty.
    corr_data = pd.DataFrame(corr_records, columns=["score", "correlation", "p-value"])
    # Drop scores whose correlation is undefined (NaN), e.g. constant or
    # NaN-containing columns under spearmanr's default nan_policy.
    corr_data = corr_data[corr_data["correlation"].notna()]
    logger.info("Correlation data: %s", corr_data)
    corr_data = corr_data.sort_values("correlation", ascending=True)

    corr_ranking_plot = go.Figure(data=[
        go.Bar(
            x=corr_data["correlation"],
            y=corr_data["score"],
            name="correlation",
            orientation="h",
            hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>",
        )
    ])
    corr_ranking_plot.update_layout(
        title="Correlation with Binding Affinity",
        yaxis_title="Score Type",
        xaxis_title="Spearman Correlation",
        template="simple_white",
        showlegend=False,
    )

    # One marker-only scatter trace per score, plotted against KD.
    scatters = [
        go.Scatter(
            x=spr_data_with_scores[kd_col],
            y=spr_data_with_scores[score_col],
            name=f"{kd_col} vs {score_col}",
            mode="markers",
            hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
        )
        for score_col in score_cols
    ]
    corr_plot = go.Figure(data=scatters)

    return spr_data_with_scores[[kd_col, *score_cols]], corr_ranking_plot, corr_plot
demo = gr.Blocks(title="Folding Studio: structure prediction with Boltz-1")

with demo:
    with gr.Tabs() as tabs:
        # Single-sequence structure prediction with a 3D view and pLDDT plot.
        with gr.TabItem("Inference"):
            gr.Markdown("# Input")
            with gr.Row():
                with gr.Column():
                    sequence = gr.Textbox(label="Sequence", value="")
                    api_key = gr.Textbox(label="Folding API Key", type="password")

            gr.Markdown("# Output")
            with gr.Row():
                predict_btn = gr.Button("Predict")
            with gr.Row():
                with gr.Column():
                    mol_output = gr.HTML()
                with gr.Column():
                    metrics_plot = gr.Plot(label="pLDDT")

            predict_btn.click(
                fn=predict,
                inputs=[sequence, api_key],
                outputs=[mol_output, metrics_plot],
            )

        # Correlation of precomputed scores with binding affinity.
        with gr.TabItem("Correlations"):
            gr.Markdown("# Upload binding affinity data")
            spr_data_with_scores = pd.read_csv("spr_af_scores_mapped.csv")
            with gr.Row():
                csv_file = gr.File(label="Upload CSV file", file_types=[".csv"])
            with gr.Row():
                dataframe = gr.Dataframe(label="Binding Affinity Data")

            gr.Markdown("# Prediction and correlation")
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        fake_predict_btn = gr.Button("Predict structures of all complexes")
                    with gr.Row():
                        prediction_dataframe = gr.Dataframe(label="Predicted Structures Data")
                with gr.Column():
                    correlation_ranking_plot = gr.Plot(label="Correlation ranking")
                    correlation_plot = gr.Plot(label="Correlation with binding affinity")

            cols = [
                "confidence_score_boltz",
                "ptm_boltz",
                "iptm_boltz",
                # "ligand_iptm_boltz",
                # "protein_iptm_boltz",
                "complex_plddt_boltz",
                "complex_iplddt_boltz",
                "complex_pde_boltz",
                "complex_ipde_boltz",
                "interchain_pae_monomer",
                "interface_pae_monomer",
                "overall_pae_monomer",
                "interface_plddt_monomer",
                "average_plddt_monomer",
                "ptm_monomer",
                "interface_ptm_monomer",
                "interchain_pae_multimer",
                "interface_pae_multimer",
                "overall_pae_multimer",
                "interface_plddt_multimer",
                "average_plddt_multimer",
                "ptm_multimer",
                "interface_ptm_multimer",
            ]

            # NOTE(review): the uploaded file's contents are not read. When a
            # file is provided, the bundled spr_af_scores_mapped.csv data (minus
            # the score columns) is shown instead.
            csv_file.change(
                fn=lambda file: spr_data_with_scores.drop(columns=cols) if file else None,
                inputs=csv_file,
                outputs=dataframe,
            )
            # The button has no inputs, so Gradio calls the handler with no
            # arguments. Accept (and ignore) any positional args rather than
            # requiring one, which would raise TypeError on click.
            fake_predict_btn.click(
                fn=lambda *_: fake_predict_and_correlate(spr_data_with_scores, cols),
                inputs=None,
                outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot],
            )

demo.launch()