|
import os |
|
from unittest import mock |
|
|
|
import pytest |
|
from folding_studio.cli import app |
|
from folding_studio.utils.data_model import ( |
|
PredictRequestCustomFiles, |
|
PredictRequestParams, |
|
) |
|
from folding_studio.utils.headers import get_auth_headers |
|
from folding_studio_data_models import ( |
|
AF2Parameters, |
|
BatchMessageStatus, |
|
BatchPublication, |
|
FeatureMode, |
|
FoldingModel, |
|
Message, |
|
MessageStatus, |
|
Publication, |
|
) |
|
from typer.testing import CliRunner |
|
|
|
# Typer CLI test runner shared by every test in this module to invoke ``app``.
runner = CliRunner()
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_get_auth_headers(request):
    """Replace ``get_auth_headers`` used by the simple-predict API call.

    Applied automatically to every test. Tests carrying the ``apikeytest``
    marker get the headers produced by the real ``get_auth_headers`` while
    ``FOLDING_API_KEY`` is set to ``"MY_KEY"``. Every other test gets a fixed
    bearer-token header.

    Args:
        request: pytest fixture request, used to read the test's markers.

    Yields:
        The mock standing in for ``get_auth_headers``.
    """
    target = "folding_studio.api_call.predict.simple_predict.get_auth_headers"

    if "apikeytest" not in request.keywords:
        with mock.patch(
            target, return_value={"Authorization": "Bearer identity_token"}
        ) as m:
            yield m
        return

    # Set the dummy key through patch.dict so os.environ is restored once the
    # test finishes. Assigning os.environ directly would leave "MY_KEY" in the
    # environment for every test that runs afterwards.
    with mock.patch.dict(os.environ, {"FOLDING_API_KEY": "MY_KEY"}):
        # The real headers are computed before the patch below is active.
        return_value = get_auth_headers()
        with mock.patch(target, return_value=return_value) as m:
            yield m
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_batch_prediction_from_file():
    """Patch ``batch_prediction_from_file`` in the AF2 predict command for every test.

    The mock returns the JSON dump of a published ``BatchPublication``. It wraps
    three published AF2 monomer publications (``monomer_0.fasta`` through
    ``monomer_2.fasta``) and has no cached publications.

    Yields:
        The mock standing in for ``batch_prediction_from_file``.
    """

    def _publication(idx):
        # One published AF2 monomer job whose FASTA name is indexed by ``idx``.
        return Publication(
            folding_model=FoldingModel.AF2,
            status=MessageStatus.PUBLISHED,
            message=Message(
                pipeline_name="alphafold_inference_pipeline",
                user_id="default-user",
                project_code="default-project",
                experiment_id="dummy-experiment",
                model_preset="monomer",
                fasta_file_name=f"monomer_{idx}.fasta",
                parameters=AF2Parameters(),
                ignore_cache=False,
            ),
        )

    batch = BatchPublication(
        batch_id="batch_id",
        status=BatchMessageStatus.PUBLISHED,
        cached=False,
        publications=[_publication(idx) for idx in range(3)],
        cached_publications=[],
    )
    canned_response = batch.model_dump(mode="json")

    with mock.patch(
        "folding_studio.commands.predict.af2_predict.batch_prediction_from_file",
        return_value=canned_response,
    ) as patched:
        yield patched
|
|
|
|
|
@pytest.fixture()
def mock_batch_prediction():
    """Patch ``batch_prediction`` in the AF2 predict command with a canned batch.

    The mocked call returns the JSON dump of a published ``BatchPublication``
    holding three published AF2 monomer publications (``monomer_0.fasta`` through
    ``monomer_2.fasta``) and no cached publications.

    Yields:
        The mock standing in for ``batch_prediction``.
    """
    messages = (
        Message(
            pipeline_name="alphafold_inference_pipeline",
            user_id="default-user",
            project_code="default-project",
            parameters=AF2Parameters(),
            experiment_id="dummy-experiment",
            model_preset="monomer",
            fasta_file_name=f"monomer_{idx}.fasta",
            ignore_cache=False,
        )
        for idx in range(3)
    )
    publications = [
        Publication(
            folding_model=FoldingModel.AF2,
            message=msg,
            status=MessageStatus.PUBLISHED,
        )
        for msg in messages
    ]
    batch = BatchPublication(
        publications=publications,
        cached_publications=[],
        batch_id="batch_id",
        cached=False,
        status=BatchMessageStatus.PUBLISHED,
    )

    target = "folding_studio.commands.predict.af2_predict.batch_prediction"
    with mock.patch(target, return_value=batch.model_dump(mode="json")) as patched:
        yield patched
