import os
from unittest import mock

import pytest
from folding_studio.cli import app
from folding_studio.utils.data_model import (
    PredictRequestCustomFiles,
    PredictRequestParams,
)
from folding_studio.utils.headers import get_auth_headers
from folding_studio_data_models import (
    BatchMessageStatus,
    BatchPublication,
    FeatureMode,
    FoldingModel,
    Message,
    MessageStatus,
    OpenFoldParameters,
    Publication,
)
from typer.testing import CliRunner

runner = CliRunner()

# Message statuses exercised by the single-prediction tests; they are fed to the
# ``mock_simple_prediction`` fixture through indirect parametrization.
_PREDICTION_STATUSES = (
    MessageStatus.PUBLISHED,
    MessageStatus.NOT_PUBLISHED_DONE,
    MessageStatus.NOT_PUBLISHED_PENDING,
)


def _make_message(fasta_file_name: str) -> Message:
    """Build the dummy OpenFold inference message used by the prediction mocks."""
    return Message(
        pipeline_name="alphafold_inference_pipeline",
        user_id="default-user",
        project_code="default-project",
        parameters=OpenFoldParameters(),
        experiment_id="dummy-experiment",
        model_preset="monomer",
        fasta_file_name=fasta_file_name,
        ignore_cache=False,
    )


def _make_batch_publication_payload() -> dict:
    """Return a JSON-dumped, published batch of three monomer publications."""
    batch_pub = BatchPublication(
        publications=[
            Publication(
                folding_model=FoldingModel.OPENFOLD,
                message=_make_message(f"monomer_{idx}.fasta"),
                status=MessageStatus.PUBLISHED,
            )
            for idx in range(3)
        ],
        batch_id="batch_id",
        cached_publications=[],
        status=BatchMessageStatus.PUBLISHED,
        cached=False,
    )
    return batch_pub.model_dump(mode="json")


@pytest.fixture(autouse=True)
def mock_get_auth_headers(request, monkeypatch):
    """Patch the auth headers used by ``simple_predict`` for every test.

    Tests marked ``apikeytest`` get the headers produced by the real
    ``get_auth_headers`` with ``FOLDING_API_KEY`` set to ``"MY_KEY"``; all
    other tests get a fixed bearer-token header.
    """
    if "apikeytest" in request.keywords:
        # monkeypatch restores the environment at teardown, so the API key
        # does not leak into tests that run afterwards.
        monkeypatch.setenv("FOLDING_API_KEY", "MY_KEY")
        return_value = get_auth_headers()
    else:
        return_value = {"Authorization": "Bearer identity_token"}
    with mock.patch(
        "folding_studio.api_call.predict.simple_predict.get_auth_headers",
        return_value=return_value,
    ) as m:
        yield m


@pytest.fixture(autouse=True)
def mock_batch_prediction_from_file():
    """Patch ``batch_prediction_from_file`` to return a three-item batch."""
    with mock.patch(
        "folding_studio.commands.predict.openfold_predict.batch_prediction_from_file",
        return_value=_make_batch_publication_payload(),
    ) as m:
        yield m


@pytest.fixture()
def mock_batch_prediction():
    """Patch ``batch_prediction`` to return a three-item batch."""
    with mock.patch(
        "folding_studio.commands.predict.openfold_predict.batch_prediction",
        return_value=_make_batch_publication_payload(),
    ) as m:
        yield m


@pytest.fixture()
def mock_simple_prediction(request):
    """Patch ``simple_prediction`` to return a single publication.

    With indirect parametrization, ``request.param`` is the ``MessageStatus``
    of the returned publication; without it, the status is ``PUBLISHED``.
    """
    status = getattr(request, "param", MessageStatus.PUBLISHED)
    pub = Publication(
        folding_model=FoldingModel.OPENFOLD,
        message=_make_message("monomer.fasta"),
        status=status,
    )
    with mock.patch(
        "folding_studio.commands.predict.openfold_predict.simple_prediction",
        return_value=pub.model_dump(mode="json"),
    ) as m:
        yield m


@pytest.fixture()
def default_params():
    """Expected request parameters for invocations that only set the project code."""
    return PredictRequestParams(
        ignore_cache=False,
        template_mode=FeatureMode.SEARCH,
        custom_template_ids=[],
        msa_mode=FeatureMode.SEARCH,
        max_msa_clusters=-1,
        max_extra_msa=-1,
        gap_trick=False,
        num_recycle=3,
        random_seed=0,
        model_subset=[],
        project_code="FOLDING_DEV",
    )


