"""Tests for ``ChaiQuery`` and ``ChaiParameters``."""

import re
from pathlib import Path
from unittest import mock

import pytest

from folding_studio.query import ChaiQuery
from folding_studio.query.chai import ChaiParameters

# Restraints CSV written to disk by the ``tmp_files`` fixture. Assertions
# compare against ``RESTRAINTS_CSV_CONTENT.strip()``, so the surrounding
# newlines are not significant.
RESTRAINTS_CSV_CONTENT = """
chainA,res_idxA,chainB,res_idxB,connection_type,confidence,min_distance_angstrom,max_distance_angstrom,comment,restraint_id
A,C387,B,Y101,contact,1.0,0.0,5.5,protein-heavy,restraint_1
C,I32,A,S483,contact,1.0,0.0,5.5,protein-light,restraint_2
"""

# Minimal A3M alignment (query + one hit) used to populate the A3M directory.
A3M_CONTENT = """\
>101
RVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFHHHHHP
>UniRef100_UPI00024DB110 356 0.741 7.627E-108 0 227 229 0 226 228
RVVPSGDVVRFPNITNLCPFGEVFNATKFPSVYAWERKKISNCVADYSVLYNSTFFSTFKCYGVSATKLNDLCFSNVYADSFVVKGDDVRQIAPGQTGVIADYNYKLPDDFMGCVLAWNTRNIDATSTGNYNYKYRFLRHGKLRPFERDISNVPFSPDGKPCT-PPAFNCYWPLNDYGFYTTTGIGYQPYRVVVLSFELLNAPATVCGPKLSTDLIKNQCVNFHHHHH-
"""


@pytest.fixture(autouse=True)
def mock_get_auth_headers():
    """Patch ``get_auth_headers`` to return a fixed placeholder value."""
    with mock.patch(
        "folding_studio.utils.headers.get_auth_headers", return_value="headers"
    ) as m:
        yield m


@pytest.fixture(autouse=True)
def mock_upload_custom_files():
    """Patch ``upload_custom_files`` so no real upload happens.

    The mock maps every given path (as a string) to the placeholder URL
    ``"new_url"``.
    """

    def side_effect(paths: list[Path], **kwargs):
        return {str(p): "new_url" for p in paths}

    upload_custom_files_mock = mock.Mock(side_effect=side_effect)
    with mock.patch(
        "folding_studio.commands.utils.upload_custom_files",
        upload_custom_files_mock,
    ):
        yield upload_custom_files_mock


@pytest.fixture(scope="module")
def tmp_files(tmp_directory: Path, tmp_files: dict):
    """Generate temporary files.

    Extends the parent ``tmp_files`` fixture (same name, presumably defined in
    a ``conftest.py`` — not visible here) with Chai-specific inputs:

    - ``valid_restraints``: a CSV file with ``RESTRAINTS_CSV_CONTENT``.
    - ``valid_a3m_dir``: a directory holding two A3M alignment files.
    - ``aligned_pqt``: an empty ``*.aligned.pqt`` custom MSA file.

    NOTE(review): the tests below also read ``monomer_fasta``,
    ``invalid_source``, ``valid_dir``, ``empty_dir`` and ``invalid_dir`` from
    the parent fixture; and because this fixture is module-scoped,
    ``tmp_directory`` / the parent ``tmp_files`` must be module- or
    session-scoped — confirm against conftest.
    """
    # A valid restraints CSV file
    valid_restraints = tmp_directory / "example_restraints.csv"
    valid_restraints.write_text(RESTRAINTS_CSV_CONTENT)
    tmp_files["valid_restraints"] = valid_restraints

    # Directory for A3M files. The two alignments must have distinct names;
    # reusing one name would silently leave a single file in the directory.
    valid_a3m_dir = tmp_directory / "valid_a3m_dir"
    valid_a3m_dir.mkdir()
    alignment_1 = valid_a3m_dir / "alignment1.a3m"
    alignment_1.write_text(A3M_CONTENT)
    alignment_2 = valid_a3m_dir / "alignment2.a3m"
    alignment_2.write_text(A3M_CONTENT)
    tmp_files["valid_a3m_dir"] = valid_a3m_dir

    # An aligned.pqt file
    custom_msa = tmp_directory / "custom_msa.aligned.pqt"
    custom_msa.write_text("")
    tmp_files["aligned_pqt"] = custom_msa

    yield tmp_files


def test_from_fasta_file(tmp_files):
    """Test ``ChaiQuery.from_file`` with a valid FASTA file."""
    fasta_path = tmp_files["monomer_fasta"]
    query = ChaiQuery.from_file(
        fasta_path,
        use_msa_server=True,
        use_templates_server=False,
        num_trunk_recycles=3,
        seed=42,
        num_diffn_timesteps=100,
        custom_msa_paths=tmp_files["aligned_pqt"],
    )
    payload = query.payload

    assert fasta_path.stem in payload["fasta_files"]
    assert payload["use_msa_server"] is True
    assert payload["use_templates_server"] is False
    assert payload["num_trunk_recycles"] == 3
    assert payload["seed"] == 42
    assert payload["num_diffn_timesteps"] == 100
    assert payload["recycle_msa_subsample"] == 0
    assert payload["num_trunk_samples"] == 1