|
|
|
|
|
@pytest.fixture()
def mock_simple_prediction(request):
    """Patch ``simple_prediction`` in the AF2 predict command with a canned result.

    The mocked call returns the JSON dump of a single AF2 ``Publication``. Tests
    choose the publication status through indirect parametrization::

        @pytest.mark.parametrize(
            "mock_simple_prediction", [MessageStatus.PUBLISHED, ...], indirect=True
        )

    When the fixture is not parametrized, the status defaults to
    ``MessageStatus.PUBLISHED``.

    Args:
        request: pytest fixture request. ``request.param`` holds the status when
            the fixture is parametrized indirectly.

    Yields:
        The mock standing in for ``simple_prediction``.
    """
    # Indirect parametrization delivers its value through ``request.param``.
    # Ignoring it would make every parametrized case use the same mock.
    status = getattr(request, "param", MessageStatus.PUBLISHED)

    pub = Publication(
        folding_model=FoldingModel.AF2,
        message=Message(
            pipeline_name="alphafold_inference_pipeline",
            user_id="default-user",
            project_code="default-project",
            parameters=AF2Parameters(),
            experiment_id="dummy-experiment",
            model_preset="monomer",
            fasta_file_name="monomer.fasta",
            ignore_cache=False,
        ),
        status=status,
    )

    with mock.patch(
        "folding_studio.commands.predict.af2_predict.simple_prediction",
        return_value=pub.model_dump(mode="json"),
    ) as m:
        yield m
|
|
|
|
|
@pytest.fixture()
def default_params():
    """Expected ``PredictRequestParams`` for project ``FOLDING_DEV``.

    The tests in this module compare this against the params the CLI builds
    when no tuning options are passed on the command line.
    """
    expected_fields = dict(
        project_code="FOLDING_DEV",
        ignore_cache=False,
        template_mode=FeatureMode.SEARCH,
        msa_mode=FeatureMode.SEARCH,
        custom_template_ids=[],
        model_subset=[],
        max_msa_clusters=-1,
        max_extra_msa=-1,
        num_recycle=3,
        random_seed=0,
        gap_trick=False,
    )
    return PredictRequestParams(**expected_fields)
|
|
|
|
|
def test_predict_with_unsupported_file_fails(tmp_files):
    """An unsupported SOURCE file is rejected as a usage error (exit code 2)."""
    result = runner.invoke(app, ["predict", "af2", str(tmp_files["invalid_source"])])

    assert result.exit_code == 2, result.output
    # Usage errors are written to stderr. ``result.output`` includes them on
    # every Click version; ``result.stdout`` only does when stderr is mixed in
    # (the Click < 8.2 default).
    assert "Invalid value for 'SOURCE'" in result.output
|
|
|
|
|
def test_predict_with_directory_containing_unsupported_file_fails(tmp_files):
    """A SOURCE directory holding an unsupported file is rejected (exit code 2)."""
    result = runner.invoke(app, ["predict", "af2", str(tmp_files["invalid_dir"])])

    assert result.exit_code == 2, result.output
    # Usage errors are written to stderr. ``result.output`` includes them on
    # every Click version; ``result.stdout`` only does when stderr is mixed in
    # (the Click < 8.2 default).
    assert "Invalid value for 'SOURCE'" in result.output
|
|
|
|
|
def test_predict_with_empty_directory_fails(tmp_files):
    """An empty SOURCE directory is rejected as a usage error (exit code 2)."""
    result = runner.invoke(app, ["predict", "af2", str(tmp_files["empty_dir"])])

    assert result.exit_code == 2, result.output
    # Usage errors are written to stderr. ``result.output`` includes them on
    # every Click version; ``result.stdout`` only does when stderr is mixed in
    # (the Click < 8.2 default).
    assert "Invalid value for 'SOURCE'" in result.output
|
|
|
|
|
def test_predict_with_unsupported_custom_template_file_fails(tmp_files):
    """An unsupported ``--custom-template`` file is rejected (exit code 2)."""
    result = runner.invoke(
        app,
        [
            "predict",
            "af2",
            str(tmp_files["monomer_fasta"]),
            "--custom-template",
            # CLI arguments are strings; stringify like every other path here.
            str(tmp_files["invalid_template"]),
        ],
    )

    assert result.exit_code == 2, result.output
    # Usage errors are written to stderr. ``result.output`` includes them on
    # every Click version; ``result.stdout`` only does when stderr is mixed in
    # (the Click < 8.2 default).
    assert "Invalid value for '--custom-template'" in result.output
