File size: 5,033 Bytes
44459bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
"""Test simple prediction."""
import os
from pathlib import Path
from unittest import mock
import pytest
from folding_studio.api_call.predict.simple_predict import (
simple_prediction,
single_job_prediction,
)
from folding_studio.config import API_URL, REQUEST_TIMEOUT
from folding_studio.utils.data_model import (
PredictRequestCustomFiles,
PredictRequestParams,
)
from folding_studio.utils.exceptions import ProjectCodeNotFound
from folding_studio_data_models import AF2Parameters, OpenFoldParameters
from folding_studio_data_models.request.folding import FoldingModel
@pytest.fixture(autouse=True)
def mock_post():
    """Replace ``requests.post`` with a mock for every test in this module.

    The mock returns a ``MagicMock`` response whose ``ok`` attribute is
    ``True``. The patched mock is yielded so tests can inspect its calls.
    """
    ok_response = mock.MagicMock()
    ok_response.ok = True
    # When an explicit replacement object is passed, ``mock.patch`` yields
    # that same object from its context manager.
    with mock.patch("requests.post", mock.Mock(return_value=ok_response)) as patched_post:
        yield patched_post
@pytest.fixture(autouse=True)
def mock_get_auth_headers():
    """Stub out ``get_auth_headers`` in ``simple_predict`` for every test.

    The stub returns a fixed bearer-token header dict. The patch object is
    yielded so tests can inspect or reconfigure it.
    """
    fake_auth_headers = {"Authorization": "Bearer identity_token"}
    patcher = mock.patch(
        "folding_studio.api_call.predict.simple_predict.get_auth_headers",
        return_value=fake_auth_headers,
    )
    with patcher as patched_get_auth_headers:
        yield patched_get_auth_headers
def test_simple_prediction_pass(
    tmp_path: Path,
    mock_post: mock.Mock,
    folding_model: FoldingModel,
    headers: dict[str, str],
):
    """Test that ``simple_prediction`` posts the expected request to the predict endpoint.

    ``folding_model`` and ``headers`` come from fixtures defined outside this
    file (presumably ``conftest.py``). ``headers`` is assumed to equal the
    value returned by the patched ``get_auth_headers`` — confirm in conftest.
    """
    file = tmp_path / "fasta_file.fasta"
    file.touch()
    params = PredictRequestParams(
        ignore_cache=False,
        template_mode="search",
        custom_template_ids=[],
        msa_mode="search",
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[1, 3, 4],
    )
    custom_files = PredictRequestCustomFiles(templates=[], msas=[])

    simple_prediction(
        file,
        folding_model,
        params,
        custom_files,
        project_code="FOLDING_DEV",
    )

    # Expected form data: the JSON dump of the request params, plus the model
    # and (empty) custom-file fields the request is expected to carry.
    expected_data = params.model_dump(mode="json")
    expected_data.update(
        {
            "folding_model": folding_model,
            "custom_msa_files": [],
            "custom_template_files": [],
            "initial_guess_file": None,
            "templates_masks_file": None,
        }
    )
    mock_post.assert_called_once_with(
        API_URL + "predict",
        data=expected_data,
        headers=headers,
        files=mock.ANY,
        timeout=REQUEST_TIMEOUT,
        params={"project_code": "FOLDING_DEV"},
    )
def test_simple_prediction_fail_because_no_project_code(
    remove_project_code_from_env_var, tmp_path: Path
):
    """Check that ``simple_prediction`` raises ``ProjectCodeNotFound`` without a project code.

    ``remove_project_code_from_env_var`` is a fixture defined outside this
    file. The assertion below verifies that ``FOLDING_PROJECT_CODE`` is unset
    before the call.
    """
    fasta_path = tmp_path / "fasta_file.fasta"
    fasta_path.touch()
    # os.environ only holds strings, so "absent" is the same as .get() is None.
    assert "FOLDING_PROJECT_CODE" not in os.environ

    request_params = PredictRequestParams(
        ignore_cache=False,
        template_mode="search",
        custom_template_ids=[],
        msa_mode="search",
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[1, 3, 4],
    )
    no_custom_files = PredictRequestCustomFiles(templates=[], msas=[])

    # No project_code argument is passed and none is in the environment.
    with pytest.raises(ProjectCodeNotFound):
        simple_prediction(fasta_path, FoldingModel.AF2, request_params, no_custom_files)
def test_single_job_prediction_pass(
    tmp_path: Path,
    mock_post: mock.Mock,
    folding_model: FoldingModel,
    headers: dict[str, str],
):
    """Test that ``single_job_prediction`` posts the expected request to the predict endpoint.

    Default ``OpenFoldParameters`` are used when ``folding_model`` is
    OpenFold, and default ``AF2Parameters`` otherwise. ``folding_model`` and
    ``headers`` come from fixtures defined outside this file.
    """
    file = tmp_path / "fasta_file.fasta"
    file.touch()
    parameters = (
        OpenFoldParameters()
        if folding_model == FoldingModel.OPENFOLD
        else AF2Parameters()
    )

    single_job_prediction(
        fasta_file=file,
        parameters=parameters,
        project_code="FOLDING_DEV",
    )

    # Expected form data: the JSON dump of the parameters, plus the model
    # name and the default (empty / disabled) custom-input fields.
    expected_data = parameters.model_dump(mode="json")
    expected_data.update(
        {
            "folding_model": folding_model.value,
            "custom_msa_files": [],
            "custom_template_ids": [],
            "custom_template_files": [],
            "initial_guess_file": None,
            "templates_masks_file": None,
            "ignore_cache": False,
        }
    )
    mock_post.assert_called_once_with(
        API_URL + "predict",
        data=expected_data,
        headers=headers,
        files=mock.ANY,
        timeout=REQUEST_TIMEOUT,
        params={"project_code": "FOLDING_DEV"},
    )
def test_single_job_prediction_handle_deprecated_af2_parameters(
    tmp_path: Path, mock_post: mock.Mock, headers: dict[str, str]
):
    """Test ``single_job_prediction``'s handling of the deprecated ``af2_parameters`` argument.

    Passing ``af2_parameters`` alone must emit a deprecation warning.
    Passing both ``af2_parameters`` and ``parameters`` must raise
    ``ValueError``.
    """
    # NOTE(review): ``mock_post`` is autouse and ``headers`` is unused here;
    # both are kept only to preserve the original fixture requests.
    file = tmp_path / "fasta_file.fasta"
    file.touch()
    parameters = AF2Parameters()

    with pytest.deprecated_call():
        single_job_prediction(
            fasta_file=file,
            af2_parameters=parameters,
            project_code="FOLDING_DEV",
        )

    # Supplying both the deprecated and the current argument is ambiguous.
    with pytest.raises(ValueError):
        single_job_prediction(
            fasta_file=file,
            af2_parameters=parameters,
            parameters=parameters,
            project_code="FOLDING_DEV",
        )
|