Spaces:
Sleeping
Sleeping
File size: 874 Bytes
e9e75df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import tempfile
import torch
from chroma.models.procap import ProteinCaption, load_model, save_model
def test_procap():
    """Smoke-test ``ProteinCaption``: construction, forward pass, save/load.

    The test checks three things:

    1. A small ``ProteinCaption`` built on ``EleutherAI/gpt-neo-125m`` has the
       expected total number of parameters.
    2. A forward pass on a random input yields logits of the expected shape.
    3. A model written with ``save_model`` and read back with ``load_model``
       has the same total number of parameters.
    """
    # Single source of truth for the parameter count checked before and
    # after the save/load round trip.
    expected_num_params = 128839584

    model = ProteinCaption(
        lm_id="EleutherAI/gpt-neo-125m",
        gnn_dim_edges=16,
        context_size=8,
        context_per_chain=1,
        gnn_num_neighbors=4,
        gnn_num_layers=1,
    )
    assert sum(p.numel() for p in model.parameters()) == expected_num_params

    # Random input of shape (1, 8, 4, 3).
    # NOTE(review): presumably (batch, residues, atoms, xyz) -- confirm
    # against the ProteinCaption.forward documentation.
    X = torch.randn(1, 8, 4, 3)
    # All-ones tensor matching the first two dimensions of X, i.e. (1, 8).
    C = torch.ones(X.shape[:2])
    caption = ["test caption"]
    chain_id = torch.tensor([1])

    # Gradients are not needed for a shape check.
    with torch.no_grad():
        logits = model(X, C, caption, chain_id).logits
    assert logits.shape == torch.Size([1, 11, 50260])

    # The context manager guarantees the temporary file is closed (and thus
    # removed) even if saving, loading, or the assertion below fails.
    with tempfile.NamedTemporaryFile() as temp:
        save_model(model, temp.name)
        # Drop the in-memory model so the check below can only pass via the
        # instance reconstructed from disk.
        del model
        model = load_model(temp.name)
        assert (
            sum(p.numel() for p in model.parameters()) == expected_num_params
        )
|