"""API simple prediction call wrappers."""

import logging
import warnings
from pathlib import Path

import requests
import typer
from folding_studio_data_models import AF2Parameters, OpenFoldParameters
from folding_studio_data_models.request.folding import FoldingModel

from folding_studio.config import API_URL, REQUEST_TIMEOUT
from folding_studio.utils.data_model import (
    PredictRequestCustomFiles,
    PredictRequestParams,
)
from folding_studio.utils.file_helpers import partition_template_pdb_from_file
from folding_studio.utils.headers import get_auth_headers
from folding_studio.utils.project_validation import define_project_code_or_raise


def single_job_prediction(
    fasta_file: Path,
    parameters: AF2Parameters | OpenFoldParameters | None = None,
    project_code: str | None = None,
    *,
    ignore_cache: bool = False,
    **kwargs,
) -> dict:
    """Make a single job prediction from folding parameters and a FASTA file.

    This is a helper function to be called in users scripts.

    Args:
        fasta_file (Path): Input FASTA file.
        parameters (AF2Parameters | OpenFoldParameters | None, optional): Job
            parameters. For backward compatibility, can be aliased with
            `af2_parameters` (passed as a keyword argument). Defaults to None.
        project_code (str | None, optional): Project code under which the jobs
            are billed. If None, value is attempted to be read from
            environment. Defaults to None.
        ignore_cache (bool, optional): Force the job submission or not.
            Defaults to False.
        **kwargs: Only `af2_parameters` (deprecated alias of `parameters`) is
            read; other keyword arguments are ignored.

    Raises:
        ValueError: If neither `parameters` nor `af2_parameters` is given, or
            if both are given.
        requests.HTTPError: If the API responds with an error status code.

    Returns:
        dict: Decoded JSON body of the API response.
    """
    old_parameters = kwargs.get("af2_parameters")
    if parameters is None:
        if old_parameters is None:
            msg = (
                "Argument `parameters` must be specified if deprecated alias "
                "`af2_parameters` is not."
            )
            raise ValueError(msg)
        warnings.warn(
            "Argument 'af2_parameters' is deprecated and will be removed in future release; use 'parameters' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        parameters = old_parameters
    elif old_parameters is not None:
        raise ValueError("Use either 'parameters' or 'af2_parameters', not both.")

    # May raise if no project code is given and none can be found in the
    # environment (behavior delegated to define_project_code_or_raise).
    project_code = define_project_code_or_raise(project_code=project_code)

    custom_files = PredictRequestCustomFiles(
        templates=parameters.custom_templates,
        msas=parameters.custom_msas,
        initial_guess_files=[parameters.initial_guess_file]
        if parameters.initial_guess_file
        else None,
        templates_masks_files=[parameters.templates_masks_file]
        if parameters.templates_masks_file
        else None,
    )
    _ = custom_files.upload()

    params = parameters.model_dump(mode="json")

    # Split custom templates into PDB identifiers vs. local files; only the
    # identifiers are sent here, the files go through `custom_files`.
    pdb_ids, _ = partition_template_pdb_from_file(
        custom_templates=parameters.custom_templates
    )

    folding_model = (
        FoldingModel.OPENFOLD
        if isinstance(parameters, OpenFoldParameters)
        else FoldingModel.AF2
    )
    params.update(
        {
            "folding_model": folding_model.value,
            "custom_msa_files": custom_files.msas,
            "custom_template_ids": list(pdb_ids),
            "custom_template_files": custom_files.templates,
            "initial_guess_file": custom_files.initial_guess_files[0]
            if custom_files.initial_guess_files
            else None,
            "templates_masks_file": custom_files.templates_masks_files[0]
            if custom_files.templates_masks_files
            else None,
            "ignore_cache": ignore_cache,
        }
    )

    url = API_URL + "predict"
    # Use a context manager so the FASTA file handle is closed once the
    # request has been sent (it was previously leaked).
    with fasta_file.open("rb") as fasta_handle:
        response = requests.post(
            url,
            data=params,
            headers=get_auth_headers(),
            files=[("fasta_file", fasta_handle)],
            params={"project_code": project_code},
            timeout=REQUEST_TIMEOUT,
        )
    response.raise_for_status()

    logging.info("Single job successfully submitted.")
    response_json = response.json()
    return response_json


def simple_prediction(
    file: Path,
    folding_model: FoldingModel,
    params: PredictRequestParams,
    custom_files: PredictRequestCustomFiles,
    project_code: str | None = None,
) -> dict:
    """Make a simple prediction from a file.

    Args:
        file (Path): Data source file path.
        folding_model (FoldingModel): Folding model to request; its `value` is
            sent as the `folding_model` form field.
        params (PredictRequestParams): API request parameters.
        custom_files (PredictRequestCustomFiles): API request custom files.
            They are uploaded before the prediction request is sent.
        project_code (str|None): Project code under which the jobs are billed.

    Raises:
        typer.Exit: If the API responds with a non-OK status (exit code 1).

    Returns:
        dict: Decoded JSON body of the API response.
    """
    project_code = define_project_code_or_raise(project_code=project_code)
    url = API_URL + "predict"

    _ = custom_files.upload()

    # Build the form payload in a separate local rather than rebinding the
    # caller-supplied `params` model to a dict.
    data = params.model_dump(mode="json")
    data.update(
        {
            "folding_model": folding_model.value,
            "custom_msa_files": custom_files.msas,
            "custom_template_files": custom_files.templates,
            "initial_guess_file": custom_files.initial_guess_files[0]
            if custom_files.initial_guess_files
            else None,
            "templates_masks_file": custom_files.templates_masks_files[0]
            if custom_files.templates_masks_files
            else None,
        }
    )

    # Close the input file once the request has been sent.
    with file.open("rb") as fasta_handle:
        response = requests.post(
            url,
            data=data,
            headers=get_auth_headers(),
            files=[("fasta_file", fasta_handle)],
            params={"project_code": project_code},
            timeout=REQUEST_TIMEOUT,
        )
    if not response.ok:
        print(f"An error occurred: {response.content.decode()}")
        raise typer.Exit(code=1)

    print("Single job successfully submitted.")
    response_json = response.json()
    return response_json