from math import isclose from pathlib import Path import pytest import torch import chroma from chroma.data.protein import Protein from chroma.layers.structure import conditioners from chroma.models.chroma import Chroma BB_MODEL_PATH = "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_backbone_v1.0.pt" #'named:nature_v3' GD_MODEL_PATH = "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_design_v1.0.pt" #'named:nature_v3' BASE_PATH = str(Path(chroma.__file__).parent.parent) PROTEIN_SAMPLE = BASE_PATH + "/tests/resources/steps200_seed42_len100.cif" @pytest.fixture(scope="session") def chroma(): return Chroma(BB_MODEL_PATH, GD_MODEL_PATH, device="cpu") def test_chroma(chroma): # Fixed Protein Value protein = Protein.from_CIF(PROTEIN_SAMPLE) # Fixed value test score torch.manual_seed(42) scores = chroma.score(protein, num_samples=5) assert isclose(scores["elbo"].score, 5.890165328979492, abs_tol=1e-3) # Test Sampling & Design # torch.manual_seed(42) # sample = chroma.sample(steps=200) # Xs, _, Ss = sample.to_XCS() # X , _, S = protein.to_XCS() # assert torch.allclose(X,Xs) # assert (S == Ss).all() # test postprocessing from chroma.layers.structure import conditioners X, C, S = protein.to_XCS() c_symmetry = conditioners.SymmetryConditioner(G="C_8", num_chain_neighbors=1) X_s, C_s, S_s = ( torch.cat([X, X], dim=1), torch.cat([C, C], dim=1), torch.cat([S, S], dim=1), ) protein_sym = Protein(X_s, C_s, S_s) chroma._postprocess(c_symmetry, protein_sym, output_dictionary=None) @pytest.mark.parametrize( "conditioner", [ conditioners.Identity(), conditioners.SymmetryConditioner(G="C_3", num_chain_neighbors=1), ], ) def test_sample(chroma, conditioner): chroma.sample(steps=3, conditioner=conditioner, design_method=None) @pytest.mark.parametrize( "conditioner", [ conditioners.Identity(), conditioners.SymmetryConditioner(G="C_3", num_chain_neighbors=1), ], ) def test_sample_backbone(chroma, conditioner): chroma._sample(steps=3, conditioner=conditioner) @pytest.mark.parametrize("design_method", ["autoregressive", "potts",]) @pytest.mark.parametrize("potts_proposal", ["dlmc", "chromatic"]) def test_design(chroma, design_method, potts_proposal): protein = Protein.from_CIF(PROTEIN_SAMPLE) chroma.design( protein, design_method=design_method, potts_proposal=potts_proposal, potts_mcmc_depth=20, )