jfaustin's picture
add dockerfile and folding studio cli
44459bb
raw
history blame
7.96 kB
import re
from pathlib import Path
from unittest import mock

import pytest

from folding_studio.query import ChaiQuery
from folding_studio.query.chai import ChaiParameters
RESTRAINTS_CSV_CONTENT = """
chainA,res_idxA,chainB,res_idxB,connection_type,confidence,min_distance_angstrom,max_distance_angstrom,comment,restraint_id
A,C387,B,Y101,contact,1.0,0.0,5.5,protein-heavy,restraint_1
C,I32,A,S483,contact,1.0,0.0,5.5,protein-light,restraint_2
"""
A3M_CONTENT = """\
>101
RVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFHHHHHP
>UniRef100_UPI00024DB110 356 0.741 7.627E-108 0 227 229 0 226 228
RVVPSGDVVRFPNITNLCPFGEVFNATKFPSVYAWERKKISNCVADYSVLYNSTFFSTFKCYGVSATKLNDLCFSNVYADSFVVKGDDVRQIAPGQTGVIADYNYKLPDDFMGCVLAWNTRNIDATSTGNYNYKYRFLRHGKLRPFERDISNVPFSPDGKPCT-PPAFNCYWPLNDYGFYTTTGIGYQPYRVVVLSFELLNAPATVCGPKLSTDLIKNQCVNFHHHHH-
"""
@pytest.fixture(autouse=True)
def mock_get_auth_headers():
    """Patch ``folding_studio.utils.headers.get_auth_headers`` for every test.

    The patch is autouse, so it applies to every test in this module. The
    replacement mock returns the placeholder string ``"headers"`` and is
    yielded so tests can inspect its calls.
    """
    patcher = mock.patch(
        "folding_studio.utils.headers.get_auth_headers", return_value="headers"
    )
    with patcher as patched_headers:
        yield patched_headers
@pytest.fixture(autouse=True)
def mock_upload_custom_files():
    """Patch ``folding_studio.commands.utils.upload_custom_files`` for every test.

    The fake upload maps each given path (as ``str``) to the fixed URL
    ``"new_url"``. The ``mock.Mock`` wrapping it is yielded so tests can
    inspect its calls.
    """

    def _fake_upload(paths: list[Path], **kwargs):
        # Every "uploaded" file resolves to the same placeholder URL.
        return dict.fromkeys((str(path) for path in paths), "new_url")

    fake_upload = mock.Mock(side_effect=_fake_upload)
    with mock.patch("folding_studio.commands.utils.upload_custom_files", fake_upload):
        yield fake_upload
@pytest.fixture(scope="module")
def tmp_files(tmp_directory: Path, tmp_files: dict):
"""Generate temporary files."""
# A valid restraints CSV file
valid_restraints = tmp_directory / "example_restraints.csv"
valid_restraints.write_text(RESTRAINTS_CSV_CONTENT)
tmp_files["valid_restraints"] = valid_restraints
# Directory for A3M files
valid_a3m_dir = tmp_directory / "valid_a3m_dir"
valid_a3m_dir.mkdir()
alignment_1 = valid_a3m_dir / "alignment1.a3m"
alignment_1.write_text(A3M_CONTENT)
alignment_2 = valid_a3m_dir / "alignment1.a3m"
alignment_2.write_text(A3M_CONTENT)
tmp_files["valid_a3m_dir"] = valid_a3m_dir
# An aligned.pqt file
custom_msa = tmp_directory / "custom_msa.aligned.pqt"
custom_msa.write_text("")
tmp_files["aligned_pqt"] = custom_msa
yield tmp_files
def test_from_fasta_file(tmp_files):
    """Check the payload ``ChaiQuery.from_file`` builds for a single FASTA file.

    Every explicitly passed option must be echoed back. The two options that
    are not passed (``recycle_msa_subsample``, ``num_trunk_samples``) must hold
    0 and 1 respectively.
    """
    fasta_path = tmp_files["monomer_fasta"]
    settings = {
        "use_msa_server": True,
        "use_templates_server": False,
        "num_trunk_recycles": 3,
        "seed": 42,
        "num_diffn_timesteps": 100,
    }
    query = ChaiQuery.from_file(
        fasta_path, **settings, custom_msa_paths=tmp_files["aligned_pqt"]
    )
    payload = query.payload

    assert fasta_path.stem in payload["fasta_files"]
    for key, value in settings.items():
        # Flags are checked by identity so that a truthy non-bool is rejected.
        if isinstance(value, bool):
            assert payload[key] is value
        else:
            assert payload[key] == value
    assert payload["recycle_msa_subsample"] == 0
    assert payload["num_trunk_samples"] == 1
def test_from_fasta_file_invalid_extension(tmp_files):
    """``ChaiQuery.from_file`` raises ``ValueError`` for an unsupported suffix.

    The error message must name the offending suffix.
    """
    invalid_source = tmp_files["invalid_source"]
    # ``match`` is a regular expression, and Path.suffix includes a leading
    # "." (a regex wildcard), so escape the expected text to match literally.
    expected_message = re.escape(f"Unsupported suffix '{invalid_source.suffix}'")
    with pytest.raises(ValueError, match=expected_message):
        ChaiQuery.from_file(
            invalid_source,
            use_msa_server=True,
            use_templates_server=False,
            num_trunk_recycles=3,
            seed=42,
            num_diffn_timesteps=100,
            custom_msa_paths=tmp_files["aligned_pqt"],
        )
