|
"""Test simple prediction.""" |
|
|
|
from pathlib import Path |
|
from unittest import mock |
|
|
|
import pytest |
|
from folding_studio.api_call.predict.batch_predict import ( |
|
_build_request_from_fasta, |
|
_extract_sequences_from_file, |
|
batch_prediction, |
|
batch_prediction_from_file, |
|
) |
|
from folding_studio.config import API_URL, REQUEST_TIMEOUT |
|
from folding_studio.utils.data_model import ( |
|
PredictRequestCustomFiles, |
|
PredictRequestParams, |
|
) |
|
from folding_studio_data_models import ( |
|
AF2Request, |
|
BatchRequest, |
|
FeatureMode, |
|
FoldingModel, |
|
OpenFoldRequest, |
|
) |
|
from folding_studio_data_models.exceptions import DuplicatedRequest |
|
|
|
# Absolute path of the directory holding this test module.
current_workdir = Path(__file__).parent.resolve()

# Fixture files used by the tests live in the sibling "data" directory.
data_dir = current_workdir.joinpath("data")
|
|
|
@pytest.fixture()
def mock_post():
    """Patch ``requests.post`` for the duration of a test and yield the mock.

    The mock returns a ``MagicMock`` response whose ``ok`` attribute is ``True``.
    """
    response = mock.MagicMock()
    response.ok = True
    fake_post = mock.Mock(return_value=response)
    with mock.patch("requests.post", fake_post):
        yield fake_post
|
|
|
@pytest.fixture(autouse=True)
def mock_get_auth_headers():
    """Replace ``get_auth_headers`` in ``batch_predict`` for every test.

    The patched function returns a fixed bearer-token header; the patch
    mock itself is yielded to the test.
    """
    target = "folding_studio.api_call.predict.batch_predict.get_auth_headers"
    fake_headers = {'Authorization': 'Bearer identity_token'}
    with mock.patch(target, return_value=fake_headers) as patched:
        yield patched
|
|
|
def test_build_request_from_fasta_pass(folding_model: FoldingModel):
    """Check that a FASTA file is turned into the expected model request.

    The request built by ``_build_request_from_fasta`` is compared against an
    ``AF2Request`` or ``OpenFoldRequest`` (depending on ``folding_model``)
    assembled by hand from the same parameters and custom files.
    """
    fasta_path = data_dir / "protein.fasta"

    request_params = PredictRequestParams(
        ignore_cache=False,
        template_mode=FeatureMode.SEARCH,
        custom_template_ids=["AB12"],
        msa_mode=FeatureMode.SEARCH,
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[1, 3, 4],
    )
    uploaded_files = PredictRequestCustomFiles(
        templates=[Path("gs://custom_template.cif")],
        msas=[Path("gs://custom_msa.sto")],
        uploaded=True,
    )

    built_request = _build_request_from_fasta(
        fasta_path,
        folding_model=folding_model,
        params=request_params,
        custom_files=uploaded_files,
    )

    # Expected parameters: template ids are followed by the custom template path.
    expected_parameters = {
        "num_recycle": request_params.num_recycle,
        "random_seed": request_params.random_seed,
        "custom_templates": request_params.custom_template_ids
        + ["gs://custom_template.cif"],
        "custom_msas": ["gs://custom_msa.sto"],
        "gap_trick": request_params.gap_trick,
        "msa_mode": request_params.msa_mode,
        "max_msa_clusters": request_params.max_msa_clusters,
        "max_extra_msa": request_params.max_extra_msa,
        "template_mode": request_params.template_mode,
        "model_subset": request_params.model_subset,
    }

    # Only the request class differs between the two supported model kinds.
    if folding_model == FoldingModel.AF2:
        request_cls = AF2Request
    else:
        request_cls = OpenFoldRequest

    expected_request = request_cls(
        complex_id="protein",
        sequences=_extract_sequences_from_file(fasta_path),
        parameters=expected_parameters,
    )
    assert built_request == expected_request
|
|
|
|
|
@pytest.mark.parametrize(
    "fasta_files, file_contents",
    [
        (
            ["fasta_file_1.fasta", "fasta_file_1_duplicate.fasta"],
            [">A\nA\n", ">A\nA\n"],
        ),
    ],
)
def test_batch_prediction_fail_duplicate(
    tmp_path: Path, mock_post: mock.Mock, fasta_files, file_contents, folding_model, headers,
):
    """Test that a batch containing duplicated requests is rejected.

    Two FASTA files with identical content are written to ``tmp_path``.
    Building a ``BatchRequest`` from them, and calling ``batch_prediction``
    on them, must both raise ``DuplicatedRequest``.
    """
    fasta_paths = [tmp_path / file_name for file_name in fasta_files]
    for fasta_path, content in zip(fasta_paths, file_contents):
        fasta_path.write_text(content)

    params = PredictRequestParams(
        ignore_cache=False,
        template_mode=FeatureMode.SEARCH,
        custom_template_ids=[],
        msa_mode=FeatureMode.SEARCH,
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[],
    )
    custom_files = PredictRequestCustomFiles(templates=[], msas=[], uploaded=True)

    # Building the batch directly is expected to fail, so no request object
    # is ever available to compare a POST payload against.
    with pytest.raises(DuplicatedRequest):
        BatchRequest(
            requests=[
                _build_request_from_fasta(
                    file,
                    folding_model=folding_model,
                    params=params,
                    custom_files=custom_files,
                )
                for file in fasta_paths
            ]
        )

    mocked_local_to_uploaded = {}
    with mock.patch(
        "folding_studio.api_call.predict.batch_predict.PredictRequestCustomFiles.upload",
        return_value=mocked_local_to_uploaded,
    ):
        with pytest.raises(DuplicatedRequest):
            batch_prediction(
                fasta_paths,
                folding_model,
                params,
                custom_files,
                num_seed=None,
                project_code="FOLDING_DEV",
            )

    # A rejected batch must not be sent to the API.
    mock_post.assert_not_called()
