"""API batch prediction call wrappers.""" from io import StringIO from pathlib import Path import requests import typer from Bio import SeqIO from folding_studio_data_models import ( AF2Request, BatchRequest, FoldingModel, OpenFoldRequest, Sequence, ) from rich import print # pylint:disable=redefined-builtin from folding_studio.config import API_URL, REQUEST_TIMEOUT from folding_studio.utils.data_model import ( PredictRequestCustomFiles, PredictRequestParams, ) from folding_studio.utils.headers import get_auth_headers from folding_studio.utils.project_validation import define_project_code_or_raise def _extract_sequences_from_file(file: Path) -> list[Sequence]: content = SeqIO.parse(StringIO(file.read_text()), "fasta") sequences = [] for records in content: description = str(records.description) sequences.append( Sequence(description=description, fasta_sequence=str(records.seq)) ) return sequences def _build_request_from_fasta( file: Path, folding_model: FoldingModel, params: PredictRequestParams, custom_files: PredictRequestCustomFiles, ) -> AF2Request | OpenFoldRequest: """Build an AF2Request from a fasta file path and request parameters. Args: file (Path): Path to a file describing the protein. folding_model (FoldingModel): Folding model to run the inference with. params (PredictRequestParams): API request parameters. custom_files (PredictRequestCustomFiles): API request custom files. Returns: AF2Request | OpenFoldRequest: Request object. """ parameters = dict( num_recycle=params.num_recycle, random_seed=params.random_seed, custom_templates=params.custom_template_ids + [str(f) for f in custom_files.templates], custom_msas=[str(f) for f in custom_files.msas], gap_trick=params.gap_trick, msa_mode=params.msa_mode, max_msa_clusters=params.max_msa_clusters, max_extra_msa=params.max_extra_msa, template_mode=params.template_mode, model_subset=params.model_subset, initial_guess_file=custom_files.initial_guess_files, templates_masks_file=custom_files.templates_masks_files, ) if folding_model == FoldingModel.AF2: return AF2Request( complex_id=file.stem, sequences=_extract_sequences_from_file(file), parameters=parameters, ignore_cache=params.ignore_cache, ) return OpenFoldRequest( complex_id=file.stem, sequences=_extract_sequences_from_file(file), parameters=parameters, ignore_cache=params.ignore_cache, ) def batch_prediction( files: list[Path], folding_model: FoldingModel, params: PredictRequestParams, custom_files: PredictRequestCustomFiles, project_code: str | None = None, num_seed: int | None = None, ) -> dict: """Make a batch prediction from a list of files. Args: files (list[Path]): List of data source file paths. params (PredictRequestParams): API request parameters. custom_files (PredictRequestCustomFiles): API request custom files. project_code (str|None): Project code under which the jobs are billed. num_seed (int | None, optional): Number of random seeds. Defaults to None. Raises: typer.Exit: If an error occurs during the API call. """ project_code = define_project_code_or_raise(project_code=project_code) # upload custom files if any custom_files.upload() if num_seed is not None: folding_requests = [] for seed in range(num_seed): params.random_seed = seed folding_requests += [ _build_request_from_fasta( file=file, folding_model=folding_model, params=params, custom_files=custom_files, ) for file in files ] else: folding_requests = [ _build_request_from_fasta( file=file, folding_model=folding_model, params=params, custom_files=custom_files, ) for file in files ] batch_request = BatchRequest(requests=folding_requests) url = API_URL + "batchPredict" response = requests.post( url, data={"batch_jobs_request": batch_request.model_dump_json()}, params={"project_code": project_code}, headers=get_auth_headers(), timeout=REQUEST_TIMEOUT, ) if not response.ok: print(f"An error occurred: {response.content.decode()}") raise typer.Exit(code=1) response_json = response.json() return response_json def batch_prediction_from_file( file: Path, project_code: str | None = None, ) -> dict: """Make a batch prediction from a configuration files. Args: file (Path): Configuration file path. project_code (str|None): Project code under which the jobs are billed. Raises: typer.Exit: If an error occurs during the API call. """ project_code = define_project_code_or_raise(project_code=project_code) url = API_URL + "batchPredictFromFile" custom_files = PredictRequestCustomFiles.from_batch_jobs_file(batch_jobs_file=file) local_to_uploaded = custom_files.upload() if local_to_uploaded: content = file.read_text() for local, uploaded in local_to_uploaded.items(): content = content.replace(local, uploaded) tmp_file = Path("tmp_batch_job" + file.suffix) tmp_file.write_text(content) file_to_upload = tmp_file else: tmp_file = None file_to_upload = file with file_to_upload.open("rb") as input_file: response = requests.post( url, headers=get_auth_headers(), files=[("batch_jobs_file", input_file)], params={"project_code": project_code}, timeout=REQUEST_TIMEOUT, ) if tmp_file and tmp_file.exists(): tmp_file.unlink() if not response.ok: print(f"An error occurred: {response.content.decode()}") raise typer.Exit(code=1) return response.json()