Spaces:
Sleeping
Sleeping
from unittest import TestCase | |
import numpy as np | |
import torch | |
from chroma.models import graph_energy | |
class TestGraphHarmonicFeatures(TestCase): | |
def test_sample(self): | |
num_batch = 1 | |
num_nodes = 10 | |
num_neighbors = 8 | |
dim_nodes = 128 | |
dim_edges = 64 | |
layer = graph_energy.GraphHarmonicFeatures( | |
dim_nodes=dim_nodes, | |
dim_edges=dim_edges, | |
node_mlp_layers=2, | |
node_mlp_dim=dim_nodes, | |
edge_mlp_layers=2, | |
edge_mlp_dim=dim_edges, | |
) | |
node_h = torch.ones(num_batch, num_nodes, dim_nodes) | |
node_features = torch.ones(num_batch, num_nodes, dim_nodes) | |
edge_h = torch.ones(num_batch, num_nodes, num_neighbors, dim_edges) | |
edge_features = torch.ones(num_batch, num_nodes, num_neighbors, dim_edges) | |
mask_i = torch.ones(num_batch, num_nodes) | |
mask_ij = torch.ones(num_batch, num_nodes, num_neighbors) | |
node_out, edge_out = layer( | |
node_h, node_features, edge_h, edge_features, mask_i, mask_ij | |
) | |
self.assertTrue(node_out.shape == (1, num_nodes, dim_nodes)) | |
self.assertTrue(edge_out.shape == (1, num_nodes, num_neighbors, dim_edges)) | |