|
|
|
|
|
def test_predict_with_unsupported_custom_msa_file_fails(tmp_files):
    """An unsupported ``--custom-msa`` file is rejected (exit code 2)."""
    result = runner.invoke(
        app,
        [
            "predict",
            "af2",
            str(tmp_files["monomer_fasta"]),
            "--custom-msa",
            # CLI arguments are strings; stringify like every other path here.
            str(tmp_files["invalid_msa"]),
        ],
    )

    assert result.exit_code == 2, result.output
    # Usage errors are written to stderr. ``result.output`` includes them on
    # every Click version; ``result.stdout`` only does when stderr is mixed in
    # (the Click < 8.2 default).
    assert "Invalid value for '--custom-msa'" in result.output
|
|
|
|
|
@pytest.mark.parametrize(
    "mock_simple_prediction",
    [
        MessageStatus.PUBLISHED,
        MessageStatus.NOT_PUBLISHED_DONE,
        MessageStatus.NOT_PUBLISHED_PENDING,
    ],
    indirect=True,
)
def test_predict_with_fasta_file_pass(
    mock_simple_prediction: mock.Mock,
    tmp_files,
    default_params,
):
    """One FASTA file with an explicit ``--project-code`` triggers one simple prediction."""
    fasta_path = tmp_files["monomer_fasta"]
    cli_args = [
        "predict", "af2", str(fasta_path),
        "--metadata-file", str(tmp_files["metadata_file"]),
        "--project-code", "FOLDING_DEV",
    ]

    result = runner.invoke(app, cli_args)
    assert result.exit_code == 0, result.output

    mock_simple_prediction.assert_called_once_with(
        project_code="FOLDING_DEV",
        custom_files=PredictRequestCustomFiles(templates=[], msas=[]),
        params=default_params,
        folding_model=FoldingModel.AF2,
        file=fasta_path,
    )
|
|
|
|
|
@pytest.mark.parametrize(
    "mock_simple_prediction",
    [
        MessageStatus.PUBLISHED,
        MessageStatus.NOT_PUBLISHED_DONE,
        MessageStatus.NOT_PUBLISHED_PENDING,
    ],
    indirect=True,
)
def test_predict_with_fasta_file_pass_with_project_code_from_env_var(
    mock_simple_prediction: mock.Mock, tmp_files, default_params
):
    """Without ``--project-code`` the call still targets project ``FOLDING_DEV``.

    NOTE(review): this relies on the project code being supplied by the
    environment (presumably ``FOLDING_PROJECT_CODE=FOLDING_DEV`` set outside
    this module). Confirm.
    """
    fasta_path = tmp_files["monomer_fasta"]
    metadata_path = tmp_files["metadata_file"]

    result = runner.invoke(
        app,
        ["predict", "af2", str(fasta_path), "--metadata-file", str(metadata_path)],
    )
    assert result.exit_code == 0, result.output

    expected_call = dict(
        file=fasta_path,
        params=default_params,
        custom_files=PredictRequestCustomFiles(templates=[], msas=[]),
        folding_model=FoldingModel.AF2,
        project_code="FOLDING_DEV",
    )
    mock_simple_prediction.assert_called_once_with(**expected_call)
|
|
|
|
|
@pytest.mark.parametrize(
    "mock_simple_prediction",
    [
        MessageStatus.PUBLISHED,
        MessageStatus.NOT_PUBLISHED_DONE,
        MessageStatus.NOT_PUBLISHED_PENDING,
    ],
    indirect=True,
)
@pytest.mark.apikeytest
def test_predict_with_fasta_file_pass_with_jwt(
    mock_simple_prediction: mock.Mock,
    tmp_files,
    default_params,
    remove_api_key_from_env_var,
):
    """Prediction succeeds when ``FOLDING_API_KEY`` is absent from the environment.

    Marked ``apikeytest``, so the autouse auth-header mock returns headers
    built by the real ``get_auth_headers``.
    """
    # Precondition: the API key must not be visible to the CLI.
    assert "FOLDING_API_KEY" not in os.environ

    fasta_path = tmp_files["monomer_fasta"]
    result = runner.invoke(
        app,
        [
            "predict", "af2", str(fasta_path),
            "--metadata-file", str(tmp_files["metadata_file"]),
        ],
    )
    assert result.exit_code == 0, result.output

    expected_files = PredictRequestCustomFiles(templates=[], msas=[])
    mock_simple_prediction.assert_called_once_with(
        folding_model=FoldingModel.AF2,
        project_code="FOLDING_DEV",
        file=fasta_path,
        custom_files=expected_files,
        params=default_params,
    )
