"""Predict protein structure using Folding Studio."""

import hashlib
import logging
import os
from io import StringIO
from pathlib import Path
from typing import Any

import gradio as gr
import numpy as np
import plotly.graph_objects as go
from Bio import SeqIO
from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer
from folding_studio.client import Client
from folding_studio.query import Query
from folding_studio.query.boltz import BoltzQuery
from folding_studio.query.chai import ChaiQuery
from folding_studio.query.protenix import ProtenixQuery
from folding_studio_data_models import FoldingModel

from folding_studio_demo.model_fasta_validators import (
    BaseFastaValidator,
    BoltzFastaValidator,
    ChaiFastaValidator,
    ProtenixFastaValidator,
)

logger = logging.getLogger(__name__)

SEQUENCE_DIR = Path("sequences")
SEQUENCE_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_DIR = Path("output")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
    """Convert a .cif file to .pdb format using Biopython.

    Args:
        cif_path (str): Path to input .cif file
        pdb_path (str): Path to output .pdb file
    """
    parser = MMCIFParser()
    structure = parser.get_structure("structure", cif_path)

    io = PDBIO()
    io.set_structure(structure)
    io.save(pdb_path)
|
|
def add_plddt_plot(plddt_vals: list[list[float]], model_name: str) -> go.Figure:
    """Create a pLDDT plot with one trace per predicted model."""
    plddt_traces = [
        go.Scatter(
            x=np.arange(len(plddt_val)),
            y=plddt_val,
            hovertemplate="<i>pLDDT</i>: %{y:.2f} <br><i>Residue index:</i> %{x}<br>",
            name=f"{model_name} {i}",
            visible=True,
        )
        for i, plddt_val in enumerate(plddt_vals)
    ]

    plddt_fig = go.Figure(data=plddt_traces)
    plddt_fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
    )
    return plddt_fig
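
# Example (hypothetical values): add_plddt_plot([[85.2, 90.1, 78.4]], "Boltz")
# returns a go.Figure with a single trace named "Boltz 0" over residue indices 0-2.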
|
|
def _write_fasta_file(
    sequence: str, directory: Path = SEQUENCE_DIR
) -> tuple[str, Path]:
    """Write a sequence to a FASTA file.

    Args:
        sequence (str): Sequence to write to FASTA file
        directory (Path): Directory to write FASTA file to (default: SEQUENCE_DIR)

    Returns:
        tuple[str, Path]: Tuple containing the sequence ID and the path to the FASTA file
    """
    input_rep = list(SeqIO.parse(StringIO(sequence), "fasta"))
    if not input_rep:
        raise gr.Error("No sequence found")

    seq_id = hashlib.sha256(
        "_".join([str(record.seq) for record in input_rep]).encode()
    ).hexdigest()
    seq_file = directory / f"sequence_{seq_id}.fasta"
    with open(seq_file, "w") as f:
        f.write(sequence)
    return seq_id, seq_file
|
|
class AF3Model:
    """Base wrapper around a Folding Studio folding model query."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        query: type[Query],
        validator: BaseFastaValidator,
    ):
        self.api_key = api_key
        self.model_name = model_name
        self.query = query
        self.validator = validator

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Predict protein structure from amino acid sequence using an AF3-style model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
        """
        is_valid, error_msg = self.check_file_description(seq_file)
        if not is_valid:
            logger.error(error_msg)
            raise gr.Error(error_msg)

        logger.info("Authenticating client with API key")
        client = Client.from_api_key(api_key=self.api_key)

        query: Query = self.query.from_file(path=seq_file, query_name="gradio")
        query.save_parameters(output_dir)

        logger.info("Payload: %s", query.payload)

        logger.info(f"Sending {self.model_name} request to Folding Studio API")
        # The target project is read from the FOLDING_PROJECT_CODE environment variable.
        response = client.send_request(
            query, project_code=os.environ["FOLDING_PROJECT_CODE"]
        )

        logger.info("Confidence data: %s", response.confidence_data)

        response.download_results(output_dir=output_dir, force=True, unzip=True)
        logger.info("Results downloaded to %s", output_dir)

    def format_fasta(self, sequence: str) -> str:
        """Format sequence to FASTA format."""
        return f">{self.model_name}\n{sequence}"

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Get the predicted structures and their metrics, keyed by model index."""
        raise NotImplementedError("Not implemented")

    def has_prediction(self, output_dir: Path) -> bool:
        """Check if a prediction exists in the output directory."""
        return len(self.predictions(output_dir)) > 0

    def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
        """Check if the file description is correct.

        Args:
            seq_file (Path | str): Path to FASTA file

        Returns:
            tuple[bool, str | None]: Tuple containing a boolean indicating if the
                format is correct and an error message if not
        """
        is_valid, error_msg = self.validator.is_valid_fasta(seq_file)
        if not is_valid:
            return False, error_msg

        return True, None
|
|
class ChaiModel(AF3Model):
    """Chai model wrapper."""

    def __init__(self, api_key: str):
        super().__init__(api_key, "Chai", ChaiQuery, ChaiFastaValidator())

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Predict protein structure from amino acid sequence using the Chai model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
        """
        super().call(seq_file, output_dir)

    def _get_chai_paired_files(self, directory: Path) -> list[tuple[Path, Path]]:
        """Get pairs of .cif and .npz files with matching model indices.

        Args:
            directory (Path): Directory containing the prediction files

        Returns:
            list[tuple[Path, Path]]: List of tuples containing (cif_path, npz_path) pairs
        """
        # Mirror the pairing logic used in predictions(): match structures and
        # score files by the integer model index in their file names.
        cif_files = {
            int(f.stem.split("model_idx_")[1]): f
            for f in directory.glob("pred.model_idx_*.cif")
        }
        npz_files = {
            int(f.stem.split("model_idx_")[1]): f
            for f in directory.glob("scores.model_idx_*.npz")
        }
        common_indices = sorted(set(cif_files) & set(npz_files))
        return [(cif_files[idx], npz_files[idx]) for idx in common_indices]

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Get the predicted structures and their metrics, keyed by model index."""
        prediction = next(output_dir.rglob("pred.model_idx_[0-9].cif"), None)
        if prediction is None:
            return {}

        # Pair each predicted structure with the .npz file holding its scores.
        cif_files = {
            int(f.stem.split("model_idx_")[1]): f
            for f in prediction.parent.glob("pred.model_idx_*.cif")
        }
        npz_files = {
            int(f.stem.split("model_idx_")[1]): f
            for f in prediction.parent.glob("scores.model_idx_*.npz")
        }

        common_indices = sorted(set(cif_files.keys()) & set(npz_files.keys()))

        return {
            idx: {"prediction_path": cif_files[idx], "metrics": np.load(npz_files[idx])}
            for idx in common_indices
        }
|
|
class ProtenixModel(AF3Model):
    """Protenix model wrapper."""

    def __init__(self, api_key: str):
        super().__init__(api_key, "Protenix", ProtenixQuery, ProtenixFastaValidator())

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Predict protein structure from amino acid sequence using the Protenix model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
        """
        super().call(seq_file, output_dir)

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Get the predicted structures, keyed by model index."""
        # Only the structure path is exposed here; per-residue pLDDT is read
        # from the CIF file downstream.
        cif_paths = sorted(output_dir.rglob("*_model_[0-9].cif"))
        return {
            idx: {"prediction_path": cif_path}
            for idx, cif_path in enumerate(cif_paths)
        }
|
|
class BoltzModel(AF3Model):
    """Boltz model wrapper."""

    def __init__(self, api_key: str):
        super().__init__(api_key, "Boltz", BoltzQuery, BoltzFastaValidator())

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Predict protein structure from amino acid sequence using the Boltz model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
        """
        super().call(seq_file, output_dir)

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Get the predicted structures and their metrics, keyed by model index."""
        prediction_paths = list(output_dir.rglob("*_model_[0-9].cif"))
        return {
            # The model index is the last character of the file stem (e.g. "..._model_0").
            int(cif_path.stem[-1]): {
                "prediction_path": cif_path,
                "metrics": np.load(list(cif_path.parent.glob("plddt_*.npz"))[0]),
            }
            for cif_path in prediction_paths
        }
|
|
def extract_plddt_from_cif(cif_path: Path | str) -> list[float]:
    """Extract per-residue pLDDT values from a predicted .cif file.

    AF3-style models store the per-residue pLDDT in the B-factor column,
    so the value is read from each residue's CA atom.

    Args:
        cif_path (Path | str): Path to the predicted .cif file

    Returns:
        list[float]: pLDDT value per residue
    """
    structure = MMCIFParser().get_structure("structure", cif_path)

    plddt_values = []
    for model in structure:
        for chain in model:
            for residue in chain:
                if "CA" in residue:
                    plddt_values.append(residue["CA"].get_bfactor())

    return plddt_values
|
|
def predict(
    sequence: str, api_key: str, model_type: FoldingModel
) -> tuple[list[str], go.Figure]:
    """Predict protein structure from amino acid sequence using the selected model.

    Args:
        sequence (str): Amino acid sequence to predict structure for
        api_key (str): Folding API key
        model_type (FoldingModel): Folding model to use

    Returns:
        tuple[list[str], go.Figure]: Tuple containing the paths to the predicted
            PDB files and the pLDDT plot
    """
    if not api_key:
        raise gr.Error("Missing API key, please enter a valid API key")

    seq_id, seq_file = _write_fasta_file(sequence)
    output_dir = OUTPUT_DIR / seq_id / model_type
    output_dir.mkdir(parents=True, exist_ok=True)

    if model_type == FoldingModel.BOLTZ:
        model = BoltzModel(api_key)
    elif model_type == FoldingModel.CHAI:
        model = ChaiModel(api_key)
    elif model_type == FoldingModel.PROTENIX:
        model = ProtenixModel(api_key)
    else:
        raise ValueError(f"Model {model_type} not supported")

    # Reuse cached results if this sequence/model pair has already been predicted.
    if not model.has_prediction(output_dir):
        logger.info(f"Predicting {seq_id}")
        model.call(seq_file=seq_file, output_dir=output_dir)
        logger.info("Prediction done. Output directory: %s", output_dir)
    else:
        logger.info("Prediction already exists. Output directory: %s", output_dir)

    if not model.has_prediction(output_dir):
        raise gr.Error("No prediction found")

    predictions = model.predictions(output_dir)
    pdb_paths = []
    model_plddt_vals = []
    for model_idx, prediction in predictions.items():
        cif_path = prediction["prediction_path"]
        logger.info("CIF file: %s", cif_path)

        converted_pdb_path = str(
            output_dir / f"{model.model_name}_prediction_{model_idx}.pdb"
        )
        convert_cif_to_pdb(str(cif_path), converted_pdb_path)
        plddt_vals = extract_plddt_from_cif(cif_path)
        pdb_paths.append(converted_pdb_path)
        model_plddt_vals.append(plddt_vals)
    plddt_plot = add_plddt_plot(
        plddt_vals=model_plddt_vals, model_name=model.model_name
    )
    return pdb_paths, plddt_plot
|
|
def align_structures(pdb_paths: list[str]) -> list[str]:
    """Align multiple PDB structures to the first structure.

    The superposition uses CA atoms only, so all structures are expected to
    contain the same number of residues (predictions of the same sequence).

    Args:
        pdb_paths (list[str]): List of paths to PDB files to align

    Returns:
        list[str]: List of paths to aligned PDB files
    """
    parser = PDBParser()
    io = PDBIO()

    # Use the first structure as the alignment reference.
    ref_structure = parser.get_structure("reference", pdb_paths[0])
    ref_atoms = [atom for atom in ref_structure.get_atoms() if atom.get_name() == "CA"]

    aligned_paths = [pdb_paths[0]]

    for i, pdb_path in enumerate(pdb_paths[1:], start=1):
        structure = parser.get_structure(f"model_{i}", pdb_path)
        atoms = [atom for atom in structure.get_atoms() if atom.get_name() == "CA"]

        # Superimpose the CA atoms onto the reference, then apply the resulting
        # rotation/translation to the whole structure.
        sup = Superimposer()
        sup.set_atoms(ref_atoms, atoms)
        sup.apply(structure.get_atoms())

        aligned_path = str(Path(pdb_path).parent / f"aligned_{Path(pdb_path).name}")
        io.set_structure(structure)
        io.save(aligned_path)
        aligned_paths.append(aligned_path)

    return aligned_paths
|
|
def predict_comparison(
    sequence: str, api_key: str, model_types: list[FoldingModel]
) -> list[str]:
    """Predict protein structure with several folding models and align the results.

    Args:
        sequence (str): Amino acid sequence to predict structure for
        api_key (str): Folding API key
        model_types (list[FoldingModel]): Folding models to compare

    Returns:
        list[str]: Paths to the aligned PDB files, one or more per model
    """
    if not api_key:
        raise gr.Error("Missing API key, please enter a valid API key")

    pdb_paths = []
    for model_type in model_types:
        model_pdb_paths, _ = predict(sequence, api_key, model_type)
        pdb_paths += model_pdb_paths

    aligned_paths = align_structures(pdb_paths)

    return aligned_paths
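

if __name__ == "__main__":
    # Minimal usage sketch (not part of the Gradio app). It assumes the API key is
    # available in a FOLDING_API_KEY environment variable, that FOLDING_PROJECT_CODE
    # is set, and uses a short hypothetical example sequence.
    logging.basicConfig(level=logging.INFO)
    example_sequence = ">example\nMKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
    example_api_key = os.environ.get("FOLDING_API_KEY", "")

    # Single-model prediction returns the converted PDB paths and a pLDDT figure.
    pdb_files, plddt_figure = predict(
        example_sequence, example_api_key, FoldingModel.BOLTZ
    )
    print(pdb_files)

    # Multi-model comparison returns PDB files aligned onto the first prediction.
    aligned = predict_comparison(
        example_sequence, example_api_key, [FoldingModel.BOLTZ, FoldingModel.CHAI]
    )
    print(aligned)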
|
|