Spaces:
Sleeping
Sleeping
File size: 1,236 Bytes
ce7bf5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from unittest import TestCase
import numpy as np
import torch
from chroma.models import graph_energy
class TestGraphHarmonicFeatures(TestCase):
    """Shape tests for ``graph_energy.GraphHarmonicFeatures``."""

    def test_sample(self):
        """A single forward pass returns node/edge outputs of the input shapes.

        Builds all-ones node and edge tensors with all-ones masks, calls the
        layer as ``layer(node_h, node_features, edge_h, edge_features, mask_i,
        mask_ij)``, and checks that ``node_out`` has shape
        ``(num_batch, num_nodes, dim_nodes)`` and ``edge_out`` has shape
        ``(num_batch, num_nodes, num_neighbors, dim_edges)``.
        """
        num_batch = 1
        num_nodes = 10
        num_neighbors = 8
        dim_nodes = 128
        dim_edges = 64

        layer = graph_energy.GraphHarmonicFeatures(
            dim_nodes=dim_nodes,
            dim_edges=dim_edges,
            node_mlp_layers=2,
            node_mlp_dim=dim_nodes,
            edge_mlp_layers=2,
            edge_mlp_dim=dim_edges,
        )

        node_h = torch.ones(num_batch, num_nodes, dim_nodes)
        node_features = torch.ones(num_batch, num_nodes, dim_nodes)
        edge_h = torch.ones(num_batch, num_nodes, num_neighbors, dim_edges)
        edge_features = torch.ones(num_batch, num_nodes, num_neighbors, dim_edges)
        mask_i = torch.ones(num_batch, num_nodes)
        mask_ij = torch.ones(num_batch, num_nodes, num_neighbors)

        node_out, edge_out = layer(
            node_h, node_features, edge_h, edge_features, mask_i, mask_ij
        )

        # Compare against num_batch (not a literal 1) so the expectation tracks
        # the batch size above; assertEqual reports both shapes on failure.
        self.assertEqual(tuple(node_out.shape), (num_batch, num_nodes, dim_nodes))
        self.assertEqual(
            tuple(edge_out.shape),
            (num_batch, num_nodes, num_neighbors, dim_edges),
        )
|