import re

import pytest
from folding_studio.query import BoltzQuery


@pytest.fixture
def inference_parameters():
    """Provide the inference parameters dictionary to tests.

    The returned mapping is unpacked as keyword arguments into the
    ``BoltzQuery`` constructors exercised below (``from_file`` and
    ``from_directory``).
    """
    return {
        "recycling_steps": 3,
        "sampling_steps": 200,
        "diffusion_samples": 1,
        "step_scale": 1.638,
        "output_format": "mmcif",
        "num_workers": 2,
        "msa_pairing_strategy": "greedy",
        "write_full_pae": False,
        "write_full_pde": False,
        "seed": 0,
    }


def test_process_fasta_file(tmp_files):
    """Test processing a valid FASTA file.

    A FASTA input must populate only the FASTA mapping, keyed by the file
    stem, and leave the YAML mapping empty.
    """
    file_path = tmp_files["monomer_fasta"]
    fasta_dict, yaml_dict = BoltzQuery._process_file(file_path)

    expected_fasta = ">tag1|tag2\nABCDEGF"
    expected_fasta_dict = {file_path.stem: expected_fasta}

    assert fasta_dict == expected_fasta_dict
    assert yaml_dict == {}


def test_process_yaml_file(tmp_files):
    """Test processing a valid YAML file.

    A YAML input must populate only the YAML mapping, keyed by the file
    stem, and leave the FASTA mapping empty.
    """
    file_path = tmp_files["yaml_file_path"]
    fasta_dict, yaml_dict = BoltzQuery._process_file(file_path)

    expected_yaml = {
        "version": 1,
        "sequences": [{"protein": {"id": "A", "sequence": "QLEDSEVEAVAKGLEE"}}],
    }
    expected_yaml_dict = {file_path.stem: expected_yaml}

    assert fasta_dict == {}
    assert yaml_dict == expected_yaml_dict


def test_process_file_invalid_extension(tmp_files):
    """Test processing a file with an invalid extension.

    ``BoltzQuery._process_file`` is expected to raise ``ValueError`` with a
    message naming the offending suffix.
    """
    invalid_source = tmp_files["invalid_source"]

    # ``match`` is applied as a regular expression (re.search), so the
    # message is escaped to make the "." of the suffix match literally
    # instead of matching any character.
    expected_message = re.escape(f"Unsupported format: {invalid_source.suffix}")

    with pytest.raises(ValueError, match=expected_message):
        BoltzQuery._process_file(invalid_source)


def test_from_file_fasta(tmp_files, inference_parameters):
    """Test creating BoltzQuery from a FASTA file."""
    file_path = tmp_files["monomer_fasta"]
    query = BoltzQuery.from_file(file_path, **inference_parameters)

    assert file_path.stem in query.fasta_dict
    assert not query.yaml_dict


def test_from_file_yaml(tmp_files, inference_parameters):
    """Test creating BoltzQuery from a YAML file."""
    file_path = tmp_files["yaml_file_path"]
    query = BoltzQuery.from_file(file_path, **inference_parameters)

    assert not query.fasta_dict
    assert file_path.stem in query.yaml_dict


def test_from_directory(tmp_files, inference_parameters):
    """Test creating BoltzQuery from a directory containing both FASTA and YAML files."""
    query = BoltzQuery.from_directory(
        tmp_files["mixed_fasta_yaml_dir"], **inference_parameters
    )

    assert "protein_A" in query.fasta_dict
    assert "protein_B" in query.yaml_dict


def test_from_source_file(tmp_files, inference_parameters):
    """Test creating BoltzQuery from a single file."""
    file_path = tmp_files["monomer_fasta"]
    # The path is handed over as a plain string rather than the fixture's
    # path object.
    query = BoltzQuery.from_file(str(file_path), **inference_parameters)

    assert "monomer" in query.fasta_dict


def test_from_source_directory(tmp_files, inference_parameters):
    """Test creating BoltzQuery from a directory."""