|
"""Models for the Folding Studio API.""" |
|
|
|
import json |
|
import logging |
|
import os |
|
import sys |
|
import time |
|
from io import StringIO |
|
from pathlib import Path |
|
from typing import Any |
|
|
|
import folding_studio |
|
import gradio as gr |
|
import numpy as np |
|
from folding_studio import single_job_prediction |
|
from folding_studio.client import Client |
|
from folding_studio.commands.experiment import results as get_results |
|
from folding_studio.commands.experiment import status as get_status |
|
from folding_studio.query import Query |
|
from folding_studio.query.boltz import BoltzQuery |
|
from folding_studio.query.chai import ChaiQuery |
|
from folding_studio.query.protenix import ProtenixQuery |
|
from folding_studio_data_models import AF2Parameters, OpenFoldParameters |
|
from folding_studio_data_models.parameters.base import BaseFoldingParameters |
|
|
|
from folding_studio_demo.model_fasta_validators import ( |
|
BaseFastaValidator, |
|
BoltzFastaValidator, |
|
ChaiFastaValidator, |
|
ProtenixFastaValidator, |
|
) |
|
|
|
|
|
class Capturing(list):
    """Capture text written to ``sys.stdout`` as a list of lines.

    Intended to be used as a context manager. While the ``with`` block runs,
    ``sys.stdout`` is replaced by an in-memory buffer. On exit the buffered
    text is split into lines and appended to this list, and the stream that
    was active before entering is put back.
    """

    def __enter__(self):
        # Remember the active stream so __exit__ can restore exactly that one.
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        return self

    def __exit__(self, *args):
        try:
            self.extend(self._stringio.getvalue().splitlines())
        finally:
            # Always restore stdout and drop the buffer, even if collecting
            # the captured lines fails; otherwise later prints would keep
            # going into a buffer nobody reads.
            sys.stdout = self._stdout
            del self._stringio
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
class AF3Model:
    """Base class for the query-based Folding Studio models in this module.

    A subclass sets ``model_name``, hands its query class and FASTA validator
    to ``__init__`` and implements ``predictions``.
    """

    # Name used in log messages; overridden by subclasses.
    model_name = None

    def __init__(
        self, api_key: str, query: type[Query], validator: BaseFastaValidator
    ):
        """Store the credentials and the model-specific collaborators.

        Args:
            api_key (str): Folding Studio API key
            query (type[Query]): Query class (not an instance); ``call`` builds
                the request with its ``from_file`` constructor
            validator (BaseFastaValidator): Validator used to check and
                reformat FASTA files
        """
        self.api_key = api_key
        self.query = query
        self.validator = validator

    def call(
        self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
    ) -> None:
        """Validate the FASTA file, send it to Folding Studio and download the results.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
            format_fasta (bool): If True, an invalid FASTA file is rewritten in
                place by the validator instead of being rejected

        Raises:
            gr.Error: If the FASTA file is invalid and ``format_fasta`` is False
            KeyError: If the ``FOLDING_PROJECT_CODE`` environment variable is unset
        """
        is_valid, error_msg = self.check_file_description(seq_file)
        if not is_valid:
            if not format_fasta:
                logger.error(error_msg)
                raise gr.Error(error_msg)
            logger.info("Invalid FASTA file format, forcing formatting...")
            self.format_fasta(seq_file)

        logger.info("Authenticating client with API key")
        client = Client.from_api_key(api_key=self.api_key)

        query = self.query.from_file(path=seq_file, query_name="gradio")
        query.save_parameters(output_dir)
        logger.info("Payload: %s", query.payload)

        logger.info("Sending %s request to Folding Studio API", self.model_name)
        response = client.send_request(
            query, project_code=os.environ["FOLDING_PROJECT_CODE"]
        )
        logger.info("Confidence data: %s", response.confidence_data)

        response.download_results(output_dir=output_dir, force=True, unzip=True)
        logger.info("Results downloaded to %s", output_dir)

    def format_fasta(self, seq_file: Path | str) -> None:
        """Rewrite the FASTA file in place with the validator's formatting.

        Args:
            seq_file (Path | str): Path to FASTA file
        """
        # Compute the new content before opening for writing: opening with
        # "w" truncates the file the validator still has to read.
        formatted_fasta = self.validator.transform_fasta(seq_file)
        with open(seq_file, "w") as f:
            f.write(formatted_fasta)

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Collect the predictions found in the output directory.

        Must be implemented by subclasses.

        Args:
            output_dir (Path): Path to output directory

        Returns:
            dict[int, dict[str, Any]]: Mapping from model index to a dict with
                the keys "prediction_path" and "metrics"; empty when nothing
                was found

        Raises:
            NotImplementedError: Always, in this base class
        """
        raise NotImplementedError()

    def has_prediction(self, output_dir: Path) -> bool:
        """Check if prediction exists in output directory."""
        return len(self.predictions(output_dir)) > 0

    def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
        """Check if the file description is correct.

        Args:
            seq_file (Path | str): Path to FASTA file

        Returns:
            tuple[bool, str | None]: ``(True, None)`` when the validator accepts
                the file, otherwise ``(False, error_message)``
        """
        is_valid, error_msg = self.validator.is_valid_fasta(seq_file)
        # Any message returned alongside a valid result is dropped.
        return (True, None) if is_valid else (False, error_msg)
|
|
|
|
|
class ChaiModel(AF3Model):
    """Folding Studio wrapper for the Chai model."""

    model_name = "Chai"

    def __init__(self, api_key: str):
        """Bind the Chai query class and FASTA validator.

        Args:
            api_key (str): Folding Studio API key
        """
        super().__init__(api_key, ChaiQuery, ChaiFastaValidator())

    def call(
        self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
    ) -> None:
        """Predict protein structure from amino acid sequence using Chai model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
            format_fasta (bool): Whether to format the FASTA file
        """
        super().call(seq_file, output_dir, format_fasta)

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Collect the Chai predictions and their score files.

        Args:
            output_dir (Path): Path to output directory

        Returns:
            dict[int, dict[str, Any]]: Mapping from model index to a dict with
                "prediction_path" (the ``pred.model_idx_N.cif`` file) and
                "metrics" (the result of ``np.load`` on the matching
                ``scores.model_idx_N.npz``). Only indices that have both
                files are included; empty when no prediction is found.
        """
        # The first hit is only used to locate the directory holding the
        # predictions; the actual listing happens in that directory.
        first_hit = next(output_dir.rglob("pred.model_idx_[0-9].cif"), None)
        if first_hit is None:
            return {}
        pred_dir = first_hit.parent

        cif_files = {
            int(f.stem.split("model_idx_")[1]): f
            for f in pred_dir.glob("pred.model_idx_*.cif")
        }
        npz_files = {
            int(f.stem.split("model_idx_")[1]): f
            for f in pred_dir.glob("scores.model_idx_*.npz")
        }

        # Keep only the indices for which both structure and scores exist.
        common_indices = sorted(cif_files.keys() & npz_files.keys())

        return {
            idx: {"prediction_path": cif_files[idx], "metrics": np.load(npz_files[idx])}
            for idx in common_indices
        }
