"""API batch prediction call wrappers."""

from io import StringIO
from pathlib import Path

import requests
import typer
from Bio import SeqIO
from folding_studio_data_models import (
    AF2Request,
    BatchRequest,
    FoldingModel,
    OpenFoldRequest,
    Sequence,
)
from rich import print

from folding_studio.config import API_URL, REQUEST_TIMEOUT
from folding_studio.utils.data_model import (
    PredictRequestCustomFiles,
    PredictRequestParams,
)
from folding_studio.utils.headers import get_auth_headers
from folding_studio.utils.project_validation import define_project_code_or_raise


def _extract_sequences_from_file(file: Path) -> list[Sequence]:
    """Extract the sequences from a FASTA file.

    Args:
        file (Path): Path to the FASTA file.

    Returns:
        list[Sequence]: One Sequence per record in the file.
    """
    content = SeqIO.parse(StringIO(file.read_text()), "fasta")
    sequences = []
    for record in content:
        description = str(record.description)
        sequences.append(
            Sequence(description=description, fasta_sequence=str(record.seq))
        )
    return sequences
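

# A minimal sketch of the expected behavior (file name and records are
# hypothetical): given a FASTA file "dimer.fasta" containing
#
#     >chainA first chain
#     MKTAYIAKQR
#     >chainB second chain
#     AGHMQWE
#
# the call
#
#     _extract_sequences_from_file(Path("dimer.fasta"))
#
# returns two Sequence objects, e.g. the first with
# description="chainA first chain" and fasta_sequence="MKTAYIAKQR".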


def _build_request_from_fasta(
    file: Path,
    folding_model: FoldingModel,
    params: PredictRequestParams,
    custom_files: PredictRequestCustomFiles,
) -> AF2Request | OpenFoldRequest:
    """Build an AF2Request or an OpenFoldRequest from a FASTA file path
    and request parameters.

    Args:
        file (Path): Path to a file describing the protein.
        folding_model (FoldingModel): Folding model to run the inference with.
        params (PredictRequestParams): API request parameters.
        custom_files (PredictRequestCustomFiles): API request custom files.

    Returns:
        AF2Request | OpenFoldRequest: Request object matching the folding model.
    """
    parameters = dict(
        num_recycle=params.num_recycle,
        random_seed=params.random_seed,
        custom_templates=params.custom_template_ids
        + [str(f) for f in custom_files.templates],
        custom_msas=[str(f) for f in custom_files.msas],
        gap_trick=params.gap_trick,
        msa_mode=params.msa_mode,
        max_msa_clusters=params.max_msa_clusters,
        max_extra_msa=params.max_extra_msa,
        template_mode=params.template_mode,
        model_subset=params.model_subset,
        initial_guess_file=custom_files.initial_guess_files,
        templates_masks_file=custom_files.templates_masks_files,
    )
    # Both request classes share the same constructor fields, so only the
    # class needs to be selected.
    request_cls = AF2Request if folding_model == FoldingModel.AF2 else OpenFoldRequest
    return request_cls(
        complex_id=file.stem,
        sequences=_extract_sequences_from_file(file),
        parameters=parameters,
        ignore_cache=params.ignore_cache,
    )
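

# Hypothetical usage sketch: `params` and `custom_files` are assumed to be
# prepared by the CLI layer; only the call shape below comes from this module.
#
#     request = _build_request_from_fasta(
#         file=Path("complex_1.fasta"),
#         folding_model=FoldingModel.AF2,
#         params=params,
#         custom_files=custom_files,
#     )
#     # request is an AF2Request with complex_id "complex_1"; any other
#     # folding model yields an OpenFoldRequest with the same parameters.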


def batch_prediction(
    files: list[Path],
    folding_model: FoldingModel,
    params: PredictRequestParams,
    custom_files: PredictRequestCustomFiles,
    project_code: str | None = None,
    num_seed: int | None = None,
) -> dict:
    """Make a batch prediction from a list of files.

    Args:
        files (list[Path]): List of data source file paths.
        folding_model (FoldingModel): Folding model to run the inference with.
        params (PredictRequestParams): API request parameters.
        custom_files (PredictRequestCustomFiles): API request custom files.
        project_code (str | None): Project code under which the jobs are billed.
        num_seed (int | None, optional): Number of random seeds. If set, each
            file is submitted once per seed in range(num_seed). Defaults to None.

    Raises:
        typer.Exit: If an error occurs during the API call.

    Returns:
        dict: The API response.
    """
    project_code = define_project_code_or_raise(project_code=project_code)

    custom_files.upload()

    if num_seed is not None:
        folding_requests = []
        for seed in range(num_seed):
            params.random_seed = seed
            folding_requests += [
                _build_request_from_fasta(
                    file=file,
                    folding_model=folding_model,
                    params=params,
                    custom_files=custom_files,
                )
                for file in files
            ]
    else:
        folding_requests = [
            _build_request_from_fasta(
                file=file,
                folding_model=folding_model,
                params=params,
                custom_files=custom_files,
            )
            for file in files
        ]
    batch_request = BatchRequest(requests=folding_requests)
    url = API_URL + "batchPredict"

    response = requests.post(
        url,
        data={"batch_jobs_request": batch_request.model_dump_json()},
        params={"project_code": project_code},
        headers=get_auth_headers(),
        timeout=REQUEST_TIMEOUT,
    )

    if not response.ok:
        print(f"An error occurred: {response.content.decode()}")
        raise typer.Exit(code=1)

    return response.json()
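

# Hedged usage sketch (file names and project code are illustrative): with
# num_seed=3, each file is submitted once per seed 0..2, so two files yield
# six requests in the batch.
#
#     response = batch_prediction(
#         files=[Path("a.fasta"), Path("b.fasta")],
#         folding_model=FoldingModel.AF2,
#         params=params,
#         custom_files=custom_files,
#         project_code="PROJ-001",
#         num_seed=3,
#     )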


def batch_prediction_from_file(
    file: Path,
    project_code: str | None = None,
) -> dict:
    """Make a batch prediction from a configuration file.

    Args:
        file (Path): Configuration file path.
        project_code (str | None): Project code under which the jobs are billed.

    Raises:
        typer.Exit: If an error occurs during the API call.

    Returns:
        dict: The API response.
    """
    project_code = define_project_code_or_raise(project_code=project_code)
    url = API_URL + "batchPredictFromFile"

    custom_files = PredictRequestCustomFiles.from_batch_jobs_file(batch_jobs_file=file)
    local_to_uploaded = custom_files.upload()

    if local_to_uploaded:
        # Rewrite local custom file paths to their uploaded locations in a
        # temporary copy of the configuration file.
        content = file.read_text()
        for local, uploaded in local_to_uploaded.items():
            content = content.replace(local, uploaded)
        tmp_file = Path("tmp_batch_job" + file.suffix)
        tmp_file.write_text(content)
        file_to_upload = tmp_file
    else:
        tmp_file = None
        file_to_upload = file

    with file_to_upload.open("rb") as input_file:
        response = requests.post(
            url,
            headers=get_auth_headers(),
            files=[("batch_jobs_file", input_file)],
            params={"project_code": project_code},
            timeout=REQUEST_TIMEOUT,
        )

    # Remove the temporary rewritten copy, if one was created.
    if tmp_file and tmp_file.exists():
        tmp_file.unlink()

    if not response.ok:
        print(f"An error occurred: {response.content.decode()}")
        raise typer.Exit(code=1)

    return response.json()
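

# Hedged example of the path-rewriting step (all values illustrative): if
# custom_files.upload() returned {"msas/local.a3m": "gs://bucket/msas/local.a3m"},
# every occurrence of the local path in the batch jobs file is replaced with
# the uploaded URI in a temporary copy, which is then posted instead of the
# original file.
#
#     batch_prediction_from_file(Path("batch_jobs.json"), project_code="PROJ-001")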