Spaces:
Sleeping
Sleeping
File size: 2,076 Bytes
466a8f2 f8402f9 eb9ae1f f8402f9 87c0dbc f8402f9 87c0dbc b7ab123 87c0dbc cfba77f 466a8f2 cfba77f 466a8f2 cfba77f 466a8f2 cfba77f 58c7b8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import torch
from Bio.PDB.Structure import Structure
from hexviz.attention import (ModelType, get_attention, get_sequences,
get_structure, unidirectional_sum_filtered)
def test_get_structure():
    """get_structure returns a Bio.PDB ``Structure`` object for entry 1AKE."""
    fetched = get_structure("1AKE")

    assert fetched is not None
    assert isinstance(fetched, Structure)
def test_get_sequences():
    """get_sequences gives two sequences for 1AKE; the first starts with M, R, I."""
    chains = get_sequences(get_structure("1AKE"))

    assert chains is not None
    assert len(chains) == 2

    # Unpack into exactly two names, matching the length assertion above.
    first_chain, _second_chain = chains
    assert first_chain[:3] == ["M", "R", "I"]
def test_get_attention_zymctrl():
    """ZymCTRL attention for the 3-character input "GGG" has shape 36x16x3x3.

    NOTE(review): the leading 36 and 16 are presumably the model's layer and
    head counts - confirm against get_attention.
    """
    attention = get_attention("GGG", model_type=ModelType.ZymCTRL)
    expected_shape = torch.Size([36, 16, 3, 3])

    assert attention is not None
    assert attention.shape == expected_shape
def test_get_attention_zymctrl_long_chain():
    """ZymCTRL attention for the first 6A5J sequence has shape 36x16x13x13."""
    # 6A5J's first chain is 13 residues long, hence the trailing 13x13.
    chains = get_sequences(get_structure(pdb_code="6A5J"))
    attention = get_attention(chains[0], model_type=ModelType.ZymCTRL)
    expected_shape = torch.Size([36, 16, 13, 13])

    assert attention is not None
    assert attention.shape == expected_shape
def test_get_attention_tape():
    """TAPE-BERT attention for the first 6A5J sequence has shape 12x12x13x13."""
    # 6A5J's first chain is 13 residues long, hence the trailing 13x13.
    chains = get_sequences(get_structure(pdb_code="6A5J"))
    attention = get_attention(chains[0], model_type=ModelType.TAPE_BERT)
    expected_shape = torch.Size([12, 12, 13, 13])

    assert attention is not None
    assert attention.shape == expected_shape
def test_get_unidirection_sum_filtered():
    """unidirectional_sum_filtered returns 10 items for a 4x4 map and 6 for a 3x3.

    NOTE(review): the three trailing zeros are presumably layer index, head
    index and a filter threshold, and the expected counts equal n*(n+1)/2,
    which suggests one entry per unordered residue pair including the
    diagonal - confirm both against unidirectional_sum_filtered.
    """
    # 1 layer, 1 head, 4 residues: a symmetric 4x4 attention map.
    four_residue_map = torch.tensor(
        [[[
            [1, 2, 3, 4],
            [2, 5, 6, 7],
            [3, 6, 8, 9],
            [4, 7, 9, 11],
        ]]],
        dtype=torch.float32,
    )
    pairs = unidirectional_sum_filtered(four_residue_map, 0, 0, 0)
    assert pairs is not None
    assert len(pairs) == 10

    # 1 layer, 1 head, 3 residues: this map is not symmetric (3 vs 4 in the corner).
    three_residue_map = torch.tensor(
        [[[
            [1, 2, 3],
            [2, 5, 6],
            [4, 7, 91],
        ]]],
        dtype=torch.float32,
    )
    pairs = unidirectional_sum_filtered(three_residue_map, 0, 0, 0)
    assert len(pairs) == 6
|