def test_from_fasta_directory(tmp_files):
    """Build a query from a FASTA directory plus a directory of A3M MSAs.

    Both ``monomer_1`` and ``monomer_2`` must appear in ``fasta_files``, every
    explicit option must be echoed back, and the options not passed must hold
    0 (``recycle_msa_subsample``) and 1 (``num_trunk_samples``).
    """
    settings = {
        "use_msa_server": False,
        "use_templates_server": True,
        "num_trunk_recycles": 2,
        "seed": 123,
        "num_diffn_timesteps": 50,
    }
    query = ChaiQuery.from_directory(
        path=str(tmp_files["valid_dir"]),
        **settings,
        custom_msa_paths=str(tmp_files["valid_a3m_dir"]),
    )
    payload = query.payload

    fasta_files = payload["fasta_files"]
    for stem in ("monomer_1", "monomer_2"):
        assert stem in fasta_files, f"Expected '{stem}' in FASTA files."

    for key, value in settings.items():
        # Flags are checked by identity so that a truthy non-bool is rejected.
        if isinstance(value, bool):
            assert payload[key] is value
        else:
            assert payload[key] == value
    assert payload["recycle_msa_subsample"] == 0
    assert payload["num_trunk_samples"] == 1
def test_from_empty_fasta_directory(tmp_files):
    """``ChaiQuery.from_directory`` raises ``ValueError`` on an empty directory."""
    empty_dir = tmp_files["empty_dir"]
    msa_dir = tmp_files["valid_a3m_dir"]
    expected_error = "No FASTA files found in directory"
    with pytest.raises(ValueError, match=expected_error):
        ChaiQuery.from_directory(
            empty_dir,
            use_msa_server=False,
            use_templates_server=False,
            num_trunk_recycles=2,
            seed=123,
            num_diffn_timesteps=50,
            custom_msa_paths=msa_dir,
        )
def test_from_fasta_directory_with_invalid_files(tmp_files):
    """``ChaiQuery.from_directory`` raises when no FASTA file is found.

    Uses the ``invalid_dir`` fixture together with custom A3M MSAs. The call
    must fail with ``ValueError("No FASTA files found in directory ...")``.
    """
    source_dir = tmp_files["invalid_dir"]
    msa_dir = tmp_files["valid_a3m_dir"]
    expected_error = "No FASTA files found in directory"
    with pytest.raises(ValueError, match=expected_error):
        ChaiQuery.from_directory(
            source_dir,
            use_msa_server=False,
            num_trunk_recycles=2,
            seed=123,
            num_diffn_timesteps=50,
            custom_msa_paths=msa_dir,
        )
def test_from_file_with_restraints(tmp_files):
    """``ChaiQuery.from_file`` carries a restraints CSV into the payload.

    The payload's ``restraints`` entry must equal the CSV text with surrounding
    whitespace stripped, and all explicit options must be echoed back.
    """
    fasta_path = tmp_files["monomer_fasta"]
    settings = {
        "use_msa_server": True,
        "use_templates_server": True,
        "num_trunk_recycles": 4,
        "seed": 10,
        "num_diffn_timesteps": 200,
    }
    query = ChaiQuery.from_file(
        fasta_path, **settings, restraints=tmp_files["valid_restraints"]
    )
    payload = query.payload

    assert fasta_path.stem in payload["fasta_files"]
    assert payload["restraints"] == RESTRAINTS_CSV_CONTENT.strip()
    for key, value in settings.items():
        # Flags are checked by identity so that a truthy non-bool is rejected.
        if isinstance(value, bool):
            assert payload[key] is value
        else:
            assert payload[key] == value
    assert payload["recycle_msa_subsample"] == 0
    assert payload["num_trunk_samples"] == 1
def test_from_fasta_directory_with_invalid_sources(tmp_files):
    """``ChaiQuery.from_directory`` raises when no FASTA file is found.

    Same scenario as the ``invalid_dir`` test with custom MSAs, but without
    passing ``custom_msa_paths``.
    """
    source_dir = tmp_files["invalid_dir"]
    options = {
        "use_msa_server": False,
        "num_trunk_recycles": 2,
        "seed": 123,
        "num_diffn_timesteps": 50,
    }
    with pytest.raises(ValueError, match="No FASTA files found in directory"):
        ChaiQuery.from_directory(source_dir, **options)
def test_ChaiParameters_read_restraints(tmp_files):
    """Check ``ChaiParameters.restraints`` for a CSV path.

    Given a path to a valid CSV, ``restraints`` must equal the file's stripped
    text.
    """
    restraints_path = tmp_files["valid_restraints"]
    expected_text = RESTRAINTS_CSV_CONTENT.strip()
    assert ChaiParameters(restraints=restraints_path).restraints == expected_text
def test_read_restraints_invalid_extension(tmp_files):
    """``ChaiParameters`` raises ``ValueError`` for a non-CSV restraints file.

    The error message must name the offending suffix.
    """
    invalid_source = tmp_files["invalid_source"]
    # ``match`` is a regular expression, and Path.suffix includes a leading
    # "." (a regex wildcard), so escape the expected text to match literally.
    expected_message = re.escape(f"Unsupported suffix '{invalid_source.suffix}'")
    with pytest.raises(ValueError, match=expected_message):
        ChaiParameters(restraints=invalid_source)
def test_from_nonexistent_file():
    """``ChaiQuery.from_file`` raises ``FileNotFoundError`` for a missing path."""
    missing_path = "nonexistent.fasta"
    options = {
        "use_msa_server": True,
        "num_trunk_recycles": 3,
        "seed": 42,
        "num_diffn_timesteps": 100,
    }
    with pytest.raises(FileNotFoundError):
        ChaiQuery.from_file(missing_path, **options)