# Commit 44459bb (jfaustin): add dockerfile and folding studio cli
"""Tests for batch prediction requests."""
from pathlib import Path
from unittest import mock
import pytest
from folding_studio.api_call.predict.batch_predict import (
_build_request_from_fasta,
_extract_sequences_from_file,
batch_prediction,
batch_prediction_from_file,
)
from folding_studio.config import API_URL, REQUEST_TIMEOUT
from folding_studio.utils.data_model import (
PredictRequestCustomFiles,
PredictRequestParams,
)
from folding_studio_data_models import (
AF2Request,
BatchRequest,
FeatureMode,
FoldingModel,
OpenFoldRequest,
)
from folding_studio_data_models.exceptions import DuplicatedRequest
# Absolute directory holding this test module, and the "data" directory next to it.
current_workdir = Path(__file__).parent.resolve()
data_dir = current_workdir / "data"
@pytest.fixture()
def mock_post():
    """Replace ``requests.post`` with a mock returning a successful response.

    Yields:
        The ``mock.Mock`` installed in place of ``requests.post``; its return
        value is a ``MagicMock`` whose ``ok`` attribute is ``True``.
    """
    response = mock.MagicMock()
    response.ok = True
    fake_post = mock.Mock(return_value=response)
    with mock.patch("requests.post", fake_post):
        yield fake_post
@pytest.fixture(autouse=True)
def mock_get_auth_headers():
    """Patch ``get_auth_headers`` in the batch-predict module for every test.

    Yields:
        The patch mock, which returns a fixed bearer-token header dict.
    """
    auth_patch = mock.patch(
        "folding_studio.api_call.predict.batch_predict.get_auth_headers",
        return_value={"Authorization": "Bearer identity_token"},
    )
    with auth_patch as patched_get_auth_headers:
        yield patched_get_auth_headers
def test_build_request_from_fasta_pass(folding_model: FoldingModel):
    """Check the request built from a FASTA file against a hand-built one.

    The expected request is an ``AF2Request`` when ``folding_model`` is
    ``FoldingModel.AF2`` and an ``OpenFoldRequest`` otherwise.
    """
    fasta_path = data_dir / "protein.fasta"
    params = PredictRequestParams(
        ignore_cache=False,
        template_mode=FeatureMode.SEARCH,
        custom_template_ids=["AB12"],
        msa_mode=FeatureMode.SEARCH,
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[1, 3, 4],
    )
    custom_files = PredictRequestCustomFiles(
        templates=[Path("gs://custom_template.cif")],
        msas=[Path("gs://custom_msa.sto")],
        uploaded=True,
    )

    request = _build_request_from_fasta(
        fasta_path,
        folding_model=folding_model,
        params=params,
        custom_files=custom_files,
    )

    # Expected parameters: template ids from ``params`` followed by the custom
    # template file, plus the custom MSA file.
    parameters = {
        "num_recycle": params.num_recycle,
        "random_seed": params.random_seed,
        "custom_templates": params.custom_template_ids + ["gs://custom_template.cif"],
        "custom_msas": ["gs://custom_msa.sto"],
        "gap_trick": params.gap_trick,
        "msa_mode": params.msa_mode,
        "max_msa_clusters": params.max_msa_clusters,
        "max_extra_msa": params.max_extra_msa,
        "template_mode": params.template_mode,
        "model_subset": params.model_subset,
    }
    request_cls = AF2Request if folding_model == FoldingModel.AF2 else OpenFoldRequest
    expected_request = request_cls(
        complex_id="protein",
        sequences=_extract_sequences_from_file(fasta_path),
        parameters=parameters,
    )

    assert request == expected_request
@pytest.mark.parametrize(
    "fasta_files, file_contents",
    [
        (
            ["fasta_file_1.fasta", "fasta_file_1_duplicate.fasta"],
            [">A\nA\n", ">A\nA\n"],
        ),  # Duplicate content, different files
    ],
)
def test_batch_prediction_fail_duplicate(
    tmp_path: Path,
    mock_post: mock.Mock,
    fasta_files,
    file_contents,
    folding_model,
    headers,
):
    """Check that FASTA files with duplicated content are rejected.

    Two things are verified: building a ``BatchRequest`` from the duplicated
    files raises ``DuplicatedRequest``, and ``batch_prediction`` raises the
    same exception without sending any POST request.

    ``headers`` is no longer used in the body; it stays in the signature so
    the fixtures requested by this test are unchanged.
    """
    fasta_paths = [tmp_path / file_name for file_name in fasta_files]
    for fasta_path, content in zip(fasta_paths, file_contents):
        fasta_path.write_text(content)

    params = PredictRequestParams(
        ignore_cache=False,
        template_mode=FeatureMode.SEARCH,
        custom_template_ids=[],
        msa_mode=FeatureMode.SEARCH,
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[],
    )
    custom_files = PredictRequestCustomFiles(templates=[], msas=[], uploaded=True)

    # Building the batch directly must raise. The result is deliberately not
    # bound to a name: the constructor never returns here, so nothing could be
    # used from it afterwards.
    with pytest.raises(DuplicatedRequest):
        BatchRequest(
            requests=[
                _build_request_from_fasta(
                    fasta_path,
                    folding_model=folding_model,
                    params=params,
                    custom_files=custom_files,
                )
                for fasta_path in fasta_paths
            ]
        )

    with mock.patch(
        "folding_studio.api_call.predict.batch_predict.PredictRequestCustomFiles.upload",
        return_value={},
    ):
        with pytest.raises(DuplicatedRequest):
            batch_prediction(
                fasta_paths,
                folding_model,
                params,
                custom_files,
                num_seed=None,
                project_code="FOLDING_DEV",
            )

    # The duplicate is expected to be rejected client-side, before the request
    # is submitted, so the mocked ``requests.post`` must not have been called.
    mock_post.assert_not_called()
