"""Test simple prediction.""" import os from pathlib import Path from unittest import mock import pytest from folding_studio.api_call.predict.simple_predict import ( simple_prediction, single_job_prediction, ) from folding_studio.config import API_URL, REQUEST_TIMEOUT from folding_studio.utils.data_model import ( PredictRequestCustomFiles, PredictRequestParams, ) from folding_studio.utils.exceptions import ProjectCodeNotFound from folding_studio_data_models import AF2Parameters, OpenFoldParameters from folding_studio_data_models.request.folding import FoldingModel @pytest.fixture(autouse=True) def mock_post(): post_mock = mock.Mock() mock_response = mock.MagicMock() mock_response.ok = True post_mock.return_value = mock_response with mock.patch("requests.post", post_mock): yield post_mock @pytest.fixture(autouse=True) def mock_get_auth_headers(): with mock.patch( "folding_studio.api_call.predict.simple_predict.get_auth_headers", return_value={'Authorization': 'Bearer identity_token'} ) as m: yield m def test_simple_prediction_pass( tmp_path: Path, mock_post: pytest.FixtureRequest, folding_model: FoldingModel, headers: dict[str, str] ): """Test simple prediction pass.""" file = tmp_path / "fasta_file.fasta" file.touch() params = PredictRequestParams( ignore_cache=False, template_mode="search", custom_template_ids=[], msa_mode="search", max_msa_clusters=-1, max_extra_msa=-1, gap_trick=False, num_recycle=3, random_seed=0, model_subset=[1, 3, 4], ) custom_files = PredictRequestCustomFiles(templates=[], msas=[]) simple_prediction( file, folding_model, params, custom_files, project_code="FOLDING_DEV", ) params = params.model_dump(mode="json") params.update( { "folding_model": folding_model, "custom_msa_files": [], "custom_template_files": [], "initial_guess_file": None, "templates_masks_file": None, } ) mock_post.assert_called_once_with( API_URL + "predict", data=params, headers=headers, files=mock.ANY, timeout=REQUEST_TIMEOUT, params={"project_code": "FOLDING_DEV"}, ) 
def test_simple_prediction_fail_because_no_project_code(
    remove_project_code_from_env_var, tmp_path: Path
):
    """Check that ``simple_prediction`` raises ``ProjectCodeNotFound``.

    No ``project_code`` argument is passed, and ``FOLDING_PROJECT_CODE`` must
    be absent from the environment. ``remove_project_code_from_env_var`` is
    requested only for its side effect; the assert below guards that
    precondition.
    """
    fasta_file = tmp_path / "fasta_file.fasta"
    fasta_file.touch()
    # Precondition: the environment must not supply a fallback project code.
    assert os.environ.get("FOLDING_PROJECT_CODE") is None
    params = PredictRequestParams(
        ignore_cache=False,
        template_mode="search",
        custom_template_ids=[],
        msa_mode="search",
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[1, 3, 4],
    )
    custom_files = PredictRequestCustomFiles(templates=[], msas=[])

    with pytest.raises(ProjectCodeNotFound):
        simple_prediction(
            fasta_file,
            FoldingModel.AF2,
            params,
            custom_files,
        )


def test_single_job_prediction_pass(
    tmp_path: Path,
    mock_post: mock.Mock,
    folding_model: FoldingModel,
    headers: dict[str, str],
):
    """Check that ``single_job_prediction`` issues exactly one POST to ``predict``.

    The parameter type follows the ``folding_model`` fixture:
    ``OpenFoldParameters`` for OpenFold, ``AF2Parameters`` otherwise. The form
    data must be the JSON dump of those parameters plus the folding model
    value and default custom-file fields.
    """
    fasta_file = tmp_path / "fasta_file.fasta"
    fasta_file.touch()
    parameters = (
        OpenFoldParameters()
        if folding_model == FoldingModel.OPENFOLD
        else AF2Parameters()
    )

    single_job_prediction(
        fasta_file=fasta_file,
        parameters=parameters,
        project_code="FOLDING_DEV",
    )

    expected_data = parameters.model_dump(mode="json")
    expected_data.update(
        {
            "folding_model": folding_model.value,
            "custom_msa_files": [],
            "custom_template_ids": [],
            "custom_template_files": [],
            "initial_guess_file": None,
            "templates_masks_file": None,
            "ignore_cache": False,
        }
    )
    mock_post.assert_called_once_with(
        API_URL + "predict",
        data=expected_data,
        headers=headers,
        files=mock.ANY,
        timeout=REQUEST_TIMEOUT,
        params={"project_code": "FOLDING_DEV"},
    )


def test_single_job_prediction_handle_deprecated_af2_parameters(
    tmp_path: Path, mock_post: mock.Mock, headers: dict[str, str]
):
    """Check the deprecated ``af2_parameters`` keyword of ``single_job_prediction``.

    Passing ``af2_parameters`` alone must emit a deprecation warning. Passing
    it together with ``parameters`` must raise ``ValueError``.
    """
    fasta_file = tmp_path / "fasta_file.fasta"
    fasta_file.touch()
    parameters = AF2Parameters()

    with pytest.deprecated_call():
        single_job_prediction(
            fasta_file=fasta_file,
            af2_parameters=parameters,
            project_code="FOLDING_DEV",
        )

    # Supplying both the deprecated and the new keyword is ambiguous.
    with pytest.raises(ValueError):
        single_job_prediction(
            fasta_file=fasta_file,
            af2_parameters=parameters,
            parameters=parameters,
            project_code="FOLDING_DEV",
        )