Spaces:
Sleeping
Sleeping
import tempfile | |
import torch | |
from chroma.models.procap import ProteinCaption, load_model, save_model | |
def test_procap(): | |
model = ProteinCaption( | |
lm_id="EleutherAI/gpt-neo-125m", | |
gnn_dim_edges=16, | |
context_size=8, | |
context_per_chain=1, | |
gnn_num_neighbors=4, | |
gnn_num_layers=1, | |
) | |
assert sum(p.numel() for p in model.parameters()) == 128839584 | |
X = torch.randn(1, 8, 4, 3) | |
C = torch.ones(X.shape[:2]) | |
caption = ["test caption"] | |
chain_id = torch.tensor([1]) | |
with torch.no_grad(): | |
logits = model(X, C, caption, chain_id).logits | |
assert logits.shape == torch.Size([1, 11, 50260]) | |
temp = tempfile.NamedTemporaryFile() | |
save_model(model, temp.name) | |
del model | |
model = load_model(temp.name) | |
assert sum(p.numel() for p in model.parameters()) == 128839584 | |
temp.close() | |