|
import re
from pathlib import Path
from unittest import mock

import pytest

from folding_studio.query import ChaiQuery
from folding_studio.query.chai import ChaiParameters
|
|
|
# Restraints CSV text used as test input: one header row followed by two
# `contact` rows. The `tmp_files` fixture writes it to
# "example_restraints.csv", and tests compare its `.strip()`-ed form against
# the restraints value produced by ChaiQuery / ChaiParameters.
RESTRAINTS_CSV_CONTENT = """ |
|
chainA,res_idxA,chainB,res_idxB,connection_type,confidence,min_distance_angstrom,max_distance_angstrom,comment,restraint_id |
|
A,C387,B,Y101,contact,1.0,0.0,5.5,protein-heavy,restraint_1 |
|
C,I32,A,S483,contact,1.0,0.0,5.5,protein-light,restraint_2 |
|
""" |
|
|
|
# Alignment text in A3M layout used as test input: a ">101" record followed by
# one ">UniRef100_..." record. The `tmp_files` fixture writes it into the
# "valid_a3m_dir" directory that tests pass as `custom_msa_paths`.
A3M_CONTENT = """\ |
|
>101 |
|
RVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFHHHHHP |
|
>UniRef100_UPI00024DB110 356 0.741 7.627E-108 0 227 229 0 226 228 |
|
RVVPSGDVVRFPNITNLCPFGEVFNATKFPSVYAWERKKISNCVADYSVLYNSTFFSTFKCYGVSATKLNDLCFSNVYADSFVVKGDDVRQIAPGQTGVIADYNYKLPDDFMGCVLAWNTRNIDATSTGNYNYKYRFLRHGKLRPFERDISNVPFSPDGKPCT-PPAFNCYWPLNDYGFYTTTGIGYQPYRVVVLSFELLNAPATVCGPKLSTDLIKNQCVNFHHHHH- |
|
""" |
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_get_auth_headers():
    """Patch ``get_auth_headers`` for every test in this module.

    The patched function returns the placeholder string ``"headers"``.

    Yields:
        The mock installed in place of
        ``folding_studio.utils.headers.get_auth_headers``.
    """
    patcher = mock.patch(
        "folding_studio.utils.headers.get_auth_headers", return_value="headers"
    )
    # The context manager removes the patch once the test has finished.
    with patcher as patched_get_auth_headers:
        yield patched_get_auth_headers
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_upload_custom_files():
    """Replace ``upload_custom_files`` with a mock for every test in this module.

    The mock answers each call with a dict mapping every given path (as a
    ``str``) to the placeholder URL ``"new_url"``.

    Yields:
        The ``mock.Mock`` installed in place of
        ``folding_studio.commands.utils.upload_custom_files``, so tests can
        inspect how it was called.
    """

    def fake_upload(paths: list[Path], **kwargs):
        # Extra keyword arguments are accepted and ignored.
        return dict.fromkeys((str(path) for path in paths), "new_url")

    upload_mock = mock.Mock(side_effect=fake_upload)
    patcher = mock.patch(
        "folding_studio.commands.utils.upload_custom_files",
        upload_mock,
    )
    with patcher:
        yield upload_mock
|
|
|
|
|
@pytest.fixture(scope="module")
def tmp_files(tmp_directory: Path, tmp_files: dict):
    """Add Chai-specific temporary files to the ``tmp_files`` mapping.

    This fixture takes a fixture of the same name as input (presumably the one
    defined in a shared conftest -- confirm) and adds three entries to it:

    * ``"valid_restraints"``: a CSV file holding ``RESTRAINTS_CSV_CONTENT``.
    * ``"valid_a3m_dir"``: a directory holding two A3M alignment files, each
      filled with ``A3M_CONTENT``.
    * ``"aligned_pqt"``: an empty ``custom_msa.aligned.pqt`` file.

    Args:
        tmp_directory: Directory in which the files are created.
        tmp_files: Mapping of fixture names to paths; extended in place.

    Yields:
        The same ``tmp_files`` mapping, with the entries above added.
    """
    # Restraints CSV.
    valid_restraints = tmp_directory / "example_restraints.csv"
    valid_restraints.write_text(RESTRAINTS_CSV_CONTENT)
    tmp_files["valid_restraints"] = valid_restraints

    # Directory of A3M alignments. The two files need distinct names: writing
    # both to "alignment1.a3m" would silently leave a single file on disk.
    valid_a3m_dir = tmp_directory / "valid_a3m_dir"
    valid_a3m_dir.mkdir()
    for alignment_name in ("alignment1.a3m", "alignment2.a3m"):
        (valid_a3m_dir / alignment_name).write_text(A3M_CONTENT)
    tmp_files["valid_a3m_dir"] = valid_a3m_dir

    # Empty placeholder for a custom MSA file.
    custom_msa = tmp_directory / "custom_msa.aligned.pqt"
    custom_msa.write_text("")
    tmp_files["aligned_pqt"] = custom_msa

    yield tmp_files
