# hexviz/tests/test_attention.py
# Last change (commit ebbe380, by aksell): "Add ModelType enum and Model class
# to hold layers and head count".
import torch
from Bio.PDB.Structure import Structure
from transformers import T5EncoderModel, T5Tokenizer
from protention.attention import (ModelType, get_attention, get_protT5,
get_sequences, get_structure)
def test_get_structure():
    """get_structure("1AKE") should yield a Biopython ``Structure`` object."""
    structure = get_structure("1AKE")

    assert structure is not None
    assert isinstance(structure, Structure)
def test_get_sequences():
    """get_sequences should return two sequences for 1AKE.

    The two entries are presumably chains A and B of the structure
    (TODO confirm against ``get_sequences``). Each sequence is compared as a
    list of one-letter residue codes.
    """
    pdb_id = "1AKE"
    structure = get_structure(pdb_id)
    sequences = get_sequences(structure)

    assert sequences is not None
    assert len(sequences) == 2

    # Only the first sequence is checked; the second is deliberately discarded
    # rather than bound to an unused name.
    chain_a, _ = sequences
    assert chain_a[:3] == ["M", "R", "I"]
def test_get_protT5():
    """get_protT5 should return a ``(T5Tokenizer, T5EncoderModel)`` tuple."""
    loaded = get_protT5()
    assert loaded is not None
    assert isinstance(loaded, tuple)

    tokenizer, model = loaded
    assert isinstance(tokenizer, T5Tokenizer)
    assert isinstance(model, T5EncoderModel)
def test_get_attention_tape():
    """TAPE BERT attention for 1AKE should come back with a fixed tensor size."""
    # Size pinned by this test. Presumably (layers, heads, tokens, tokens), with
    # the token count including any special tokens the model adds.
    # TODO confirm against get_attention.
    expected_shape = torch.Size([12, 12, 456, 456])

    attention = get_attention("1AKE", model=ModelType.tape_bert)
    assert attention is not None
    assert attention.shape == expected_shape