|
|
|
|
|
class ProtenixModel(AF3Model):
    """Folding Studio wrapper for the Protenix model."""

    model_name = "Protenix"

    def __init__(self, api_key: str):
        """Bind the Protenix query class and FASTA validator.

        Args:
            api_key (str): Folding Studio API key
        """
        super().__init__(api_key, ProtenixQuery, ProtenixFastaValidator())

    def call(
        self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
    ) -> None:
        """Predict protein structure from amino acid sequence using Protenix model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
            format_fasta (bool): Whether to format the FASTA file
        """
        super().call(seq_file, output_dir, format_fasta)

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Collect the Protenix predictions and their confidence summaries.

        Args:
            output_dir (Path): Path to output directory

        Returns:
            dict[int, dict[str, Any]]: Mapping from sample index to a dict with
                "prediction_path" (the ``sequence_*_sample_N.cif`` file) and
                "metrics" (the parsed content of the matching
                ``sequence_*_summary_confidence_sample_N.json``). Only indices
                that have both files are included; empty when no prediction
                is found.
        """
        # The first hit is only used to locate the directory holding the
        # predictions; the actual listing happens in that directory.
        first_hit = next(output_dir.rglob("sequence_*_sample_[0-9].cif"), None)
        if first_hit is None:
            return {}
        pred_dir = first_hit.parent

        # The sample index is the last character of the stem; the glob
        # patterns restrict it to a single digit.
        cif_files = {
            int(f.stem[-1]): f for f in pred_dir.glob("sequence_*_sample_[0-9].cif")
        }
        json_files = {
            int(f.stem[-1]): f
            for f in pred_dir.glob("sequence_*_summary_confidence_sample_[0-9].json")
        }

        results = {}
        for idx in sorted(cif_files.keys() & json_files.keys()):
            # Read through a context manager so the file handle is closed
            # right away instead of lingering until garbage collection.
            with open(json_files[idx]) as f:
                metrics = json.load(f)
            results[idx] = {
                "prediction_path": cif_files[idx],
                "metrics": metrics,
            }
        return results
|
|
|
|
|
class BoltzModel(AF3Model):
    """Folding Studio wrapper for the Boltz model."""

    model_name = "Boltz"

    def __init__(self, api_key: str):
        """Bind the Boltz query class and FASTA validator.

        Args:
            api_key (str): Folding Studio API key
        """
        super().__init__(api_key, BoltzQuery, BoltzFastaValidator())

    def call(
        self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
    ) -> None:
        """Predict protein structure from amino acid sequence using Boltz model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
            format_fasta (bool): Whether to format the FASTA file
        """
        super().call(seq_file, output_dir, format_fasta)

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Collect the Boltz predictions and their pLDDT files.

        Args:
            output_dir (Path): Path to output directory

        Returns:
            dict[int, dict[str, Any]]: Mapping from model index to a dict with
                "prediction_path" (the ``*_model_N.cif`` file) and "metrics"
                (the result of ``np.load`` on its pLDDT file). Predictions
                without any pLDDT file next to them are left out, matching
                the other models in this module which only report
                predictions that have metrics.
        """
        results = {}
        for cif_path in output_dir.rglob("*_model_[0-9].cif"):
            plddt_path = self._plddt_file_for(cif_path)
            if plddt_path is None:
                continue
            # The model index is the last character of the stem; the glob
            # pattern restricts it to a single digit.
            results[int(cif_path.stem[-1])] = {
                "prediction_path": cif_path,
                "metrics": np.load(plddt_path),
            }
        return results

    @staticmethod
    def _plddt_file_for(cif_path: Path) -> Path | None:
        """Return the pLDDT file belonging to ``cif_path``, or None if there is none."""
        # Prefer the file named after this exact prediction so each model
        # gets its own scores rather than whichever file the glob lists first.
        exact = cif_path.parent / f"plddt_{cif_path.stem}.npz"
        if exact.is_file():
            return exact
        # Fallback for layouts that do not follow the per-model naming: use
        # the first pLDDT file of the directory, sorted for determinism.
        candidates = sorted(cif_path.parent.glob("plddt_*.npz"))
        return candidates[0] if candidates else None
