File size: 6,807 Bytes
89c0b51 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import biotite.structure as struc
import numpy as np
from biotite.structure import AtomArray
from protenix.data.constants import ELEMS, STD_RESIDUES
class Token(object):
    """
    Container for a single token value plus free-form annotations.

    ``value`` is stored as a regular instance attribute. Every other
    attribute assigned on an instance is kept in the ``_annot`` dict and is
    read back through normal attribute access.

    Example:
        >>> token = Token(1)
        >>> token.value
        1
        >>> token.atom_indices = [1, 2, 3]
        >>> token.atom_indices
        [1, 2, 3]
    """

    def __init__(self, value, **kwargs):
        """
        Args:
            value: the token value.
            **kwargs: initial annotations, stored as ``name -> annotation``.
        """
        self.value = value
        self._annot = dict(kwargs)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails. ``_annot`` is
        # fetched through the base class so that a missing ``_annot`` raises
        # AttributeError instead of re-entering this method.
        annotations = super().__getattribute__("_annot")
        if attr not in annotations:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{attr}'"
            )
        return annotations[attr]

    def __repr__(self):
        annot_str = ",".join(f"{k}={v}" for k, v in self._annot.items())
        return f"Token({self.value}, {annot_str})"

    def __setattr__(self, attr, value):
        # ``_annot`` and ``value`` are real instance attributes; any other
        # name is recorded as an annotation.
        if attr in ("_annot", "value"):
            super().__setattr__(attr, value)
        else:
            self._annot[attr] = value
class TokenArray(object):
    """
    A group of Token objects used for batch operations.

    The tokens are held in ``self.tokens``. The container supports ``len()``,
    iteration and indexing by integer, slice, or an iterable of integer
    positions.
    """

    def __init__(self, tokens: "list[Token]"):
        """
        Args:
            tokens: the Token objects to wrap (stored as given, not copied).
        """
        self.tokens = tokens

    def __repr__(self):
        body = "".join(f"\t{token}\n" for token in self.tokens)
        return f"TokenArray(\n{body})"

    def __len__(self):
        return len(self.tokens)

    def __iter__(self):
        yield from self.tokens

    def __getitem__(self, index):
        """
        Select tokens by position.

        Args:
            index: an integer (built-in or NumPy integer scalar), a slice, or
                an iterable of integer positions.

        Returns:
            The single token for an integer index; otherwise a new TokenArray
            holding the selected tokens.
        """
        # NumPy integer scalars (e.g. elements of an ``np.where`` result) are
        # not instances of ``int`` and must not be treated as iterables.
        if isinstance(index, (int, np.integer)):
            return self.tokens[index]
        if isinstance(index, slice):
            return TokenArray(self.tokens[index])
        return TokenArray([self.tokens[i] for i in index])

    def get_annotation(self, category):
        """
        Return the ``category`` annotation of every token, in token order.

        Raises:
            KeyError: if any token lacks the annotation.
        """
        return [token._annot[category] for token in self.tokens]

    def set_annotation(self, category, values):
        """
        Assign ``values[i]`` as the ``category`` annotation of the i-th token.

        Args:
            category: annotation name.
            values: one value per token; its length must equal ``len(self)``.
        """
        assert len(values) == len(
            self.tokens
        ), "Length of values must match the number of tokens"
        for token, value in zip(self.tokens, values):
            token._annot[category] = value

    def get_values(self):
        """Return the ``value`` of every token, in token order."""
        return [token.value for token in self.tokens]
