Spaces:
Sleeping
Sleeping
from math import isclose | |
from pathlib import Path | |
import pytest | |
import torch | |
import chroma | |
from chroma.data.protein import Protein | |
from chroma.layers.structure import conditioners | |
from chroma.models.chroma import Chroma | |
BB_MODEL_PATH = "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_backbone_v1.0.pt" #'named:nature_v3' | |
GD_MODEL_PATH = "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_design_v1.0.pt" #'named:nature_v3' | |
BASE_PATH = str(Path(chroma.__file__).parent.parent) | |
PROTEIN_SAMPLE = BASE_PATH + "/tests/resources/steps200_seed42_len100.cif" | |
def chroma(): | |
return Chroma(BB_MODEL_PATH, GD_MODEL_PATH, device="cpu") | |
def test_chroma(chroma): | |
# Fixed Protein Value | |
protein = Protein.from_CIF(PROTEIN_SAMPLE) | |
# Fixed value test score | |
torch.manual_seed(42) | |
scores = chroma.score(protein, num_samples=5) | |
assert isclose(scores["elbo"].score, 5.890165328979492, abs_tol=1e-3) | |
# Test Sampling & Design | |
# torch.manual_seed(42) | |
# sample = chroma.sample(steps=200) | |
# Xs, _, Ss = sample.to_XCS() | |
# X , _, S = protein.to_XCS() | |
# assert torch.allclose(X,Xs) | |
# assert (S == Ss).all() | |
# test postprocessing | |
from chroma.layers.structure import conditioners | |
X, C, S = protein.to_XCS() | |
c_symmetry = conditioners.SymmetryConditioner(G="C_8", num_chain_neighbors=1) | |
X_s, C_s, S_s = ( | |
torch.cat([X, X], dim=1), | |
torch.cat([C, C], dim=1), | |
torch.cat([S, S], dim=1), | |
) | |
protein_sym = Protein(X_s, C_s, S_s) | |
chroma._postprocess(c_symmetry, protein_sym, output_dictionary=None) | |
def test_sample(chroma, conditioner): | |
chroma.sample(steps=3, conditioner=conditioner, design_method=None) | |
def test_sample_backbone(chroma, conditioner): | |
chroma._sample(steps=3, conditioner=conditioner) | |
def test_design(chroma, design_method, potts_proposal): | |
protein = Protein.from_CIF(PROTEIN_SAMPLE) | |
chroma.design( | |
protein, | |
design_method=design_method, | |
potts_proposal=potts_proposal, | |
potts_mcmc_depth=20, | |
) | |