# Commit 44459bb (jfaustin): add dockerfile and folding studio cli
import os
from unittest import mock
import pytest
from folding_studio.cli import app
from folding_studio.utils.data_model import (
PredictRequestCustomFiles,
PredictRequestParams,
)
from folding_studio.utils.headers import get_auth_headers
from folding_studio_data_models import (
AF2Parameters,
BatchMessageStatus,
BatchPublication,
FeatureMode,
FoldingModel,
Message,
MessageStatus,
Publication,
)
from typer.testing import CliRunner
# Module-level Typer test runner; the tests below call ``runner.invoke(app, ...)``.
runner = CliRunner()
@pytest.fixture(autouse=True)
def mock_get_auth_headers(request):
    """Patch ``get_auth_headers`` as imported by ``simple_predict`` for every test.

    Tests marked ``apikeytest`` receive the headers that the real
    ``get_auth_headers`` builds while ``FOLDING_API_KEY`` is set to
    ``"MY_KEY"``; every other test receives a fixed bearer-token header.

    The value ``FOLDING_API_KEY`` had before the fixture ran is restored on
    teardown, so the variable does not leak into tests that run afterwards.

    Yields:
        The mock that replaces ``get_auth_headers``.
    """
    uses_api_key = "apikeytest" in request.keywords
    # Remember the pre-existing value so teardown can put it back.
    previous_api_key = os.environ.get("FOLDING_API_KEY")
    try:
        if uses_api_key:
            os.environ["FOLDING_API_KEY"] = "MY_KEY"
            return_value = get_auth_headers()
        else:
            return_value = {"Authorization": "Bearer identity_token"}
        with mock.patch(
            "folding_studio.api_call.predict.simple_predict.get_auth_headers",
            return_value=return_value,
        ) as m:
            yield m
    finally:
        if uses_api_key:
            if previous_api_key is None:
                os.environ.pop("FOLDING_API_KEY", None)
            else:
                os.environ["FOLDING_API_KEY"] = previous_api_key
@pytest.fixture(autouse=True)
def mock_batch_prediction_from_file():
    """Patch ``batch_prediction_from_file`` in the af2 predict command module.

    Applied automatically to every test. The mock returns the JSON dump of a
    ``BatchPublication`` holding three published AF2 monomer publications
    (``monomer_0.fasta`` to ``monomer_2.fasta``).

    Yields:
        The mock that replaces ``batch_prediction_from_file``.
    """
    publications = []
    for idx in range(3):
        message = Message(
            pipeline_name="alphafold_inference_pipeline",
            user_id="default-user",
            project_code="default-project",
            parameters=AF2Parameters(),
            experiment_id="dummy-experiment",
            model_preset="monomer",
            fasta_file_name=f"monomer_{idx}.fasta",
            ignore_cache=False,
        )
        publications.append(
            Publication(
                folding_model=FoldingModel.AF2,
                message=message,
                status=MessageStatus.PUBLISHED,
            )
        )

    batch = BatchPublication(
        publications=publications,
        batch_id="batch_id",
        cached_publications=[],
        status=BatchMessageStatus.PUBLISHED,
        cached=False,
    )
    payload = batch.model_dump(mode="json")

    target = "folding_studio.commands.predict.af2_predict.batch_prediction_from_file"
    with mock.patch(target, return_value=payload) as patched:
        yield patched
@pytest.fixture()
def mock_batch_prediction():
    """Patch ``batch_prediction`` in the af2 predict command module.

    The mock returns the JSON dump of a ``BatchPublication`` holding three
    published AF2 monomer publications (``monomer_0.fasta`` to
    ``monomer_2.fasta``).

    Yields:
        The mock that replaces ``batch_prediction``.
    """
    publications = []
    for idx in range(3):
        message = Message(
            pipeline_name="alphafold_inference_pipeline",
            user_id="default-user",
            project_code="default-project",
            parameters=AF2Parameters(),
            experiment_id="dummy-experiment",
            model_preset="monomer",
            fasta_file_name=f"monomer_{idx}.fasta",
            ignore_cache=False,
        )
        publications.append(
            Publication(
                folding_model=FoldingModel.AF2,
                message=message,
                status=MessageStatus.PUBLISHED,
            )
        )

    batch = BatchPublication(
        publications=publications,
        batch_id="batch_id",
        cached_publications=[],
        status=BatchMessageStatus.PUBLISHED,
        cached=False,
    )
    payload = batch.model_dump(mode="json")

    target = "folding_studio.commands.predict.af2_predict.batch_prediction"
    with mock.patch(target, return_value=payload) as patched:
        yield patched
