# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import logging
import re
import warnings
from dataclasses import dataclass
from functools import partial
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import chroma.utility.polyseq as polyseq
import chroma.utility.starparser as sp
from chroma import constants
class SystemAssemblyInfo:
    """A class for representing the assembly information for System objects.

    Attributes:
        assemblies (dict): a dictionary of assemblies, with keys being assembly IDs
            and values being dictionaries of the following structure:
            {
                "details": "complete icosahedral assembly",
                "instructions": [
                    {
                        "oper_expression": "(1-60)",
                        "chains": [0, 1, 2],
                        # Each assembly instruction has information for generating
                        # one or more images, with image `i` generated by applying
                        # the sequence of operations with IDs in `operations[i]` to the
                        # list of chains in `chains`. The corresponding operations
                        # are described under `assembly_info["operations"][ID]`.
                        "operations": [["X0", "1", "2", "3"], ["X0", "4", "5", "6"]],
                    },
                    ...
                ],
            }
        operations (dict): a dictionary of symmetry operations, with keys being
            operation IDs and values being dictionaries of the following structure:
            {
                "type": "identity operation",
                "name": "1_555",
                "matrix": np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]),
                "vector": np.array([0., 0., 0.]),
            },
            ...
    """
assemblies: dict | |
operations: dict | |
def __init__(self, assemblies: dict = dict(), operations: dict = dict()): | |
self.assemblies = assemblies | |
self.operations = operations | |
    @staticmethod
    def make_operation(type: str, name: str, matrix: list, vector: list):
        """Builds an operation dictionary from a flat 9-element rotation matrix
        and a 3-element translation vector."""
op = { | |
"type": type, | |
"name": name, | |
"matrix": np.zeros([3, 3]), | |
"vector": np.zeros([3, 1]), | |
} | |
assert len(matrix) == 9, "expected 9 elements in rotation matrix" | |
assert len(vector) == 3, "expected 3 elements in translation vector" | |
for i in range(3): | |
op["vector"][i] = float(vector[i]) | |
for j in range(3): | |
op["matrix"][i][j] = float(matrix[i * 3 + j]) | |
return op | |
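    # Illustrative sketch (example usage, not part of the original module): `make_operation`
    # expects the rotation matrix flattened row-major into 9 values and the translation as
    # 3 values, e.g. an identity operation could be built and registered as:
    #
    #   op = SystemAssemblyInfo.make_operation(
    #       "identity operation", "1_555",
    #       matrix=[1, 0, 0, 0, 1, 0, 0, 0, 1], vector=[0, 0, 0],
    #   )
    #   assembly_info = SystemAssemblyInfo(operations={"1": op})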
def delete_chain(self, cid: str): | |
"""Deletes the mention of the chain from assembly information. | |
Args: | |
cid (str): Chain ID to delete. | |
""" | |
for ass_id, assembly in self.assemblies.items(): | |
for ins in assembly["instructions"]: | |
ins["chains"] = [_id for _id in ins["chains"] if _id != cid] | |
def rename_chain(self, old_cid: str, new_cid: str): | |
"""Renames all mentions of a chain to its new chain ID. | |
Args: | |
old_cid (str): Chain ID to rename. | |
new_cid (str): Newly assigned Chain ID. | |
""" | |
for ass_id, assembly in self.assemblies.items(): | |
for ins in assembly["instructions"]: | |
ins["chains"] = [ | |
new_cid if cid == old_cid else cid for cid in ins["chains"] | |
] | |
class StringList:
    """A class for representing and accessing a list of strings in a highly memory-efficient
    manner. Access is constant time, but modification is linear in the length of the list.
    """
def __init__(self, init_list: List[str] = []): | |
self.string = "" | |
self.rng = ArrayList(2, dtype=int) | |
for i in range(len(init_list)): | |
self.append(init_list[i]) | |
def __getitem__(self, i: int): | |
beg, length = self.rng[i] | |
return self.string[beg : beg + length] | |
def __setitem__(self, i: int, new_string: str): | |
beg, length = self.rng[i] | |
self.string = self.string[:beg] + new_string + self.string[beg + length :] | |
if len(new_string) != length: | |
self.rng[i, 1] = len(new_string) | |
self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] + len(new_string) - length | |
def __str__(self): | |
return self.string | |
def __len__(self): | |
return len(self.rng) | |
def copy(self): | |
new_list = StringList() | |
new_list.string = self.string | |
new_list.rng = self.rng.copy() | |
return new_list | |
def append(self, new_string: str): | |
self.rng.append([len(self.string), len(new_string)]) | |
self.string = self.string + new_string | |
def insert(self, i: int, new_string: str): | |
if i < len(self): | |
ix, _ = self.rng[i] | |
elif i == len(self): | |
if len(self) == 0: | |
ix = 0 | |
else: | |
ix = self.rng[i - 1].sum() | |
else: | |
raise Exception( | |
f"cannot insert in position {i} for stringList of length {len(self)}" | |
) | |
self.string = self.string[0:ix] + new_string + self.string[ix:] | |
self.rng.insert(i, [ix, len(new_string)]) | |
if len(new_string) > 0: | |
self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] + len(new_string) | |
def pop(self, i: int): | |
beg, length = self.rng[i] | |
val = self.string[beg : beg + length] | |
self.string = self.string[0:beg] + self.string[beg + length :] | |
self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] - len(val) | |
self.rng.pop(i) | |
return val | |
def delete_range(self, rng: range): | |
rng = sorted(rng) | |
[i, j] = [rng[0], rng[-1]] | |
beg, _ = self.rng[i] | |
end = self.rng[j].sum() | |
self.string = self.string[0:beg] + self.string[end:] | |
        # `end - beg` characters were removed, so shift subsequent offsets by that amount
        self.rng[j + 1 :, 0] = self.rng[j + 1 :, 0] - (end - beg)
self.rng.delete_range(rng) | |
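# Illustrative sketch (example usage, not part of the original module): StringList stores
# all entries in one backing string plus an ArrayList of (offset, length) pairs, which is
# what makes reads cheap while edits shift the offsets of later entries:
#
#   names = StringList(["ALA", "GLY"])
#   names.append("SER")     # backing string becomes "ALAGLYSER"
#   names[1]                # -> "GLY"
#   names[1] = "HIS"        # same length, so only that slice of the backing string changes
#   names.pop(0)            # -> "ALA"; offsets of remaining entries shift left by 3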
class NameList: | |
"""A class for representing and accessing a list of "names"--i.e., strings that tend to | |
have generic values, such that many repeat values are expected in a given list.""" | |
def __init__(self, init_list: List[str] = []): | |
self._reindex(init_list) | |
def _reindex(self, init_list: List[str]): | |
self.unique_names = [] | |
self.name_indicies = dict() | |
self.index_use = dict() | |
self.indices = ArrayList(1, dtype=int) | |
for name in init_list: | |
self.append(name) | |
def copy(self): | |
new_list = NameList() | |
new_list.unique_names = self.unique_names.copy() | |
new_list.name_indicies = self.name_indicies.copy() | |
new_list.index_use = self.index_use.copy() | |
new_list.indices = self.indices.copy() | |
return new_list | |
def _check_index(self): | |
L = len(self.unique_names) | |
I = len(self.index_use) | |
if (L > 2 * I) and (L - I > 10): | |
self._reindex([self[i] for i in range(len(self))]) | |
def __getitem__(self, i: int): | |
try: | |
idx = self.indices[i].item() | |
except IndexError as e: | |
raise IndexError(f"index {i} out of range for nameList\n" + str(e)) | |
return self.unique_names[idx] | |
def __setitem__(self, i: int, new_name: str): | |
try: | |
idx = self.indices[i] | |
except IndexError as e: | |
raise IndexError(f"index {i} out of range for nameList\n" + str(e)) | |
self.index_use[idx] = self.index_use[idx] - 1 | |
if self.index_use[idx] == 0: | |
del self.index_use[idx] | |
if new_name not in self.name_indicies: | |
idx = len(self.name_indicies) | |
self.name_indicies[new_name] = idx | |
self.unique_names.append(new_name) | |
else: | |
idx = self.name_indicies[new_name] | |
self.indices[i] = idx | |
self._update_use(idx, 1) | |
self._check_index() | |
def __str__(self): | |
return str([self[i] for i in range(len(self))]) | |
def __len__(self): | |
return len(self.indices) | |
def _update_use(self, idx, delta): | |
self.index_use[idx] = self.index_use.get(idx, 0) + delta | |
if self.index_use[idx] <= 0: | |
del self.index_use[idx] | |
def _get_name_index(self, name: str): | |
if name not in self.name_indicies: | |
idx = len(self.name_indicies) | |
self.name_indicies[name] = idx | |
self.unique_names.append(name) | |
else: | |
idx = self.name_indicies[name] | |
return idx | |
def append(self, name: str): | |
idx = self._get_name_index(name) | |
self.indices.append(idx) | |
self.index_use[idx] = self.index_use.get(idx, 0) + 1 | |
def insert(self, i: int, new_string: str): | |
idx = self._get_name_index(new_string) | |
self.indices.insert(i, idx) | |
self.index_use[idx] = self.index_use.get(idx, 0) + 1 | |
def pop(self, i: int): | |
idx = self.indices.pop(i).item() | |
val = self.unique_names[idx] | |
self._update_use(idx, -1) | |
self._check_index() | |
return val | |
def delete_range(self, rng: range): | |
for i in reversed(sorted(rng)): | |
self.pop(i) | |
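# Illustrative sketch (example usage, not part of the original module): NameList de-duplicates
# repeated values such as atom or residue names, storing one integer index per element plus a
# small table of unique strings:
#
#   atoms = NameList(["CA", "CB", "CA", "CA"])
#   atoms[2]                    # -> "CA"
#   len(atoms.unique_names)     # -> 2; "CA" and "CB" are each stored once
#   atoms.append("N")           # adds a third unique name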
class ArrayList: | |
def __init__(self, ndims: int, dtype: type, length: int = 0, val=0): | |
if ndims == 1: | |
self._array = np.ndarray(shape=(max(length, 2)), dtype=dtype) | |
else: | |
self._array = np.ndarray(shape=(max(length, 2), ndims), dtype=dtype) | |
self.ndims = ndims | |
self._array[:] = val | |
self.length = length | |
# view of just the data without the extra allocated stuff | |
self.array = self._array[: self.length] | |
def convert_negative_slice(self, slice_obj): | |
start = slice_obj.start if slice_obj.start is not None else 0 | |
stop = slice_obj.stop if slice_obj.stop is not None else self.length | |
if start < 0: | |
start = self.length + start | |
if stop < 0: | |
stop = self.length + stop | |
return slice(start, stop, slice_obj.step) | |
def copy(self): | |
new_list = ArrayList(ndims=self.ndims, dtype=self.array.dtype, length=len(self)) | |
new_list[:] = self[:] | |
return new_list | |
def __len__(self): | |
return self.length | |
def capacity(self): | |
return self._array.shape[0] | |
def __getitem__(self, i: int): | |
return self.array[i] | |
def __setitem__(self, i: int, row: list): | |
self.array[i] = row | |
def resize(self, delta): | |
# for speed, hard-code instead of calling len() and capacity() | |
new_length = self.length + delta | |
cap = self._array.shape[0] | |
if (new_length > cap) or (new_length < cap / 3): | |
new_capacity = 2 * new_length | |
self._resize(new_capacity) | |
self.length = new_length | |
self.array = self._array[: self.length] | |
def _resize(self, new_size): | |
if self.ndims == 1: | |
self._array.resize((new_size), refcheck=False) | |
else: | |
self._array.resize((new_size, self.ndims), refcheck=False) | |
def items(self): | |
for i in range(self.length): | |
yield self.array[i, :] | |
def append(self, row: list): | |
self.resize(1) | |
self.array[-1] = row | |
def insert(self, i: int, row: list): | |
"""Insert the row such that it ends up being at index ``i`` in the new arrayList""" | |
# resize by +1 | |
self.resize(1) | |
# everything in range [i:end-1) moves over by +1 | |
self.array[i + 1 :] = self.array[i:-1] | |
# set the value at index i | |
self.array[i] = row | |
def pop(self, i: int): | |
"""Remove and return element at index i""" | |
# get the element at index i | |
row = self.array[i].copy() | |
# everything from [i+1; end) moves over by -1 | |
self.array[i:-1] = self.array[i + 1 :] | |
# resize by -1 | |
self.resize(-1) | |
return row | |
def delete_range(self, rng: range): | |
i, j = min(rng), max(rng) | |
# move over to the left to account for the removed part | |
cut_length = j - i + 1 | |
new_length = len(self) - cut_length | |
self.array[i:new_length] = self.array[j + 1 :] | |
        # resize by -cut_length
self.resize(-cut_length) | |
def __str__(self): | |
return str([self[i] for i in range(len(self))]) | |
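# Illustrative sketch (example usage, not part of the original module): ArrayList is a
# NumPy-backed dynamic array that over-allocates (roughly doubling on growth), so repeated
# appends are amortized O(1); `array` is always a view onto the used portion of `_array`:
#
#   coords = ArrayList(3, dtype=float)
#   coords.append([1.0, 2.0, 3.0])
#   coords.append([4.0, 5.0, 6.0])
#   coords[1]           # -> array([4., 5., 6.])
#   coords.pop(0)       # -> array([1., 2., 3.]); remaining rows shift up
#   len(coords)         # -> 1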
class HierarchicList:
    """A utility class that represents a hierarchy of lists. Each level represents
    a list of elements, each element having a set of properties (each property being
    stored as an array-like object over elements). Further, each element has a number
    of children corresponding to a range of elements in a lower-hierarchy list."""
_properties: dict | |
_parent_list: HierarchicList | |
_child_list: HierarchicList | |
_num_children: ArrayList # (1, n) | |
_child_offset: ArrayList # (1, n) | |
def __init__( | |
self, | |
properties: dict, | |
parent_list: HierarchicList = None, | |
num_children: ArrayList = ArrayList(1, dtype=int), | |
): | |
self._properties = dict() | |
for key in properties: | |
self._properties[key] = properties[key].copy() | |
self._parent_list = parent_list | |
if self._parent_list is not None: | |
self._parent_list._child_list = self | |
self._child_list = None | |
self._num_children = num_children.copy() if num_children is not None else None | |
# start off with lazy offsets, self.reindex() creates them | |
self._child_offset = None | |
def copy(self): | |
new_list = HierarchicList( | |
self._properties, self._parent_list, self._num_children | |
) | |
new_list._child_list = self._child_list | |
if self._child_offset is None: | |
new_list._child_offset = None | |
else: | |
new_list._child_offset = self._child_offset.copy() | |
return new_list | |
def set_parent(self, parent_list: HierarchicList): | |
self._parent_list = parent_list | |
def child_index(self, i: int, at: int): | |
if self._child_offset is not None: | |
return self._child_offset[i] + at | |
return self._num_children[0:i].sum() + at | |
def reindex(self): | |
if self._num_children is not None: | |
self._child_offset = ArrayList( | |
1, dtype=int, length=len(self._num_children), val=0 | |
) | |
for i in range(1, len(self)): | |
self._child_offset[i] = ( | |
self._child_offset[i - 1] + self._num_children[i - 1] | |
) | |
def append_child(self, properties): | |
self._num_children[len(self._num_children) - 1] += 1 | |
self._child_list.append(properties) | |
def insert_child(self, i: int, at: int, properties): | |
idx = self.child_index(i, at) | |
self._num_children[i] += 1 | |
self._child_offset = None | |
self._child_list.insert(idx, properties) | |
return idx | |
def delete_child(self, i: int, at: int): | |
idx = self.child_index(i, at) | |
self._num_children[i] -= 1 | |
self._child_offset = None | |
self._child_list.delete(idx) | |
def append(self, properties): | |
        if set(properties.keys()) != set(self._properties.keys()):
            raise Exception(f"unexpected set of attributes '{list(properties.keys())}'")
for key, value in properties.items(): | |
self._properties[key].append(value) | |
if self._child_offset is not None: | |
self._child_offset.append( | |
self._child_offset[-1:].sum() + self._num_children[-1:].sum() | |
) | |
if self._num_children is not None: | |
self._num_children.append(0) | |
def insert(self, i: int, properties): | |
        if set(properties.keys()) != set(self._properties.keys()):
            raise Exception(f"unexpected set of attributes '{list(properties.keys())}'")
for key, value in properties.items(): | |
self._properties[key].insert(i, value) | |
if self._child_offset is not None: | |
if i >= len(self._child_offset): | |
off = self._child_offset[-1:].sum() + self._num_children[-1:].sum() | |
else: | |
off = self._child_offset[i] | |
self._child_offset.insert(i, off) | |
if self._num_children is not None: | |
self._num_children.insert(i, 0) | |
def delete(self, i: int): | |
for key in self._properties: | |
self._properties[key].pop(i) | |
if self._num_children is not None and self._num_children[i] != 0: | |
for at in range(self._num_children[i] - 1, -1, -1): | |
self.delete_child(i, at) | |
self._num_children.pop(i) | |
self._child_offset = None | |
def delete_range(self, rng: range): | |
for key in self._properties: | |
self._properties[key].delete_range(rng) | |
# iterating in descending order so that child offsets remain valid for subsequent elements | |
for i in reversed(sorted(rng)): | |
if self._num_children is not None and self._num_children[i] != 0: | |
idx = self.child_index(i, 0) | |
                self._child_list.delete_range(
                    range(idx, idx + self._num_children[i])
                )
self._num_children[i] = 0 | |
self._child_offset = None | |
def __len__(self): | |
for key in self._properties: | |
return len(self._properties[key]) | |
return None | |
def __getitem__(self, i: str): | |
return self._properties[i] | |
# def __setitem__(self, i: tuple, val): | |
# self._properties[i[0]][i[1]] = val | |
def num_children(self, i: int): | |
return self._num_children[i] | |
def has_children(self, i: int): | |
return self._num_children is not None and self._num_children[i] | |
def __str__(self): | |
string = "Properties:\n" | |
for key in self._properties: | |
string += f"{key}: {str(self._properties[key])}\n" | |
string += f"num_children: {str(self._num_children)}\n" | |
string += f"child_offset: {str(self._child_offset)}\n" | |
string += "----\n" | |
string += str(self._child_list) | |
return string | |
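# Illustrative sketch (example usage, not part of the original module): a two-level
# HierarchicList where each parent element owns a contiguous block of children; the System
# class below wires four such levels together (chains -> residues -> atoms -> locations):
#
#   parents = HierarchicList(properties={"cid": StringList()})
#   children = HierarchicList(properties={"name": NameList()}, parent_list=parents)
#   parents.append({"cid": "A"})
#   parents.append_child({"name": "GLY"})   # child of the last-appended parent
#   parents.append_child({"name": "ALA"})
#   parents.num_children(0)                 # -> 2
#   children["name"][parents.child_index(0, 1)]   # -> "ALA"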
class System: | |
"""A class for storing, accessing, managing, and manipulating a molecular | |
system's structure, sequence, and topological information. The class is | |
organized as a hierarchy of objects: | |
System: top-level class containing all information about a molecular system | |
        -> Chain: a sub-portion of the System; for polymers this is generally a
            chemically connected molecular graph belonging to a System (e.g., for
            protein complexes, this would be one of the proteins).
-> Residue: a generally chemically-connected molecular unit (for polymers, | |
the repeating unit), belonging to a Chain. | |
-> Atom: an atom belonging to a Residue with zero, one, or more locations. | |
-> AtomLocation: the location of an Atom (3D coordinates and other information). | |
Attributes: | |
name (str): given name for System | |
_chains (list): a list of Chain objects | |
_entities (dict): a dictionary of SystemEntity objects, with keys being entity IDs | |
_chain_entities (list): `chain_entities[ci]` stores entity IDs (i.e., keys into | |
`entities`) corresponding to the entity for chain `ci` | |
        _extra_models (list): a list of HierarchicList objects, representing locations
            for alternative models
_labels (dict): a dictionary of residue labels. A label is a string value, | |
under some category (also a string), associated with a residue. E.g., | |
the category could be "SSE" and the value could be "H" or "S". If entry | |
`labels[category][gti]` exists and is equal to `value`, this means that | |
residue with global template index `gti` has the label `category:value`. | |
_selections (dict): a dictionary of selections. Keys are selection names and | |
values are lists of corresponding gti indices. | |
_assembly_info (SystemAssemblyInfo): information on symmetric assemblies that can | |
be constructed from components of the molecular system. See ``SystemAssemblyInfo``. | |
""" | |
name: str | |
_chains: HierarchicList | |
_residues: HierarchicList | |
_atoms: HierarchicList | |
_locations: HierarchicList | |
_entities: Dict[int, SystemEntity] | |
_chain_entities: List[int] | |
_extra_models: List[HierarchicList] | |
_labels: Dict[str, Dict[int, str]] | |
_selections: Dict[str, List[int]] | |
_assembly_info: SystemAssemblyInfo | |
def __init__(self, name: str = "system"): | |
self.name = name | |
self._chains = HierarchicList( | |
properties={ | |
"cid": StringList(), | |
"segid": StringList(), | |
"authid": StringList(), | |
} | |
) | |
self._residues = HierarchicList( | |
properties={ | |
"name": NameList(), | |
"resnum": ArrayList(1, dtype=int), | |
"authresid": StringList(), | |
"icode": ArrayList(1, dtype="U1"), | |
}, | |
parent_list=self._chains, | |
) | |
self._atoms = HierarchicList( | |
properties={"name": NameList(), "het": ArrayList(1, dtype=bool)}, | |
parent_list=self._residues, | |
) | |
self._locations = HierarchicList( | |
properties={ | |
"coor": ArrayList(5, dtype=float), | |
"alt": ArrayList(1, dtype="U1"), | |
}, | |
parent_list=self._atoms, | |
num_children=None, | |
) | |
self._entities = dict() | |
self._chain_entities = [] | |
self._extra_models = [] | |
self._labels = dict() | |
self._selections = dict() | |
self._assembly_info = SystemAssemblyInfo() | |
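    # Illustrative sketch (example usage, not part of the original module): the System is
    # built top-down through the view objects, e.g. a one-residue glycine "protein":
    #
    #   system = System("demo")
    #   chain = system.add_chain("A")
    #   residue = chain.add_residue("GLY", 1, "1", " ")
    #   residue.add_atom("CA", False, 0.0, 0.0, 0.0, 1.0, 0.0, " ")
    #   system.num_residues()   # -> 1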
def _reindex(self): | |
self._chains.reindex() | |
self._residues.reindex() | |
self._atoms.reindex() | |
self._locations.reindex() | |
def _print_indexing(self): | |
for chain in self.chains(): | |
off = self._chains.child_index(chain._ix, 0) | |
num = self._chains.num_children(chain._ix) | |
print(f"chain {chain._ix}, {chain}: [{off} - {off + num})") | |
for residue in chain.residues(): | |
off = self._residues.child_index(residue._ix, 0) | |
num = self._residues.num_children(residue._ix) | |
print(f"\tresidue {residue._ix}, {residue}: [{off} - {off + num})") | |
for atom in residue.atoms(): | |
off = self._atoms.child_index(atom._ix, 0) | |
num = self._atoms.num_children(atom._ix) | |
print(f"\t\tatom {atom._ix}, {atom}: [{off} - {off + num})") | |
for loc in atom.locations(): | |
has_children = self._locations.has_children(loc._ix) | |
print( | |
f"\t\t\tlocation {loc._ix}, {loc}: has children? {has_children}" | |
) | |
    @classmethod
    def from_XCS(
cls, | |
X: torch.Tensor, | |
C: torch.Tensor, | |
S: torch.Tensor, | |
alternate_alphabet: str = None, | |
) -> System: | |
"""Convert an XCS set of pytorch tensors to a new System object. | |
        B is the batch size (currently only a batch size of one is supported).
        N is the number of residues.
Args: | |
X (torch.Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`. | |
`num_atoms` will be 14 if `all_atom=True` or 4 otherwise. | |
C (torch.LongTensor): Chain map with shape `(1, num_residues)`. It codes | |
positions as 0 when masked, positive integers for chain indices, | |
and negative integers to represent missing residues of the | |
corresponding positive integers. | |
S (torch.LongTensor): Sequence with shape `(1, num_residues)`. | |
            alternate_alphabet (str, optional): Optional alternative alphabet for
                sequence encoding. If None, the default amino acid alphabet in
                `constants.AA20` is used.
Returns: | |
System: A System object with the new XCS data. | |
""" | |
alphabet = constants.AA20 if alternate_alphabet is None else alternate_alphabet | |
all_atom = X.shape[2] == 14 | |
assert X.shape[0] == 1 | |
assert C.shape[0] == 1 | |
assert S.shape[0] == 1 | |
assert X.shape[1] == S.shape[1] | |
        assert X.shape[1] == C.shape[1]
X, C, S = [T.squeeze(0).cpu().data.numpy() for T in [X, C, S]] | |
chain_ids = np.abs(C) | |
atom_count = 0 | |
new_system = cls("system") | |
for i, chain_id in enumerate(np.unique(chain_ids)): | |
if chain_id == 0: | |
continue | |
chain_bool = chain_ids == chain_id | |
X_chain = X[chain_bool, :, :].tolist() | |
C_chain = C[chain_bool].tolist() | |
S_chain = S[chain_bool].tolist() | |
# Build chain | |
chain = new_system.add_chain("A") | |
for chain_ix, (X_i, C_i, S_i) in enumerate(zip(X_chain, C_chain, S_chain)): | |
resname = polyseq.to_triple(alphabet[int(S_i)]) | |
# Build residue | |
residue = chain.add_residue( | |
resname, chain_ix + 1, str(chain_ix + 1), " " | |
) | |
if C_i > 0: | |
atom_names = constants.ATOMS_BB | |
if all_atom and resname in constants.AA_GEOMETRY: | |
atom_names = ( | |
atom_names + constants.AA_GEOMETRY[resname]["atoms"] | |
) | |
for atom_ix, atom_name in enumerate(atom_names): | |
x, y, z = X_i[atom_ix] | |
atom_count += 1 | |
residue.add_atom(atom_name, False, x, y, z, 1.0, 0.0, " ") | |
# add an entity for each chain (copy from chain information) | |
for ci, chain in enumerate(new_system.chains()): | |
seq = [None] * chain.num_residues() | |
het = [None] * chain.num_residues() | |
for ri, res in enumerate(chain.residues()): | |
seq[ri] = res.name | |
het[ri] = all(a.het for a in res.atoms()) | |
entity_type, polymer_type = SystemEntity.guess_entity_and_polymer_type(seq) | |
entity = SystemEntity( | |
entity_type, f"chain {chain.cid}", polymer_type, seq, het | |
) | |
new_system.add_new_entity(entity, [ci]) | |
return new_system | |
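    # Illustrative sketch (not part of the original module) of the chain-map convention used
    # by from_XCS/to_XCS: for two chains of lengths 3 and 2 with the second residue of the
    # first chain missing, C would look like
    #
    #   C = torch.tensor([[1, -1, 1, 2, 2]])
    #
    # i.e. 0 marks masked positions, k marks residues of chain k, and -k marks missing
    # residues of chain k.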
def to_XCS( | |
self, | |
all_atom: bool = False, | |
batch_dimension: bool = True, | |
mask_unknown: bool = True, | |
unknown_token: int = 0, | |
reorder_chain: bool = True, | |
alternate_alphabet=None, | |
alternate_atoms=None, | |
get_indices=False, | |
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
"""Convert System object to XCS format. | |
`C` tensor has shape [num_residues], where it codes positions as 0 | |
when masked, positive integers for chain indices, and negative integers | |
to represent missing residues of the corresponding positive integers. | |
        `S` tensor has shape [num_residues]; it maps each residue's amino acid to an
        integer index in `alphabet`. If a residue is not found in `alphabet`, it defaults
        to `unknown_token`. Set `mask_unknown` to True to also mask such unknown
        residues in the chain map.
This function takes into account missing residues and updates chain_map | |
accordingly. | |
        Args:
all_atom (bool): Include side chain atoms. Default is `False`. | |
batch_dimension (bool): Include a batch dimension. Default is `True`. | |
mask_unknown (bool): Mask residues not found in the alphabet. Default is | |
`True`. | |
unknown_token (int): Default token index if a residue is not found in | |
the alphabet. Default is `0`. | |
reorder_chain (bool): If set to true will start indexing chain at 1, | |
else will use the alphabet index (Default: True) | |
            alternate_alphabet (str): Alternative alphabet if not `None`.
alternate_atoms (list): Alternate atom name subset for `X` if not `None`. | |
get_indices (bool): Also return the location indices corresponding to the | |
returned `X` tensor. | |
Returns: | |
X (torch.Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`. | |
`num_atoms` will be 14 if `all_atom=True` or 4 otherwise. | |
C (torch.LongTensor): Chain map with shape `(1, num_residues)`. It codes | |
positions as 0 when masked, positive integers for chain indices, | |
and negative integers to represent missing residues of the | |
corresponding positive integers. | |
S (torch.LongTensor): Sequence with shape `(1, num_residues)`. | |
            location_indices (np.ndarray, optional): location indices corresponding to
                the coordinates in `X`.
""" | |
alphabet = constants.AA20 if alternate_alphabet is None else alternate_alphabet | |
# Get chain map grabbing each chain in system and look at length | |
C = [] | |
for ch_id, chain in enumerate(self.chains()): | |
ch_str = chain.cid | |
if ch_str in list(constants.CHAIN_ALPHABET): | |
map_ch_id = list(constants.CHAIN_ALPHABET).index(ch_str) | |
else: | |
# fmt: off | |
map_ch_id = np.setdiff1d(np.arange(1, len(constants.CHAIN_ALPHABET)), np.unique(C))[0] | |
# fmt: on | |
if reorder_chain: | |
map_ch_id = ch_id + 1 | |
C += [map_ch_id] * chain.num_residues() | |
# Grab full sequence | |
oneLetterSeq = self.sequence(format="one-letter-string") | |
if len(oneLetterSeq) != len(C): | |
logging.warning("Warning, System and chain_map length don't agree") | |
# Initialize recipient arrays | |
atom_names = None | |
if all_atom: | |
num_atoms = 14 if all_atom else 4 | |
else: | |
if alternate_atoms is not None: | |
atom_names = alternate_atoms | |
else: | |
atom_names = constants.ATOMS_BB | |
num_atoms = len(atom_names) | |
atom_names = {a: i for (i, a) in enumerate(atom_names)} | |
num_residues = self.num_residues() | |
X = np.zeros([num_residues, num_atoms, 3]) | |
location_indices = ( | |
np.zeros([num_residues * num_atoms], dtype=int) if get_indices else None | |
) | |
S = [] # Array will contain sequence indices | |
for i in range(num_residues): | |
# If residue should be mask or not | |
is_mask = False | |
# Add sequence | |
if oneLetterSeq[i] in list(alphabet): | |
S.append(alphabet.index(oneLetterSeq[i])) | |
else: | |
S.append(unknown_token) | |
if mask_unknown: | |
is_mask = True | |
# Get residue from system | |
res = self.get_residue(i) | |
if res is None or not res.has_structure(): | |
is_mask = True | |
# If residue is mask because no structure or not found in alphabet | |
if is_mask: | |
# Set chain map to -x | |
C[i] = -abs(C[i]) | |
else: | |
# Loop through atoms | |
if all_atom: | |
code3 = constants.AA20_1_TO_3[oneLetterSeq[i]] | |
atom_names = ( | |
constants.ATOMS_BB + constants.AA_GEOMETRY[code3]["atoms"] | |
) | |
atom_names = {a: i for (i, a) in enumerate(atom_names)} | |
X[ | |
i, : | |
] = np.nan # so we can tell whether some atom was previously found | |
num_rem = len(atom_names) | |
for atom in res.atoms(): | |
name = System.protein_backbone_atom_type(atom.name, False, True) | |
if name is None: | |
name = atom.name | |
ix = atom_names.get(name, None) | |
if ix is None or not np.isnan(X[i, ix, 0]): | |
continue | |
for loc in atom.locations(): | |
X[i, ix] = loc.coors | |
if location_indices is not None: | |
location_indices[i * num_atoms + ix] = loc.get_index() | |
num_rem -= 1 | |
break | |
if num_rem == 0: | |
break | |
if num_rem != 0: | |
C[i] = -abs(C[i]) | |
X[i, :] = 0 | |
np.nan_to_num(X[i, :], copy=False, nan=0) | |
# Tensor everything | |
X = torch.tensor(X).float() | |
C = torch.tensor(C).type(torch.long) | |
S = torch.tensor(S).type(torch.long) | |
# Unsqueeze all the thing | |
if batch_dimension: | |
X = X.unsqueeze(0) | |
C = C.unsqueeze(0) | |
S = S.unsqueeze(0) | |
if location_indices is not None: | |
return X, C, S, location_indices | |
return X, C, S | |
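    # Illustrative sketch (example usage, not part of the original module): round-tripping a
    # System through XCS tensors, assuming `system` is an all-protein System:
    #
    #   X, C, S = system.to_XCS(all_atom=True)
    #   rebuilt = System.from_XCS(X, C, S)
    #   assert rebuilt.num_residues() == int((C != 0).sum())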
def update_with_XCS(self, X, C=None, S=None, alternate_alphabet=None): | |
"""Update the System with XCS coordinates. NOTE: if the System has | |
more than one model, and if the shape of the System changes (i.e., | |
atoms are added or deleted), the additional models will be wiped. | |
Args: | |
X (Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`. | |
`num_atoms` will be 14 if `all_atom=True` or 4 otherwise. | |
C (LongTensor): Chain map with shape `(1, num_residues)`. It codes | |
positions as 0 when masked, positive integers for chain indices, | |
and negative integers to represent missing residues of the | |
corresponding positive integers. Defaults to the current System's | |
chain map. | |
S (LongTensor): Sequence with shape `(1, num_residues)`. Defaults to | |
the current System's sequence. | |
""" | |
if C is None or S is None: | |
_, _C, _S = self.to_XCS() | |
if C is None: | |
C = _C | |
if S is None: | |
S = _S | |
# check to make sure sizes agree | |
if not ( | |
(X.shape[1] == self.num_residues()) | |
and (X.shape[1] == C.shape[1]) | |
and (X.shape[1] == S.shape[1]) | |
): | |
raise Exception( | |
f"input tensor sizes {X.shape}, {C.shape}, and {S.shape}, disagree with System size {self.num_residues()}" | |
) | |
def _process_inputs(T): | |
if T is not None: | |
if len(T.shape) == 2 or len(T.shape) == 4: | |
T = T.squeeze(0) | |
T = T.to("cpu").detach().numpy() | |
return T | |
X, C, S = map(_process_inputs, [X, C, S]) | |
shape_changed = False | |
alphabet = constants.AA20 if alternate_alphabet is None else alternate_alphabet | |
for i, res in enumerate(self.residues()): | |
# atoms to update must have structure and are present in the chain map | |
if not res.has_structure() or C[i] <= 0: | |
continue | |
# First, determine if the sequence has changed | |
resname = "UNK" | |
if S is not None and S[i] < len(alphabet): | |
resname = polyseq.to_triple(alphabet[S[i]]) | |
# If the identity changes, rename and delete side chain atoms | |
if res.name != resname: | |
res.rename(resname) | |
# Second, delete all atoms that are missing in XCS or have duplicate names | |
atoms_sys = [atom.name for atom in res.atoms()] | |
atoms_XCS = constants.ATOMS_BB | |
if resname in constants.AA_GEOMETRY: | |
atoms_XCS = atoms_XCS + constants.AA_GEOMETRY[resname]["atoms"] | |
atoms_XCS = atoms_XCS[: X.shape[1]] | |
to_delete = [] | |
for ix_a, atom in enumerate(res.atoms()): | |
name = atom.name | |
if name not in atoms_XCS or name in atoms_sys[:ix_a]: | |
to_delete.append(atom) | |
if len(to_delete) > 0: | |
shape_changed = True | |
res.delete_atoms(to_delete) | |
# Finally, update all atom coordinates and manufacture any missing atoms | |
for x_id, atom_name in enumerate(atoms_XCS): | |
atom = res.find_atom(atom_name) | |
x, y, z = [X[i][x_id][k].item() for k in range(3)] | |
if atom is not None and atom.num_locations() > 0: | |
atom.x = x | |
atom.y = y | |
atom.z = z | |
else: | |
shape_changed = True | |
if atom is not None: | |
atom.add_location(x, y, z) | |
else: | |
res.add_atom(atom_name, False, x, y, z, 1.0, 0.0) | |
# wipe extra models if the shape of the System changed | |
if shape_changed: | |
self._extra_models = [] | |
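    # Illustrative sketch (example usage, not part of the original module): pushing perturbed
    # coordinates back into an existing System while keeping its chain map and sequence:
    #
    #   X, C, S = system.to_XCS()
    #   system.update_with_XCS(X + 0.1 * torch.randn_like(X))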
def __str__(self): | |
return "system " + self.name | |
def chains(self): | |
"""Chain iterator (generator function).""" | |
for ci in range(len(self._chains)): | |
yield ChainView(ci, self) | |
def get_chain(self, ci: int): | |
"""Returns the chain by index. | |
Args: | |
ci (int): Chain index (from 0) | |
Returns: | |
ChainView object corresponding to the chain in question. | |
""" | |
return ChainView(ci, self) | |
def get_chain_by_id(self, cid: str, segid=False): | |
"""Returns the chain by its string ID. | |
Args: | |
cid (str): Chain ID. | |
segid (bool, optional): If set to True (default is False) will | |
return the chain with the matching segment ID and not chain ID. | |
Returns: | |
ChainView object corresponding to the chain in question. | |
""" | |
for ci, chain in enumerate(self.chains()): | |
if (not segid and cid == chain.cid) or (segid and cid == chain.segid): | |
return ChainView(ci, self) | |
return None | |
def get_chains(self): | |
"""Returns the list of all chains.""" | |
return [ChainView(ci, self) for ci in range(len(self._chains))] | |
def get_chains_of_entity(self, entity_id: int, by=None): | |
"""Returns the list of chains that correspond to the given entity ID. | |
Args: | |
entity_id (int): Entity ID. | |
by (str, optional): If specified as "index", will return a | |
list of chain indices instead of ChainView objects. | |
Returns: | |
List of ChainView objects or chain indices. | |
""" | |
cixs = [ci for (ci, eid) in enumerate(self._chain_entities) if entity_id == eid] | |
if by == "index": | |
return cixs | |
return [ChainView(ci, self) for ci in cixs] | |
def residues(self): | |
"""Residue iterator (generator function).""" | |
for chain in self.chains(): | |
for residue in chain.residues(): | |
yield residue | |
def get_residue(self, gti: int): | |
"""Returns the residue at the given global index. | |
Args: | |
gti (int): Global residue index. | |
Returns: | |
ResidueView object corresponding to the index. | |
""" | |
if gti < 0: | |
raise Exception(f"negative residue index: {gti}") | |
off = 0 | |
for chain in self.chains(): | |
nr = chain.num_residues() | |
if gti < off + nr: | |
return chain.get_residue(gti - off) | |
off = off + nr | |
raise Exception( | |
f"residue index {gti} out of range for System, which has {self.num_residues()} residues" | |
) | |
def atoms(self): | |
"""Iterator of atoms in this System (generator function).""" | |
for chain in self.chains(): | |
for residue in chain.residues(): | |
for atom in residue.atoms(): | |
yield atom | |
def get_atom(self, aidx: int): | |
"""Returns the atom at the given global atom index. | |
Args: | |
            aidx (int): Global atom index.
Returns: | |
AtomView object corresponding to the index. | |
""" | |
if aidx < 0: | |
raise Exception(f"negative atom index: {aidx}") | |
off = 0 | |
for chain in self.chains(): | |
na = chain.num_atoms() | |
if aidx < off + na: | |
return chain.get_atom(aidx - off) | |
off = off + na | |
raise Exception( | |
f"atom index {aidx} out of range for System, which has {self.num_atoms()} atoms" | |
) | |
def locations(self): | |
"""Iterator of atoms in this System (generator function).""" | |
for chain in self.chains(): | |
for residue in chain.residues(): | |
for atom in residue.atoms(): | |
for loc in atom.locations(): | |
yield loc | |
def _new_locations(self): | |
new_locs = self._locations.copy() | |
for li in range(len(new_locs)): | |
new_locs["coor"][li] = [np.nan] * 5 | |
return new_locs | |
def select(self, expression: str, left_associativity: bool = True): | |
"""Evalates the given selection expression and returns all atoms | |
involved in the result as a list of AtomView's. | |
Args: | |
expression (str): selection expression. | |
left_associativity (bool, optional): determines whether operators | |
in the expression are left-associative. | |
Returns: | |
List of AtomView's. | |
""" | |
val, selex_info = self._select( | |
expression, left_associativity=left_associativity | |
) | |
# make a list of AtomView | |
result = [selex_info["all_atoms"][i].atom for i in sorted(val)] | |
return result | |
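    # Illustrative sketch (not part of the original module) of the selection algebra handled
    # by _selex_eval below, e.g.:
    #
    #   system.select("chain A and name CA")          # CA atoms of chain A
    #   system.select("(resid 10-20) around 5.0")     # atoms within 5 A of residues 10-20
    #   system.select_residues("hyd", gti=True)       # gti's of residues containing hydrogens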
def select_residues( | |
self, | |
expression: str, | |
gti: bool = False, | |
allow_unstructured=False, | |
left_associativity: bool = True, | |
): | |
"""Evalates the given selection expression and returns all residues with any | |
atoms involved in the result as a list of ResidueView's or list of gti's. | |
Args: | |
expression (str): selection expression. | |
gti (bool): if True (default is False), will return a list of gti | |
instead of a list of ResidueView's. | |
allow_unstructured (bool): If True (default is False), will allow | |
unstructured residues to be selected. | |
left_associativity (bool, optional): determines whether operators | |
in the expression are left-associative. | |
Returns: | |
List of ResidueView's or gti's (ints). | |
""" | |
val, selex_info = self._select( | |
expression, | |
unstructured=allow_unstructured, | |
left_associativity=left_associativity, | |
) | |
# make a list of ResidueView or gti's | |
if gti: | |
result = sorted(set([selex_info["all_atoms"][i].rix for i in val])) | |
else: | |
residues = dict() | |
for i in val: | |
a = selex_info["all_atoms"][i] | |
residues[a.rix] = a.atom.residue | |
result = [residues[rix] for rix in sorted(residues.keys())] | |
return result | |
def select_chains( | |
self, expression: str, allow_unstructured=False, left_associativity: bool = True | |
): | |
"""Evalates the given selection expression and returns all chains with any | |
atoms involved in the result as a list of ChainView's. | |
Args: | |
expression (str): selection expression. | |
allow_unstructured (bool): If True (default is False), will allow | |
unstructured chains to be selected. | |
left_associativity (bool, optional): determines whether operators | |
in the expression are left-associative. | |
Returns: | |
            List of ChainView's.
""" | |
val, selex_info = self._select( | |
expression, | |
unstructured=allow_unstructured, | |
left_associativity=left_associativity, | |
) | |
        # make a list of ChainView's
        chains = dict()
        for i in val:
            a = selex_info["all_atoms"][i]
            chains[a.cix] = a.atom.chain
        result = [chains[cix] for cix in sorted(chains.keys())]
return result | |
def _select( | |
self, | |
expression: str, | |
unstructured: bool = False, | |
left_associativity: bool = True, | |
): | |
# Build some helpful data structures to support _selex_eval | |
        @dataclass
        class MappableAtom:
atom: AtomView | |
aix: int | |
rix: int | |
cix: int | |
def __hash__(self) -> int: | |
return self.aix | |
# first collect all real atoms | |
all_atoms = [None] * self.num_atoms() | |
cix, rix, aix = 0, 0, 0 | |
for chain in self.chains(): | |
for residue in chain.residues(): | |
for atom in residue.atoms(): | |
all_atoms[aix] = MappableAtom(atom, aix, rix, cix) | |
aix = aix + 1 | |
# for residues that do not have atoms, add a dummy atom with no location | |
# or name; that way, we can still select the residue even though selection | |
# algebra fundamentally works on atoms | |
if residue.num_atoms() == 0: | |
view = DummyAtomView(residue) | |
view.dummy = True | |
                    # make more room at the end of the list since this is an "extra" atom
all_atoms.append(None) | |
all_atoms[aix] = MappableAtom(view, aix, rix, cix) | |
aix = aix + 1 | |
rix = rix + 1 | |
cix = cix + 1 | |
_selex_info = {"all_atoms": all_atoms} | |
_selex_info["all_indices_set"] = set([a.aix for a in all_atoms]) | |
# fmt: off | |
# make an expression tree object | |
tree = ExpressionTreeEvaluator( | |
["hyd", "all", "none"], | |
["not", "byres", "bychain", "first", "last", | |
"chain", "authchain", "segid", "namesel", "gti", "resix", "resid", | |
"authresid", "resname", "re", "x", "y", "z", "b", "icode", "name"], | |
["and", "or", "around", "saround"], | |
eval_function=partial(self._selex_eval, _selex_info), | |
left_associativity=left_associativity, | |
debug=False, | |
) | |
# fmt: on | |
# evaluate the expression | |
val = tree.evaluate(expression) | |
# if we are not looking to select unstructured residues, remove any dummy | |
# atoms. NOTE: making dummy atoms can still impact what structured atoms | |
# are selected (e.g., consider `saround` relative to an unstructured residue) | |
if not unstructured: | |
val = { | |
i for i in val if not hasattr(_selex_info["all_atoms"][i].atom, "dummy") | |
} | |
return val, _selex_info | |
def save_selection( | |
self, | |
expression: Optional[str] = None, | |
gti: Optional[List[int]] = None, | |
selname: str = "_default", | |
allow_unstructured=False, | |
left_associativity: bool = True, | |
): | |
"""Performs a selection on the System according to the given | |
selection string and saves the indices of residues involved in | |
the result (global template indices) under the given name. | |
Args: | |
expression (str): (optional) selection expression. | |
gti (list of int): (optional) list of gti to define selection expression | |
selname (str): selection name. | |
allow_unstructured (bool): If True (default is False), will allow | |
unstructured residues to be selected. | |
left_associativity (bool, optional): determines whether operators | |
in the expression are left-associative. | |
""" | |
if gti is not None: | |
if expression is not None: | |
                warnings.warn(
                    "Both `expression` and `gti` were provided; `expression` will be"
                    " ignored and `gti` will be used!"
                )
result = sorted(gti) | |
else: | |
result = self.select_residues( | |
expression, | |
allow_unstructured=allow_unstructured, | |
left_associativity=left_associativity, | |
gti=True, | |
) | |
# save the list of gti's | |
self._selections[selname] = result | |
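    # Illustrative sketch (example usage, not part of the original module): named selections
    # store residue gti's and can be reused inside later expressions via `namesel`:
    #
    #   system.save_selection("name CA and chain A", selname="siteA")
    #   system.get_selected("siteA")                  # -> sorted list of gti's
    #   system.select("(namesel siteA) around 6.0")   # atoms near the saved selection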
def get_selected(self, selname: str = "_default"): | |
"""Returns the list of gti saved under the specified name. | |
Args: | |
selname (str): selection name. | |
Returns: | |
List of global template indices. | |
""" | |
if selname not in self._selections: | |
raise Exception( | |
f"selection by name '{selname}' does not exist in the System" | |
) | |
return self._selections[selname] | |
def has_selection(self, selname: str = "_default"): | |
"""Returns whether the given named selection exists. | |
Args: | |
selname (str): selection name. | |
Returns: | |
Whether the selection exists in the System. | |
""" | |
return selname in self._selections | |
def get_selection_names(self): | |
"""Returns the list of all currently stored named selections.""" | |
return list(self._selections.keys()) | |
def remove_selection(self, selname: str = "_default"): | |
"""Deletes the selection under the specified name. | |
Args: | |
selname (str): selection name. | |
""" | |
if selname not in self._selections: | |
raise Exception( | |
f"selection by name '{selname}' does not exist in the System" | |
) | |
del self._selections[selname] | |
def _selex_eval(self, _selex_info, op: str, left, right): | |
def _is_numeric(string: str) -> bool: | |
try: | |
float(string) | |
return True | |
except ValueError: | |
return False | |
def _is_int(string: str) -> bool: | |
try: | |
int(string) | |
return True | |
except ValueError: | |
return False | |
def _unpack_operands(operands, dests): | |
assert len(operands) == len(dests) | |
unpacked = [None] * len(operands) | |
succ = True | |
for i, (operand, dest) in enumerate(zip(operands, dests)): | |
if dest is None: | |
if operand is not None: | |
succ = False | |
break | |
elif dest == "result": | |
if not (isinstance(operand, dict) and "result" in operand): | |
succ = False | |
break | |
unpacked[i] = operand["result"] | |
elif dest == "string": | |
if not (len(operand) == 1 and isinstance(operand[0], str)): | |
succ = False | |
break | |
unpacked[i] = operand[0] | |
elif dest == "strings": | |
if not ( | |
isinstance(operand, list) | |
and all([isinstance(val, str) for val in operands]) | |
): | |
succ = False | |
break | |
unpacked[i] = operands | |
elif dest == "float": | |
if not (len(operand) == 1 and _is_numeric(operand[0])): | |
succ = False | |
break | |
unpacked[i] = float(operand[0]) | |
elif dest == "floats": | |
if not ( | |
isinstance(operand, list) | |
and all([_is_numeric(val) for val in operands]) | |
): | |
succ = False | |
break | |
unpacked[i] = [float(val) for val in operands] | |
elif dest == "range": | |
test = _parse_range(operand) | |
if test is None: | |
succ = False | |
break | |
unpacked[i] = test | |
elif dest == "int": | |
if not (len(operand) == 1 and _is_int(operand[0])): | |
succ = False | |
break | |
unpacked[i] = int(operand[0]) | |
elif dest == "ints": | |
if not ( | |
isinstance(operand, list) | |
and all([_is_int(val) for val in operands]) | |
): | |
succ = False | |
break | |
unpacked[i] = [int(val) for val in operands] | |
elif dest == "int_range": | |
test = _parse_int_range(operand) | |
if test is None: | |
succ = False | |
break | |
unpacked[i] = test | |
elif dest == "int_range_string": | |
test = _parse_int_range(operand, string=True) | |
if test is None: | |
succ = False | |
break | |
unpacked[i] = test | |
return unpacked, succ | |
def _parse_range(operands: list): | |
"""Parses range information given a list of operands that were originally separated | |
by spaces. Allowed range expressiosn are of the form: `< n`, `> n`, `n:m` with | |
optional spaces allowed between operands.""" | |
if not ( | |
isinstance(operands, list) | |
and all([isinstance(opr, str) for opr in operands]) | |
): | |
return None | |
operand = "".join(operands) | |
if operand.startswith(">") or operand.startswith("<"): | |
if not _is_numeric(operand[1:]): | |
return None | |
num = float(operand[1:]) | |
if operand.startswith(">"): | |
test = lambda x, cut=num: x > cut | |
else: | |
test = lambda x, cut=num: x < cut | |
elif ":" in operand: | |
parts = operand.split(":") | |
if (len(parts) != 2) or not all([_is_numeric(p) for p in parts]): | |
return None | |
parts = [float(p) for p in parts] | |
test = lambda x, lims=parts: lims[0] < x < lims[1] | |
elif _is_numeric(operand): | |
target = float(operand) | |
test = lambda x, t=target: x == t | |
else: | |
return None | |
return test | |
def _parse_int_range(operands: list, string: bool = False): | |
"""Parses range of integers information given a list of operands that were | |
originally separated by spaces. Allowed range expressiosn are of the form: | |
`n`, `n-m`, `n+m`, with optional spaces allowed anywhere and combinations | |
also allowed (e.g., "n+m+s+r-p+a").""" | |
if not ( | |
isinstance(operands, list) | |
and all([isinstance(opr, str) for opr in operands]) | |
): | |
return None | |
operand = "".join(operands) | |
operands = operand.split("+") | |
ranges = [] | |
for operand in operands: | |
                m = re.fullmatch(r"(.*\d)-(.+)", operand)
if m: | |
if not all([_is_int(g) for g in m.groups()]): | |
return None | |
r = range(int(m.group(1)), int(m.group(2)) + 1) | |
ranges.append(r) | |
else: | |
if not _is_int(operand): | |
return None | |
if string: | |
ranges.append(set([operand])) | |
else: | |
ranges.append(set([int(operand)])) | |
if string: | |
ranges = [[str(x) for x in r] for r in ranges] | |
test = lambda x, ranges=ranges: any([x in r for r in ranges]) | |
return test | |
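        # Illustrative examples (not part of the original module) of the two range grammars
        # parsed above, as they would appear in selection expressions:
        #
        #   "resid 10-20+35"   -> residues 10..20 plus residue 35        (_parse_int_range)
        #   "b < 30"           -> locations with B-factor below 30       (_parse_range)
        #   "x 0:10"           -> x coordinate strictly between 0 and 10 (_parse_range)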
# evaluate expression and store result in list `result` | |
result = set() | |
if op in ("and", "or"): | |
(Si, Sj), succ = _unpack_operands([left, right], ["result", "result"]) | |
if not succ: | |
return None | |
if op == "and": | |
result = set(Si).intersection(set(Sj)) | |
else: | |
result = set(Si).union(set(Sj)) | |
elif op == "not": | |
(_, S), succ = _unpack_operands([left, right], [None, "result"]) | |
if not succ: | |
return None | |
result = _selex_info["all_indices_set"].difference(S) | |
elif op == "all": | |
(_, _), succ = _unpack_operands([left, right], [None, None]) | |
if not succ: | |
return None | |
result = _selex_info["all_indices_set"] | |
elif op == "none": | |
(_, _), succ = _unpack_operands([left, right], [None, None]) | |
if not succ: | |
return None | |
elif op == "around": | |
(S, rad), succ = _unpack_operands([left, right], ["result", "float"]) | |
if not succ: | |
return None | |
# Convert to numpy for distance calculation | |
atom_indices = np.asarray( | |
[ | |
ai.aix | |
for ai in _selex_info["all_atoms"] | |
for xi in ai.atom.locations() | |
] | |
) | |
X_i = np.asarray( | |
[ | |
[xi.x, xi.y, xi.z] | |
for ai in _selex_info["all_atoms"] | |
for xi in ai.atom.locations() | |
] | |
) | |
X_j = np.asarray( | |
[ | |
[xi.x, xi.y, xi.z] | |
for j in S | |
for xi in _selex_info["all_atoms"][j].atom.locations() | |
] | |
) | |
D = np.sqrt(((X_j[np.newaxis, :, :] - X_i[:, np.newaxis, :]) ** 2).sum(-1)) | |
ix_match = (D <= rad).sum(1) > 0 | |
match_hits = atom_indices[ix_match] | |
result = set(match_hits.tolist()) | |
elif op == "saround": | |
(S, srad), succ = _unpack_operands([left, right], ["result", "int"]) | |
if not succ: | |
return None | |
for j in S: | |
aj = _selex_info["all_atoms"][j] | |
rj = aj.rix | |
for ai in _selex_info["all_atoms"]: | |
if aj.atom.residue.chain != ai.atom.residue.chain: | |
continue | |
ri = ai.rix | |
if abs(ri - rj) <= srad: | |
result.add(ai.aix) | |
elif op == "byres": | |
(_, S), succ = _unpack_operands([left, right], [None, "result"]) | |
if not succ: | |
return None | |
gtis = set() | |
for j in S: | |
gtis.add(_selex_info["all_atoms"][j].rix) | |
for a in _selex_info["all_atoms"]: | |
if a.rix in gtis: | |
result.add(a.aix) | |
elif op == "bychain": | |
(_, S), succ = _unpack_operands([left, right], [None, "result"]) | |
if not succ: | |
return None | |
cixs = set() | |
for j in S: | |
cixs.add(_selex_info["all_atoms"][j].cix) | |
for a in _selex_info["all_atoms"]: | |
if a.cix in cixs: | |
result.add(a.aix) | |
elif op in ("first", "last"): | |
(_, S), succ = _unpack_operands([left, right], [None, "result"]) | |
if not succ: | |
return None | |
if op == "first": | |
mi = min([_selex_info["all_atoms"][i].aix for i in S]) | |
else: | |
mi = max([_selex_info["all_atoms"][i].aix for i in S]) | |
result.add(mi) | |
elif op == "name": | |
(_, name), succ = _unpack_operands([left, right], [None, "string"]) | |
if not succ: | |
return None | |
for a in _selex_info["all_atoms"]: | |
if a.atom.name == name: | |
result.add(a.aix) | |
elif op in ("re", "hyd"): | |
if op == "re": | |
(_, regex), succ = _unpack_operands([left, right], [None, "string"]) | |
else: | |
(_, _), succ = _unpack_operands([left, right], [None, None]) | |
regex = "[0123456789]?H.*" | |
if not succ: | |
return None | |
ex = re.compile(regex) | |
for a in _selex_info["all_atoms"]: | |
if a.atom.name is not None and ex.fullmatch(a.atom.name): | |
result.add(a.aix) | |
elif op in ("chain", "authchain", "segid"): | |
(_, match_id), succ = _unpack_operands([left, right], [None, "string"]) | |
if not succ: | |
return None | |
if op == "chain": | |
prop = "cid" | |
elif op == "authchain": | |
prop = "authid" | |
elif op == "segid": | |
prop = "segid" | |
for a in _selex_info["all_atoms"]: | |
if getattr(a.atom.residue.chain, prop) == match_id: | |
result.add(a.aix) | |
elif op == "resid": | |
(_, test), succ = _unpack_operands([left, right], [None, "int_range"]) | |
if not succ: | |
return None | |
for a in _selex_info["all_atoms"]: | |
if test(a.atom.residue.num): | |
result.add(a.aix) | |
elif op in ("resname", "icode"): | |
(_, match_id), succ = _unpack_operands([left, right], [None, "string"]) | |
if not succ: | |
return None | |
if op == "resname": | |
prop = "name" | |
elif op == "icode": | |
prop = "icode" | |
for a in _selex_info["all_atoms"]: | |
if getattr(a.atom.residue, prop) == match_id: | |
result.add(a.aix) | |
elif op == "authresid": | |
(_, test), succ = _unpack_operands( | |
[left, right], [None, "int_range_string"] | |
) | |
if not succ: | |
return None | |
for a in _selex_info["all_atoms"]: | |
if test(a.atom.residue.authid): | |
result.add(a.aix) | |
elif op == "gti": | |
(_, test), succ = _unpack_operands([left, right], [None, "int_range"]) | |
if not succ: | |
return None | |
for a in _selex_info["all_atoms"]: | |
if test(a.rix): | |
result.add(a.aix) | |
elif op in ("x", "y", "z", "b", "occ"): | |
(_, test), succ = _unpack_operands([left, right], [None, "range"]) | |
if not succ: | |
return None | |
prop = op | |
if op == "b": | |
prop = "B" | |
for a in _selex_info["all_atoms"]: | |
for loc in a.atom.locations(): | |
if test(getattr(loc, prop)): | |
result.add(a.aix) | |
break | |
elif op == "namesel": | |
(_, selname), succ = _unpack_operands([left, right], [None, "string"]) | |
if not succ: | |
return None | |
if selname not in self._selections: | |
return None | |
gtis = set(self._selections[selname]) | |
for a in _selex_info["all_atoms"]: | |
if a.rix in gtis: | |
result.add(a.aix) | |
else: | |
return None | |
return {"result": result} | |
def __getitem__(self, chain_idx: int): | |
"""Returns the chain at the given index.""" | |
return self.get_chain(chain_idx) | |
def add_chain( | |
self, | |
cid: str, | |
segid: str = None, | |
authid: str = None, | |
entity_id: int = None, | |
auto_rename: bool = True, | |
at: int = None, | |
): | |
"""Adds a new chain to the System and returns a reference to it. | |
Args: | |
cid (str): Chain ID. | |
segid (str): Segment ID. | |
authid (str): Author chain ID. | |
entity_id (int, optional): Entity ID of the entity corresponding to this chain. | |
            auto_rename (bool, optional): If True, will pick a unique chain ID if the specified
                one clashes with an already existing chain.
            at (int, optional): Index at which to insert the new chain. Defaults to
                appending at the end.
        Returns:
            ChainView object corresponding to the newly added chain.
""" | |
if auto_rename: | |
cid = self._pick_unique_chain_name(cid) | |
if segid is None: | |
segid = cid | |
if authid is None: | |
authid = cid | |
if at is None: | |
at = self.num_chains() | |
self._chains.append({"cid": cid, "segid": segid, "authid": authid}) | |
self._chain_entities.append(entity_id) | |
else: | |
self._chains.insert(at, {"cid": cid, "segid": segid, "authid": authid}) | |
self._chain_entities.insert(at, entity_id) | |
return ChainView(at, self) | |
def _append_residue(self, name: str, num: int, authid: str, icode: str): | |
"""Add a new residue to the end this System. Internal method, do not use. | |
Args: | |
name (str): Residue name. | |
num (int): Residue number (i.e., residue ID). | |
authid (str): Author residue ID. | |
icode (str): Insertion code. | |
Returns: | |
Global index to the newly added residue. | |
""" | |
self._chains.append_child( | |
{"name": name, "resnum": num, "authresid": authid, "icode": icode} | |
) | |
return len(self._residues) - 1 | |
def _append_atom( | |
self, | |
name: str, | |
het: bool, | |
x: float = None, | |
y: float = None, | |
z: float = None, | |
occ: float = None, | |
B: float = None, | |
alt: str = None, | |
): | |
"""Adds a new atom to the end of this System. Internal method, do not use. | |
Args: | |
name (str): Atom name. | |
het (bool): Whether it is a hetero-atom. | |
x, y, z (float): Atom location coordinates. | |
occ (float): Occupancy. | |
B (float): B-factor. | |
alt (str): Alternative position character. | |
Returns: | |
Global index to the newly added atom. | |
""" | |
self._residues.append_child({"name": name, "het": het}) | |
return len(self._atoms) - 1 | |
def _append_location(self, x, y, z, occ, B, alt): | |
"""Adds a location to the end of this System. Internal method, do not use. | |
Args: | |
x, y, z (float): coordinates of the location. | |
occ (float): occupancy for the location. | |
B (float): B-factor for the location. | |
alt (str): alternative location character. | |
Returns: | |
Global index to the newly added location. | |
""" | |
self._atoms.append_child({"coor": [x, y, z, occ, B], "alt": alt}) | |
return len(self._locations) - 1 | |
def add_new_entity(self, entity: SystemEntity, chain_indices: list): | |
"""Adds a new entity to the list contained within the System and | |
assigns chains with provided indices to this entity. | |
Args: | |
entity (SystemEntity): The new entity to add to the System. | |
chain_indices (list): a list of Chain indices for chains to | |
assign to this entity. | |
Returns: | |
The entity ID of the newly added entity. | |
""" | |
new_entity_id = len(self._entities) | |
while new_entity_id in self._entities: | |
new_entity_id = new_entity_id + 1 | |
self._entities[new_entity_id] = entity | |
for ci in chain_indices: | |
self._chain_entities[ci] = new_entity_id | |
return new_entity_id | |
def delete_entity(self, entity_id: int): | |
"""Deletes the entity with the specified ID. Takes care to unlink | |
any chains belonging to this entity from it. | |
Args: | |
entity_id (int): Entity ID. | |
""" | |
chain_indices = self.get_chains_of_entity(entity_id) | |
for ci in chain_indices: | |
self._chain_entities[ci] = None | |
del self._entities[entity_id] | |
def _pick_unique_chain_name(self, hint: str, verbose=False): | |
goodNames = list( | |
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" | |
) | |
taken = set([chain.cid for chain in self.chains()]) | |
# first try to pick a conforming chain name (single alpha-numeric character) | |
for cid in [hint] + goodNames: | |
if cid not in taken: | |
return cid | |
if verbose: | |
warnings.warn( | |
"ran out of reasonable single-letter chain names, will use more than one character (PDB sctructure may be repeating chain IDs upon writing, but should still have unique segment IDs)!" | |
) | |
# if that does not work, get a longer chain name | |
for i in range(-1, len(goodNames)): | |
# first try to expand the original chain ID | |
base = hint if i < 0 else goodNames[i] | |
if base == "": | |
continue | |
for k in range(1000): | |
longName = f"{base}{k}" | |
if longName not in taken: | |
return longName | |
raise Exception( | |
"ran out of even multi-character chain names; PDB structure appears to have an enormous number of chains" | |
) | |
def _ensure_unique_entity(self, ci: int): | |
"""Any time we need to update some piece of information about a Chain that | |
relates to its entity (e.g., sequence info or hetero info), we cannot just | |
update it directly because other Chains may be pointing to the same entity. | |
This function checks for any other chains pointing to the same entity as the | |
specified chain, and (if so) assigns the given chain to a new (duplicate) | |
entity and returns its new ID. This clears the way for updates to this Chain's entity. | |
Args: | |
ci (int): Index of the Chain for which we are trying to update | |
entity information. | |
Returns: | |
entity ID for either a newly created entity mapped to the Chain or its | |
initial entity ID if no other chains point to the same entity. | |
""" | |
chain = self.get_chain(ci) | |
entity_id = chain.get_entity_id() | |
if entity_id is None: | |
return entity_id | |
# see if any other chains point to the same entity | |
unique = True | |
for other in self.chains(): | |
if (other != chain) and (entity_id == other.get_entity_id()): | |
unique = False | |
break | |
if unique: | |
return entity_id | |
# otherwise, we need to make a new (duplicate) entity and point this chain to it | |
new_entity = copy.deepcopy(self._entities[entity_id]) | |
new_entity_id = self.add_new_entity(new_entity, [ci]) | |
return new_entity_id | |
def num_chains(self): | |
"""Returns the number of chains in the System.""" | |
return len(self._chains) | |
def num_chains_of_entity(self, entity_id: int): | |
"""Returns the number of chains of a given entity. | |
Args: | |
entity_id (int): Entity ID. | |
Returns: | |
number of chains mapping to the entity. | |
""" | |
return sum([entity_id == eid for eid in self._chain_entities]) | |
def num_molecules_of_entity(self, entity_id: int): | |
"""Returns the number of molecules of the given entity: the number of chains | |
for polymer entities, or the total number of residues across the entity's | |
chains for non-polymer entities (e.g., waters or ligands).""" | |
if self._entities[entity_id].is_polymer(): | |
return self.num_chains_of_entity(entity_id) | |
cixs = [ci for (ci, id) in enumerate(self._chain_entities) if id == entity_id] | |
return sum([self[ci].num_residues() for ci in cixs]) | |
def num_entities(self): | |
"""Returns the number of entities in the System.""" | |
return len(self._entities) | |
def num_residues(self): | |
"""Returns the number of residues in the System.""" | |
return len(self._residues) | |
def num_structured_residues(self): | |
"""Returns the number of residues with any structure information.""" | |
return sum([chain.num_structured_residues() for chain in self.chains()]) | |
def num_atoms(self): | |
"""Returns the number of atoms in the System.""" | |
return len(self._atoms) | |
def num_structured_atoms(self): | |
"""Returns the number of atoms with any location information.""" | |
num = 0 | |
for chain in self.chains(): | |
for residue in chain.residues(): | |
for atom in residue.atoms(): | |
num = num + (atom.num_locations() > 0) | |
return num | |
def num_atom_locations(self): | |
"""Returns the number of atom locations. Note that an atom can have | |
multiple (alternative) locations and this function counts all of them. | |
""" | |
return len(self._locations) | |
def num_models(self): | |
"""Returns the number of models in the System. A model is effectively | |
a conformation of the molecular system and each System object can have | |
an arbitrary number of different conformations. | |
""" | |
return len(self._extra_models) + 1 | |
def swap_model(self, i: int): | |
"""Swaps the model at index `i` with the current model (i.e., the | |
model at index 0). | |
Args: | |
i (int): Model index | |
""" | |
if i == 0: | |
return | |
if i < 0 or i >= self.num_models(): | |
raise Exception(f"model index {i} out of range") | |
tmp = self._locations | |
self._locations = self._extra_models[i - 1] | |
self._extra_models[i - 1] = tmp | |
def add_model(self, other: System): | |
"""Adds a new model to the System by taking the current model from the | |
specified System `other`. Note that `other` and the present System | |
must have the same number of atom locations, with matching atom and | |
residue names. | |
Args: | |
other (System): The System to take the model from. | |
""" | |
if len(self._locations) != len(other._locations): | |
raise Exception( | |
f"System has {len(self._locations)} atom locations while {len(other._locations)} were provided" | |
) | |
self._extra_models.append(other._locations.copy()) | |
self._extra_models[-1].set_parent(self._atoms) | |
def add_model_from_X(self, X: torch.Tensor): | |
"""Adds a new model to the System with given coordinates. | |
Args: | |
X (torch.Tensor): Coordinate tensor of shape | |
`(residues, atoms (4 or 14), coordinates (3))` | |
""" | |
if len(self._locations) != X.numel() / 3: | |
raise Exception( | |
f"System has {len(self._locations)} atom locations while provided tensor shape is {X.shape}" | |
) | |
X = X.detach().cpu() | |
self._extra_models.append(self._locations.copy()) | |
self._extra_models[-1]["coor"][:, 0:3] = X.flatten(0, 1) | |
return None | |
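# Minimal sketch of adding an extra conformation from a coordinate tensor. Assumes | |
# `system` is a System whose number of atom locations equals the number of elements | |
# in X divided by 3; the 4-atom backbone shape below is purely illustrative. | |
# | |
#   import torch | |
#   X = torch.randn(system.num_residues(), 4, 3)  # hypothetical backbone coordinates | |
#   system.add_model_from_X(X) | |
#   assert system.num_models() == 2 | |
#   system.swap_model(1)  # make the new conformation the current model | |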
def num_assemblies(self): | |
"""Returns the number of biological assemblies defined in this System.""" | |
return len(self._assembly_info.assemblies) | |
def from_CIF_string(cif_string: str): | |
"""Initializes and returns a System object from a CIF string.""" | |
import io | |
f = io.StringIO(cif_string) | |
return System._read_cif(f)[0] | |
def from_CIF(input_file: str): | |
"""Initializes and returns a System object from a CIF file.""" | |
with open(input_file, "r") as f: | |
return System._read_cif(f)[0] | |
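# Illustrative usage of the CIF readers (the file name is hypothetical): | |
# | |
#   system = System.from_CIF("example.cif") | |
#   print(system.num_chains(), system.num_residues(), system.num_models()) | |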
def _read_cif(f, strict=False): | |
def _warn_or_error(strict: bool, msg: str): | |
if strict: | |
raise Exception(msg) | |
else: | |
warnings.warn(msg) | |
is_read = { | |
part: False for part in ["coors", "entities", "sequence", "entity_poly"] | |
} | |
category = "" | |
(in_loop, success) = (False, True) | |
peeked = sp.PeekedLine("", 0) | |
# number of molecules per entity prescribed in the CIF file | |
num_of_mols = dict() | |
system = System("system") | |
while sp.peek_line(f, peeked): | |
if peeked.line.startswith("#"): | |
# nothing to do, skip comments | |
sp.advance(f, peeked) | |
elif peeked.line.startswith("data_"): | |
# nothing to do, this is the beginning of the file | |
sp.advance(f, peeked) | |
elif peeked.line.startswith("loop_"): | |
in_loop = True | |
category = "" | |
sp.advance(f, peeked) | |
else: | |
(cat, name, val) = ("", "", "") | |
if peeked.line.startswith("_"): | |
(cat, name, val) = sp.star_item_parse(peeked.line) | |
if cat != category: | |
if category != "": | |
in_loop = False | |
category = cat | |
if (cat == "_entry") and (name == "id"): | |
if val != "": | |
system.name = val | |
sp.advance(f, peeked) | |
elif cat == "_entity_poly": | |
if is_read["entity_poly"]: | |
raise Exception("entity_poly block encountered multiple times") | |
tab = sp.star_read_data(f, ["entity_id", "type"], in_loop) | |
for row in tab: | |
ent_id = int(row[0]) - 1 | |
if ent_id not in system._entities: | |
system._entities[ent_id] = SystemEntity( | |
None, None, row[1], None, None | |
) | |
else: | |
system._entities[ent_id]._polymer_type = row[1] | |
is_read["entity_poly"] = True | |
elif cat == "_entity": | |
if is_read["entities"]: | |
raise Exception( | |
f"entities block encountered multiple times: {peeked.line}" | |
) | |
tab = sp.star_read_data( | |
f, | |
["id", "type", "pdbx_description", "pdbx_number_of_molecules"], | |
in_loop, | |
) | |
for row in tab: | |
ent_id = int(row[0]) - 1 | |
if ent_id not in system._entities: | |
system._entities[ent_id] = SystemEntity( | |
row[1], row[2], None, None, None | |
) | |
else: | |
system._entities[ent_id]._type = row[1] | |
system._entities[ent_id]._desc = row[2] | |
if row[3].isnumeric(): | |
num_of_mols[ent_id] = int(row[3]) | |
is_read["entities"] = True | |
elif cat == "_entity_poly_seq": | |
if is_read["sequence"]: | |
raise Exception(f"sequence block encountered multiple times") | |
tab = sp.star_read_data( | |
f, ["entity_id", "num", "mon_id", "hetero"], in_loop | |
) | |
(seq, het) = ([], []) | |
for i in range(len(tab)): | |
# accumulate sequence information until we reach the end or a new entity ID | |
seq.append(tab[i][2]) | |
het.append(tab[i][3].startswith("y")) | |
if (i == len(tab) - 1) or (tab[i][0] != tab[i + 1][0]): | |
ent_id = int(tab[i][0]) - 1 | |
system._entities[ent_id]._seq = seq | |
system._entities[ent_id]._het = het | |
(seq, het) = ([], []) | |
is_read["sequence"] = True | |
elif cat == "_pdbx_struct_assembly": | |
tab = sp.star_read_data(f, ["id", "details"], in_loop) | |
for row in tab: | |
system._assembly_info.assemblies[row[0]] = {"details": row[1]} | |
elif cat == "_pdbx_struct_assembly_gen": | |
tab = sp.star_read_data( | |
f, ["assembly_id", "oper_expression", "asym_id_list"], in_loop | |
) | |
for row in tab: | |
assembly = system._assembly_info.assemblies[row[0]] | |
if "instructions" not in assembly: | |
assembly["instructions"] = [] | |
chain_ids = [cid.strip() for cid in row[2].strip().split(",")] | |
assembly["instructions"].append( | |
{"oper_expression": row[1], "chains": chain_ids} | |
) | |
elif cat == "_pdbx_struct_oper_list": | |
tab = sp.star_read_data( | |
f, | |
[ | |
"id", | |
"type", | |
"name", | |
"matrix[1][1]", | |
"matrix[1][2]", | |
"matrix[1][3]", | |
"matrix[2][1]", | |
"matrix[2][2]", | |
"matrix[2][3]", | |
"matrix[3][1]", | |
"matrix[3][2]", | |
"matrix[3][3]", | |
"vector[1]", | |
"vector[2]", | |
"vector[3]", | |
], | |
in_loop, | |
) | |
for row in tab: | |
system._assembly_info.operations[ | |
row[0] | |
] = SystemAssemblyInfo.make_operation( | |
row[1], row[2], row[3:12], row[12:15] | |
) | |
elif cat == "_generate_selections": | |
tab = sp.star_read_data(f, ["name", "indices"], in_loop) | |
for row in tab: | |
system._selections[row[0]] = [ | |
int(gti.strip()) for gti in row[1].strip().split() | |
] | |
elif cat == "_generate_labels": | |
tab = sp.star_read_data(f, ["name", "index", "value"], in_loop) | |
for row in tab: | |
if row[0] not in system._labels: | |
system._labels[row[0]] = dict() | |
idx = int(row[1]) | |
system._labels[row[0]][idx] = row[2] | |
elif cat == "_atom_site": | |
if is_read["coors"]: | |
raise Exception(f"ATOM_SITE block encountered multiple times") | |
# this section is special as it cannot have quoted blocks (because some atom names have the single quote character in them) | |
tab = sp.star_read_data( | |
f, | |
[ | |
"group_PDB", | |
"id", | |
"label_atom_id", | |
"label_alt_id", | |
"label_comp_id", | |
"label_asym_id", | |
"label_entity_id", | |
"label_seq_id", | |
"pdbx_PDB_ins_code", | |
"Cartn_x", | |
"Cartn_y", | |
"Cartn_z", | |
"occupancy", | |
"B_iso_or_equiv", | |
"pdbx_PDB_model_num", | |
"auth_seq_id", | |
"auth_asym_id", | |
], | |
in_loop, | |
cols=False, | |
has_blocks=False, | |
) | |
groupCol = 0 | |
idxCol = 1 | |
atomNameCol = 2 | |
altIdCol = 3 | |
resNameCol = 4 | |
chainNameCol = 5 | |
entityIdCol = 6 | |
seqIdCol = 7 | |
insCodeCol = 8 | |
xCol = 9 | |
yCol = 10 | |
zCol = 11 | |
occCol = 12 | |
bCol = 13 | |
modelCol = 14 | |
authSeqIdCol = 15 | |
authChainNameCol = 16 | |
( | |
atom, | |
residue, | |
chain, | |
prev_chain, | |
prev_residue, | |
prev_atom, | |
prev_entity_id, | |
prev_seq_id, | |
prev_auth_seq_id, | |
) = (None, None, None, None, None, None, None, None, None) | |
loc = None # first model location | |
aIdx = 0 | |
for i in range(len(tab)): | |
if i == 0: | |
first_model = tab[i][modelCol] | |
prev_model = first_model | |
elif (tab[i][modelCol] != prev_model) or ( | |
tab[i][modelCol] != first_model | |
): | |
if tab[i][modelCol] != prev_model: | |
aIdx = 0 | |
num_loc = system.num_atom_locations() | |
# setting the default value to None allows us to tell when the | |
# same coordinate in a subsequent model was not specified (e.g., | |
# when an alternative coordinate is not specified) | |
system._extra_models.append(system._new_locations()) | |
prev_model = tab[i][modelCol] | |
locations_generator = (l for l in system.locations()) | |
loc = next(locations_generator, None) | |
if aIdx >= num_loc: | |
_warn_or_error( | |
strict, | |
f"at atom id: {tab[i][idxCol]} -- too many atoms in model {tab[i][modelCol]} relative to first model {first_model}", | |
) | |
success = False | |
system._extra_models.clear() | |
break | |
# check that the atoms correspond | |
same = ( | |
(loc is not None) | |
and (tab[i][chainNameCol] == loc.atom.residue.chain.cid) | |
and (tab[i][resNameCol] == loc.atom.residue.name) | |
and ( | |
int( | |
sp.star_value( | |
tab[i][seqIdCol], loc.atom.residue.num | |
) | |
) | |
== loc.atom.residue.num | |
) | |
and (tab[i][atomNameCol] == loc.atom.name) | |
) | |
if not same: | |
_warn_or_error( | |
strict, | |
f"at atom id: {tab[i][idxCol]} -- atoms in model {tab[i][modelCol]} do not correspond exactly to atoms in first model", | |
) | |
success = False | |
system._extra_models.clear() | |
break | |
coor = [ | |
float(tab[i][c]) | |
for c in [xCol, yCol, zCol, occCol, bCol] | |
] | |
system._extra_models[-1]["coor"][aIdx] = coor | |
system._extra_models[-1]["alt"][aIdx] = sp.star_value( | |
tab[i][altIdCol], " " | |
)[0] | |
aIdx = aIdx + 1 | |
# advance to the corresponding location of the first model for the next comparison | |
loc = next(locations_generator, None) | |
continue | |
# new chain? | |
if ( | |
(chain is None) | |
or (prev_entity_id != tab[i][entityIdCol]) | |
or (tab[i][chainNameCol] != chain.cid) | |
): | |
authid = ( | |
tab[i][authChainNameCol] | |
if (tab[i][authChainNameCol] != "") | |
else tab[i][chainNameCol] | |
) | |
chain = system.add_chain( | |
tab[i][chainNameCol], | |
tab[i][chainNameCol], | |
authid, | |
int(tab[i][entityIdCol]) - 1, | |
) | |
# new residue | |
if ( | |
(residue is None) | |
or (chain != prev_chain) | |
or (prev_seq_id != tab[i][seqIdCol]) | |
or (prev_auth_seq_id != tab[i][authSeqIdCol]) | |
): | |
resnum = ( | |
int(tab[i][seqIdCol]) | |
if sp.star_value_defined(tab[i][seqIdCol]) | |
else chain.num_residues() + 1 | |
) | |
ri = system._append_residue( | |
tab[i][resNameCol], | |
resnum, | |
tab[i][authSeqIdCol], | |
sp.star_value(tab[i][insCodeCol], " ")[0], | |
) | |
residue = ResidueView(ri, chain) | |
# usually will be a new atom, but may be an alternative coordinate | |
# TODO: this only covers cases where alternative atom coordinates are listed next to each other, | |
# but sometimes they are not -- need to actively use the altIdCol information | |
x, y, z, occ, B = [ | |
float(tab[i][col]) | |
for col in [xCol, yCol, zCol, occCol, bCol] | |
] | |
alt = sp.star_value(tab[i][altIdCol], " ")[0] | |
if ( | |
(atom is None) | |
or (residue != prev_residue) | |
or (tab[i][atomNameCol] != atom.name) | |
): | |
ai = system._append_atom( | |
tab[i][atomNameCol], (tab[i][groupCol] == "HETATM") | |
) | |
atom = AtomView(ai, residue) | |
system._append_location(x, y, z, occ, B, alt) | |
prev_chain = chain | |
prev_residue = residue | |
prev_entity_id = tab[i][entityIdCol] | |
prev_seq_id = tab[i][seqIdCol] | |
prev_auth_seq_id = tab[i][authSeqIdCol] | |
is_read["coors"] = True | |
else: | |
sp.advance(f, peeked) | |
# fill in any "missing" polymer chains (e.g., chains with no visible density | |
# or known structure, but which are nevertheless present) | |
for entity_id in num_of_mols: | |
if system._entities[entity_id].is_polymer(): | |
rem = num_of_mols[entity_id] - system.num_chains_of_entity(entity_id) | |
for _ in range(rem): | |
# the chain will get renamed to avoid clashes | |
system.add_chain("A", None, None, entity_id, auto_rename=True) | |
# fill in missing residues (i.e., those that exist in the entity but not | |
# the atomistic section) | |
for chain in system.chains(): | |
entity = chain.get_entity() | |
if not entity.is_polymer() or entity._seq is None: | |
continue | |
k = 0 | |
for ri in range(len(entity._seq)): | |
cur_res = chain.get_residue(k) if k < chain.num_residues() else None | |
if (cur_res is None) or (cur_res.num > ri + 1): | |
# insert new residue to correspond to entity monomer with index ri | |
chain.add_residue(entity._seq[ri], ri + 1, str(ri + 1), " ", at=k) | |
elif cur_res.num < ri + 1: | |
_warn_or_error( | |
strict, f"inconsistent numbering in chain {chain.cid}" | |
) | |
break | |
k = k + 1 | |
# do an entity-to-structure sequence check for all chains | |
for chain in system.chains(): | |
if not chain.check_sequence(): | |
_warn_or_error( | |
strict, | |
f"chain {chain.cid} did not pass sequence check against corresponding entity", | |
) | |
system._reindex() | |
return system, success | |
def from_PDB_string(pdb_string: str, options=""): | |
"""Initializes and returns a System object from a PDB string.""" | |
import io | |
f = io.StringIO(pdb_string) | |
sys = System._read_pdb(f, options=options) | |
sys.name = "from_string" | |
return sys | |
def from_PDB(input_file: str, options=""): | |
"""Initializes and returns a System object from a PDB file.""" | |
with open(input_file, "r") as f: | |
sys = System._read_pdb(f, options=options) | |
sys.name = input_file | |
return sys | |
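# Illustrative usage of the PDB readers and their option strings (the file name is | |
# hypothetical); recognized sub-strings such as USESEGID and QUIET are parsed in | |
# _read_pdb below: | |
# | |
#   system = System.from_PDB("example.pdb", options="USESEGID QUIET") | |
#   text = open("example.pdb").read() | |
#   system2 = System.from_PDB_string(text) | |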
def _read_pdb(f, strict=False, options=""): | |
def _to_float(strval, default): | |
v = default | |
try: | |
v = float(strval) | |
except (TypeError, ValueError): | |
pass | |
return v | |
last_resnum = None | |
last_resname = None | |
last_icode = None | |
last_chain_id = None | |
last_alt = None | |
chain = None | |
residue = None | |
# flag to indicate that chain terminus was reached. Initialize to true so as to create a new chain upon reading the first atom. | |
ter = True | |
# various parsing options (the wonders of dealing with the good-old PDB format) | |
# and any user-specified overrides | |
options = options.upper() | |
# use segment IDs to name chains instead of chain IDs? (useful when the latter | |
# are absent OR when too many chains, so need multi-letter names) | |
use_segid = True if ("USESEGID" in options) else False | |
# the PDB file was written by CHARMM (slightly different format) | |
charmm_format = True if ("CHARMM" in options) else False | |
# upon reading, convert from all-hydrogen topology (param22 and higher) to | |
# the CHARMM19 united-atom topology (matters for HIS protonation states) | |
charmm19_format = True if ("CHARMM19" in options) else False | |
# make sure chain IDs are unique, even if they are not unique in the read file | |
uniq_chain_ids = False if ("ALLOW DUPLICATE CIDS" in options) else True | |
# rename CD in ILE to CD1 (as is standard in PDB, but not some MM packages) | |
fix_Ile_CD = False if ("ALLOW ILE CD" in options) else True | |
# consecutive residues that differ only in their insertion code will be treated | |
# as separate residues | |
icodes_as_sep_res = True | |
# if true, will not pay attention to TER lines in deciding when chains end/begin | |
ignore_ter = True if ("IGNORE-TER" in options) else False | |
# report various warnings when weird things are found and fixed? | |
verbose = False if ("QUIET" in options) else True | |
chains_to_rename = [] | |
# read line by line and build the System | |
system = System("system") | |
all_system = system | |
model_index = 0 | |
for line in f: | |
line = line.strip() | |
if line.startswith("ENDMDL"): | |
# merge the last read model with the overall System | |
if model_index: | |
try: | |
all_system.add_model(system) | |
except Exception as e: | |
warnings.warn( | |
f"error when adding model {model_index + 1}: {str(e)}, skipping model..." | |
) | |
system = System("system") | |
model_index = model_index + 1 | |
last_resnum = None | |
last_resname = None | |
last_icode = None | |
last_chain_id = None | |
last_alt = None | |
chain = None | |
residue = None | |
continue | |
if line.startswith("END"): | |
break | |
if line.startswith("MODEL"): | |
# new model | |
continue | |
if line.startswith("TER") and not ignore_ter: | |
ter = True | |
continue | |
if not (line.startswith("ATOM") or line.startswith("HETATM")): | |
continue | |
""" Now read atom record. Sometimes PDB lines are too short (if they do not contain some | |
of the last optional columns). We don't want to read past the end of the string!""" | |
line += " " * 100 | |
atominx = int(line[6:11]) | |
atomname = line[12:16].strip() | |
alt = line[16:17] | |
resname = line[17:21].strip() | |
chain_id = line[21:22].strip() | |
resnum = int(line[23:27]) if charmm_format else int(line[22:26]) | |
icode = " " if charmm_format else line[26:27] | |
x = float(line[30:38]) | |
y = float(line[38:46]) | |
z = float(line[46:54]) | |
seg_id = line[72:76].strip() | |
B = _to_float(line[60:66], 0.0) | |
occ = _to_float(line[54:60], 0.0) | |
het = line.startswith("HETATM") | |
# use segment ID's instead of chain ID's? | |
if use_segid: | |
chain_id = seg_id | |
elif (chain_id == "") and (len(seg_id) > 0) and seg_id[0].isalnum(): | |
# use first character of segment name if no chain name is specified, a segment ID | |
# is specified, and the latter starts with an alphanumeric character | |
chain_id = seg_id[0:1] | |
# create a new chain object, if necessary | |
if (chain_id != last_chain_id) or ter: | |
cid_used = system.get_chain_by_id(chain_id) is not None | |
chain = system.add_chain(chain_id, seg_id, chain_id, auto_rename=False) | |
# non-unique chains will be automatically renamed (unless the user specified not to rename chains), BUT we need to | |
# remember the name that was actually read, since this name is what will be used to determine when the next chain comes | |
if uniq_chain_ids and cid_used: | |
chain.cid = chain.cid + f"|to rename {len(chains_to_rename)}" | |
if model_index == 0: | |
chains_to_rename.append(chain) | |
if verbose: | |
warnings.warn( | |
"chain name '" | |
+ chain_id | |
+ "' was repeated while reading, will rename at the end..." | |
) | |
# start to count residue numbers in this chain | |
last_resnum = None | |
last_resname = None | |
ter = False | |
if charmm19_format: | |
if resname == "HSE": | |
resname = "HSD" # neutral HIS, proton on ND1 | |
if resname == "HSD": | |
resname = "HIS" # neutral HIS, proton on NE2 | |
if resname == "HSC": | |
resname = "HSP" # doubley-protonated +1 HIS | |
# many PDB files in the Protein Data Bank call the delta carbon of isoleucine CD1, but | |
# the convention in basically all MM packages is to call it CD, since there is only one | |
if fix_Ile_CD and (resname == "ILE") and (atomname == "CD"): | |
atomname = "CD1" | |
# if necessary, make a new residue | |
really_new_atom = True  # is this a truly new atom, as opposed to an alternative position? | |
if ( | |
(resnum != last_resnum) | |
or (resname != last_resname) | |
or (icodes_as_sep_res and (icode != last_icode)) | |
): | |
# this corresponds to a case, where the alternative location flag is being used to | |
# designate two (or more) different possible amino acids at a particular position | |
# (e.g., where the density is not clear to assign one). In this case, we shall keep | |
# only the first option, because we don't know any better. But we need to separate | |
# this from the case, where we end up here because we are trying to separate residues | |
# by insertion code. | |
if ( | |
(resnum == last_resnum) | |
and (resname != last_resname) | |
and (alt != last_alt) | |
and (not icodes_as_sep_res or (icode == last_icode)) | |
): | |
continue | |
residue = chain.add_residue( | |
resname, chain.num_residues() + 1, str(resnum), icode[0] | |
) | |
elif alt != " ": | |
# if this is not a new residue AND the alternative location flag is specified, | |
# figure out if another location for this atom has already been given. If not, | |
# then treat this as the "primary" location, and whatever other locations | |
# are specified will be treated as alternatives. | |
a = residue.find_atom(atomname) | |
if a is not None: | |
really_new_atom = False | |
a.add_location(x, y, z, occ, B, alt[0]) | |
# if necessary, make a new atom | |
if really_new_atom: | |
a = residue.add_atom(atomname, het, x, y, z, occ, B, alt[0]) | |
# remember previous values for determining whether something interesting happens next | |
last_resnum = resnum | |
last_icode = icode | |
last_resname = resname | |
last_chain_id = chain_id | |
last_alt = alt | |
# take care of renaming any chains that had duplicate IDs | |
for chain in chains_to_rename: | |
parts = chain.cid.split("|") | |
assert ( | |
len(parts) > 1 | |
), "something went wrong when renaming a chain at the end of reading" | |
name = all_system._pick_unique_chain_name(parts[0], verbose) | |
chain.cid = name | |
if len(name): | |
chain.segid = name | |
# add an entity for each chain (copy from chain information) | |
for ci, chain in enumerate(all_system.chains()): | |
seq = [None] * chain.num_residues() | |
het = [None] * chain.num_residues() | |
for ri, res in enumerate(chain.residues()): | |
seq[ri] = res.name | |
het[ri] = all(a.het for a in res.atoms()) | |
entity_type, polymer_type = SystemEntity.guess_entity_and_polymer_type(seq) | |
entity = SystemEntity( | |
entity_type, f"chain {chain.cid}", polymer_type, seq, het | |
) | |
all_system.add_new_entity(entity, [ci]) | |
return all_system | |
def to_CIF(self, output_file: str): | |
"""Writes the System to a CIF file.""" | |
f = open(output_file, "w") | |
self._write_cif(f) | |
def to_CIF_string(self): | |
"""Returns a CIF string representing the System.""" | |
import io | |
f = io.StringIO("") | |
self._write_cif(f) | |
cif_str = f.getvalue() | |
f.close() | |
return cif_str | |
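# Round-trip sketch: serialize to a CIF string and parse it back (assumes an | |
# existing System `system`): | |
# | |
#   cif_text = system.to_CIF_string() | |
#   system_copy = System.from_CIF_string(cif_text) | |
#   assert system_copy.num_atoms() == system.num_atoms() | |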
def _write_cif(self, f): | |
# fmt: off | |
_specials_atom_names = [ | |
"MG", "CL", "FE", "ZN", "MN", "NI", "SE", "CU", "BR", "CO", "AS", | |
"BE", "RU", "RB", "ZR", "OS", "SR", "GD", "MO", "AU", "AG", "PT", | |
"AL", "XE", "BE", "CS", "EU", "IR", "AM", "TE", "BA", "SB" | |
] | |
# fmt: on | |
_ambiguous_atom_names = ["CA", "CD", "NA", "HG", "PB"] | |
def _guess_type(atom_name, res_name): | |
if len(atom_name) > 0 and atom_name[0] == '"': | |
atom_name = atom_name.replace('"', "") | |
if atom_name[:2] in _specials_atom_names: | |
return atom_name[:2] | |
else: | |
if atom_name in _ambiguous_atom_names and res_name == atom_name: | |
return atom_name | |
elif atom_name == "UNK": | |
return "X" | |
return atom_name[:1] | |
entry_id = self.name.strip() | |
if entry_id == "": | |
entry_id = "system" | |
f.write( | |
"data_GNR8\n#\n" | |
+ "_entry.id " | |
+ sp.star_string_escape(entry_id) | |
+ "\n#\n" | |
) | |
# write entities table | |
sp.star_loop_header_write( | |
f, "_entity", ["id", "type", "pdbx_description", "pdbx_number_of_molecules"] | |
) | |
for id, entity in self._entities.items(): | |
num_mol = self.num_molecules_of_entity(id) | |
f.write( | |
f"{id + 1} {sp.star_string_escape(entity._type)} {sp.star_string_escape(entity._desc)} {num_mol}\n" | |
) | |
f.write("#\n") | |
# write entity polymer sequences | |
sp.star_loop_header_write( | |
f, "_entity_poly_seq", ["entity_id", "num", "mon_id", "hetero"] | |
) | |
for id, entity in self._entities.items(): | |
if entity._seq is not None: | |
for i, (res, het) in enumerate(zip(entity._seq, entity._het)): | |
f.write(f"{id + 1} {i + 1} {res} {'y' if het else 'n'}\n") | |
f.write("#\n") | |
# write entity polymer types | |
sp.star_loop_header_write(f, "_entity_poly", ["entity_id", "type"]) | |
for id, entity in self._entities.items(): | |
if entity.is_polymer(): | |
f.write(f"{id + 1} {sp.star_string_escape(entity._polymer_type)}\n") | |
f.write("#\n") | |
if self.num_assemblies(): | |
assemblies = self._assembly_info.assemblies | |
ops = self._assembly_info.operations | |
# assembly info table | |
sp.star_loop_header_write(f, "_pdbx_struct_assembly", ["id", "details"]) | |
for assembly_id, assembly in assemblies.items(): | |
f.write(f"{assembly_id} {sp.star_string_escape(assembly['details'])}\n") | |
f.write("#\n") | |
# assembly generation instructions table | |
sp.star_loop_header_write( | |
f, | |
"_pdbx_struct_assembly_gen", | |
["assembly_id", "oper_expression", "asym_id_list"], | |
) | |
for assembly_id, assembly in assemblies.items(): | |
for instruction in assembly["instructions"]: | |
chain_list = ",".join([str(ci) for ci in instruction["chains"]]) | |
f.write( | |
f"{assembly_id} {sp.star_string_escape(instruction['oper_expression'])} {chain_list}\n" | |
) | |
f.write("#\n") | |
# symmetry operations table | |
sp.star_loop_header_write( | |
f, | |
"_pdbx_struct_oper_list", | |
[ | |
"id", | |
"type", | |
"name", | |
"matrix[1][1]", | |
"matrix[1][2]", | |
"matrix[1][3]", | |
"matrix[2][1]", | |
"matrix[2][2]", | |
"matrix[2][3]", | |
"matrix[3][1]", | |
"matrix[3][2]", | |
"matrix[3][3]", | |
"vector[1]", | |
"vector[2]", | |
"vector[3]", | |
], | |
) | |
for op_id, op in ops.items(): | |
f.write( | |
f"{op_id} {sp.star_string_escape(op['type'])} {sp.star_string_escape(op['name'])} " | |
) | |
f.write( | |
f"{float(op['matrix'][0][0]):g} {float(op['matrix'][0][1]):g} {float(op['matrix'][0][2]):g} " | |
) | |
f.write( | |
f"{float(op['matrix'][1][0]):g} {float(op['matrix'][1][1]):g} {float(op['matrix'][1][2]):g} " | |
) | |
f.write( | |
f"{float(op['matrix'][2][0]):g} {float(op['matrix'][2][1]):g} {float(op['matrix'][2][2]):g} " | |
) | |
f.write( | |
f"{float(op['vector'][0]):g} {float(op['vector'][1]):g} {float(op['vector'][2]):g}\n" | |
) | |
f.write("#\n") | |
sp.star_loop_header_write( | |
f, | |
"_atom_site", | |
[ | |
"group_PDB", | |
"id", | |
"label_atom_id", | |
"label_alt_id", | |
"label_comp_id", | |
"label_asym_id", | |
"label_entity_id", | |
"label_seq_id", | |
"pdbx_PDB_ins_code", | |
"Cartn_x", | |
"Cartn_y", | |
"Cartn_z", | |
"occupancy", | |
"B_iso_or_equiv", | |
"pdbx_PDB_model_num", | |
"auth_seq_id", | |
"auth_asym_id", | |
"type_symbol", | |
], | |
) | |
idx = -1 | |
for model_index in range(self.num_models()): | |
self.swap_model(model_index) | |
for chain, entity_id in zip(self.chains(), self._chain_entities): | |
authchainid = ( | |
chain.authid if sp.star_value_defined(chain.authid) else chain.cid | |
) | |
for residue in chain.residues(): | |
authresid = ( | |
residue.authid | |
if sp.star_value_defined(residue.authid) | |
else residue.num | |
) | |
for atom in residue.atoms(): | |
idx = idx + 1 | |
for location in atom.locations(): | |
# this means this coordinate was not specified for this model | |
if not location.defined(): | |
continue | |
coor = location.coor_info | |
f.write("HETATM " if atom.het else "ATOM ") | |
f.write( | |
f"{idx + 1} {atom.name} {sp.atom_site_token(location.alt)} " | |
) | |
entity_id_str = ( | |
f"{entity_id + 1}" if entity_id is not None else "?" | |
) | |
f.write( | |
f"{residue.name} {chain.cid} {entity_id_str} {residue.num} " | |
) | |
f.write( | |
f"{sp.atom_site_token(residue.icode)} {coor[0]:g} {coor[1]:g} {coor[2]:g} " | |
) | |
f.write(f"{coor[3]:g} {coor[4]:g} {model_index} ") | |
f.write( | |
f"{authresid} {authchainid} {_guess_type(atom.name, residue.name)}\n" | |
) | |
self.swap_model(model_index) | |
f.write("#\n") | |
# write out selections | |
if len(self._selections): | |
sp.star_loop_header_write(f, "_generate_selections", ["name", "indices"]) | |
for name, indices in self._selections.items(): | |
f.write( | |
f"{sp.star_string_escape(name)} \"{' '.join([str(i) for i in indices])}\"\n" | |
) | |
f.write("#\n") | |
# write out labels | |
if len(self._labels): | |
sp.star_loop_header_write(f, "_generate_labels", ["name", "index", "value"]) | |
for category, label_dict in self._labels.items(): | |
for gti, label in label_dict.items(): | |
f.write( | |
f"{sp.star_string_escape(category)} {gti} {sp.star_string_escape(label)}\n" | |
) | |
f.write("#\n") | |
def to_PDB(self, output_file: str, options: str = ""): | |
"""Writes the System to a PDB file. | |
Args: | |
output_file (str): output PDB file name. | |
options (str, optional): a string specifying various options for | |
the writing process. The presence of certain sub-strings will | |
trigger specific behaviors. Currently recognized sub-strings | |
include "CHARMM", "CHARMM19", "CHARMM22", "RENUMBER", "NOEND", | |
"NOTER", and "NOALT". This option is case-insensitive. | |
""" | |
f = open(output_file, "w") | |
self._write_pdb(f, options) | |
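# Illustrative usage (the output path is hypothetical); the option sub-strings are | |
# the ones handled in _write_pdb below, e.g. RENUMBER to renumber atoms/residues | |
# and NOTER to omit TER records: | |
# | |
#   system.to_PDB("out.pdb", options="RENUMBER NOTER") | |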
def to_PDB_string(self, options=""): | |
"""Writes the System to a PDB string. The options string has the same | |
interpretation as with System::toPDB. | |
""" | |
import io | |
f = io.StringIO("") | |
self._write_pdb(f, options) | |
pdb_str = f.getvalue() | |
f.close() | |
return pdb_str | |
def _write_pdb(self, f, options=""): | |
def _pdb_line(loc: AtomLocationView, ai: int, ri=None, rn=None, an=None): | |
if rn is None: | |
rn = loc.atom.residue.name | |
if ri is None: | |
ri = loc.atom.residue.num | |
if an is None: | |
an = loc.atom.name | |
icode = loc.atom.residue.icode | |
cid = loc.atom.residue.chain.cid | |
if len(cid) > 1: | |
cid = cid[0] | |
segid = loc.atom.residue.chain.segid | |
if len(segid) > 4: | |
segid = segid[0:4] | |
# atom name placement is different when it is 4 characters long | |
if len(an) < 4: | |
an_str = " %-.3s" % an | |
else: | |
an_str = "%.4s" % an | |
# moduli are used to make sure numbers do not go over prescribed field widths | |
# (unlike for strings, this is not enforced by %-style formatting) | |
line = ( | |
"%6s%5d %-4s%c%-4s%.1s%4d%c %8.3f%8.3f%8.3f%6.2f%6.2f %.4s" | |
% ( | |
"HETATM" if loc.atom.het else "ATOM ", | |
ai % 100000, | |
an_str, | |
loc.alt, | |
rn, | |
cid, | |
ri % 10000, | |
icode, | |
loc.x, | |
loc.y, | |
loc.z, | |
loc.occ, | |
loc.B, | |
segid, | |
) | |
) | |
return line | |
# various formatting options (the wonders of dealing with the good-old PDB format) | |
# and user-defined overrides | |
options = options.upper() | |
# the PDB file is intended for use in CHARMM or some other MM package | |
charmmFormat = True if "CHARMM" in options else False | |
# upon writing, convert from all-hydrogen topology (param 22 and higher) | |
# to CHARMM19 united-atom topology (matters for HIS protonation states) | |
charmm19Format = True if "CHARMM19" in options else False | |
# upon writing, convert from CHARMM19 united-atom topology to all-hydrogen | |
# param 22 topology (matters for HIS protonation states). Also works for | |
# converting generic PDB files downloaded from the PDB. | |
charmm22Format = True if "CHARMM22" in options else False | |
# upon writing, renumber residues and atoms to start from 1 and go in order | |
renumber = True if "RENUMBER" in options else False | |
# do not write END at the end of the PDB file (e.g., useful for | |
# concatenating chains from several structures) | |
noend = True if "NOEND" in options else False | |
# do not demark the end of each chain with TER (this is not _really_ | |
# necessary, assuming chain names are unique, and it is sometimes nice | |
# not to have extra lines other than atoms) | |
noter = True if "NOTER" in options else False | |
# write alternative locations by default, unless NOALT is specified | |
writeAlt = False if ("NOALT" in options) else True | |
# upon writing, convert to a generic PDB naming convention (no | |
# protonation state specified for HIS) | |
genericFormat = False | |
if charmm19Format and charmm22Format: | |
raise Exception( | |
"CHARMM 19 and 22 formatting options cannot be specified together" | |
) | |
atomIndex = 1 | |
for ci, chain in enumerate(self.chains()): | |
for ri, residue in enumerate(chain.residues()): | |
for ai, atom in enumerate(residue.atoms()): | |
# dirty details of formatting conversions for MM purposes | |
atomname = atom.name | |
resname = residue.name | |
if charmmFormat: | |
if (residue.name == "ILE") and (atom.name == "CD1"): | |
atomname = "CD" | |
if (atom.name == "O") and (ri == chain.num_residues() - 1): | |
atomname = "OT1" | |
if (atom.name == "OXT") and (ri == chain.num_residues() - 1): | |
atomname = "OT2" | |
if residue.name == "HOH": | |
resname = "TIP3" | |
if charmm19Format: | |
if residue.name == "HSD": # neutral HIS, proton on ND1 | |
resname = "HIS" | |
if residue.name == "HSE": # neutral HIS, proton on NE2 | |
resname = "HSD" | |
if residue.name == "HSC": # doubley-protonated +1 HIS | |
resname = "HSP" | |
elif charmm22Format: | |
"""This will convert from CHARMM19 to CHARMM22 as well as from a generic downlodaded | |
* PDB file to one ready for use in CHARMM22. The latter is because in the all-hydrogen | |
* topology, HIS protonation state must be explicitly specified, so there is no HIS per se. | |
* Whereas in typical downloaded PDB files HIS is used for all histidines (usually, one | |
* does not even really know the protonation state). Whether sometimes people do specify it | |
* nevertheless, and what naming format they use to do so, I am not sure (welcome to the | |
* PDB file format). But certainly almost always it is just HIS. Below HIS is renamed to | |
* HSD, the neutral form with proton on ND1. This is an assumption; not a perfect one, but | |
* something needs to be assumed. Doing this renaming will make the PDB file work in MM | |
* packages with the all-hydrogen model.""" | |
if residue.name == "HSD": # neutral HIS, proton on NE2 | |
resname = "HSE" | |
if residue.name == "HIS": # neutral HIS, proton on ND1 | |
resname = "HSD" | |
if residue.name == "HSP": # doubley-protonated +1 HIS | |
resname = "HSC" | |
elif genericFormat: | |
if residue.name in ["HSD", "HSP", "HSE", "HSC"]: | |
resname = "HIS" | |
if (residue.name == "ILE") and (atom.name == "CD"): | |
atomname = "CD1" | |
# write the atom line | |
for li in range(atom.num_locations()): | |
if renumber: | |
f.write( | |
_pdb_line( | |
atom.get_location(li), | |
atomIndex, | |
ri=ri + 1, | |
rn=resname, | |
an=atomname, | |
) | |
+ "\n" | |
) | |
else: | |
f.write( | |
_pdb_line( | |
atom.get_location(li), | |
atomIndex, | |
rn=resname, | |
an=atomname, | |
) | |
+ "\n" | |
) | |
atomIndex = atomIndex + 1 | |
if not noter and (ri == chain.num_residues() - 1): | |
f.write("TER\n") | |
if not noend and (ci == self.num_chains() - 1): | |
f.write("END\n") | |
def canonicalize_protein( | |
self, | |
level=2, | |
drop_coors_unknowns=False, | |
drop_coors_missing_backbone=False, | |
filter_by_entity=False, | |
): | |
"""Canonicalize the calling System object (in place) by assuming that it represents | |
a protein molecular system. Different canonicalization rigor and options | |
can be specified but are all optional. | |
Args: | |
level (int): Canonicalization level that determines which nonstandard-to-standard | |
residue mappings are performed. Possible values are 1, 2 or 3, with 2 being | |
the default and higher values meaning more rigorous (and less conservative) | |
canonicalization. With level 1, only truly equivalent mappings are performed | |
(e.g., different His protonation states are mapped to the canonical residue | |
name HIS that does not specify protonation). Level 2 adds to this some less | |
exact but still quite close mappings--i.e., selenomethionine (MSE) and seleno- | |
cysteine (SEC) to methionine (MET) and cysteine (CYS). Level 3 further adds | |
even less equivalent but still reasonable mappings--i.e., phosphorylated SER, | |
THR, TYR, and HIS to their unphosphorylated counterparts as well as S-oxy Cys | |
to Cys. | |
drop_coors_unknowns (bool, optional): if True, will discard structural information | |
for all residues that are not natural or mappable under the current level. | |
NOTE: any sequence record for these residues (i.e., if they are part of a | |
polymer entity) will be preserved. | |
drop_coors_missing_backbone (bool, optional): if True, will discard structural | |
information for residues that do not have at least the N, CA, C, and O | |
backbone atoms. Same note applies regarding the full sequence record as in | |
the above. | |
filter_by_entity (bool, optional): if True, will remove any chains that do not | |
represent polymer/polypeptide entities. This is convenient for cases where a | |
System object has both protein and non-protein components. However, depending | |
on how the System object was generated, entity metadata may not have been filled, | |
so applying this canonicalization approach will remove the entire structure. | |
For this reason, the option is False by default. | |
""" | |
def _mod_to_standard_aa_mappings( | |
less_standard: bool, almost_standard: bool, standard: bool | |
): | |
# Perfectly corresponding to standard residues | |
standard_map = {"HSD": "HIS", "HSE": "HIS", "HSC": "HIS", "HSP": "HIS"} | |
# Almost perfectly corresponding to standard residues: | |
# * MSE -- selenomethionine; SEC -- selenocysteine | |
almost_standard_map = {"MSE": "MET", "SEC": "CYS"} | |
# A little less perfectly corresponding pairings, but can be acceptable (depends): | |
# * HIP -- ND1-phosphohistidine; SEP -- phosphoserine; TPO -- phosphothreonine; | |
# * PTR -- o-phosphotyrosine. | |
less_standard_map = { | |
"HIP": "HIS", | |
"CSX": "CYS", | |
"SEP": "SER", | |
"TPO": "THR", | |
"PTR": "TYR", | |
} | |
ret = dict() | |
if standard: | |
ret.update(standard_map) | |
if almost_standard: | |
ret.update(almost_standard_map) | |
if less_standard: | |
ret.update(less_standard_map) | |
return ret | |
def _to_standard_aa_mappings( | |
less_standard: bool, almost_standard: bool, standard: bool | |
): | |
# get the mapping between modifications and their corresponding standard forms | |
mapping = _mod_to_standard_aa_mappings( | |
less_standard, almost_standard, standard | |
) | |
# add mapping between standard names and themselves | |
import chroma.utility.polyseq as polyseq | |
for aa in polyseq.canonical_amino_acids(): | |
mapping[aa] = aa | |
return mapping | |
less_standard, almost_standard, standard = False, False, False | |
if level == 3: | |
less_standard, almost_standard, standard = True, True, True | |
elif level == 2: | |
less_standard, almost_standard, standard = False, True, True | |
elif level == 1: | |
less_standard, almost_standard, standard = False, False, True | |
else: | |
raise Exception(f"unknown canonicalization level {level}") | |
to_standard = _to_standard_aa_mappings(less_standard, almost_standard, standard) | |
# NOTE: need to re-implement the canonicalization procedure such that it: | |
# 1. checks to make sure entity sequence and structure sequence agree (error if not) | |
# 2. goes over entities and looks for residues to rename, does the renaming on the entities | |
# and all chains simultaneously (so that no new entities are created) | |
# 3. then goes over the structured part and fixes atoms | |
# For residue renamings, we will first record all edits and will perform them | |
# afterwards in one go, so we can judge whether any new entities have to be | |
# created. The dictionary `residues_to_rename` will be as follows: | |
# entity_id: { | |
# chain_index: [list of (residue index, rew name) tuples] | |
# } | |
chains_to_delete = [] | |
residues_to_rename = dict() | |
for ci, chain in enumerate(self.chains()): | |
entity = chain.get_entity() | |
if filter_by_entity: | |
if ( | |
(entity is None) | |
or (entity._type != "polymer") | |
or ("polypeptide" not in entity.polymer_time) | |
): | |
chains_to_delete.append(chain) | |
continue | |
# iterate in reverse order so we can safely delete any residues we find necessary | |
cleared_residues = 0 | |
for residue in reversed(list(chain.residues())): | |
aa = residue.name | |
delete_atoms = False | |
# canonicalize amino acid (delete structure if unknown, provided this was asked for) | |
if aa in to_standard: | |
aa_new = to_standard[aa] | |
if aa != aa_new: | |
# edit any atoms to reflect the mutation | |
if ( | |
(aa == "HSD") | |
or (aa == "HSE") | |
or (aa == "HSC") | |
or (aa == "HSP") | |
) and (aa_new == "HIS"): | |
pass | |
elif ((aa == "MSE") and (aa_new == "MET")) or ( | |
(aa == "SEC") and (aa_new == "CYS") | |
): | |
SE = residue.find_atom("SE") | |
if SE is not None: | |
if aa == "MSE": | |
SE.residue.rename("SD") | |
else: | |
SE.residue.rename("SG") | |
elif ( | |
((aa == "HIP") and (aa_new == "HIS")) | |
or ((aa == "SEP") and (aa_new == "SER")) | |
or ((aa == "TPO") and (aa_new == "THR")) | |
or ((aa == "PTR") and (aa_new == "TYR")) | |
): | |
# delete the phosphate group | |
for atomname in ["P", "O1P", "O2P", "O3P", "HOP2", "HOP3"]: | |
a = residue.find_atom(atomname) | |
if a is not None: | |
a.delete() | |
elif (aa == "CSX") and (aa_new == "CYS"): | |
a = residue.find_atom("OD") | |
if a is not None: | |
a.delete() | |
# record residue renaming operation to be done later | |
entity_id = chain.get_entity_id() | |
if entity_id is None: | |
residue.rename(aa_new) | |
else: | |
if entity_id not in residues_to_rename: | |
residues_to_rename[entity_id] = dict() | |
if ci not in residues_to_rename[entity_id]: | |
residues_to_rename[entity_id][ci] = list() | |
residues_to_rename[entity_id][ci].append( | |
(residue.get_index_in_chain(), aa_new) | |
) | |
else: | |
if aa == "ARG": | |
A = {an: None for an in ["CD", "NE", "CZ", "NH1", "NH2"]} | |
for an in A: | |
atom = residue.find_atom(an) | |
if atom is not None and atom.num_locations(): | |
A[an] = atom.get_location(0) | |
if all([a is not None for n, a in A.items()]): | |
dihe1 = System.dihedral( | |
A["CD"], A["NE"], A["CZ"], A["NH1"] | |
) | |
dihe2 = System.dihedral( | |
A["CD"], A["NE"], A["CZ"], A["NH2"] | |
) | |
if abs(dihe1) > abs(dihe2): | |
A["NH1"].name = "NH2" | |
A["NH2"].name = "NH1" | |
elif drop_coors_unknowns: | |
delete_atoms = True | |
if drop_coors_missing_backbone: | |
if not delete_atoms and not residue.has_full_backbone(): | |
delete_atoms = True | |
if delete_atoms: | |
residue.delete_atoms() | |
cleared_residues += 1 | |
# If we have deleted all residues in this chain, then this is probably not | |
# a protein chain, so get rid of it. Unless we are asked to pay attention to | |
# the entity type (i.e., whether it is peptidic), in which case the decision | |
# of whether to keep the chain would have been made previously. | |
if ( | |
not filter_by_entity | |
and (cleared_residues != 0) | |
and (cleared_residues == chain.num_residues()) | |
): | |
chains_to_delete.append(chain) | |
# rename residues differently depending on whether all chains of a given entity | |
# have the same set of renamings | |
for entity_id, ops in residues_to_rename.items(): | |
chain_indices = set(ops.keys()) | |
entity_chains = set(self.get_chains_of_entity(entity_id, by="index")) | |
unique_renames = set([tuple(v) for v in ops.values()]) | |
fork = True | |
if (chain_indices == entity_chains) and (len(unique_renames) == 1): | |
# we can rename without updating entities, because all entity chains are updated the same way | |
fork = False | |
for ci, renames in ops.items(): | |
chain = self.get_chain(ci) | |
for ri, new_name in renames: | |
chain.get_residue(ri).rename(new_name, fork_entity=fork) | |
# now delete any chains | |
for chain in reversed(chains_to_delete): | |
chain.delete() | |
self._reindex() | |
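# Typical invocation sketch: map nonstandard residues to their standard counterparts | |
# and drop structure that cannot be interpreted. The level and flags shown are just | |
# one reasonable choice, not a recommendation baked into the library: | |
# | |
#   system.canonicalize_protein( | |
#       level=2, | |
#       drop_coors_unknowns=True, | |
#       drop_coors_missing_backbone=True, | |
#   ) | |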
def sequence(self, format="three-letter-list"): | |
"""Returns the full sequence of this System, concatenated over all | |
chains in their order within the System. | |
Args: | |
format (str): sequence format. Possible options are either | |
"three-letter-list" (default) or "one-letter-string". | |
Returns: | |
List (default) or string. | |
""" | |
if format == "three-letter-list": | |
seq = [] | |
else: | |
seq = "" | |
for chain in self.chains(): | |
seq = seq + chain.sequence(format) | |
return seq | |
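# Example of the two supported sequence formats (assumes `system` holds protein chains): | |
# | |
#   three = system.sequence()                           # e.g., ["MET", "ALA", ...] | |
#   one = system.sequence(format="one-letter-string")   # e.g., "MA..." | |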
def distance(a1: AtomLocationView, a2: AtomLocationView): | |
"""Computes the distance between atom locations `a1` and `a2`.""" | |
v21 = a1.coors - a2.coors | |
return np.linalg.norm(v21) | |
def angle( | |
a1: AtomLocationView, a2: AtomLocationView, a3: AtomLocationView, radians=False | |
): | |
"""Computes the angle formed by three 3D points represented by AtomLocationView objects. | |
Args: | |
a1, a2, a3 (AtomLocationView): three 3D points. | |
radians (bool, optional): if True (default False), will return the angle in radians. | |
Otherwise, in degrees. | |
Returns: | |
Angle `a1`-`a2`-`a3`. | |
""" | |
v21 = a1.coors - a2.coors | |
v23 = a3.coors - a2.coors | |
v21 = v21 / np.linalg.norm(v21) | |
v23 = v23 / np.linalg.norm(v23) | |
# clamp to guard against floating-point values slightly outside [-1, 1] | |
c = np.clip(np.dot(v21, v23), -1.0, 1.0) | |
return np.arctan2(np.sqrt(1 - c * c), c) * (1 if radians else 180.0 / np.pi) | |
def dihedral( | |
a1: AtomLocationView, | |
a2: AtomLocationView, | |
a3: AtomLocationView, | |
a4: AtomLocationView, | |
radians=False, | |
): | |
"""Computes the dihedral angle formed by four 3D points represented by AtomLocationView objects. | |
Args: | |
a1, a2, a3, a4 (AtomLocationView): four 3D points. | |
radians (bool, optional): if True (default False), will return the angle in radians. | |
Otherwise, in degrees. | |
Returns: | |
Dihedral angle `a1`-`a2`-`a3`-`a4`. | |
""" | |
AB = a1.coors - a2.coors | |
CB = a3.coors - a2.coors | |
DC = a4.coors - a3.coors | |
if min([np.linalg.norm(p) for p in [AB, CB, DC]]) == 0.0: | |
raise Exception("some points coincide in dihedral calculation") | |
ABxCB = np.cross(AB, CB) | |
ABxCB = ABxCB / np.linalg.norm(ABxCB) | |
DCxCB = np.cross(DC, CB) | |
DCxCB = DCxCB / np.linalg.norm(DCxCB) | |
# the following is necessary for values very close to 1 but just above | |
dotp = np.dot(ABxCB, DCxCB) | |
if dotp > 1.0: | |
dotp = 1.0 | |
elif dotp < -1.0: | |
dotp = -1.0 | |
angle = np.arccos(dotp) | |
if np.dot(ABxCB, DC) > 0: | |
angle *= -1 | |
if not radians: | |
angle *= 180.0 / np.pi | |
return angle | |
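# Sketch: computing a backbone phi dihedral for residue `i` of a chain from its atom | |
# locations. Assumes the residues are structured with N/CA/C atoms; the index `i` and | |
# the variable `chain` are illustrative. | |
# | |
#   prev_res, res = chain.get_residue(i - 1), chain.get_residue(i) | |
#   phi = System.dihedral( | |
#       prev_res.find_atom("C").get_location(0), | |
#       res.find_atom("N").get_location(0), | |
#       res.find_atom("CA").get_location(0), | |
#       res.find_atom("C").get_location(0), | |
#   ) | |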
def protein_backbone_atom_type(atom_name: str, no_hyd=True, by_name=True): | |
"""Backbone atoms can be either nitrogens, carbons, oxigens, or hydrogens. | |
Specifically, possible known names in each category are: | |
'N', 'NT' | |
'CA', 'C', 'CY', 'CAY' | |
'OY', 'O', 'OCT*', 'OXT', 'OT1', 'OT2' | |
'H', 'HY*', 'HA*', 'HN', 'HT*', '1H', '2H', '3H' | |
""" | |
array = ["N", "CA", "C", "O", "H"] if by_name else [0, 1, 2, 3, 4] | |
if atom_name in ["N", "NT"]: | |
return array[0] | |
if atom_name == "CA": | |
return array[1] | |
if (atom_name == "C") or (atom_name == "CY"): | |
return array[2] | |
if atom_name in ["O", "OY", "OXT", "OT1", "OT2"] or atom_name.startswith("OCT"): | |
return array[3] | |
if not no_hyd: | |
if atom_name in ["H", "HA", "HN"]: | |
return array[4] | |
if atom_name.startswith("HT") or atom_name.startswith("HY"): | |
return array[4] | |
# Rosetta's N-terminal amine has hydrogens named 1H, 2H, and 3H | |
if ( | |
atom_name.startswith("1H") | |
or atom_name.startswith("2H") | |
or atom_name.startswith("3H") | |
): | |
return array[4] | |
return None | |
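# Usage sketch: classify atom names into the five backbone categories. | |
# | |
#   System.protein_backbone_atom_type("OT1")                 # -> "O" | |
#   System.protein_backbone_atom_type("CA", by_name=False)   # -> 1 | |
#   System.protein_backbone_atom_type("CB")                  # -> None (side-chain atom) | |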
class SystemEntity: | |
"""A molecular entity represented in a molecular system.""" | |
_type: str | |
_desc: str | |
_polymer_type: str | |
_seq: list | |
_het: list | |
def is_polymer(self): | |
"""Returns whether the entity represents a polymer.""" | |
return self._type == "polymer" | |
@classmethod | |
def guess_entity_and_polymer_type(cls, seq: List): | |
is_poly = np.mean([polyseq.is_polymer_residue(res, None) for res in seq]) > 0.8 | |
polymer_type = None | |
if is_poly: | |
entity_type = "polymer" | |
for ptype in polyseq.polymerType: | |
if ( | |
np.mean([polyseq.is_polymer_residue(res, ptype) for res in seq]) | |
> 0.8 | |
): | |
polymer_type = polyseq.polymer_type_name(ptype) | |
break | |
else: | |
entity_type = "unknown" | |
return entity_type, polymer_type | |
def type(self): | |
return self._type | |
def description(self): | |
return self._desc | |
def polymer_type(self): | |
return self._polymer_type | |
def sequence(self): | |
return self._seq | |
def hetero(self): | |
return self._het | |
class BaseView: | |
"""An abstract base "view" class for accessing different parts of System.""" | |
_ix: int | |
_parent: object | |
def get_index(self): | |
"""Return the index of this atom location in its System.""" | |
return self._ix | |
def is_valid(self): | |
return self._ix >= 0 and self._parent is not None | |
def _delete(self): | |
at = self._ix - self.parent._siblings.child_index(self.parent._ix, 0) | |
self.parent._siblings.delete_child(self.parent._ix, at) | |
@property | |
def parent(self): | |
return self._parent | |
class ChainView(BaseView): | |
"""A Chain view, allowing hierarchical exploration and editing.""" | |
def __init__(self, ix: int, system: System): | |
self._ix = ix | |
self._parent = system | |
self._siblings = system._chains | |
def __str__(self): | |
return f"{self.cid} ({self.segid}/{self.authid}) -> {str(self.system)}" | |
def residues(self): | |
for rn in range(self.num_residues()): | |
ri = self._siblings.child_index(self._ix, rn) | |
yield ResidueView(ri, self) | |
def num_residues(self): | |
"""Returns the number of residues in the Chain.""" | |
return self._siblings.num_children(self._ix) | |
def num_structured_residues(self): | |
return sum([res.has_structure() for res in self.residues()]) | |
def num_atoms(self): | |
return sum([res.num_atoms() for res in self.residues()]) | |
def num_atom_locations(self): | |
return sum([res.num_atom_locations() for res in self.residues()]) | |
def sequence(self, format="three-letter-list"): | |
"""Returns the sequence of this chain. See `System::sequence()` for | |
possible formats. | |
""" | |
if format == "three-letter-list": | |
seq = [None] * self.num_residues() | |
for ri, residue in enumerate(self.residues()): | |
seq[ri] = residue.name | |
return seq | |
elif format == "one-letter-string": | |
import chroma.utility.polyseq as polyseq | |
seq = [None] * self.num_residues() | |
for ri, residue in enumerate(self.residues()): | |
seq[ri] = polyseq.to_single(residue.name) | |
return "".join(seq) | |
else: | |
raise Exception(f"unknown sequence format {format}") | |
def get_residue(self, ri: int): | |
"""Get the residue at the specified index within the Chain. | |
Args: | |
ri (int): Residue index within the Chain. | |
Returns: | |
ResidueView object corresponding to the residue in question. | |
""" | |
if ri < 0 or ri >= self.num_residues(): | |
raise Exception( | |
f"residue index {ri} out of range for Chain, which has {self.num_residues()} residues" | |
) | |
ri = self._siblings.child_index(self._ix, ri) | |
return ResidueView(ri, self) | |
def get_residue_index(self, residue: ResidueView): | |
"""Get the index of the given residue in this Chain.""" | |
return residue._ix - self._siblings.child_index(self._ix, 0) | |
def get_atom(self, aidx: int): | |
"""Get the atom at index `aidx` within this chain.""" | |
if aidx < 0: | |
raise Exception(f"negative atom index: {aidx}") | |
off = 0 | |
for residue in self.residues(): | |
na = residue.num_atoms() | |
if aidx < off + na: | |
return residue.get_atom(aidx - off) | |
off = off + na | |
raise Exception( | |
f"atom index {aidx} out of range for System, which has {self.num_atoms()} atoms" | |
) | |
def get_atoms(self): | |
"""Return a list of all atoms in this chain.""" | |
atoms_views = [] | |
for residue in self.residues(): | |
atoms_views.extend(residue.get_atoms()) | |
return atoms_views | |
def __getitem__(self, res_idx: int): | |
return self.get_residue(res_idx) | |
def get_entity_id(self): | |
"""Return the entity ID corresponding to this chain.""" | |
return self.system._chain_entities[self._ix] | |
def get_entity(self): | |
"""Return the entity this chain belongs to.""" | |
entity_id = self.get_entity_id() | |
if entity_id is None: | |
return None | |
return self.system._entities[entity_id] | |
def check_sequence(self): | |
"""Compare the list of residue names of this chain to the corresponding entity sequence record.""" | |
entity = self.get_entity() | |
if entity is not None and entity.is_polymer(): | |
if self.num_residues() != len(entity._seq): | |
return False | |
for res, ent_aan in zip(self.residues(), entity._seq): | |
if res.name != ent_aan: | |
return False | |
return True | |
def add_residue(self, name: str, num: int, authid: str, icode: str = " ", at=None): | |
"""Add a new residue to this chain. | |
Args: | |
name (str): Residue name. | |
num (int): Residue number (i.e., residue ID). | |
authid (str): Author residue ID. | |
icode (str): Insertion code. | |
at (int, optional): Index at which to insert the residue. Default | |
is to append to the end of the chain (i.e., equivalent to `at` | |
being equal to the present length of the chain). | |
""" | |
if at is None: | |
at = self.num_residues() | |
ri = self._siblings.insert_child( | |
self._ix, | |
at, | |
{"name": name, "resnum": num, "authresid": authid, "icode": icode}, | |
) | |
return ResidueView(ri, self) | |
def delete(self, keep_entity=False): | |
"""Deletes this Chain from its System. | |
Args: | |
keep_entity (bool, optional): If False (default) and if the chain | |
being deleted happens to be the last representative of the | |
entity it belongs to, the entity will be deleted. If True, the | |
entity will always be kept. | |
""" | |
# delete the mention of the chain from assembly information | |
self.system._assembly_info.delete_chain(self.cid) | |
# optionally, delete the corresponding entity if no other chains point to it | |
if not keep_entity: | |
eid = self.get_entity_id()
# this chain still counts toward the entity, so a count of 1 means it is the last representative
if self.system.num_chains_of_entity(eid) == 1:
self.system.delete_entity(eid)
self.system._chain_entities.pop(self._ix) | |
self._siblings.delete(self._ix) | |
self._ix = -1 # invalidate the view | |
@property
def system(self):
return self._parent
@property
def cid(self):
return self._siblings["cid"][self._ix]
@property
def segid(self):
return self._siblings["segid"][self._ix]
@property
def authid(self):
return self._siblings["authid"][self._ix]
@cid.setter
def cid(self, val):
self._siblings["cid"][self._ix] = val
@segid.setter
def segid(self, val):
self._siblings["segid"][self._ix] = val
@authid.setter
def authid(self, val):
self._siblings["authid"][self._ix] = val
class ResidueView(BaseView): | |
"""A Residue view, allowing hierarchical exploration and editing.""" | |
def __init__(self, ix: int, chain: ChainView): | |
self._ix = ix | |
self._parent = chain | |
self._siblings = chain.system._residues | |
def __str__(self): | |
return f"{self.name} {self.num} ({self.authid}) -> {str(self.chain)}" | |
def atoms(self): | |
off = self._siblings.child_index(self._ix, 0) | |
for an in range(self.num_atoms()): | |
yield AtomView(off + an, self) | |
def num_atoms(self): | |
return self._siblings.num_children(self._ix) | |
def num_atom_locations(self): | |
return sum([a.num_locations() for a in self.atoms()]) | |
def has_structure(self): | |
"""Returns whether the atom has any structural information (i.e., one or more locations).""" | |
for a in self.atoms(): | |
if a.num_locations(): | |
return True | |
return False | |
def get_atom(self, ai: int): | |
"""Get the atom at the specified index within the Residue. | |
Args: | |
ai (int): Atom index within the Residue.
Returns: | |
AtomView object corresponding to the atom in question. | |
""" | |
if ai < 0 or ai >= self.num_atoms(): | |
raise Exception( | |
f"atom index {ai} out of range for Residue, which has {self.num_atoms()} atoms" | |
) | |
ai = self._siblings.child_index(self._ix, ai) | |
return AtomView(ai, self) | |
def get_atom_index(self, atom: AtomView): | |
"""Get the index of the given atom in this Residue.""" | |
return atom._ix - self._siblings.child_index(self._ix, 0) | |
def find_atom(self, name): | |
"""Find and return the first atom (as AtomView object) with the given name | |
within the Residue or None.""" | |
for atom in self.atoms(): | |
if atom.name == name: | |
return atom | |
return None | |
def __getitem__(self, atom_idx: int): | |
return self.get_atom(atom_idx) | |
def get_index_in_chain(self): | |
"""Return the index of the Residue in its parent Chain.""" | |
return self.chain.get_residue_index(self) | |
def rename(self, new_name: str, fork_entity=True): | |
"""Assigns the residue a new name with all proper updates. | |
Args: | |
new_name (str): New residue name. | |
fork_entity (bool, optional): If True (default), if the parent
chain corresponds to an entity that has other chains
associated with it, and if there is a real renaming (i.e.,
the old name differs from the new name), a new (duplicate)
entity will be created for this chain and edited, leaving
the old entity unchanged. If False, no forking is performed.
NOTE: setting this to False can create an inconsistent state
between chain and entity sequence information.
""" | |
entity_id = self.chain.get_entity_id() | |
if entity_id is not None: | |
entity = self.system._entities[entity_id] | |
ri = self.get_index_in_chain() | |
if fork_entity and (entity._seq[ri] != new_name): | |
ci = self.chain.get_index() | |
entity_id = self.system._ensure_unique_entity(ci) | |
entity = self.system._entities[entity_id] | |
entity._seq[ri] = new_name | |
self._siblings["name"][self._ix] = new_name | |
def add_atom( | |
self, | |
name: str, | |
het: bool, | |
x: float = None, | |
y: float = None, | |
z: float = None, | |
occ: float = 1.0, | |
B: float = 0.0, | |
alt: str = " ", | |
at=None, | |
): | |
"""Adds a new atom to the residue (appending it at the end) and | |
returns an AtomView to it. If atom location information is | |
specified, will also add a location to the atom. | |
Args: | |
name (str): Atom name. | |
het (bool): Whether it is a hetero-atom. | |
x, y, z (float): Atom location coordinates. | |
occ (float): Occupancy. | |
B (float): B-factor. | |
alt (str): Alternative position character. | |
at (int, optional): Index at which to insert the atom. Default
is to append to the end of the residue (i.e., equivalent to
`at` being equal to the number of atoms in the residue).
Returns: | |
AtomView object corresponding to the newly added atom. | |
""" | |
if at is None: | |
at = self.num_atoms() | |
ai = self._siblings.insert_child(self._ix, at, {"name": name, "het": het}) | |
atom = AtomView(ai, self) | |
# now add a location to this atom | |
if x is not None: | |
atom.add_location(x, y, z, occ, B, alt) | |
return atom | |
def delete(self, fork_entity=True): | |
"""Deletes this residue from its Chain/System. | |
Args: | |
fork_entity (bool, optional): If True (default) and if the parent
chain corresponds to an entity that has other chains
associated with it, a new (duplicate) entity will be created
for this chain and edited, leaving the old one unchanged.
If False, no forking is performed.
NOTE: setting this to False can create an inconsistent state | |
between chain and entity sequence information. | |
""" | |
# update the entity (duplicating, if necessary) | |
entity_id = self.chain.get_entity_id() | |
if entity_id is not None: | |
entity = self.system._entities[entity_id] | |
ri = self.get_index_in_chain() | |
if fork_entity: | |
ci = self.chain.get_index() | |
entity_id = self.system._ensure_unique_entity(ci) | |
entity = self.system._entities[entity_id] | |
entity._seq.pop(ri) | |
# delete the residue | |
self._delete() | |
self._ix = -1 # invalidate the view | |
def delete_atoms(self, atoms=None): | |
"""Delete either the specified list of atoms or all atoms from the residue. | |
Args: | |
atoms (list, optional): List of AtomView objects corresponding to the | |
atoms to delete. If not specified, will delete all atoms in the residue. | |
""" | |
if atoms is None: | |
atoms = list(self.atoms()) | |
for atom in reversed(atoms): | |
if atom.residue != self: | |
raise Exception(f"Atom {atom} does not belong to Residue {self}") | |
atom.delete() | |
@property
def chain(self):
return self._parent
@property
def system(self):
return self.chain.system
@property
def name(self):
return self._siblings["name"][self._ix]
@property
def num(self):
return self._siblings["resnum"][self._ix]
@property
def authid(self):
return self._siblings["authresid"][self._ix]
@property
def icode(self):
return self._siblings["icode"][self._ix]
def get_backbone(self, no_hyd=True): | |
"""Assuming that this is a protein residue (i.e., an amino acid), returns the | |
list of atoms corresponding to the residue's backbone, in the order: | |
backbone amide (N), alpha carbon (CA), carbonyl carbon (C), carbonyl oxygen (O), | |
and amide hydrogen (H, optional). | |
Args: | |
no_hyd (bool, optional): If True (default), will exclude the amide hydrogen | |
and only return four atoms. If False, will include the amide hydrogen. | |
Returns: | |
A list with each entry being an AtomView object corresponding to the backbone | |
atom in the order above or None if the atom does not exist in the residue. | |
""" | |
bb = [None] * (4 if no_hyd else 5) | |
left = len(bb) | |
for atom in self.atoms(): | |
i = System.protein_backbone_atom_type(atom.name, no_hyd) | |
if i is None or bb[i] is not None: | |
continue | |
bb[i] = atom | |
left = left - 1 | |
if left == 0: | |
break | |
return bb | |
def has_full_backbone(self, no_hyd=True): | |
"""Assuming that this is a protein residue (i.e., an amino acid), returns | |
whether the residue harbors a structurally defined backbone (i.e., has | |
all backbone atoms each of which has location information). | |
Args: | |
no_hyd (bool, optional): If True (default), will ignore whether the amide | |
hydrogen exists or not (if False will consider it). | |
Returns: | |
Boolean indicating whether there is a full backbone in the residue. | |
""" | |
bb = self.get_backbone(no_hyd) | |
return all([(a is not None) and a.num_locations() for a in bb]) | |
def delete_non_backbone(self, no_hyd=True): | |
"""Assuming that this is a protein residue (i.e., an amino acid), deletes | |
all atoms except backbone atoms. | |
Args: | |
no_hyd (bool, optional): If True (default), will not consider the amide | |
hydrogen as a backbone atom (if False will consider it). | |
""" | |
to_delete = [] | |
for atom in self.atoms(): | |
if System.protein_backbone_atom_type(atom.name, no_hyd) is None: | |
to_delete.append(atom) | |
self.delete_atoms(to_delete) | |
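# --- Usage sketch (illustrative addition, not part of the original module) ---
# Demonstrates ResidueView editing: appending a glycine with a four-atom
# backbone and then pruning any non-backbone atoms. The helper name and the
# placeholder coordinates are assumptions made for illustration only.
def _example_append_glycine(chain: ChainView) -> ResidueView:
    """Append a GLY residue with N, CA, C, and O atoms to `chain`."""
    num = chain.num_residues() + 1
    res = chain.add_residue("GLY", num, str(num))
    res.add_atom("N", False, 0.0, 0.0, 0.0)
    res.add_atom("CA", False, 1.46, 0.0, 0.0)
    res.add_atom("C", False, 2.02, 1.42, 0.0)
    res.add_atom("O", False, 3.25, 1.46, 0.0)
    res.delete_non_backbone()  # a no-op here, shown to round out the workflow
    return res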
class AtomView(BaseView): | |
"""An Atom view, allowing hierarchical exploration and editing.""" | |
def __init__(self, ix: int, residue: ResidueView): | |
self._ix = ix | |
self._parent = residue | |
self._siblings = residue.system._atoms | |
def __str__(self): | |
string = self.name + (" (HET) " if self.het else " ") | |
if self.num_locations() > 0: | |
string = string + str(self.get_location(0)) | |
string = string + f" ({self.num_locations()})" | |
return string + " -> " + str(self.residue) | |
def locations(self): | |
off = self._siblings.child_index(self._ix, 0) | |
for ln in range(self.num_locations()): | |
yield AtomLocationView(off + ln, self) | |
def num_locations(self): | |
return self._siblings.num_children(self._ix) | |
def __getitem__(self, loc_idx: int): | |
return self.get_location(loc_idx) | |
def get_location(self, li: int = 0): | |
"""Returns the (li+1)-th location of the atom.""" | |
if li < 0 or li >= self.num_locations(): | |
raise Exception( | |
f"location index {li} out of range for Atom with {self.num_locations()} locations" | |
) | |
li = self._siblings.child_index(self._ix, li) | |
return AtomLocationView(li, self) | |
def add_location(self, x, y, z, occ=1.0, B=0.0, alt=" ", at=None):
"""Adds a location to this atom, appending it to the end by default.
Args:
x, y, z (float): coordinates of the location.
occ (float): occupancy for the location.
B (float): B-factor for the location.
alt (str): alternative location character.
at (int, optional): Index at which to insert the location. Default
is to append at the end (i.e., equivalent to `at` being equal
to the current number of locations).
"""
if at is None: | |
at = self.num_locations() | |
li = self._siblings.insert_child( | |
self._ix, at, {"coor": [x, y, z, occ, B], "alt": alt} | |
) | |
return AtomLocationView(li, self) | |
def delete(self): | |
"""Deletes this atom from its Residue/Chain/System.""" | |
self._delete() | |
self._ix = -1 # invalidate the view | |
@property
def residue(self):
return self._parent
@property
def chain(self):
return self.residue.chain
@property
def system(self):
return self.chain.system
@property
def name(self):
return self._siblings["name"][self._ix]
@property
def het(self):
return self._siblings["het"][self._ix]
"""Location information getters and setters operate on the default (first)
location for this atom and raise an exception if there are no locations."""
@property
def x(self):
if self._siblings.num_children(self._ix) == 0:
raise Exception("atom has no locations")
ix = self._siblings.child_index(self._ix, 0)
return self.system._locations["coor"][ix, 0]
@property
def y(self):
if self._siblings.num_children(self._ix) == 0:
raise Exception("atom has no locations")
ix = self._siblings.child_index(self._ix, 0)
return self.system._locations["coor"][ix, 1]
@property
def z(self):
if self._siblings.num_children(self._ix) == 0:
raise Exception("atom has no locations")
ix = self._siblings.child_index(self._ix, 0)
return self.system._locations["coor"][ix, 2]
@property
def coors(self):
if self._siblings.num_children(self._ix) == 0:
raise Exception("atom has no locations")
ix = self._siblings.child_index(self._ix, 0)
return self.system._locations["coor"][ix, 0:3]
@property
def occ(self):
if self._siblings.num_children(self._ix) == 0:
raise Exception("atom has no locations")
ix = self._siblings.child_index(self._ix, 0)
return self.system._locations["coor"][ix, 3]
@property
def B(self):
if self._siblings.num_children(self._ix) == 0:
raise Exception("atom has no locations")
ix = self._siblings.child_index(self._ix, 0)
return self.system._locations["coor"][ix, 4]
@property
def alt(self):
if self._siblings.num_children(self._ix) == 0:
raise Exception("atom has no locations")
ix = self._siblings.child_index(self._ix, 0)
return self.system._locations["alt"][ix]
@x.setter
def x(self, val):
if self._siblings.num_children(self._ix) == 0:
raise Exception("atom has no locations")
ix = self._siblings.child_index(self._ix, 0)
self.system._locations["coor"][ix, 0] = val
@y.setter
def y(self, val):
if self._siblings.num_children(self._ix) == 0:
raise Exception("atom has no locations")
ix = self._siblings.child_index(self._ix, 0)
self.system._locations["coor"][ix, 1] = val
@z.setter
def z(self, val):
if self._siblings.num_children(self._ix) == 0:
raise Exception("atom has no locations")
ix = self._siblings.child_index(self._ix, 0)
self.system._locations["coor"][ix, 2] = val
@occ.setter
def occ(self, val):
if self._siblings.num_children(self._ix) == 0:
raise Exception("atom has no locations")
ix = self._siblings.child_index(self._ix, 0)
self.system._locations["coor"][ix, 3] = val
@B.setter
def B(self, val):
if self._siblings.num_children(self._ix) == 0:
raise Exception("atom has no locations")
ix = self._siblings.child_index(self._ix, 0)
self.system._locations["coor"][ix, 4] = val
@alt.setter
def alt(self, val):
if self._siblings.num_children(self._ix) == 0:
raise Exception("atom has no locations")
ix = self._siblings.child_index(self._ix, 0)
self.system._locations["alt"][ix] = val
class DummyAtomView(AtomView): | |
"""An dummy Atom view that can be attached to a residue but that does not | |
have any locations and with no other information.""" | |
def __init__(self, residue: ResidueView): | |
self._ix = -1 | |
self._parent = residue | |
def __str__(self): | |
return "DUMMY -> " + str(self.residue) | |
def locations(self): | |
return | |
yield | |
def num_locations(self): | |
return 0 | |
def __getitem__(self, loc_idx: int): | |
return None | |
def get_location(self, li: int = 0):
raise Exception("no locations in DUMMY atom")
def add_location(self, x, y, z, occ, B, alt, at=None):
raise Exception("cannot add locations to a DUMMY atom")
@property
def residue(self):
return self._parent
@property
def chain(self):
return self.residue.chain
@property
def system(self):
return self.chain.system
@property
def name(self):
return None
@property
def het(self):
return None
@property
def x(self):
raise Exception("no coordinates in DUMMY atom")
@property
def y(self):
raise Exception("no coordinates in DUMMY atom")
@property
def z(self):
raise Exception("no coordinates in DUMMY atom")
@property
def occ(self):
raise Exception("no occupancy in DUMMY atom")
@property
def B(self):
raise Exception("no B-factor in DUMMY atom")
@property
def alt(self):
raise Exception("no alt flag in DUMMY atom")
@x.setter
def x(self, val):
raise Exception("can't set coordinate for DUMMY atom")
@y.setter
def y(self, val):
raise Exception("can't set coordinate for DUMMY atom")
@z.setter
def z(self, val):
raise Exception("can't set coordinate for DUMMY atom")
@occ.setter
def occ(self, val):
raise Exception("can't set occupancy for DUMMY atom")
@B.setter
def B(self, val):
raise Exception("can't set B-factor for DUMMY atom")
@alt.setter
def alt(self, val):
raise Exception("can't set alt flag for DUMMY atom")
class AtomLocationView(BaseView): | |
"""An AtomLocation view, allowing hierarchical exploration and editing.""" | |
def __init__(self, ix: int, atom: AtomView): | |
self._ix = ix | |
self._parent = atom | |
self._siblings = atom.system._locations | |
def __str__(self): | |
return f"{self.x} {self.y} {self.z}" | |
def swap(self, other: AtomLocationView): | |
"""Swaps information between itself and the provided atom location. | |
Args: | |
other (AtomLocationView): the other atom location to swap with. | |
""" | |
self.x, other.x = other.x, self.x | |
self.y, other.y = other.y, self.y | |
self.z, other.z = other.z, self.z | |
self.occ, other.occ = other.occ, self.occ | |
self.B, other.B = other.B, self.B | |
self.alt, other.alt = other.alt, self.alt | |
def defined(self): | |
"""Return whether this is a valid location.""" | |
return (self.x is not None) and (self.y is not None) and (self.z is not None) | |
@property
def atom(self):
return self._parent
@property
def residue(self):
return self.atom.residue
@property
def chain(self):
return self.residue.chain
@property
def system(self):
return self.chain.system
@property
def x(self):
return self.system._locations["coor"][self._ix, 0]
@property
def y(self):
return self.system._locations["coor"][self._ix, 1]
@property
def z(self):
return self.system._locations["coor"][self._ix, 2]
@property
def occ(self):
return self.system._locations["coor"][self._ix, 3]
@property
def B(self):
return self.system._locations["coor"][self._ix, 4]
@property
def alt(self):
return self.system._locations["alt"][self._ix]
@property
def coors(self):
return np.array(self.system._locations["coor"][self._ix, 0:3])
@property
def coor_info(self):
return np.array(self.system._locations["coor"][self._ix])
@x.setter
def x(self, val):
self.system._locations["coor"][self._ix, 0] = val
@y.setter
def y(self, val):
self.system._locations["coor"][self._ix, 1] = val
@z.setter
def z(self, val):
self.system._locations["coor"][self._ix, 2] = val
@coors.setter
def coors(self, val):
self.system._locations["coor"][self._ix, 0:3] = val
@coor_info.setter
def coor_info(self, val):
self.system._locations["coor"][self._ix] = val
@occ.setter
def occ(self, val):
self.system._locations["coor"][self._ix, 3] = val
@B.setter
def B(self, val):
self.system._locations["coor"][self._ix, 4] = val
@alt.setter
def alt(self, val):
self.system._locations["alt"][self._ix] = val
class ExpressionTreeEvaluator: | |
"""A class for evaluating custom logical parenthetical expressions. The | |
implementation is very generic, supports nullary, unary, and binary | |
operators, and does not know anything about what the expressions actually | |
mean. Instead, the class interprets the expression as a tree of sub-
expressions, governed by parentheses and operators, and traverses this
tree, calling upon a user-specified evaluation function to evaluate leaf
nodes as the tree is gradually collapsed into a single node. This
can be used for evaluating set expressions, algebraic expressions, and
others.
Args: | |
operators_nullary (list): A list of strings designating nullary operators
(i.e., operators that do not take any operands). E.g., if the language
describes selection algebra, these could be "hyd", "all", or "none".
operators_unary (list): A list of strings designating unary operators
(i.e., operators that take one operand, which must come to the right
of the operator). E.g., if the language describes selection algebra,
these could be "name", "resid", or "chain".
operators_binary (list): A list of strings designating binary operators
(i.e., operators that take two operands, one on each side of the
operator). E.g., if the language describes selection algebra, these
could be "and", "or", or "around".
eval_function (callable): A function that is able to evaluate a leaf node of
the expression tree. It shall accept three parameters:
operator (str): name of the operator.
left: the left operand. Will be None if the left operand is missing or
not relevant. Otherwise, can be either a list of strings, which
should represent an evaluatable sub-expression corresponding to the
left operand, or the result of a prior evaluation of this function.
right: Same as `left` but for the right operand.
The function should attempt to evaluate the resulting expression and
return None in the case of failure or a dictionary with the result of
the evaluation stored under key "result".
left_associativity (bool): If True (the default), operators are taken to be
left-associative, meaning something like "A and B or C" is "(A and B) or C".
If False, the operators are taken to be right-associative, such that
the same expression becomes "A and (B or C)". NOTE: MST is right-associative,
but human intuition often tends to be left-associative.
debug (bool): If True (default is False), will print a great deal of debugging
messages to help diagnose any evaluation problems.
"""
def __init__( | |
self, | |
operators_nullary: list, | |
operators_unary: list, | |
operators_binary: list, | |
eval_function: callable,
left_associativity: bool = True, | |
debug: bool = False, | |
): | |
self.operators_nullary = operators_nullary | |
self.operators_unary = operators_unary | |
self.operators_binary = operators_binary | |
self.operators = operators_nullary + operators_unary + operators_binary | |
self.eval_function = eval_function | |
self.debug = debug | |
self.left_associativity = left_associativity | |
def _traverse_expression_tree(self, E, i=0, eval_all=True, debug=False): | |
def _collect_operands(E, j): | |
# collect all operands before hitting an operator | |
operands = [] | |
for k in range(len(E[j:])): | |
if E[j + k] in self.operators: | |
k = k - 1 | |
break | |
operands.append(E[j + k]) | |
return operands, j + k + 1 | |
def _find_matching_close_paren(E, beg: int): | |
c = 0 | |
for i in range(beg, len(E)): | |
if E[i] == "(": | |
c = c + 1 | |
elif E[i] == ")": | |
c = c - 1 | |
if c == 0: | |
return i | |
return None | |
def _my_eval(op, left, right, debug=False): | |
if debug: | |
print( | |
f"\t-> evaluating {operand_str(left)} | {op} | {operand_str(right)}" | |
) | |
result = self.eval_function(op, left, right) | |
if debug: | |
print(f"\t-> got result {operand_str(result)}") | |
return result | |
def operand_str(operand): | |
if isinstance(operand, dict): | |
if "result" in operand and len(operand["result"]) > 15: | |
vec = list(operand["result"]) | |
beg = ", ".join([str(i) for i in vec[:5]]) | |
end = ", ".join([str(i) for i in vec[-5:]]) | |
return "{'result': " + f"{beg} ... {end} ({len(vec)} long)" + "}" | |
return str(operand) | |
return str(operand) | |
left, right, op = None, None, None | |
if debug: | |
print(f"-> received {E[i:]}") | |
while i < len(E): | |
if all([x is None for x in (left, right, op)]): | |
# first part can either be a left parenthesis, a left operand, a nullary operator, or a unary operator | |
if E[i] == "(": | |
end = _find_matching_close_paren(E, i) | |
if end is None: | |
return None, f"parenthesis imbalance starting with {E[i:]}" | |
# evaluate expression inside the parentheses, and it becomes the left operand | |
left, rem = self._traverse_expression_tree( | |
E[i + 1 : end], 0, eval_all=True, debug=debug | |
) | |
if left is None: | |
return None, rem | |
i = end + 1 | |
if not eval_all: | |
return left, i | |
elif E[i] in self.operators_nullary: | |
# evaluate nullary op | |
left = _my_eval(E[i], None, None, debug) | |
if left is None: | |
return None, f"failed to evaluate nullary operator '{E[i]}'" | |
i = i + 1 | |
elif E[i] in self.operators_unary: | |
op = E[i] | |
i = i + 1 | |
elif E[i] in self.operators: | |
# an operator other than a unary operator cannot appear first | |
return None, f"unexpected binary operator in the context {E[i:]}" | |
else: | |
# if not an operator, then we are looking at operand(s) | |
left, i = _collect_operands(E, i) | |
elif (left is not None) and (op is None) and (right is None): | |
# we have a left operand and now looking for a binary operator | |
if E[i] not in self.operators_binary: | |
return ( | |
None, | |
f"expected end or a binary operator when got '{E[i]}' in expression: {E}", | |
) | |
op = E[i] | |
i = i + 1 | |
elif ( | |
(left is None) and (op in self.operators_unary) and (right is None) | |
) or ( | |
(left is not None) and (op in self.operators_binary) and (right is None) | |
): | |
# we saw a unary operator before and now looking for a right operand, another unary operator, or a nullary operator | |
# OR | |
# we have a left operand and a binary operator before, now looking for a right operand, a unary operator, or a nullary operator | |
if ( | |
E[i] in (self.operators_nullary + self.operators_unary) | |
or E[i] == "(" | |
): | |
right, i = self._traverse_expression_tree( | |
E, i, eval_all=not self.left_associativity, debug=debug | |
) | |
if right is None: | |
return None, i | |
else: | |
right, i = _collect_operands(E, i) | |
# We are now ready to evaluate, because: | |
# we have a unary operator and a right operand | |
# OR | |
# we have a left operand, a binary operator, and a right operand | |
result = _my_eval(op, left, right, debug) | |
if result is None: | |
return ( | |
None, | |
f"failed to evaluate operator '{op}' (in expression {E}) with operands {operand_str(left)} and {operand_str(right)}", | |
) | |
if not eval_all: | |
return result, i | |
left = result | |
op, right = None, None | |
else: | |
return ( | |
None, | |
f"encountered an unexpected condition when evaluating {E}: left is {operand_str(left)}, op is {op}, or right {operand_str(right)}", | |
) | |
if (op is not None) or (right is not None): | |
return None, f"expression ended unexpectedly" | |
if left is None: | |
return None, f"failed to evaluate expression: {E}" | |
return left, i | |
def evaluate(self, expression: str): | |
"""Evaluates the expression and returns the result.""" | |
def _split_tokens(expr): | |
# first split by parentheses (preserving the parentheses themselves) | |
parts = list(re.split("([()])", expr)) | |
# then split by space (getting rid of space) | |
return [ | |
t.strip() | |
for p in parts | |
for t in re.split("\s+", p.strip()) | |
if t.strip() != "" | |
] | |
# parse expression into tokens | |
E = _split_tokens(expression) | |
val, rem = self._traverse_expression_tree(E, debug=self.debug) | |
if val is None: | |
raise Exception( | |
f"failed to evaluate expression: '{expression}', reason: {rem}" | |
) | |
return val["result"] | |