|
"""API simple prediction call wrappers.""" |
|
|
|
import logging |
|
import warnings |
|
from pathlib import Path |
|
|
|
import requests |
|
import typer |
|
from folding_studio_data_models import AF2Parameters, OpenFoldParameters |
|
from folding_studio_data_models.request.folding import FoldingModel |
|
|
|
from folding_studio.config import API_URL, REQUEST_TIMEOUT |
|
from folding_studio.utils.data_model import ( |
|
PredictRequestCustomFiles, |
|
PredictRequestParams, |
|
) |
|
from folding_studio.utils.file_helpers import partition_template_pdb_from_file |
|
from folding_studio.utils.headers import get_auth_headers |
|
from folding_studio.utils.project_validation import define_project_code_or_raise |
|
|
|
|
|
def single_job_prediction(
    fasta_file: Path,
    parameters: AF2Parameters | OpenFoldParameters | None = None,
    project_code: str | None = None,
    api_key: str | None = None,
    *,
    ignore_cache: bool = False,
    **kwargs,
) -> dict:
    """Make a single job prediction from folding parameters and a FASTA file.

    This is a helper function to be called in users scripts.

    Args:
        fasta_file (Path): Input FASTA file, sent as the `fasta_file` multipart field.
        parameters (AF2Parameters | OpenFoldParameters | None, optional): Job parameters.
            For backward compatibility, can be aliased with `af2_parameters`. Defaults to None.
        project_code (str | None, optional): Project code under which the jobs are billed.
            If None, value is attempted to be read from environment. Defaults to None.
        api_key (str | None, optional): API key forwarded to the custom-file upload and
            to `get_auth_headers`. Defaults to None.
        ignore_cache (bool, optional): Force the job submission or not. Defaults to False.
        **kwargs: Only `af2_parameters` (deprecated alias of `parameters`) is read;
            any other keyword argument is ignored.

    Raises:
        ValueError: If neither `parameters` nor `af2_parameters` is given, or if both are.
        requests.HTTPError: If the API responds with an error status code.

    Returns:
        dict: API response.
    """
    # Resolve the deprecated `af2_parameters` alias into `parameters`.
    old_parameters = kwargs.get("af2_parameters")
    if parameters is None:
        if old_parameters is None:
            msg = "Argument `parameters` must be specified if deprecated alias `af2_parameters` is not. "
            raise ValueError(msg)
        warnings.warn(
            "Argument 'af2_parameters' is deprecated and will be removed in future release; use 'parameters' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        parameters = old_parameters
    elif old_parameters is not None:
        raise ValueError("Use either 'parameters' or 'af2_parameters', not both.")

    project_code = define_project_code_or_raise(project_code=project_code)

    custom_files = PredictRequestCustomFiles(
        templates=parameters.custom_templates,
        msas=parameters.custom_msas,
        initial_guess_files=[parameters.initial_guess_file]
        if parameters.initial_guess_file
        else None,
        templates_masks_files=[parameters.templates_masks_file]
        if parameters.templates_masks_file
        else None,
    )
    # The return value is unused; the attributes of `custom_files` are read
    # after this call, presumably updated by `upload` (TODO confirm).
    _ = custom_files.upload(api_key=api_key)

    params = parameters.model_dump(mode="json")
    pdb_ids, _ = partition_template_pdb_from_file(
        custom_templates=parameters.custom_templates
    )

    folding_model = (
        FoldingModel.OPENFOLD
        if isinstance(parameters, OpenFoldParameters)
        else FoldingModel.AF2
    )

    params.update(
        {
            "folding_model": folding_model.value,
            "custom_msa_files": custom_files.msas,
            "custom_template_ids": list(pdb_ids),
            "custom_template_files": custom_files.templates,
            "initial_guess_file": custom_files.initial_guess_files[0]
            if custom_files.initial_guess_files
            else None,
            "templates_masks_file": custom_files.templates_masks_files[0]
            if custom_files.templates_masks_files
            else None,
            "ignore_cache": ignore_cache,
        }
    )

    url = API_URL + "predict"
    # Open the FASTA file in a context manager so the handle is closed even
    # if the request fails (previously it was leaked).
    with fasta_file.open("rb") as fasta_handle:
        response = requests.post(
            url,
            data=params,
            headers=get_auth_headers(api_key),
            files=[("fasta_file", fasta_handle)],
            params={"project_code": project_code},
            timeout=REQUEST_TIMEOUT,
        )
    response.raise_for_status()

    logging.info("Single job successfully submitted.")
    response_json = response.json()
    return response_json
|
|
|
|
|
def simple_prediction(
    file: Path,
    folding_model: FoldingModel,
    params: PredictRequestParams,
    custom_files: PredictRequestCustomFiles,
    project_code: str | None = None,
) -> dict:
    """Make a simple prediction from a file.

    Args:
        file (Path): Data source file path, sent as the `fasta_file` multipart field.
        folding_model (FoldingModel): Folding model whose `value` is sent as `folding_model`.
        params (PredictRequestParams): API request parameters.
        custom_files (PredictRequestCustomFiles): API request custom files; they are
            uploaded before the prediction request is sent.
        project_code (str|None): Project code under which the jobs are billed.

    Returns:
        dict: Decoded JSON body of the API response.

    Raises:
        typer.Exit: If an error occurs during the API call (exit code 1).
    """
    project_code = define_project_code_or_raise(project_code=project_code)

    url = API_URL + "predict"

    # The return value is unused; the attributes of `custom_files` are read
    # after this call, presumably updated by `upload` (TODO confirm).
    _ = custom_files.upload()

    # Use a distinct name instead of rebinding `params` to a dict.
    form_data = params.model_dump(mode="json")
    form_data.update(
        {
            "folding_model": folding_model.value,
            "custom_msa_files": custom_files.msas,
            "custom_template_files": custom_files.templates,
            "initial_guess_file": custom_files.initial_guess_files[0]
            if custom_files.initial_guess_files
            else None,
            "templates_masks_file": custom_files.templates_masks_files[0]
            if custom_files.templates_masks_files
            else None,
        }
    )

    # Open the input file in a context manager so the handle is closed even
    # if the request fails (previously it was leaked).
    with file.open("rb") as fasta_handle:
        response = requests.post(
            url,
            data=form_data,
            headers=get_auth_headers(),
            files=[("fasta_file", fasta_handle)],
            params={"project_code": project_code},
            timeout=REQUEST_TIMEOUT,
        )

    if not response.ok:
        print(f"An error occurred: {response.content.decode()}")
        raise typer.Exit(code=1)

    print("Single job successfully submitted.")
    response_json = response.json()
    return response_json
|
|