# ``result.output`` is checked instead of ``result.stdout`` because Click writes
# usage errors to stderr, which ``result.stdout`` excludes on newer Click.


def test_predict_with_unsupported_file_fails(tmp_files):
    """An unsupported SOURCE file is rejected as a usage error."""
    result = runner.invoke(
        app, ["predict", "openfold", str(tmp_files["invalid_source"])]
    )
    assert result.exit_code == 2
    assert "Invalid value for 'SOURCE'" in result.output


def test_predict_with_directory_containing_unsupported_file_fails(tmp_files):
    """A SOURCE directory containing an unsupported file is rejected."""
    result = runner.invoke(app, ["predict", "openfold", str(tmp_files["invalid_dir"])])
    assert result.exit_code == 2
    assert "Invalid value for 'SOURCE'" in result.output


def test_predict_with_empty_directory_fails(tmp_files):
    """An empty SOURCE directory is rejected."""
    result = runner.invoke(app, ["predict", "openfold", str(tmp_files["empty_dir"])])
    assert result.exit_code == 2
    assert "Invalid value for 'SOURCE'" in result.output


def test_predict_with_unsupported_custom_template_file_fails(tmp_files):
    """An unsupported ``--custom-template`` file is rejected."""
    result = runner.invoke(
        app,
        [
            "predict",
            "openfold",
            str(tmp_files["monomer_fasta"]),
            "--custom-template",
            str(tmp_files["invalid_template"]),
        ],
    )
    assert result.exit_code == 2
    assert "Invalid value for '--custom-template'" in result.output


def test_predict_with_unsupported_custom_msa_file_fails(tmp_files):
    """An unsupported ``--custom-msa`` file is rejected."""
    result = runner.invoke(
        app,
        [
            "predict",
            "openfold",
            str(tmp_files["monomer_fasta"]),
            "--custom-msa",
            str(tmp_files["invalid_msa"]),
        ],
    )
    assert result.exit_code == 2
    assert "Invalid value for '--custom-msa'" in result.output


@pytest.mark.parametrize("mock_simple_prediction", _PREDICTION_STATUSES, indirect=True)
def test_predict_with_fasta_file_pass(
    mock_simple_prediction: mock.Mock,
    tmp_files,
    default_params,
):
    """A single FASTA file triggers exactly one simple prediction."""
    result = runner.invoke(
        app,
        [
            "predict",
            "openfold",
            str(tmp_files["monomer_fasta"]),
            "--metadata-file",
            str(tmp_files["metadata_file"]),
            "--project-code",
            "FOLDING_DEV",
        ],
    )

    assert result.exit_code == 0, result.output
    custom_files = PredictRequestCustomFiles(templates=[], msas=[])
    mock_simple_prediction.assert_called_once_with(
        file=tmp_files["monomer_fasta"],
        folding_model=FoldingModel.OPENFOLD,
        params=default_params,
        custom_files=custom_files,
        project_code="FOLDING_DEV",
    )


@pytest.mark.parametrize("mock_simple_prediction", _PREDICTION_STATUSES, indirect=True)
def test_predict_with_fasta_file_pass_with_project_code_from_env_var(
    mock_simple_prediction: mock.Mock, tmp_files, default_params
):
    """Without ``--project-code``, the project code is taken from the environment."""
    result = runner.invoke(
        app,
        [
            "predict",
            "openfold",
            str(tmp_files["monomer_fasta"]),
            "--metadata-file",
            str(tmp_files["metadata_file"]),
            # Here we deliberately leave out the
            # "--project-code" option to check that
            # env variable is correctly used
        ],
    )

    assert result.exit_code == 0, result.output
    custom_files = PredictRequestCustomFiles(templates=[], msas=[])
    mock_simple_prediction.assert_called_once_with(
        file=tmp_files["monomer_fasta"],
        params=default_params,
        custom_files=custom_files,
        folding_model=FoldingModel.OPENFOLD,
        project_code="FOLDING_DEV",
    )


def test_predict_unset_project_code_fails(tmp_files, remove_project_code_from_env_var):
    """The command fails as a usage error when no project code is available."""
    result = runner.invoke(
        app,
        [
            "predict",
            "openfold",
            str(tmp_files["monomer_fasta"]),
        ],
    )
    assert result.exit_code == 2, result.output