def test_from_fasta_file_invalid_extension(tmp_files):
    """Test ``ChaiQuery.from_file`` rejects an unsupported file extension."""
    suffix = tmp_files["invalid_source"].suffix
    with pytest.raises(
        ValueError,
        # ``match`` is a regex; escape so the suffix's dot matches literally.
        match=re.escape(f"Unsupported suffix '{suffix}'"),
    ):
        ChaiQuery.from_file(
            tmp_files["invalid_source"],
            use_msa_server=True,
            use_templates_server=False,
            num_trunk_recycles=3,
            seed=42,
            num_diffn_timesteps=100,
            custom_msa_paths=tmp_files["aligned_pqt"],
        )


def test_from_fasta_directory(tmp_files):
    """Test ``ChaiQuery.from_directory`` with a directory of valid FASTA files."""
    query = ChaiQuery.from_directory(
        path=str(tmp_files["valid_dir"]),
        use_msa_server=False,
        use_templates_server=True,
        num_trunk_recycles=2,
        seed=123,
        num_diffn_timesteps=50,
        custom_msa_paths=str(tmp_files["valid_a3m_dir"]),
    )
    payload = query.payload
    fasta_files = payload["fasta_files"]

    # Verify that both expected FASTA files are included in the payload.
    assert "monomer_1" in fasta_files, "Expected 'monomer_1' in FASTA files."
    assert "monomer_2" in fasta_files, "Expected 'monomer_2' in FASTA files."

    # Verify other parameters are correctly set.
    assert payload["use_msa_server"] is False
    assert payload["use_templates_server"] is True
    assert payload["num_trunk_recycles"] == 2
    assert payload["seed"] == 123
    assert payload["num_diffn_timesteps"] == 50
    assert payload["recycle_msa_subsample"] == 0
    assert payload["num_trunk_samples"] == 1


def test_from_empty_fasta_directory(tmp_files):
    """Test ``ChaiQuery.from_directory`` raises on an empty directory."""
    with pytest.raises(ValueError, match="No FASTA files found in directory"):
        ChaiQuery.from_directory(
            tmp_files["empty_dir"],
            use_msa_server=False,
            use_templates_server=False,
            num_trunk_recycles=2,
            seed=123,
            num_diffn_timesteps=50,
            custom_msa_paths=tmp_files["valid_a3m_dir"],
        )


def test_from_fasta_directory_with_invalid_files(tmp_files):
    """Test ``ChaiQuery.from_directory`` ignores non-FASTA files.

    The directory contains only files with unsupported extensions, so no
    FASTA file is found; custom MSA paths are supplied in this variant.
    """
    with pytest.raises(ValueError, match="No FASTA files found in directory"):
        ChaiQuery.from_directory(
            tmp_files["invalid_dir"],
            use_msa_server=False,
            num_trunk_recycles=2,
            seed=123,
            num_diffn_timesteps=50,
            custom_msa_paths=tmp_files["valid_a3m_dir"],
        )


def test_from_file_with_restraints(tmp_files):
    """Test ``ChaiQuery.from_file`` with a FASTA file and a restraints CSV."""
    query = ChaiQuery.from_file(
        tmp_files["monomer_fasta"],
        use_msa_server=True,
        use_templates_server=True,
        num_trunk_recycles=4,
        seed=10,
        num_diffn_timesteps=200,
        restraints=tmp_files["valid_restraints"],
    )
    payload = query.payload

    assert tmp_files["monomer_fasta"].stem in payload["fasta_files"]
    assert payload["restraints"] == RESTRAINTS_CSV_CONTENT.strip()
    assert payload["use_msa_server"] is True
    assert payload["use_templates_server"] is True
    assert payload["num_trunk_recycles"] == 4
    assert payload["seed"] == 10
    assert payload["num_diffn_timesteps"] == 200
    assert payload["recycle_msa_subsample"] == 0
    assert payload["num_trunk_samples"] == 1


def test_from_fasta_directory_with_invalid_sources(tmp_files):
    """Test ``ChaiQuery.from_directory`` ignores non-FASTA files.

    Same scenario as ``test_from_fasta_directory_with_invalid_files`` but
    without passing ``custom_msa_paths``.
    """
    with pytest.raises(ValueError, match="No FASTA files found in directory"):
        ChaiQuery.from_directory(
            tmp_files["invalid_dir"],
            use_msa_server=False,
            num_trunk_recycles=2,
            seed=123,
            num_diffn_timesteps=50,
        )


def test_ChaiParameters_read_restraints(tmp_files):
    """Test ``ChaiParameters`` loads the contents of a valid restraints CSV."""
    parameters = ChaiParameters(restraints=tmp_files["valid_restraints"])
    assert parameters.restraints == RESTRAINTS_CSV_CONTENT.strip()


def test_read_restraints_invalid_extension(tmp_files):
    """Test ``ChaiParameters`` rejects a restraints file that is not a CSV."""
    suffix = tmp_files["invalid_source"].suffix
    with pytest.raises(
        ValueError,
        # ``match`` is a regex; escape so the suffix's dot matches literally.
        match=re.escape(f"Unsupported suffix '{suffix}'"),
    ):
        ChaiParameters(restraints=tmp_files["invalid_source"])


def test_from_nonexistent_file():
    """Test ``ChaiQuery.from_file`` raises for a path that does not exist."""
    with pytest.raises(FileNotFoundError):
        ChaiQuery.from_file(
            "nonexistent.fasta",
            use_msa_server=True,
            num_trunk_recycles=3,
            seed=42,
            num_diffn_timesteps=100,
        )