chengzhang1006's picture
add more information (#15)
01fba1c verified
"""API simple prediction call wrappers."""
import logging
import warnings
from pathlib import Path
import requests
import typer
from folding_studio_data_models import AF2Parameters, OpenFoldParameters
from folding_studio_data_models.request.folding import FoldingModel
from folding_studio.config import API_URL, REQUEST_TIMEOUT
from folding_studio.utils.data_model import (
PredictRequestCustomFiles,
PredictRequestParams,
)
from folding_studio.utils.file_helpers import partition_template_pdb_from_file
from folding_studio.utils.headers import get_auth_headers
from folding_studio.utils.project_validation import define_project_code_or_raise
def single_job_prediction(
    fasta_file: Path,
    parameters: AF2Parameters | OpenFoldParameters | None = None,
    project_code: str | None = None,
    api_key: str | None = None,
    *,
    ignore_cache: bool = False,
    **kwargs,
) -> dict:
    """Make a single job prediction from folding parameters and a FASTA file.

    This is a helper function to be called in users scripts. Custom files
    referenced by ``parameters`` are uploaded first, then the serialized
    parameters and the FASTA file are posted to the ``predict`` endpoint.

    Args:
        fasta_file (Path): Input FASTA file.
        parameters (AF2Parameters | OpenFoldParameters | None, optional): Job parameters.
            The folding model sent to the API is OpenFold for `OpenFoldParameters`
            and AF2 otherwise. For backward compatibility, can be passed through the
            deprecated `af2_parameters` keyword instead. Defaults to None.
        project_code (str | None, optional): Project code under which the jobs are billed.
            If None, value is attempted to be read from environment. Defaults to None.
        api_key (str | None, optional): API key forwarded to the custom files upload
            and to `get_auth_headers`. Defaults to None.
        ignore_cache (bool, optional): Force the job submission or not. Defaults to False.
        **kwargs: Only the deprecated `af2_parameters` alias is read; any other
            keyword argument is ignored.

    Warns:
        DeprecationWarning: If the deprecated `af2_parameters` alias is used.

    Raises:
        ValueError: If neither `parameters` nor `af2_parameters` is given, or if
            both are given.
        requests.HTTPError: If the API responds with an error status code.

    Returns:
        dict: Decoded JSON body of the API response.
    """
    old_parameters = kwargs.get("af2_parameters")
    if parameters is None:
        if old_parameters is None:
            msg = "Argument `parameters` must be specified if deprecated alias `af2_parameters` is not. "
            raise ValueError(msg)
        warnings.warn(
            "Argument 'af2_parameters' is deprecated and will be removed in future release; use 'parameters' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        parameters = old_parameters
    elif old_parameters is not None:
        raise ValueError("Use either 'parameters' or 'af2_parameters', not both.")

    project_code = define_project_code_or_raise(project_code=project_code)

    # Single-file fields are wrapped in one-element lists because the custom
    # files container stores them as lists; they are unwrapped again below.
    custom_files = PredictRequestCustomFiles(
        templates=parameters.custom_templates,
        msas=parameters.custom_msas,
        initial_guess_files=[parameters.initial_guess_file]
        if parameters.initial_guess_file
        else None,
        templates_masks_files=[parameters.templates_masks_file]
        if parameters.templates_masks_file
        else None,
    )
    custom_files.upload(api_key=api_key)

    form_data = parameters.model_dump(mode="json")
    pdb_ids, _ = partition_template_pdb_from_file(
        custom_templates=parameters.custom_templates
    )
    folding_model = (
        FoldingModel.OPENFOLD
        if isinstance(parameters, OpenFoldParameters)
        else FoldingModel.AF2
    )
    form_data.update(
        {
            "folding_model": folding_model.value,
            "custom_msa_files": custom_files.msas,
            "custom_template_ids": list(pdb_ids),
            "custom_template_files": custom_files.templates,
            "initial_guess_file": custom_files.initial_guess_files[0]
            if custom_files.initial_guess_files
            else None,
            "templates_masks_file": custom_files.templates_masks_files[0]
            if custom_files.templates_masks_files
            else None,
            "ignore_cache": ignore_cache,
        }
    )

    url = API_URL + "predict"
    headers = get_auth_headers(api_key)
    # Open the FASTA file in a context manager so the handle is closed once the
    # upload completes (the original passed an unclosed handle to requests).
    with fasta_file.open("rb") as fasta_handle:
        response = requests.post(
            url,
            data=form_data,
            headers=headers,
            files=[("fasta_file", fasta_handle)],
            params={"project_code": project_code},
            timeout=REQUEST_TIMEOUT,
        )
    response.raise_for_status()
    logging.info("Single job successfully submitted.")
    return response.json()
def simple_prediction(
    file: Path,
    folding_model: FoldingModel,
    params: PredictRequestParams,
    custom_files: PredictRequestCustomFiles,
    project_code: str | None = None,
) -> dict:
    """Make a simple prediction from a file.

    Custom files are uploaded first. The serialized request parameters and the
    data file are then posted to the ``predict`` endpoint.

    Args:
        file (Path): Data source file path, sent as the ``fasta_file`` form field.
        folding_model (FoldingModel): Folding model whose value is sent with the request.
        params (PredictRequestParams): API request parameters.
        custom_files (PredictRequestCustomFiles): API request custom files.
        project_code (str | None, optional): Project code under which the jobs are
            billed. Resolved through `define_project_code_or_raise`. Defaults to None.

    Raises:
        typer.Exit: If the API responds with a non-success status (exit code 1).

    Returns:
        dict: Decoded JSON body of the API response.
    """
    project_code = define_project_code_or_raise(project_code=project_code)
    url = API_URL + "predict"
    custom_files.upload()

    # Keep the pydantic model and the outgoing form dict under separate names.
    form_data = params.model_dump(mode="json")
    form_data.update(
        {
            "folding_model": folding_model.value,
            "custom_msa_files": custom_files.msas,
            "custom_template_files": custom_files.templates,
            "initial_guess_file": custom_files.initial_guess_files[0]
            if custom_files.initial_guess_files
            else None,
            "templates_masks_file": custom_files.templates_masks_files[0]
            if custom_files.templates_masks_files
            else None,
        }
    )

    headers = get_auth_headers()
    # Context manager guarantees the data file handle is closed after the upload.
    with file.open("rb") as fasta_handle:
        response = requests.post(
            url,
            data=form_data,
            headers=headers,
            files=[("fasta_file", fasta_handle)],
            params={"project_code": project_code},
            timeout=REQUEST_TIMEOUT,
        )

    if not response.ok:
        # errors="replace": a non-UTF-8 error body must not hide the API error
        # behind a UnicodeDecodeError.
        print(f"An error occurred: {response.content.decode(errors='replace')}")
        raise typer.Exit(code=1)
    print("Single job successfully submitted.")
    return response.json()