@pytest.fixture()
def mock_simple_prediction(request):
    """Patch ``simple_prediction`` in the af2 predict command module.

    The mock returns the JSON dump of a single AF2 monomer ``Publication``.
    When a test parametrizes this fixture indirectly, the parameter is used as
    the publication status; without parametrization the status is
    ``MessageStatus.PUBLISHED``.

    Yields:
        The mock that replaces ``simple_prediction``.
    """
    # ``request.param`` only exists under indirect parametrization.
    status = getattr(request, "param", MessageStatus.PUBLISHED)
    pub = Publication(
        folding_model=FoldingModel.AF2,
        message=Message(
            pipeline_name="alphafold_inference_pipeline",
            user_id="default-user",
            project_code="default-project",
            parameters=AF2Parameters(),
            experiment_id="dummy-experiment",
            model_preset="monomer",
            fasta_file_name="monomer.fasta",
            ignore_cache=False,
        ),
        status=status,
    )
    with mock.patch(
        "folding_studio.commands.predict.af2_predict.simple_prediction",
        return_value=pub.model_dump(mode="json"),
    ) as m:
        yield m
@pytest.fixture()
def default_params():
    """Yield the ``PredictRequestParams`` the tests compare mocked calls against.

    The ``project_code`` is ``"FOLDING_DEV"``, the value the tests in this
    module expect the CLI to forward.
    """
    expected_params = PredictRequestParams(
        ignore_cache=False,
        template_mode=FeatureMode.SEARCH,
        custom_template_ids=[],
        msa_mode=FeatureMode.SEARCH,
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[],
        project_code="FOLDING_DEV",
    )
    yield expected_params
def test_predict_with_unsupported_file_fails(tmp_files):
    """An unsupported source file yields exit code 2 and an invalid SOURCE message."""
    cli_args = ["predict", "af2", str(tmp_files["invalid_source"])]

    outcome = runner.invoke(app, cli_args)

    assert outcome.exit_code == 2
    assert "Invalid value for 'SOURCE'" in outcome.stdout
def test_predict_with_directory_containing_unsupported_file_fails(tmp_files):
    """The ``invalid_dir`` source yields exit code 2 and an invalid SOURCE message."""
    cli_args = ["predict", "af2", str(tmp_files["invalid_dir"])]

    outcome = runner.invoke(app, cli_args)

    assert outcome.exit_code == 2
    assert "Invalid value for 'SOURCE'" in outcome.stdout
def test_predict_with_empty_directory_fails(tmp_files):
    """The ``empty_dir`` source yields exit code 2 and an invalid SOURCE message."""
    cli_args = ["predict", "af2", str(tmp_files["empty_dir"])]

    outcome = runner.invoke(app, cli_args)

    assert outcome.exit_code == 2
    assert "Invalid value for 'SOURCE'" in outcome.stdout
def test_predict_with_unsupported_custom_template_file_fails(tmp_files):
    """An unsupported ``--custom-template`` file yields exit code 2.

    The template path is converted with ``str()`` like every other path
    argument in this module, so the CLI receives string arguments only.
    """
    result = runner.invoke(
        app,
        [
            "predict",
            "af2",
            str(tmp_files["monomer_fasta"]),
            "--custom-template",
            str(tmp_files["invalid_template"]),
        ],
    )
    assert result.exit_code == 2
    assert "Invalid value for '--custom-template'" in result.stdout
def test_predict_with_unsupported_custom_msa_file_fails(tmp_files):
    """An unsupported ``--custom-msa`` file yields exit code 2.

    The MSA path is converted with ``str()`` like every other path argument
    in this module, so the CLI receives string arguments only.
    """
    result = runner.invoke(
        app,
        [
            "predict",
            "af2",
            str(tmp_files["monomer_fasta"]),
            "--custom-msa",
            str(tmp_files["invalid_msa"]),
        ],
    )
    assert result.exit_code == 2
    assert "Invalid value for '--custom-msa'" in result.stdout
