"""Tests for building and submitting batch prediction requests."""

from pathlib import Path
from unittest import mock

import pytest
from folding_studio.api_call.predict.batch_predict import (
    _build_request_from_fasta,
    _extract_sequences_from_file,
    batch_prediction,
    batch_prediction_from_file,
)
from folding_studio.config import API_URL, REQUEST_TIMEOUT
from folding_studio.utils.data_model import (
    PredictRequestCustomFiles,
    PredictRequestParams,
)
from folding_studio_data_models import (
    AF2Request,
    BatchRequest,
    FeatureMode,
    FoldingModel,
    OpenFoldRequest,
)
from folding_studio_data_models.exceptions import DuplicatedRequest

current_workdir = Path(__file__).parent.resolve()
data_dir = Path(current_workdir / "data")

# Dotted path of the upload method patched out in every submission test so
# that no custom file is actually uploaded.
_UPLOAD_TARGET = (
    "folding_studio.api_call.predict.batch_predict.PredictRequestCustomFiles.upload"
)


@pytest.fixture()
def mock_post():
    """Patch ``requests.post`` with a mock whose response has ``ok`` set to True."""
    post_mock = mock.Mock()
    mock_response = mock.MagicMock()
    mock_response.ok = True
    post_mock.return_value = mock_response
    with mock.patch("requests.post", post_mock):
        yield post_mock


@pytest.fixture(autouse=True)
def mock_get_auth_headers():
    """Patch ``get_auth_headers`` in ``batch_predict`` to return a fixed bearer token."""
    with mock.patch(
        "folding_studio.api_call.predict.batch_predict.get_auth_headers",
        return_value={"Authorization": "Bearer identity_token"},
    ) as m:
        yield m


def _write_fasta_files(directory: Path, file_names, file_contents) -> list[Path]:
    """Write each content string to ``directory / name`` and return the paths."""
    fasta_paths = [directory / file_name for file_name in file_names]
    for fasta_path, content in zip(fasta_paths, file_contents):
        fasta_path.write_text(content)
    return fasta_paths


def _default_batch_inputs():
    """Return the ``(params, custom_files)`` pair shared by the batch tests."""
    params = PredictRequestParams(
        ignore_cache=False,
        template_mode=FeatureMode.SEARCH,
        custom_template_ids=[],
        msa_mode=FeatureMode.SEARCH,
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[],
    )
    custom_files = PredictRequestCustomFiles(templates=[], msas=[], uploaded=True)
    return params, custom_files


def test_build_request_from_fasta_pass(folding_model: FoldingModel):
    """Check that a FASTA file is turned into the request type of the folding model."""
    file = Path(data_dir, "protein.fasta")
    params = PredictRequestParams(
        ignore_cache=False,
        template_mode=FeatureMode.SEARCH,
        custom_template_ids=["AB12"],
        msa_mode=FeatureMode.SEARCH,
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[1, 3, 4],
    )
    custom_files = PredictRequestCustomFiles(
        templates=[Path("gs://custom_template.cif")],
        msas=[Path("gs://custom_msa.sto")],
        uploaded=True,
    )

    request = _build_request_from_fasta(
        file,
        folding_model=folding_model,
        params=params,
        custom_files=custom_files,
    )

    parameters = {
        "num_recycle": params.num_recycle,
        "random_seed": params.random_seed,
        # Template ids and custom template files are expected to be merged.
        "custom_templates": params.custom_template_ids + ["gs://custom_template.cif"],
        "custom_msas": ["gs://custom_msa.sto"],
        "gap_trick": params.gap_trick,
        "msa_mode": params.msa_mode,
        "max_msa_clusters": params.max_msa_clusters,
        "max_extra_msa": params.max_extra_msa,
        "template_mode": params.template_mode,
        "model_subset": params.model_subset,
    }
    # Any model other than AF2 is expected to produce an OpenFold request.
    request_cls = AF2Request if folding_model == FoldingModel.AF2 else OpenFoldRequest
    expected_request = request_cls(
        complex_id="protein",
        sequences=_extract_sequences_from_file(file),
        parameters=parameters,
    )
    assert request == expected_request