@pytest.mark.parametrize(
    "fasta_files, file_contents",
    [
        (["fasta_file_1.fasta", "fasta_file_1_unique.fasta"], [">A\nA\n", ">B\nb\n"]),
    ],
)
def test_batch_prediction_pass(
    tmp_path: Path,
    mock_post: mock.Mock,
    fasta_files,
    file_contents,
    folding_model,
    headers,
):
    """Check that a batch of distinct FASTA files is posted exactly once.

    The POST is expected to go to the ``batchPredict`` endpoint with the
    serialized ``BatchRequest``, the auth headers, the configured timeout and
    the project code.
    """
    fasta_paths = [tmp_path / file_name for file_name in fasta_files]
    for fasta_path, content in zip(fasta_paths, file_contents):
        fasta_path.write_text(content)

    params = PredictRequestParams(
        ignore_cache=False,
        template_mode=FeatureMode.SEARCH,
        custom_template_ids=[],
        msa_mode=FeatureMode.SEARCH,
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[],
    )
    custom_files = PredictRequestCustomFiles(templates=[], msas=[], uploaded=True)

    # Build the payload the client is expected to send, one request per file.
    single_requests = []
    for fasta_path in fasta_paths:
        single_requests.append(
            _build_request_from_fasta(
                fasta_path,
                folding_model=folding_model,
                params=params,
                custom_files=custom_files,
            )
        )
    expected_request = BatchRequest(requests=single_requests)

    upload_patch = mock.patch(
        "folding_studio.api_call.predict.batch_predict.PredictRequestCustomFiles.upload",
        return_value={},
    )
    with upload_patch:
        batch_prediction(
            fasta_paths,
            folding_model,
            params,
            custom_files,
            num_seed=None,
            project_code="FOLDING_DEV",
        )

    mock_post.assert_called_once_with(
        API_URL + "batchPredict",
        data={"batch_jobs_request": expected_request.model_dump_json()},
        headers=headers,
        timeout=REQUEST_TIMEOUT,
        params={"project_code": "FOLDING_DEV"},
    )
@pytest.mark.parametrize("batch_file", ["batch.json", "batch.csv"])
def test_batch_prediction_from_file_pass(
    mock_post: mock.Mock, batch_file: str, headers: dict[str, str]
):
    """Check that a batch jobs file is posted to ``batchPredictFromFile``.

    Empty placeholder custom files are created next to the batch file for the
    duration of the test and removed afterwards, even if the call under test
    or the final assertion fails.
    """
    batch_jobs_file = Path(__file__).parent / f"data/{batch_file}"
    custom_file_names = (
        "1agw.cif",
        "1agz.cif",
        "1agb_A.sto",
        "1agb_B.sto",
        "6m0j_A.sto",
        "6m0j_B.sto",
    )
    custom_files = [batch_jobs_file.parent / name for name in custom_file_names]
    mocked_local_to_uploaded = {
        local.name: f"gs://bucket/{local.name}" for local in custom_files
    }

    # Everything that may raise sits inside the ``try`` so the placeholder
    # files never leak into the data directory.
    try:
        for placeholder in custom_files:
            placeholder.touch()

        with mock.patch(
            "folding_studio.api_call.predict.batch_predict.PredictRequestCustomFiles.upload",
            return_value=mocked_local_to_uploaded,
        ):
            batch_prediction_from_file(batch_jobs_file, project_code="FOLDING_DEV")

        mock_post.assert_called_once_with(
            API_URL + "batchPredictFromFile",
            headers=headers,
            files=[
                ("batch_jobs_file", mock.ANY),
            ],
            timeout=REQUEST_TIMEOUT,
            params={"project_code": "FOLDING_DEV"},
        )
    finally:
        # ``missing_ok`` covers the case where creation failed part-way.
        for placeholder in custom_files:
            placeholder.unlink(missing_ok=True)