|
|
|
|
|
def test_from_fasta_file(tmp_files):
    """Check the payload ``ChaiQuery.from_file`` builds from a single FASTA file."""
    fasta_path = tmp_files["monomer_fasta"]
    options = {
        "use_msa_server": True,
        "use_templates_server": False,
        "num_trunk_recycles": 3,
        "seed": 42,
        "num_diffn_timesteps": 100,
        "custom_msa_paths": tmp_files["aligned_pqt"],
    }

    payload = ChaiQuery.from_file(fasta_path, **options).payload

    assert fasta_path.stem in payload["fasta_files"]

    # Flags are checked by identity so a merely truthy/falsy value fails.
    assert payload["use_msa_server"] is True
    assert payload["use_templates_server"] is False

    # The last two keys are not passed above; presumably these are the
    # defaults applied by ChaiQuery.
    expected_values = {
        "num_trunk_recycles": 3,
        "seed": 42,
        "num_diffn_timesteps": 100,
        "recycle_msa_subsample": 0,
        "num_trunk_samples": 1,
    }
    for key, expected in expected_values.items():
        assert payload[key] == expected
|
|
|
|
|
def test_from_fasta_file_invalid_extension(tmp_files):
    """Check that ``ChaiQuery.from_file`` rejects a file with an unsupported suffix.

    Expects a ``ValueError`` whose message contains
    ``Unsupported suffix '<suffix>'`` for the ``invalid_source`` fixture file.
    """
    invalid_source = tmp_files["invalid_source"]
    # ``match`` is interpreted as a regular expression, so the message is
    # escaped to keep the "." of the suffix (and any other metacharacter)
    # literal instead of matching an arbitrary character.
    expected_message = re.escape(f"Unsupported suffix '{invalid_source.suffix}'")

    with pytest.raises(ValueError, match=expected_message):
        ChaiQuery.from_file(
            invalid_source,
            use_msa_server=True,
            use_templates_server=False,
            num_trunk_recycles=3,
            seed=42,
            num_diffn_timesteps=100,
            custom_msa_paths=tmp_files["aligned_pqt"],
        )
|
|
|
|
|
def test_from_fasta_directory(tmp_files):
    """Check the payload ``ChaiQuery.from_directory`` builds from a FASTA directory."""
    fasta_dir = str(tmp_files["valid_dir"])
    msa_dir = str(tmp_files["valid_a3m_dir"])

    payload = ChaiQuery.from_directory(
        path=fasta_dir,
        use_msa_server=False,
        use_templates_server=True,
        num_trunk_recycles=2,
        seed=123,
        num_diffn_timesteps=50,
        custom_msa_paths=msa_dir,
    ).payload

    # Both FASTA entries are expected to be present in the payload.
    fasta_files = payload["fasta_files"]
    assert "monomer_1" in fasta_files, "Expected 'monomer_1' in FASTA files."
    assert "monomer_2" in fasta_files, "Expected 'monomer_2' in FASTA files."

    # Flags are checked by identity so a merely truthy/falsy value fails.
    assert payload["use_msa_server"] is False
    assert payload["use_templates_server"] is True

    # The last two keys are not passed above; presumably these are the
    # defaults applied by ChaiQuery.
    expected_values = {
        "num_trunk_recycles": 2,
        "seed": 123,
        "num_diffn_timesteps": 50,
        "recycle_msa_subsample": 0,
        "num_trunk_samples": 1,
    }
    for key, expected in expected_values.items():
        assert payload[key] == expected
|
|
|
|
|
|
|
|
|
def test_from_empty_fasta_directory(tmp_files):
    """Check that ``ChaiQuery.from_directory`` raises ``ValueError`` for the ``empty_dir`` fixture."""
    expected_error = pytest.raises(
        ValueError, match="No FASTA files found in directory"
    )
    with expected_error:
        ChaiQuery.from_directory(
            tmp_files["empty_dir"],
            use_msa_server=False,
            use_templates_server=False,
            num_trunk_recycles=2,
            seed=123,
            num_diffn_timesteps=50,
            custom_msa_paths=tmp_files["valid_a3m_dir"],
        )