|
|
|
|
|
@pytest.mark.parametrize(
    "fasta_files, file_contents",
    [
        (["fasta_file_1.fasta", "fasta_file_1_unique.fasta"], [">A\nA\n", ">B\nb\n"]),
    ],
)
def test_batch_prediction_pass(
    tmp_path: Path, mock_post: mock.Mock, fasta_files, file_contents, folding_model, headers,
):
    """Test that a batch of distinct FASTA files is posted to the batch endpoint.

    The FASTA files are written to ``tmp_path``, ``batch_prediction`` is run
    with the custom-file upload patched out, and the single resulting POST is
    compared against a ``BatchRequest`` built from the same inputs.
    """
    fasta_paths = []
    for file_name, content in zip(fasta_files, file_contents):
        fasta_path = tmp_path / file_name
        fasta_path.write_text(content)
        fasta_paths.append(fasta_path)

    params = PredictRequestParams(
        ignore_cache=False,
        template_mode=FeatureMode.SEARCH,
        custom_template_ids=[],
        msa_mode=FeatureMode.SEARCH,
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[],
    )
    custom_files = PredictRequestCustomFiles(templates=[], msas=[], uploaded=True)

    # Reference payload: one request per FASTA file, in input order.
    single_requests = [
        _build_request_from_fasta(
            fasta_path,
            folding_model=folding_model,
            params=params,
            custom_files=custom_files,
        )
        for fasta_path in fasta_paths
    ]
    expected_request = BatchRequest(requests=single_requests)

    upload_target = (
        "folding_studio.api_call.predict.batch_predict.PredictRequestCustomFiles.upload"
    )
    with mock.patch(upload_target, return_value={}):
        batch_prediction(
            fasta_paths,
            folding_model,
            params,
            custom_files,
            num_seed=None,
            project_code="FOLDING_DEV",
        )

    expected_payload = {"batch_jobs_request": expected_request.model_dump_json()}
    mock_post.assert_called_once_with(
        API_URL + "batchPredict",
        data=expected_payload,
        headers=headers,
        timeout=REQUEST_TIMEOUT,
        params={"project_code": "FOLDING_DEV"},
    )
|
|
|
|
|
@pytest.mark.parametrize("batch_file", ["batch.json", "batch.csv"])
def test_batch_prediction_from_file_pass(mock_post: mock.Mock, batch_file: str, headers: dict[str, str]):
    """Test that a batch jobs file is posted to the ``batchPredictFromFile`` endpoint.

    Empty placeholder template/MSA files are created next to the batch file
    for the duration of the test and removed afterwards, whether or not the
    test body succeeds.
    """
    batch_jobs_file = Path(__file__).parent / f"data/{batch_file}"
    custom_file_names = [
        "1agw.cif",
        "1agz.cif",
        "1agb_A.sto",
        "1agb_B.sto",
        "6m0j_A.sto",
        "6m0j_B.sto",
    ]
    custom_files = [batch_jobs_file.parent / name for name in custom_file_names]

    # Value returned by the patched upload: local file name -> bucket URL.
    mocked_local_to_uploaded = {
        local.name: f"gs://bucket/{local.name}" for local in custom_files
    }

    try:
        # Files are created inside the ``try`` so that a failure anywhere
        # below still removes them from the tests data directory.
        for f in custom_files:
            f.touch()

        with mock.patch(
            "folding_studio.api_call.predict.batch_predict.PredictRequestCustomFiles.upload",
            return_value=mocked_local_to_uploaded,
        ):
            batch_prediction_from_file(batch_jobs_file, project_code="FOLDING_DEV")

        mock_post.assert_called_once_with(
            API_URL + "batchPredictFromFile",
            headers=headers,
            files=[
                ("batch_jobs_file", mock.ANY),
            ],
            timeout=REQUEST_TIMEOUT,
            params={"project_code": "FOLDING_DEV"},
        )
    finally:
        for f in custom_files:
            # ``missing_ok`` keeps cleanup from masking the original error
            # when a file was never created.
            f.unlink(missing_ok=True)
|
|