|
|
|
|
|
def test_predict_unset_project_code_fails(
    tmp_files,
    remove_project_code_from_env_var,
):
    """With no ``--project-code`` and no ``FOLDING_PROJECT_CODE``, the CLI exits with code 2."""
    assert "FOLDING_PROJECT_CODE" not in os.environ

    fasta_arg = str(tmp_files["monomer_fasta"])
    result = runner.invoke(app, ["predict", "af2", fasta_arg])

    assert result.exit_code == 2, result.output
|
|
|
|
|
@pytest.mark.parametrize(
    "mock_simple_prediction",
    [
        MessageStatus.PUBLISHED,
        MessageStatus.NOT_PUBLISHED_DONE,
        MessageStatus.NOT_PUBLISHED_PENDING,
    ],
    indirect=True,
)
def test_predict_with_fasta_file_with_initial_guess_pass(
    mock_simple_prediction: mock.Mock, tmp_files, default_params
):
    """``--initial-guess-file`` is forwarded as ``initial_guess_files`` in the custom files."""
    fasta_path = tmp_files["monomer_fasta"]
    guess_path = tmp_files["valid_initial_guess"]

    cli_args = [
        "predict", "af2", str(fasta_path),
        "--initial-guess-file", str(guess_path),
        "--metadata-file", str(tmp_files["metadata_file"]),
        "--project-code", "FOLDING_DEV",
    ]
    result = runner.invoke(app, cli_args)
    assert result.exit_code == 0, result.output

    expected_files = PredictRequestCustomFiles(
        templates=[],
        msas=[],
        initial_guess_files=[guess_path],
    )
    mock_simple_prediction.assert_called_once_with(
        custom_files=expected_files,
        params=default_params,
        file=fasta_path,
        folding_model=FoldingModel.AF2,
        project_code="FOLDING_DEV",
    )
|
|
|
|
|
@pytest.mark.parametrize(
    "mock_simple_prediction",
    [
        MessageStatus.PUBLISHED,
        MessageStatus.NOT_PUBLISHED_DONE,
        MessageStatus.NOT_PUBLISHED_PENDING,
    ],
    indirect=True,
)
def test_predict_with_fasta_file_with_templates_masks_pass(
    mock_simple_prediction: mock.Mock, tmp_files, default_params, valid_templates_masks
):
    """A templates-masks file plus two custom templates reach ``simple_prediction``.

    Both ``--custom-template`` files and the masks file are expected in the
    forwarded ``PredictRequestCustomFiles``.
    """
    fasta_path = tmp_files["monomer_fasta"]
    templates = [tmp_files["valid_template_3"], tmp_files["valid_template_4"]]

    cli_args = ["predict", "af2", str(fasta_path)]
    cli_args += ["--templates-masks-file", str(valid_templates_masks)]
    for template in templates:
        cli_args += ["--custom-template", str(template)]
    cli_args += [
        "--metadata-file", str(tmp_files["metadata_file"]),
        "--project-code", "FOLDING_DEV",
    ]

    result = runner.invoke(app, cli_args)
    assert result.exit_code == 0, result.output

    expected_files = PredictRequestCustomFiles(
        templates=templates,
        msas=[],
        templates_masks_files=[valid_templates_masks],
    )
    mock_simple_prediction.assert_called_once_with(
        project_code="FOLDING_DEV",
        file=fasta_path,
        params=default_params,
        custom_files=expected_files,
        folding_model=FoldingModel.AF2,
    )
|
|
|
|
|
def test_predict_with_fasta_file_with_templates_masks_fail(
    tmp_files, valid_templates_masks
):
    """The command aborts (exit code 1) with "Check your input command.".

    The same masks file that passes with templates 3 and 4 is used here with
    templates 3 and 2. Presumably template 2 is not covered by the masks file
    (TODO confirm against the fixture).
    """
    cli_args = [
        "predict", "af2", str(tmp_files["monomer_fasta"]),
        "--templates-masks-file", str(valid_templates_masks),
    ]
    for template_key in ("valid_template_3", "valid_template_2"):
        cli_args += ["--custom-template", str(tmp_files[template_key])]
    cli_args += [
        "--metadata-file", str(tmp_files["metadata_file"]),
        "--project-code", "FOLDING_DEV",
    ]

    result = runner.invoke(app, cli_args)

    assert result.exit_code == 1, result.output
    error_text = str(result.exception)
    assert "Check your input command." in error_text