class AtomArrayTokenizer(object):
    """
    Tokenize an AtomArray object into a list of Token object.
    """

    def __init__(self, atom_array: AtomArray):
        """
        Args:
            atom_array (AtomArray): the structure to tokenize. The methods of
                this class read its ``res_name``, ``mol_type``, ``element``,
                ``atom_name`` and ``centre_atom_mask`` annotations.
        """
        self.atom_array = atom_array

    def tokenize(self) -> list[Token]:
        """
        Ref: AlphaFold3 SI Chapter 2.6
        Tokenize an AtomArray object into a list of Token object.

        A residue whose name is found in STD_RESIDUES and whose first atom does
        not have mol_type "ligand" becomes one token covering all its atoms.
        Any other residue is split into one token per atom, with the token
        value looked up in ELEMS by the atom's element.

        Returns:
            list : a list of Token object. Every token carries ``atom_indices``
                (positions in the atom array) and ``atom_names``.

        Raises:
            ValueError: if an atom that is tokenized individually has an
                element that is not a key of ELEMS.
        """
        tokens = []
        # Read the annotation array once; indexing the AtomArray with an int
        # would build a whole Atom object just to look up one name.
        all_atom_names = self.atom_array.atom_name
        total_atom_num = 0
        for res in struc.residue_iter(self.atom_array):
            atom_num = len(res)
            first_atom = res[0]
            res_name = first_atom.res_name
            mol_type = first_atom.mol_type
            res_token = STD_RESIDUES.get(res_name, None)
            if res_token is not None and mol_type != "ligand":
                # for std residues: one token per residue
                stop = total_atom_num + atom_num
                token = Token(res_token)
                token.atom_indices = list(range(total_atom_num, stop))
                token.atom_names = list(all_atom_names[total_atom_num:stop])
                tokens.append(token)
                total_atom_num = stop
            else:
                # for ligand and non-std residues: one token per atom
                for atom in res:
                    atom_elem = atom.element
                    atom_token = ELEMS.get(atom_elem, None)
                    if atom_token is None:
                        raise ValueError(f"Unknown atom element: {atom_elem}")
                    token = Token(atom_token)
                    token.atom_indices = [total_atom_num]
                    token.atom_names = [all_atom_names[total_atom_num]]
                    tokens.append(token)
                    total_atom_num += 1
        # Every atom must have been assigned to exactly one token.
        assert total_atom_num == len(self.atom_array)
        return tokens

    def _set_token_annotations(self, token_array: TokenArray) -> TokenArray:
        """
        Set annotations for the token_array.

        The annotations include:
            - centre_atom_index: for the i-th token, the atom-array index of
              the i-th atom whose ``centre_atom_mask`` equals 1.

        Args:
            token_array (TokenArray): TokenArray object created by tokenize bioassembly AtomArray.

        Returns:
            TokenArray: the same TokenArray object, annotated in place.
        """
        centre_atom_indices = np.where(self.atom_array.centre_atom_mask == 1)[0]
        # Check before annotating: the indices are paired with tokens by
        # position, so the two counts have to agree.
        assert len(token_array) == len(centre_atom_indices), (
            "Number of centre atoms in centre_atom_mask must match "
            "the number of tokens"
        )
        token_array.set_annotation("centre_atom_index", centre_atom_indices)
        return token_array

    def get_token_array(self) -> TokenArray:
        """
        Get TokenArray object with annotations (atom_indices, atom_names,
        centre_atom_index).

        Returns:
            TokenArray: The TokenArray object with annotations.

            Illustrative example (exact element formatting may differ):
                TokenArray(
                    Token(1, atom_indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                          atom_names=['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2'],
                          centre_atom_index=2)
                    Token(15, atom_indices=[11, 12, 13, 14, 15, 16],
                          atom_names=['N', 'CA', 'C', 'O', 'CB', 'OG'],
                          centre_atom_index=13)
                    Token(15, atom_indices=[17, 18, 19, 20, 21, 22],
                          atom_names=['N', 'CA', 'C', 'O', 'CB', 'OG'],
                          centre_atom_index=19)
                )

            Each entry satisfies the following format:
                Token($token_value, atom_indices=[global_atom_indices],
                      atom_names=[names], centre_atom_index=global_atom_index)
        """
        tokens = self.tokenize()
        token_array = TokenArray(tokens=tokens)
        return self._set_token_annotations(token_array=token_array)
|