|
|
|
|
|
class OldModel:
    """Base class for models run through ``single_job_prediction``.

    The job is submitted as an experiment, its status is polled until it
    reports "Done", and the results are then downloaded.
    """

    # Overridden by subclasses.
    model_name = None

    def __init__(self, api_key: str):
        self.api_key = api_key

    def call(
        self,
        seq_file: Path | str,
        output_dir: Path,
        parameters: BaseFoldingParameters,
        *args,
        **kwargs,
    ) -> None:
        """Submit a folding job, wait for it to finish and download its results.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
            parameters (BaseFoldingParameters): Folding parameters sent with the job
            *args: Accepted and ignored
            **kwargs: Accepted and ignored
        """
        submission = single_job_prediction(
            fasta_file=seq_file,
            parameters=parameters,
            api_key=self.api_key,
        )
        experiment_id = submission["message"]["experiment_id"]

        # NOTE(review): the loop only stops on the status "Done"; any other
        # status keeps polling, with no timeout.
        while True:
            # The status command prints its result, so stdout is captured
            # and the first printed line is taken as the status.
            with Capturing() as captured:
                get_status(experiment_id, api_key=self.api_key)
            status = captured[0]
            logger.info(f"Experiment {experiment_id} status: {status}")

            if status != "Done":
                logger.info("Sleeping for 10 seconds")
                time.sleep(10)
                continue

            logger.info("Downloading results")
            get_results(
                experiment_id,
                force=True,
                unzip=True,
                output=output_dir / "results.zip",
                api_key=self.api_key,
            )
            logger.info("Results downloaded to %s", output_dir)
            return

    def format_fasta(self, seq_file: Path | str) -> None:
        """Do nothing; present so the interface matches the other models.

        Args:
            seq_file (Path | str): Path to FASTA file
        """
        return None

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Collect the relaxed predictions and their per-model metrics.

        Args:
            output_dir (Path): Path to output directory

        Returns:
            dict[int, dict[str, Any]]: Mapping from model index to a dict with
                "prediction_path" and "metrics"; empty when
                ``results/metrics_per_model.json`` does not exist
        """
        results_dir = output_dir / "results"
        pdb_paths = list(results_dir.rglob("relaxed_model_[0-9]_*_pred_0.pdb"))

        metrics_file = results_dir / "metrics_per_model.json"
        if not metrics_file.exists():
            return {}
        with open(metrics_file, "r") as fh:
            metrics_by_model = json.load(fh)

        found = {}
        for pdb_path in pdb_paths:
            # Stem layout: relaxed_model_<id>_..._pred_0
            parts = pdb_path.stem.split("_")
            model_id = int(parts[2])
            # Metrics are keyed by the stem without the leading "relaxed"
            # and the trailing "pred_0".
            metrics_key = "_".join(parts[1:-2])
            found[model_id] = {
                "prediction_path": pdb_path,
                "metrics": metrics_by_model[metrics_key],
            }
        return found

    def has_prediction(self, output_dir: Path) -> bool:
        """Tell whether the output directory holds at least one prediction."""
        found = self.predictions(output_dir)
        return len(found) > 0

    def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
        """Accept any file; no validation is performed for these models.

        Args:
            seq_file (Path | str): Path to FASTA file

        Returns:
            tuple[bool, str | None]: Always ``(True, None)``
        """
        return (True, None)
|
|
|
|
|
class AF2Model(OldModel):
    """AlphaFold2 model, run with default ``AF2Parameters``."""

    model_name = "AlphaFold2"

    def call(self, seq_file: Path | str, output_dir: Path, *args, **kwargs) -> None:
        """Run the prediction with a freshly constructed ``AF2Parameters``.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
            *args: Forwarded to the parent ``call``
            **kwargs: Forwarded to the parent ``call``
        """
        default_parameters = AF2Parameters()
        super().call(seq_file, output_dir, default_parameters, *args, **kwargs)
|
|
|
|
|
class OpenFoldModel(OldModel):
    """OpenFold model, run with default ``OpenFoldParameters``."""

    model_name = "OpenFold"

    def call(self, seq_file: Path | str, output_dir: Path, *args, **kwargs) -> None:
        """Run the prediction with a freshly constructed ``OpenFoldParameters``.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
            *args: Forwarded to the parent ``call``
            **kwargs: Forwarded to the parent ``call``
        """
        default_parameters = OpenFoldParameters()
        super().call(seq_file, output_dir, default_parameters, *args, **kwargs)
|
|