@pytest.mark.parametrize("mock_simple_prediction", _PREDICTION_STATUSES, indirect=True)
def test_predict_with_fasta_file_with_templates_masks_pass(
    mock_simple_prediction: mock.Mock, tmp_files, default_params, valid_templates_masks
):
    """Templates masks and matching custom templates are forwarded as custom files."""
    result = runner.invoke(
        app,
        [
            "predict",
            "openfold",
            str(tmp_files["monomer_fasta"]),
            "--templates-masks-file",
            str(valid_templates_masks),
            "--custom-template",
            str(tmp_files["valid_template_3"]),
            "--custom-template",
            str(tmp_files["valid_template_4"]),
            "--metadata-file",
            str(tmp_files["metadata_file"]),
            "--project-code",
            "FOLDING_DEV",
        ],
    )

    assert result.exit_code == 0, result.output
    custom_files = PredictRequestCustomFiles(
        templates=[tmp_files["valid_template_3"], tmp_files["valid_template_4"]],
        msas=[],
        templates_masks_files=[valid_templates_masks],
    )
    mock_simple_prediction.assert_called_once_with(
        file=tmp_files["monomer_fasta"],
        folding_model=FoldingModel.OPENFOLD,
        params=default_params,
        custom_files=custom_files,
        project_code="FOLDING_DEV",
    )


def test_predict_with_fasta_file_with_templates_masks_fail(
    tmp_files, valid_templates_masks
):
    """A templates-masks file not matching the custom templates aborts the command."""
    result = runner.invoke(
        app,
        [
            "predict",
            "openfold",
            str(tmp_files["monomer_fasta"]),
            "--templates-masks-file",
            str(valid_templates_masks),
            "--custom-template",
            str(tmp_files["valid_template_3"]),
            "--custom-template",
            str(tmp_files["valid_template_2"]),
            "--metadata-file",
            str(tmp_files["metadata_file"]),
            "--project-code",
            "FOLDING_DEV",
        ],
    )

    assert result.exit_code == 1, result.output
    assert "Check your input command." in str(result.exception)


def test_predict_with_fasta_file_with_multi_seed_pass(
    mock_batch_prediction: mock.Mock,
    tmp_files,
    default_params,
):
    """``--num-seed`` routes a single FASTA file through batch prediction."""
    result = runner.invoke(
        app,
        [
            "predict",
            "openfold",
            str(tmp_files["monomer_fasta"]),
            "--num-seed",
            "5",
            "--metadata-file",
            str(tmp_files["metadata_file"]),
            "--project-code",
            "FOLDING_DEV",
        ],
    )

    assert result.exit_code == 0, result.output
    custom_files = PredictRequestCustomFiles(templates=[], msas=[])
    mock_batch_prediction.assert_called_once_with(
        files=[tmp_files["monomer_fasta"]],
        folding_model=FoldingModel.OPENFOLD,
        params=default_params,
        custom_files=custom_files,
        num_seed=5,
        project_code="FOLDING_DEV",
    )


def test_predict_with_directory_pass(
    mock_batch_prediction: mock.Mock,
    tmp_files,
    default_params,
):
    """A SOURCE directory submits every file it contains as one batch."""
    result = runner.invoke(
        app,
        [
            "predict",
            "openfold",
            str(tmp_files["valid_dir"]),
            "--metadata-file",
            str(tmp_files["metadata_file"]),
            "--project-code",
            "FOLDING_DEV",
        ],
    )

    assert result.exit_code == 0, result.output
    custom_files = PredictRequestCustomFiles(templates=[], msas=[])
    mock_batch_prediction.assert_called_once_with(
        files=mock.ANY,
        params=default_params,
        folding_model=FoldingModel.OPENFOLD,
        custom_files=custom_files,
        num_seed=None,
        project_code="FOLDING_DEV",
    )
    # File order is not guaranteed, so compare the submitted files as sorted lists.
    _, kwargs = mock_batch_prediction.call_args
    assert sorted(kwargs["files"]) == sorted(tmp_files["valid_dir"].iterdir())


@pytest.mark.parametrize("file", ["valid_batch_file_json", "valid_batch_file_csv"])
def test_predict_with_json_batch_file_pass(
    file, mock_batch_prediction_from_file: mock.Mock, tmp_files
):
    """A JSON or CSV batch file is delegated to ``batch_prediction_from_file``."""
    result = runner.invoke(
        app,
        [
            "predict",
            "openfold",
            str(tmp_files[file]),
            "--metadata-file",
            str(tmp_files["metadata_file"]),
            "--project-code",
            "FOLDING_DEV",
        ],
    )

    assert result.exit_code == 0, result.output
    mock_batch_prediction_from_file.assert_called_once_with(
        file=tmp_files[file],
        project_code="FOLDING_DEV",
    )