@pytest.mark.parametrize(
    "mock_simple_prediction",
    (
        MessageStatus.PUBLISHED,
        MessageStatus.NOT_PUBLISHED_DONE,
        MessageStatus.NOT_PUBLISHED_PENDING,
    ),
    indirect=True,
)
def test_predict_with_fasta_file_pass(
    mock_simple_prediction: mock.Mock,
    tmp_files,
    default_params,
):
    """A single FASTA file with an explicit project code triggers one simple prediction.

    Parametrized indirectly over three ``MessageStatus`` values through the
    ``mock_simple_prediction`` fixture.
    """
    fasta_path = tmp_files["monomer_fasta"]
    cli_args = [
        "predict",
        "af2",
        str(fasta_path),
        "--metadata-file",
        str(tmp_files["metadata_file"]),
        "--project-code",
        "FOLDING_DEV",
    ]

    outcome = runner.invoke(app, cli_args)

    assert outcome.exit_code == 0, outcome.output
    mock_simple_prediction.assert_called_once_with(
        file=fasta_path,
        folding_model=FoldingModel.AF2,
        params=default_params,
        custom_files=PredictRequestCustomFiles(templates=[], msas=[]),
        project_code="FOLDING_DEV",
    )
@pytest.mark.parametrize(
    "mock_simple_prediction",
    (
        MessageStatus.PUBLISHED,
        MessageStatus.NOT_PUBLISHED_DONE,
        MessageStatus.NOT_PUBLISHED_PENDING,
    ),
    indirect=True,
)
def test_predict_with_fasta_file_pass_with_project_code_from_env_var(
    mock_simple_prediction: mock.Mock, tmp_files, default_params
):
    """Without ``--project-code`` the prediction is still called with ``FOLDING_DEV``.

    The option is deliberately omitted so that the project code has to come
    from the environment variable instead.
    """
    fasta_path = tmp_files["monomer_fasta"]
    # "--project-code" is left out on purpose: the env variable must be used.
    cli_args = [
        "predict",
        "af2",
        str(fasta_path),
        "--metadata-file",
        str(tmp_files["metadata_file"]),
    ]

    outcome = runner.invoke(app, cli_args)

    assert outcome.exit_code == 0, outcome.output
    mock_simple_prediction.assert_called_once_with(
        file=fasta_path,
        params=default_params,
        custom_files=PredictRequestCustomFiles(templates=[], msas=[]),
        folding_model=FoldingModel.AF2,
        project_code="FOLDING_DEV",
    )
@pytest.mark.parametrize(
    "mock_simple_prediction",
    (
        MessageStatus.PUBLISHED,
        MessageStatus.NOT_PUBLISHED_DONE,
        MessageStatus.NOT_PUBLISHED_PENDING,
    ),
    indirect=True,
)
@pytest.mark.apikeytest
def test_predict_with_fasta_file_pass_with_jwt(
    mock_simple_prediction: mock.Mock,
    tmp_files,
    default_params,
    remove_api_key_from_env_var,
):
    """The prediction succeeds while ``FOLDING_API_KEY`` is absent from the environment.

    The test first checks that the variable is unset (it requests the
    ``remove_api_key_from_env_var`` fixture), then runs the command without
    ``--project-code`` and expects ``FOLDING_DEV`` to be forwarded.
    """
    assert os.environ.get("FOLDING_API_KEY") is None

    fasta_path = tmp_files["monomer_fasta"]
    cli_args = [
        "predict",
        "af2",
        str(fasta_path),
        "--metadata-file",
        str(tmp_files["metadata_file"]),
    ]

    outcome = runner.invoke(app, cli_args)

    assert outcome.exit_code == 0, outcome.output
    mock_simple_prediction.assert_called_once_with(
        file=fasta_path,
        params=default_params,
        custom_files=PredictRequestCustomFiles(templates=[], msas=[]),
        folding_model=FoldingModel.AF2,
        project_code="FOLDING_DEV",
    )
