jfaustin's picture
add dockerfile and folding studio cli
44459bb
raw
history blame
5.03 kB
"""Test simple prediction."""
import os
from pathlib import Path
from unittest import mock
import pytest
from folding_studio.api_call.predict.simple_predict import (
simple_prediction,
single_job_prediction,
)
from folding_studio.config import API_URL, REQUEST_TIMEOUT
from folding_studio.utils.data_model import (
PredictRequestCustomFiles,
PredictRequestParams,
)
from folding_studio.utils.exceptions import ProjectCodeNotFound
from folding_studio_data_models import AF2Parameters, OpenFoldParameters
from folding_studio_data_models.request.folding import FoldingModel
@pytest.fixture(autouse=True)
def mock_post():
    """Replace ``requests.post`` with a mock for every test in this module.

    The mock returns a ``MagicMock`` response whose ``ok`` attribute is
    ``True``. The mock itself is yielded so tests can inspect its calls.
    """
    fake_response = mock.MagicMock(ok=True)
    patched_post = mock.Mock(return_value=fake_response)
    with mock.patch("requests.post", patched_post):
        yield patched_post
@pytest.fixture(autouse=True)
def mock_get_auth_headers():
    """Stub ``get_auth_headers`` as looked up by the ``simple_predict`` module.

    The stub always returns a fixed bearer-token header. The patch object is
    yielded so tests can inspect it.
    """
    fake_headers = {"Authorization": "Bearer identity_token"}
    patcher = mock.patch(
        "folding_studio.api_call.predict.simple_predict.get_auth_headers",
        return_value=fake_headers,
    )
    with patcher as patched_get_auth_headers:
        yield patched_get_auth_headers
def test_simple_prediction_pass(
    tmp_path: Path,
    mock_post: mock.Mock,
    folding_model: FoldingModel,
    headers: dict[str, str],
):
    """Test that ``simple_prediction`` sends the expected request.

    ``requests.post`` is replaced by the autouse ``mock_post`` fixture. The
    test only checks the arguments of the single POST call.
    """
    fasta_file = tmp_path / "fasta_file.fasta"
    fasta_file.touch()
    params = PredictRequestParams(
        ignore_cache=False,
        template_mode="search",
        custom_template_ids=[],
        msa_mode="search",
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[1, 3, 4],
    )
    custom_files = PredictRequestCustomFiles(templates=[], msas=[])
    simple_prediction(
        fasta_file,
        folding_model,
        params,
        custom_files,
        project_code="FOLDING_DEV",
    )

    # Expected form data: the JSON-serialised request parameters plus the
    # folding model and the (empty) custom-file fields. A separate name keeps
    # the ``PredictRequestParams`` instance distinct from the plain dict.
    expected_data = params.model_dump(mode="json")
    expected_data.update(
        {
            "folding_model": folding_model,
            "custom_msa_files": [],
            "custom_template_files": [],
            "initial_guess_file": None,
            "templates_masks_file": None,
        }
    )
    mock_post.assert_called_once_with(
        API_URL + "predict",
        data=expected_data,
        headers=headers,
        files=mock.ANY,  # uploaded files are not compared
        timeout=REQUEST_TIMEOUT,
        params={"project_code": "FOLDING_DEV"},
    )
def test_simple_prediction_fail_because_no_project_code(
    remove_project_code_from_env_var, tmp_path: Path
):
    """Check that ``simple_prediction`` raises when no project code is known.

    No ``project_code`` argument is given. ``FOLDING_PROJECT_CODE`` must be
    absent from the environment, which the
    ``remove_project_code_from_env_var`` fixture is relied upon to ensure.
    """
    fasta_path = tmp_path / "fasta_file.fasta"
    fasta_path.touch()
    # Precondition guard: presumably the function would otherwise fall back to
    # the environment variable and the failure would not be exercised.
    assert os.environ.get("FOLDING_PROJECT_CODE") is None

    request_params = PredictRequestParams(
        ignore_cache=False,
        template_mode="search",
        custom_template_ids=[],
        msa_mode="search",
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[1, 3, 4],
    )
    no_custom_files = PredictRequestCustomFiles(templates=[], msas=[])

    with pytest.raises(ProjectCodeNotFound):
        simple_prediction(
            fasta_path,
            FoldingModel.AF2,
            request_params,
            no_custom_files,
        )
def test_single_job_prediction_pass(
    tmp_path: Path,
    mock_post: mock.Mock,
    folding_model: FoldingModel,
    headers: dict[str, str],
):
    """Test that ``single_job_prediction`` sends the expected request.

    The parameter model is chosen from ``folding_model``:
    ``OpenFoldParameters`` for OpenFold and ``AF2Parameters`` otherwise.
    ``requests.post`` is replaced by the autouse ``mock_post`` fixture.
    """
    fasta_file = tmp_path / "fasta_file.fasta"
    fasta_file.touch()
    parameters = (
        OpenFoldParameters()
        if folding_model == FoldingModel.OPENFOLD
        else AF2Parameters()
    )
    # ``folding_model`` is not passed explicitly. The expected payload below
    # requires the call to report the model matching the parameter class.
    single_job_prediction(
        fasta_file=fasta_file,
        parameters=parameters,
        project_code="FOLDING_DEV",
    )

    expected_data = parameters.model_dump(mode="json")
    expected_data.update(
        {
            "folding_model": folding_model.value,
            "custom_msa_files": [],
            "custom_template_ids": [],
            "custom_template_files": [],
            "initial_guess_file": None,
            "templates_masks_file": None,
            "ignore_cache": False,
        }
    )
    mock_post.assert_called_once_with(
        API_URL + "predict",
        data=expected_data,
        headers=headers,
        files=mock.ANY,  # uploaded files are not compared
        timeout=REQUEST_TIMEOUT,
        params={"project_code": "FOLDING_DEV"},
    )
def test_single_job_prediction_handle_deprecated_af2_parameters(
    tmp_path: Path, mock_post: mock.Mock, headers: dict[str, str]
):
    """Test the deprecated ``af2_parameters`` argument of ``single_job_prediction``.

    Passing ``af2_parameters`` alone must emit a deprecation warning and still
    submit exactly one request. Passing it together with ``parameters`` must
    raise ``ValueError`` without submitting anything.
    """
    fasta_file = tmp_path / "fasta_file.fasta"
    fasta_file.touch()
    parameters = AF2Parameters()

    with pytest.deprecated_call():
        single_job_prediction(
            fasta_file=fasta_file,
            af2_parameters=parameters,
            project_code="FOLDING_DEV",
        )
    # The deprecated argument should still go through to a submission.
    mock_post.assert_called_once()

    with pytest.raises(ValueError):
        single_job_prediction(
            fasta_file=fasta_file,
            af2_parameters=parameters,
            parameters=parameters,
            project_code="FOLDING_DEV",
        )
    # Conflicting arguments should be rejected before any further request.
    mock_post.assert_called_once()