|
|
|
|
|
def test_predict_with_fasta_file_with_multi_seed_pass(
    mock_batch_prediction: mock.Mock,
    tmp_files,
    default_params,
):
    """``--num-seed`` sends a single FASTA through ``batch_prediction`` with ``num_seed`` set."""
    fasta_path = tmp_files["monomer_fasta"]
    cli_args = [
        "predict", "af2", str(fasta_path),
        "--num-seed", 5,
        "--metadata-file", str(tmp_files["metadata_file"]),
        "--project-code", "FOLDING_DEV",
    ]

    result = runner.invoke(app, cli_args)
    assert result.exit_code == 0, result.output

    expected_call = dict(
        files=[fasta_path],
        num_seed=5,
        project_code="FOLDING_DEV",
        folding_model=FoldingModel.AF2,
        params=default_params,
        custom_files=PredictRequestCustomFiles(templates=[], msas=[]),
    )
    mock_batch_prediction.assert_called_once_with(**expected_call)
|
|
|
|
|
def test_predict_with_fasta_file_with_initial_guess_with_multi_seed_pass(
    mock_batch_prediction: mock.Mock, tmp_files, default_params
):
    """``--num-seed`` combined with ``--initial-guess-file`` goes through ``batch_prediction``.

    The initial guess is expected in the forwarded custom files.
    """
    fasta_path = tmp_files["monomer_fasta"]
    guess_path = tmp_files["valid_initial_guess"]

    result = runner.invoke(
        app,
        [
            "predict", "af2", str(fasta_path),
            "--num-seed", 5,
            "--initial-guess-file", str(guess_path),
            "--metadata-file", str(tmp_files["metadata_file"]),
            "--project-code", "FOLDING_DEV",
        ],
    )
    assert result.exit_code == 0, result.output

    expected_files = PredictRequestCustomFiles(
        templates=[],
        msas=[],
        initial_guess_files=[guess_path],
    )
    mock_batch_prediction.assert_called_once_with(
        num_seed=5,
        custom_files=expected_files,
        params=default_params,
        folding_model=FoldingModel.AF2,
        files=[fasta_path],
        project_code="FOLDING_DEV",
    )
|
|
|
|
|
def test_predict_with_directory_pass(
    mock_batch_prediction: mock.Mock,
    tmp_files,
    default_params,
):
    """A directory source submits every file it contains through ``batch_prediction``."""
    source_dir = tmp_files["valid_dir"]

    result = runner.invoke(
        app,
        [
            "predict", "af2", str(source_dir),
            "--metadata-file", str(tmp_files["metadata_file"]),
            "--project-code", "FOLDING_DEV",
        ],
    )
    assert result.exit_code == 0, result.output

    mock_batch_prediction.assert_called_once_with(
        files=mock.ANY,  # compared order-insensitively below
        folding_model=FoldingModel.AF2,
        params=default_params,
        custom_files=PredictRequestCustomFiles(templates=[], msas=[]),
        num_seed=None,
        project_code="FOLDING_DEV",
    )

    submitted_files = mock_batch_prediction.call_args.kwargs["files"]
    assert sorted(submitted_files) == sorted(source_dir.iterdir())
|
|
|
|
|
@pytest.mark.parametrize("file", ["valid_batch_file_json", "valid_batch_file_csv"])
def test_predict_with_json_batch_file_pass(
    file, mock_batch_prediction_from_file: mock.Mock, tmp_files
):
    """A JSON or CSV batch file path is passed to ``batch_prediction_from_file``."""
    batch_file = tmp_files[file]
    cli_args = [
        "predict", "af2", str(batch_file),
        "--metadata-file", str(tmp_files["metadata_file"]),
        "--project-code", "FOLDING_DEV",
    ]

    result = runner.invoke(app, cli_args)
    assert result.exit_code == 0, result.output

    mock_batch_prediction_from_file.assert_called_once_with(
        project_code="FOLDING_DEV",
        file=batch_file,
    )
|
|