def test_predict_unset_project_code_fails(
    tmp_files,
    remove_project_code_from_env_var,
):
    """With no ``FOLDING_PROJECT_CODE`` and no ``--project-code`` the command exits with 2."""
    assert os.environ.get("FOLDING_PROJECT_CODE") is None

    cli_args = ["predict", "af2", str(tmp_files["monomer_fasta"])]
    outcome = runner.invoke(app, cli_args)

    assert outcome.exit_code == 2, outcome.output
@pytest.mark.parametrize(
    "mock_simple_prediction",
    (
        MessageStatus.PUBLISHED,
        MessageStatus.NOT_PUBLISHED_DONE,
        MessageStatus.NOT_PUBLISHED_PENDING,
    ),
    indirect=True,
)
def test_predict_with_fasta_file_with_initial_guess_pass(
    mock_simple_prediction: mock.Mock, tmp_files, default_params
):
    """An ``--initial-guess-file`` is forwarded in the custom files of the prediction."""
    fasta_path = tmp_files["monomer_fasta"]
    initial_guess_path = tmp_files["valid_initial_guess"]
    cli_args = [
        "predict",
        "af2",
        str(fasta_path),
        "--initial-guess-file",
        str(initial_guess_path),
        "--metadata-file",
        str(tmp_files["metadata_file"]),
        "--project-code",
        "FOLDING_DEV",
    ]

    outcome = runner.invoke(app, cli_args)

    assert outcome.exit_code == 0, outcome.output
    expected_custom_files = PredictRequestCustomFiles(
        templates=[], msas=[], initial_guess_files=[initial_guess_path]
    )
    mock_simple_prediction.assert_called_once_with(
        file=fasta_path,
        folding_model=FoldingModel.AF2,
        params=default_params,
        custom_files=expected_custom_files,
        project_code="FOLDING_DEV",
    )
@pytest.mark.parametrize(
    "mock_simple_prediction",
    (
        MessageStatus.PUBLISHED,
        MessageStatus.NOT_PUBLISHED_DONE,
        MessageStatus.NOT_PUBLISHED_PENDING,
    ),
    indirect=True,
)
def test_predict_with_fasta_file_with_templates_masks_pass(
    mock_simple_prediction: mock.Mock, tmp_files, default_params, valid_templates_masks
):
    """A templates-masks file plus two custom templates is forwarded to the prediction."""
    fasta_path = tmp_files["monomer_fasta"]
    template_paths = [tmp_files["valid_template_3"], tmp_files["valid_template_4"]]
    cli_args = [
        "predict",
        "af2",
        str(fasta_path),
        "--templates-masks-file",
        str(valid_templates_masks),
        "--custom-template",
        str(template_paths[0]),
        "--custom-template",
        str(template_paths[1]),
        "--metadata-file",
        str(tmp_files["metadata_file"]),
        "--project-code",
        "FOLDING_DEV",
    ]

    outcome = runner.invoke(app, cli_args)

    assert outcome.exit_code == 0, outcome.output
    expected_custom_files = PredictRequestCustomFiles(
        templates=template_paths,
        msas=[],
        templates_masks_files=[valid_templates_masks],
    )
    mock_simple_prediction.assert_called_once_with(
        file=fasta_path,
        folding_model=FoldingModel.AF2,
        params=default_params,
        custom_files=expected_custom_files,
        project_code="FOLDING_DEV",
    )
def test_predict_with_fasta_file_with_templates_masks_fail(
    tmp_files, valid_templates_masks
):
    """The masks file combined with templates 3 and 2 makes the command exit with 1.

    NOTE(review): presumably ``valid_template_2`` is not covered by the masks
    file, which is what triggers the error - confirm against the fixtures.
    """
    cli_args = [
        "predict",
        "af2",
        str(tmp_files["monomer_fasta"]),
        "--templates-masks-file",
        str(valid_templates_masks),
        "--custom-template",
        str(tmp_files["valid_template_3"]),
        "--custom-template",
        str(tmp_files["valid_template_2"]),
        "--metadata-file",
        str(tmp_files["metadata_file"]),
        "--project-code",
        "FOLDING_DEV",
    ]

    outcome = runner.invoke(app, cli_args)

    assert outcome.exit_code == 1, outcome.output
    assert "Check your input command." in str(outcome.exception)