|
|
|
|
|
def test_from_fasta_directory_with_invalid_files(tmp_files):
    """Check that ``from_directory`` raises ``ValueError`` for ``invalid_dir`` when custom MSAs are given.

    The ``invalid_dir`` fixture presumably contains only files with
    unsupported extensions, so no FASTA file is found in it.
    """
    expected_error = pytest.raises(
        ValueError, match="No FASTA files found in directory"
    )
    with expected_error:
        ChaiQuery.from_directory(
            tmp_files["invalid_dir"],
            use_msa_server=False,
            num_trunk_recycles=2,
            seed=123,
            num_diffn_timesteps=50,
            custom_msa_paths=tmp_files["valid_a3m_dir"],
        )
|
|
|
|
|
def test_from_file_with_restraints(tmp_files):
    """Check the payload ``ChaiQuery.from_file`` builds from a FASTA file plus a restraints CSV."""
    fasta_path = tmp_files["monomer_fasta"]

    payload = ChaiQuery.from_file(
        fasta_path,
        use_msa_server=True,
        use_templates_server=True,
        num_trunk_recycles=4,
        seed=10,
        num_diffn_timesteps=200,
        restraints=tmp_files["valid_restraints"],
    ).payload

    assert tmp_files["monomer_fasta"].stem in payload["fasta_files"]

    # The payload is expected to carry the CSV text without the surrounding
    # whitespace of the constant.
    assert payload["restraints"] == RESTRAINTS_CSV_CONTENT.strip()

    # Flags are checked by identity so a merely truthy value fails.
    assert payload["use_msa_server"] is True
    assert payload["use_templates_server"] is True

    # The last two keys are not passed above; presumably these are the
    # defaults applied by ChaiQuery.
    expected_values = {
        "num_trunk_recycles": 4,
        "seed": 10,
        "num_diffn_timesteps": 200,
        "recycle_msa_subsample": 0,
        "num_trunk_samples": 1,
    }
    for key, expected in expected_values.items():
        assert payload[key] == expected
|
|
|
|
|
def test_from_fasta_directory_with_invalid_sources(tmp_files):
    """Check that ``from_directory`` raises ``ValueError`` for ``invalid_dir`` without custom MSAs.

    The ``invalid_dir`` fixture presumably contains only files with
    unsupported extensions, so no FASTA file is found in it.
    """
    expected_error = pytest.raises(
        ValueError, match="No FASTA files found in directory"
    )
    with expected_error:
        ChaiQuery.from_directory(
            tmp_files["invalid_dir"],
            use_msa_server=False,
            num_trunk_recycles=2,
            seed=123,
            num_diffn_timesteps=50,
        )
|
|
|
|
|
def test_ChaiParameters_read_restraints(tmp_files):
    """Check that ``ChaiParameters`` exposes the stripped text of a restraints CSV file."""
    restraints_path = tmp_files["valid_restraints"]
    loaded_restraints = ChaiParameters(restraints=restraints_path).restraints
    assert loaded_restraints == RESTRAINTS_CSV_CONTENT.strip()
|
|
|
|
|
def test_read_restraints_invalid_extension(tmp_files):
    """Check that ``ChaiParameters`` rejects a restraints file with an unsupported suffix.

    Expects a ``ValueError`` whose message contains
    ``Unsupported suffix '<suffix>'`` for the ``invalid_source`` fixture file.
    """
    invalid_source = tmp_files["invalid_source"]
    # ``match`` is interpreted as a regular expression, so the message is
    # escaped to keep the "." of the suffix (and any other metacharacter)
    # literal instead of matching an arbitrary character.
    expected_message = re.escape(f"Unsupported suffix '{invalid_source.suffix}'")

    with pytest.raises(ValueError, match=expected_message):
        ChaiParameters(restraints=invalid_source)
|
|
|
|
|
def test_from_nonexistent_file():
    """Check that ``ChaiQuery.from_file`` raises ``FileNotFoundError`` for a missing path."""
    missing_path = "nonexistent.fasta"
    options = {
        "use_msa_server": True,
        "num_trunk_recycles": 3,
        "seed": 42,
        "num_diffn_timesteps": 100,
    }
    with pytest.raises(FileNotFoundError):
        ChaiQuery.from_file(missing_path, **options)
|
|