Spaces:
Sleeping
Sleeping
import math | |
from unittest import TestCase | |
import pytest | |
import torch | |
import torch.nn as nn | |
from chroma.layers.basic import ( | |
MaybeOnehotEmbedding, | |
MeanEmbedding, | |
NodeProduct, | |
NoOp, | |
OneHot, | |
PeriodicPositionalEncoding, | |
PositionalEncoding, | |
PositionWiseFeedForward, | |
Transpose, | |
TriangleMultiplication, | |
Unsqueeze, | |
) | |
class TestBasicLayers(TestCase):
    """Shape and value checks for the simple layers in ``chroma.layers.basic``."""

    def setUp(self):
        # unittest calls setUp before every test method, so each test gets
        # freshly constructed modules.
        self.noop = NoOp()
        self.onehot = OneHot(n_tokens=4)
        self.transpose = Transpose(1, 2)
        self.unsqueeze = Unsqueeze(1)
        self.mean_embedding = MeanEmbedding(nn.Embedding(4, 64), use_softmax=False)
        self.periodic = PeriodicPositionalEncoding(64)
        self.pwff = PositionWiseFeedForward(64, 64)

    def test_noop(self):
        """NoOp returns a tensor equal to its input."""
        x = torch.randn(4, 2, 2)
        # torch.equal also requires matching shapes; an elementwise ``==``
        # followed by ``.all()`` would silently broadcast.
        self.assertTrue(torch.equal(self.noop(x), x))

    def test_onehot(self):
        """OneHot places a 1 at each token's id along the last axis."""
        tokens = torch.tensor([[0, 1, 2], [3, 0, 1]])
        # After swapping axes 1 and 2, row k of each batch entry marks the
        # positions holding token id k (id 3 is absent in the first sequence).
        onehot = self.onehot(tokens).transpose(1, 2)
        target = torch.tensor(
            [
                [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0]],
                [[0, 1, 0], [0, 0, 1], [0, 0, 0], [1, 0, 0]],
            ],
            dtype=onehot.dtype,
        )
        self.assertTrue(torch.equal(onehot, target))

    def test_mean_embedding(self):
        """MeanEmbedding agrees for integer tokens and their float one-hot form."""
        tokens = torch.tensor([[0, 1, 2], [3, 0, 1]])
        onehot = self.onehot(tokens)
        from_indices = self.mean_embedding(tokens)
        from_onehot = self.mean_embedding(onehot.float())
        self.assertEqual(from_indices.shape, from_onehot.shape)
        # Tolerance-based comparison: the one-hot path may be computed with
        # arithmetic whose rounding need not match a direct table lookup.
        self.assertTrue(torch.allclose(from_indices, from_onehot))

    def test_triangle_multiplication(self):
        """TriangleMultiplication preserves the (bs, nres, nres, d_model) shape."""
        bs, nres, d_model = 4, 25, 12
        m = TriangleMultiplication(d_model=d_model)
        X = torch.randn(bs, nres, nres, d_model)
        mask = torch.ones(bs, nres, nres, 1).bool()
        self.assertEqual(m(X, mask).size(), torch.Size([bs, nres, nres, d_model]))

    def test_node_product(self):
        """NodeProduct turns (bs, nres, d_in) node features into pairwise features."""
        bs, nres, d_model = 4, 25, 12
        m = NodeProduct(d_in=d_model, d_out=d_model)
        node_h = torch.randn(bs, nres, d_model)
        node_mask = torch.ones(bs, nres).bool()
        edge_mask = torch.ones(bs, nres, nres).bool()
        self.assertEqual(
            m(node_h, node_mask, edge_mask).size(),
            torch.Size([bs, nres, nres, d_model]),
        )

    def test_transpose(self):
        """Transpose(1, 2) is undone by transposing the same axes again."""
        x = torch.randn(4, 5, 2)
        self.assertTrue(torch.equal(self.transpose(x).transpose(1, 2), x))

    def test_unsqueeze(self):
        """Unsqueeze(1) inserts a singleton axis at dim 1."""
        x = torch.randn(4, 5, 2)
        out = self.unsqueeze(x)
        self.assertEqual(out.shape, torch.Size([4, 1, 5, 2]))
        self.assertTrue(torch.equal(out, x.unsqueeze(1)))

    def test_periodic(self):
        """The even columns of the precomputed table match the sinusoid formula."""
        # NOTE(review): assumes the module precomputes 4000 positions for
        # d_model=64; confirm against the layer's default max length.
        d_model = 64
        position = torch.arange(0.0, 4000).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        expected_even = torch.sin(position * div_term)
        self.assertTrue(torch.equal(self.periodic.pe.squeeze()[:, 0::2], expected_even))
        # Smoke test: the forward pass runs on a 3-D input whose last dim is d_model.
        self.periodic(torch.randn(6, 30, 64))

    def test_pwff(self):
        """PositionWiseFeedForward(64, 64) preserves the input shape."""
        x = torch.randn(4, 5, 64)
        self.assertEqual(self.pwff(x).size(), x.size())
# NOTE(review): each d_model is chosen as a multiple of 2 * d_input, a
# conservative choice in case the encoding splits d_model into sin/cos pairs
# per input dimension.
@pytest.mark.parametrize("d_model, d_input", [(2, 1), (12, 1), (12, 3), (128, 4)])
def test_positional_encoding(d_model, d_input):
    """PositionalEncoding maps (*batch, d_input) to finite (*batch, d_model) outputs.

    The ``parametrize`` decorator is required: without it pytest treats
    ``d_model`` and ``d_input`` as fixture names and errors before running.
    """
    encoding = PositionalEncoding(d_model, d_input)
    for batch_shape in [(), (4,), (3, 2)]:
        inputs = torch.randn(batch_shape + (d_input,), requires_grad=True)
        outputs = encoding(inputs)
        assert outputs.shape == batch_shape + (d_model,)
        assert torch.isfinite(outputs).all()
        # Smoke test: a backward pass through the encoding completes without error.
        outputs.sum().backward()
def test_maybe_onehot_embedding():
    """MaybeOnehotEmbedding gives the same output for indices and their one-hot form."""
    n_tokens, d_embed = 4, 8
    # Ten random token ids drawn from [0, n_tokens).
    tokens = torch.empty(10, dtype=torch.long).random_(n_tokens)
    tokens_onehot = nn.functional.one_hot(tokens, n_tokens).float()
    layer = MaybeOnehotEmbedding(n_tokens, d_embed)
    from_indices = layer(tokens)
    from_onehot = layer(tokens_onehot)
    assert torch.allclose(from_indices, from_onehot)