Spaces:
Sleeping
Sleeping
File size: 1,210 Bytes
ce7bf5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
from unittest import TestCase
import torch
from chroma.models.graph_classifier import GraphClassifier
class TestGraphClassifier(TestCase):
    """Smoke tests for ``GraphClassifier``: construction, forward pass, gradient."""

    def test_graph_classifier(self):
        """Build a small classifier and check output and gradient shapes.

        The model is configured with two labels:

        * ``dummy_1``: loss ``"bce"``, level ``"chain"``
        * ``dummy_2``: loss ``"ce"``, level ``"first_order"``

        Each label gets a ``"tokenizer"`` dict mapping its tokens to integer
        ids in list order.

        The test asserts two things:

        * The node embeddings returned by the forward pass have shape
          ``(bs, sl, dim_nodes)``.
        * ``model.gradient`` returns a tensor with the same shape as ``X``.
        """
        dim_nodes = 16  # shared by the constructor and the shape assertion
        class_config = {
            "dummy_1": {
                "tokens": ["a", "b", "c", "d"],
                "loss": "bce",
                "level": "chain",
            },
            "dummy_2": {
                "tokens": ["w", "x", "y", "z"],
                "loss": "ce",
                "level": "first_order",
            },
        }
        # Attach a token -> id mapping to every label config (ids follow list
        # order). Only the per-label dicts are needed, so iterate the values.
        for label_config in class_config.values():
            label_config["tokenizer"] = {
                token: idx for idx, token in enumerate(label_config["tokens"])
            }

        model = GraphClassifier(
            dim_nodes=dim_nodes,
            dim_edges=16,
            edge_mlp_dim=8,
            node_mlp_dim=8,
            class_config=class_config,
        )

        bs = 1
        sl = 8
        # NOTE(review): X is presumably per-residue backbone coordinates
        # (4 atoms x 3D) and C a chain map placing every residue in chain 1;
        # confirm against GraphClassifier.forward.
        X = torch.randn(bs, sl, 4, 3)
        C = torch.ones(bs, sl)

        # The forward pass is only shape-checked, so no autograd graph is needed.
        with torch.no_grad():
            node_h, _ = model(X, C)
        self.assertEqual(node_h.size(), torch.Size([bs, sl, dim_nodes]))

        # Query the gradient for token "w" of label "dummy_2" at t=0.5. The result
        # is expected to match X's shape (presumably d/dX; TODO confirm).
        grad = model.gradient(X, C, t=0.5, label="dummy_2", value="w")
        self.assertEqual(grad.size(), torch.Size([bs, sl, 4, 3]))
|