# ProteinDesignDemo / chroma / tests / models / test_graph_classifier.py
# (Hugging Face file-view header: uploaded by Hukuna in "Upload 275 files",
#  commit e9e75df (verified); raw / history / blame views; 1.21 kB)
from unittest import TestCase
import torch
from chroma.models.graph_classifier import GraphClassifier
class TestGraphClassifier(TestCase):
    """Smoke tests for ``GraphClassifier`` output shapes."""

    def test_graph_classifier(self):
        """Check forward-pass and ``gradient`` output shapes on a tiny random input.

        Builds a ``GraphClassifier`` with two dummy label heads: a ``"bce"`` head
        at ``"chain"`` level and a ``"ce"`` head at ``"first_order"`` level. It
        then runs a forward pass and checks the node-embedding shape. Finally it
        checks that ``gradient`` for label ``"dummy_2"`` / value ``"w"`` returns
        a tensor with the same shape as the input ``X``.
        """
        class_config = {
            "dummy_1": {
                "tokens": ["a", "b", "c", "d"],
                "loss": "bce",
                "level": "chain",
            },
            "dummy_2": {
                "tokens": ["w", "x", "y", "z"],
                "loss": "ce",
                "level": "first_order",
            },
        }
        # Give every head a token -> index lookup built from its token list.
        # Only the per-head config dicts are needed, so iterate values() and use
        # distinct names instead of reusing the label key inside the comprehension.
        for label_config in class_config.values():
            label_config["tokenizer"] = {
                token: index for index, token in enumerate(label_config["tokens"])
            }

        dim_nodes = 16
        model = GraphClassifier(
            dim_nodes=dim_nodes,
            dim_edges=16,
            edge_mlp_dim=8,
            node_mlp_dim=8,
            class_config=class_config,
        )

        bs = 1  # batch size
        sl = 8  # presumably sequence length (number of positions)
        # Random input shaped (bs, sl, 4, 3). Presumably 4 backbone atoms with
        # xyz coordinates each -- TODO confirm against GraphClassifier's docs.
        X = torch.randn(bs, sl, 4, 3)
        # All-ones C; presumably a chain map placing every position in chain 1.
        C = torch.ones(bs, sl)

        # Only output shapes are checked here, so skip building the autograd graph.
        with torch.no_grad():
            node_h, _edge_h = model(X, C)  # edge output is not checked by this test
        self.assertEqual(node_h.size(), torch.Size([bs, sl, dim_nodes]))

        # Called outside no_grad(): gradient() presumably relies on autograd.
        grad = model.gradient(X, C, t=0.5, label="dummy_2", value="w")
        self.assertEqual(grad.size(), torch.Size([bs, sl, 4, 3]))