@pytest.mark.parametrize(
    "fasta_files, file_contents",
    [
        (
            ["fasta_file_1.fasta", "fasta_file_1_duplicate.fasta"],
            [">A\nA\n", ">A\nA\n"],
        ),  # Duplicate content, different files
    ],
)
def test_batch_prediction_fail_duplicate(
    tmp_path: Path,
    mock_post: mock.Mock,
    fasta_files,
    file_contents,
    folding_model,
    headers,
):
    """Check that FASTA files with duplicate content are rejected before any POST."""
    fasta_paths = _write_fasta_files(tmp_path, fasta_files, file_contents)
    params, custom_files = _default_batch_inputs()

    # Building the batch request from the duplicated files must itself fail.
    with pytest.raises(DuplicatedRequest):
        BatchRequest(
            requests=[
                _build_request_from_fasta(
                    file,
                    folding_model=folding_model,
                    params=params,
                    custom_files=custom_files,
                )
                for file in fasta_paths
            ]
        )

    with mock.patch(_UPLOAD_TARGET, return_value={}):
        with pytest.raises(DuplicatedRequest):
            batch_prediction(
                fasta_paths,
                folding_model,
                params,
                custom_files,
                num_seed=None,
                project_code="FOLDING_DEV",
            )

    # No valid request exists to send, so the API must not have been called.
    mock_post.assert_not_called()


@pytest.mark.parametrize(
    "fasta_files, file_contents",
    [
        (["fasta_file_1.fasta", "fasta_file_1_unique.fasta"], [">A\nA\n", ">B\nb\n"]),
    ],
)
def test_batch_prediction_pass(
    tmp_path: Path,
    mock_post: mock.Mock,
    fasta_files,
    file_contents,
    folding_model,
    headers,
):
    """Check that distinct FASTA files are posted as a single batch request."""
    fasta_paths = _write_fasta_files(tmp_path, fasta_files, file_contents)
    params, custom_files = _default_batch_inputs()

    expected_request = BatchRequest(
        requests=[
            _build_request_from_fasta(
                file,
                folding_model=folding_model,
                params=params,
                custom_files=custom_files,
            )
            for file in fasta_paths
        ]
    )

    with mock.patch(_UPLOAD_TARGET, return_value={}):
        batch_prediction(
            fasta_paths,
            folding_model,
            params,
            custom_files,
            num_seed=None,
            project_code="FOLDING_DEV",
        )

    mock_post.assert_called_once_with(
        API_URL + "batchPredict",
        data={"batch_jobs_request": expected_request.model_dump_json()},
        headers=headers,
        timeout=REQUEST_TIMEOUT,
        params={"project_code": "FOLDING_DEV"},
    )


@pytest.mark.parametrize("batch_file", ["batch.json", "batch.csv"])
def test_batch_prediction_from_file_pass(
    mock_post: mock.Mock, batch_file: str, headers: dict[str, str]
):
    """Check that a batch jobs file is posted to the ``batchPredictFromFile`` endpoint."""
    batch_jobs_file = Path(__file__).parent / f"data/{batch_file}"
    custom_files = [
        batch_jobs_file.parent / name
        for name in (
            "1agw.cif",
            "1agz.cif",
            "1agb_A.sto",
            "1agb_B.sto",
            "6m0j_A.sto",
            "6m0j_B.sto",
        )
    ]
    mocked_local_to_uploaded = {
        str(local.name): f"gs://bucket/{local.name}" for local in custom_files
    }

    # The try block starts before the files are created so the placeholders
    # are removed even when the call under test (or a touch) raises.
    try:
        for f in custom_files:
            f.touch()

        with mock.patch(_UPLOAD_TARGET, return_value=mocked_local_to_uploaded):
            batch_prediction_from_file(batch_jobs_file, project_code="FOLDING_DEV")

        mock_post.assert_called_once_with(
            API_URL + "batchPredictFromFile",
            headers=headers,
            files=[
                ("batch_jobs_file", mock.ANY),
            ],
            timeout=REQUEST_TIMEOUT,
            params={"project_code": "FOLDING_DEV"},
        )
    finally:
        for f in custom_files:
            # missing_ok: a file may never have been created if setup failed.
            f.unlink(missing_ok=True)