|
"""Test simple prediction.""" |
|
|
|
import os |
|
from pathlib import Path |
|
from unittest import mock |
|
|
|
|
|
import pytest |
|
from folding_studio.api_call.predict.simple_predict import ( |
|
simple_prediction, |
|
single_job_prediction, |
|
) |
|
from folding_studio.config import API_URL, REQUEST_TIMEOUT |
|
from folding_studio.utils.data_model import ( |
|
PredictRequestCustomFiles, |
|
PredictRequestParams, |
|
) |
|
from folding_studio.utils.exceptions import ProjectCodeNotFound |
|
from folding_studio_data_models import AF2Parameters, OpenFoldParameters |
|
from folding_studio_data_models.request.folding import FoldingModel |
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_post():
    """Patch ``requests.post`` for every test in this module.

    The replacement callable returns a ``MagicMock`` response whose ``ok``
    attribute is ``True``. The replacement itself is yielded so tests can
    inspect how ``requests.post`` was called.
    """
    fake_response = mock.MagicMock()
    fake_response.ok = True
    fake_post = mock.Mock(return_value=fake_response)
    with mock.patch("requests.post", fake_post):
        yield fake_post
|
|
|
@pytest.fixture(autouse=True)
def mock_get_auth_headers():
    """Stub ``get_auth_headers`` as looked up from ``simple_predict``.

    The stub always returns a fixed bearer-token header, so the real helper
    is never called during these tests. The patched mock is yielded.
    """
    fixed_headers = {'Authorization': 'Bearer identity_token'}
    target = "folding_studio.api_call.predict.simple_predict.get_auth_headers"
    with mock.patch(target, return_value=fixed_headers) as patched_get_auth_headers:
        yield patched_get_auth_headers
|
|
|
def test_simple_prediction_pass(
    tmp_path: Path,
    mock_post: mock.Mock,
    folding_model: FoldingModel,
    headers: dict[str, str],
):
    """Test that simple prediction posts the expected request.

    ``simple_prediction`` is called with an empty FASTA file, explicit
    request parameters and no custom files. The patched ``requests.post``
    must then be called exactly once with the serialized parameters, the
    ``headers`` fixture, ``REQUEST_TIMEOUT`` and the project code as a
    query parameter.
    """
    fasta_file = tmp_path / "fasta_file.fasta"
    fasta_file.touch()

    params = PredictRequestParams(
        ignore_cache=False,
        template_mode="search",
        custom_template_ids=[],
        msa_mode="search",
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[1, 3, 4],
    )
    custom_files = PredictRequestCustomFiles(templates=[], msas=[])

    simple_prediction(
        fasta_file,
        folding_model,
        params,
        custom_files,
        project_code="FOLDING_DEV",
    )

    # Expected form data: the JSON-serialized params plus the model and the
    # (empty) custom-file fields the request is expected to carry.
    expected_data = params.model_dump(mode="json")
    expected_data.update(
        {
            "folding_model": folding_model,
            "custom_msa_files": [],
            "custom_template_files": [],
            "initial_guess_file": None,
            "templates_masks_file": None,
        }
    )
    mock_post.assert_called_once_with(
        API_URL + "predict",
        data=expected_data,
        headers=headers,
        files=mock.ANY,
        timeout=REQUEST_TIMEOUT,
        params={"project_code": "FOLDING_DEV"},
    )
|
|
|
|
|
def test_simple_prediction_fail_because_no_project_code(
    remove_project_code_from_env_var, tmp_path: Path
):
    """Test that simple prediction raises when no project code is available.

    The ``remove_project_code_from_env_var`` fixture is requested so that
    ``FOLDING_PROJECT_CODE`` is unset (checked below). Because no
    ``project_code`` argument is passed either, ``simple_prediction`` must
    raise ``ProjectCodeNotFound``.
    """
    fasta_file = tmp_path / "fasta_file.fasta"
    fasta_file.touch()

    # Precondition: the environment must not provide a project code.
    assert "FOLDING_PROJECT_CODE" not in os.environ

    request_kwargs = dict(
        ignore_cache=False,
        template_mode="search",
        custom_template_ids=[],
        msa_mode="search",
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[1, 3, 4],
    )
    request_params = PredictRequestParams(**request_kwargs)
    no_custom_files = PredictRequestCustomFiles(templates=[], msas=[])

    with pytest.raises(ProjectCodeNotFound):
        simple_prediction(fasta_file, FoldingModel.AF2, request_params, no_custom_files)
|
|
|
|
|
def test_single_job_prediction_pass(
    tmp_path: Path,
    mock_post: mock.Mock,
    folding_model: FoldingModel,
    headers: dict[str, str],
):
    """Test that single job prediction posts the expected request.

    The parameter object is ``OpenFoldParameters()`` when ``folding_model``
    is OPENFOLD and ``AF2Parameters()`` otherwise. ``single_job_prediction``
    is called without an explicit model. The patched ``requests.post`` must
    be called exactly once with the serialized parameters plus the model
    value and default file/cache fields.
    """
    fasta_file = tmp_path / "fasta_file.fasta"
    fasta_file.touch()

    parameters = (
        OpenFoldParameters()
        if folding_model == FoldingModel.OPENFOLD
        else AF2Parameters()
    )

    single_job_prediction(
        fasta_file=fasta_file,
        parameters=parameters,
        project_code="FOLDING_DEV",
    )

    # Expected form data: the JSON-serialized parameters plus the fields the
    # single-job request is expected to fill in with defaults.
    expected_data = parameters.model_dump(mode="json")
    expected_data.update(
        {
            "folding_model": folding_model.value,
            "custom_msa_files": [],
            "custom_template_ids": [],
            "custom_template_files": [],
            "initial_guess_file": None,
            "templates_masks_file": None,
            "ignore_cache": False,
        }
    )
    mock_post.assert_called_once_with(
        API_URL + "predict",
        data=expected_data,
        headers=headers,
        files=mock.ANY,
        timeout=REQUEST_TIMEOUT,
        params={"project_code": "FOLDING_DEV"},
    )
|
|
|
|
|
def test_single_job_prediction_handle_deprecated_af2_parameters(
    tmp_path: Path, mock_post: mock.Mock, headers: dict[str, str]
):
    """Test the deprecated ``af2_parameters`` argument of single job prediction.

    Passing ``af2_parameters`` alone must emit a deprecation warning and
    still submit the request. Passing it together with ``parameters`` must
    raise ``ValueError``.
    """
    fasta_file = tmp_path / "fasta_file.fasta"
    fasta_file.touch()

    af2_parameters = AF2Parameters()

    with pytest.deprecated_call():
        single_job_prediction(
            fasta_file=fasta_file,
            af2_parameters=af2_parameters,
            project_code="FOLDING_DEV",
        )

    # The deprecated argument should still lead to exactly one submission
    # authenticated with the fixture headers.
    mock_post.assert_called_once()
    assert mock_post.call_args.kwargs["headers"] == headers

    # Supplying both the deprecated and the new argument is ambiguous.
    with pytest.raises(ValueError):
        single_job_prediction(
            fasta_file=fasta_file,
            af2_parameters=af2_parameters,
            parameters=af2_parameters,
            project_code="FOLDING_DEV",
        )
|
|