def test_predict_with_fasta_file_with_multi_seed_pass(
    mock_batch_prediction: mock.Mock,
    tmp_files,
    default_params,
):
    """A single FASTA file with ``--num-seed 5`` is dispatched as a batch prediction.

    The seed count is passed as the string ``"5"``, as it would arrive from a
    real command line; the mocked call is expected to receive the int ``5``.
    """
    result = runner.invoke(
        app,
        [
            "predict",
            "af2",
            str(tmp_files["monomer_fasta"]),
            "--num-seed",
            "5",
            "--metadata-file",
            str(tmp_files["metadata_file"]),
            "--project-code",
            "FOLDING_DEV",
        ],
    )
    assert result.exit_code == 0, result.output
    custom_files = PredictRequestCustomFiles(templates=[], msas=[])
    mock_batch_prediction.assert_called_once_with(
        files=[tmp_files["monomer_fasta"]],
        folding_model=FoldingModel.AF2,
        params=default_params,
        custom_files=custom_files,
        num_seed=5,
        project_code="FOLDING_DEV",
    )
def test_predict_with_fasta_file_with_initial_guess_with_multi_seed_pass(
    mock_batch_prediction: mock.Mock, tmp_files, default_params
):
    """``--num-seed 5`` with an initial-guess file is dispatched as a batch prediction.

    The seed count is passed as the string ``"5"``, as it would arrive from a
    real command line; the mocked call is expected to receive the int ``5``
    and the initial-guess file in its custom files.
    """
    result = runner.invoke(
        app,
        [
            "predict",
            "af2",
            str(tmp_files["monomer_fasta"]),
            "--num-seed",
            "5",
            "--initial-guess-file",
            str(tmp_files["valid_initial_guess"]),
            "--metadata-file",
            str(tmp_files["metadata_file"]),
            "--project-code",
            "FOLDING_DEV",
        ],
    )
    assert result.exit_code == 0, result.output
    custom_files = PredictRequestCustomFiles(
        templates=[], msas=[], initial_guess_files=[tmp_files["valid_initial_guess"]]
    )
    mock_batch_prediction.assert_called_once_with(
        files=[tmp_files["monomer_fasta"]],
        folding_model=FoldingModel.AF2,
        params=default_params,
        custom_files=custom_files,
        num_seed=5,
        project_code="FOLDING_DEV",
    )
def test_predict_with_directory_pass(
    mock_batch_prediction: mock.Mock,
    tmp_files,
    default_params,
):
    """A directory source is dispatched as one batch prediction over its entries.

    The ``files`` argument is matched with ``mock.ANY`` first and then compared
    after sorting, so the order in which the directory is listed is irrelevant.
    """
    source_dir = tmp_files["valid_dir"]
    cli_args = [
        "predict",
        "af2",
        str(source_dir),
        "--metadata-file",
        str(tmp_files["metadata_file"]),
        "--project-code",
        "FOLDING_DEV",
    ]

    outcome = runner.invoke(app, cli_args)

    assert outcome.exit_code == 0, outcome.output
    mock_batch_prediction.assert_called_once_with(
        files=mock.ANY,
        params=default_params,
        folding_model=FoldingModel.AF2,
        custom_files=PredictRequestCustomFiles(templates=[], msas=[]),
        num_seed=None,
        project_code="FOLDING_DEV",
    )
    called_kwargs = mock_batch_prediction.call_args[1]
    assert sorted(called_kwargs["files"]) == sorted(source_dir.iterdir())
@pytest.mark.parametrize("file", ["valid_batch_file_json", "valid_batch_file_csv"])
def test_predict_with_json_batch_file_pass(
    file, mock_batch_prediction_from_file: mock.Mock, tmp_files
):
    """A batch file source is dispatched to ``batch_prediction_from_file``.

    Despite the name, the test runs for both the JSON and the CSV batch file.
    """
    batch_file_path = tmp_files[file]
    cli_args = [
        "predict",
        "af2",
        str(batch_file_path),
        "--metadata-file",
        str(tmp_files["metadata_file"]),
        "--project-code",
        "FOLDING_DEV",
    ]

    outcome = runner.invoke(app, cli_args)

    assert outcome.exit_code == 0, outcome.output
    mock_batch_prediction_from_file.assert_called_once_with(
        file=batch_file_path,
        project_code="FOLDING_DEV",
    )