Spaces:
Sleeping
Sleeping
from unittest import TestCase | |
import torch | |
from chroma.models.graph_classifier import GraphClassifier | |
class TestGraphClassifier(TestCase): | |
def test_graph_classifier(self): | |
class_config = { | |
"dummy_1": { | |
"tokens": ["a", "b", "c", "d"], | |
"loss": "bce", | |
"level": "chain", | |
}, | |
"dummy_2": { | |
"tokens": ["w", "x", "y", "z"], | |
"loss": "ce", | |
"level": "first_order", | |
}, | |
} | |
for k, v in class_config.items(): | |
v["tokenizer"] = {k: i for i, k in enumerate(v["tokens"])} | |
model = GraphClassifier( | |
dim_nodes=16, | |
dim_edges=16, | |
edge_mlp_dim=8, | |
node_mlp_dim=8, | |
class_config=class_config, | |
) | |
bs = 1 | |
sl = 8 | |
X = torch.randn(bs, sl, 4, 3) | |
C = torch.ones(bs, sl) | |
with torch.no_grad(): | |
node_h, edge_h = model(X, C) | |
self.assertTrue(node_h.size() == torch.Size([bs, sl, 16])) | |
grad = model.gradient(X, C, t=0.5, label="dummy_2", value="w") | |
self.assertTrue(grad.size() == torch.Size([bs, sl, 4, 3])) | |