diff --git "a/chroma/chroma/data/system.py" "b/chroma/chroma/data/system.py" deleted file mode 100644--- "a/chroma/chroma/data/system.py" +++ /dev/null @@ -1,4524 +0,0 @@ -# Copyright Generate Biomedicines, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import copy -import logging -import re -import warnings -from dataclasses import dataclass -from functools import partial -from typing import Dict, List, Tuple - -import numpy as np -import torch - -import chroma.utility.polyseq as polyseq -import chroma.utility.starparser as sp -from chroma import constants - - -@dataclass -class SystemAssemblyInfo: - """A class for representing the assembly information for System objects. - - assemblies (dict): a dictionary of assemblies with keys being assembly IDs - and values being dictionaries with of the following structure: - { - "details": "complete icosahedral assembly", - "instructions": [ - { - "oper_expression": "(1-60)", - "chains": [0, 1, 2], - - # Each assembly instruction has information for generating - # one or more images, with image `i` generated by applying - # the sequence of operations with IDs in `operations[i]` to the - # list of chains in `chains`. The corresponding operations - # are described under `assembly_info["operations"][ID]`. - "operations": [["X0", "1", "2", "3"], ["X0", "4", "5", "6"]]], - }, - ... - ], - } - - operations (dict): a dictionary with symmetry operations. Keys are operation IDs - and values being dictionaries with the following structure: - { - "type": "identity operation", - "name": "1_555", - "matrix": np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]), - "vector": np.array([0., 0., 0.]), - }, - ... - """ - - assemblies: dict - operations: dict - - def __init__(self, assemblies: dict = dict(), operations: dict = dict()): - self.assemblies = assemblies - self.operations = operations - - @staticmethod - def make_operation(type: str, name: str, matrix: list, vector: list): - op = { - "type": type, - "name": name, - "matrix": np.zeros([3, 3]), - "vector": np.zeros([3, 1]), - } - assert len(matrix) == 9, "expected 9 elements in rotation matrix" - assert len(vector) == 3, "expected 3 elements in translation vector" - for i in range(3): - op["vector"][i] = float(vector[i]) - for j in range(3): - op["matrix"][i][j] = float(matrix[i * 3 + j]) - return op - - def delete_chain(self, cid: str): - """Deletes the mention of the chain from assembly information. - - Args: - cid (str): Chain ID to delete. - """ - for ass_id, assembly in self.assemblies.items(): - for ins in assembly["instructions"]: - ins["chains"] = [_id for _id in ins["chains"] if _id != cid] - - def rename_chain(self, old_cid: str, new_cid: str): - """Renames all mentions of a chain to its new chain ID. - - Args: - old_cid (str): Chain ID to rename. - new_cid (str): Newly assigned Chain ID. - """ - for ass_id, assembly in self.assemblies.items(): - for ins in assembly["instructions"]: - ins["chains"] = [ - new_cid if cid == old_cid else cid for cid in ins["chains"] - ] - - -class StringList: - """A class for representing and accessing a list of strings in a highly memory-efficient - manner. Access is constant time, but modification is linear time in length of list. - """ - - def __init__(self, init_list: List[str] = []): - self.string = "" - self.rng = ArrayList(2, dtype=int) - for i in range(len(init_list)): - self.append(init_list[i]) - - def __getitem__(self, i: int): - beg, length = self.rng[i] - return self.string[beg : beg + length] - - def __setitem__(self, i: int, new_string: str): - beg, length = self.rng[i] - self.string = self.string[:beg] + new_string + self.string[beg + length :] - if len(new_string) != length: - self.rng[i, 1] = len(new_string) - self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] + len(new_string) - length - - def __str__(self): - return self.string - - def __len__(self): - return len(self.rng) - - def copy(self): - new_list = StringList() - new_list.string = self.string - new_list.rng = self.rng.copy() - return new_list - - def append(self, new_string: str): - self.rng.append([len(self.string), len(new_string)]) - self.string = self.string + new_string - - def insert(self, i: int, new_string: str): - if i < len(self): - ix, _ = self.rng[i] - elif i == len(self): - if len(self) == 0: - ix = 0 - else: - ix = self.rng[i - 1].sum() - else: - raise Exception( - f"cannot insert in position {i} for stringList of length {len(self)}" - ) - self.string = self.string[0:ix] + new_string + self.string[ix:] - self.rng.insert(i, [ix, len(new_string)]) - if len(new_string) > 0: - self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] + len(new_string) - - def pop(self, i: int): - beg, length = self.rng[i] - val = self.string[beg : beg + length] - self.string = self.string[0:beg] + self.string[beg + length :] - self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] - len(val) - self.rng.pop(i) - return val - - def delete_range(self, rng: range): - rng = sorted(rng) - [i, j] = [rng[0], rng[-1]] - beg, _ = self.rng[i] - end = self.rng[j].sum() - self.string = self.string[0:beg] + self.string[end:] - self.rng[j + 1 :, 0] = self.rng[j + 1 :, 0] - (end - beg + 1) - self.rng.delete_range(rng) - - -class NameList: - """A class for representing and accessing a list of "names"--i.e., strings that tend to - have generic values, such that many repeat values are expected in a given list.""" - - def __init__(self, init_list: List[str] = []): - self._reindex(init_list) - - def _reindex(self, init_list: List[str]): - self.unique_names = [] - self.name_indicies = dict() - self.index_use = dict() - self.indices = ArrayList(1, dtype=int) - for name in init_list: - self.append(name) - - def copy(self): - new_list = NameList() - new_list.unique_names = self.unique_names.copy() - new_list.name_indicies = self.name_indicies.copy() - new_list.index_use = self.index_use.copy() - new_list.indices = self.indices.copy() - return new_list - - def _check_index(self): - L = len(self.unique_names) - I = len(self.index_use) - if (L > 2 * I) and (L - I > 10): - self._reindex([self[i] for i in range(len(self))]) - - def __getitem__(self, i: int): - try: - idx = self.indices[i].item() - except IndexError as e: - raise IndexError(f"index {i} out of range for nameList\n" + str(e)) - return self.unique_names[idx] - - def __setitem__(self, i: int, new_name: str): - try: - idx = self.indices[i] - except IndexError as e: - raise IndexError(f"index {i} out of range for nameList\n" + str(e)) - self.index_use[idx] = self.index_use[idx] - 1 - if self.index_use[idx] == 0: - del self.index_use[idx] - if new_name not in self.name_indicies: - idx = len(self.name_indicies) - self.name_indicies[new_name] = idx - self.unique_names.append(new_name) - else: - idx = self.name_indicies[new_name] - self.indices[i] = idx - self._update_use(idx, 1) - self._check_index() - - def __str__(self): - return str([self[i] for i in range(len(self))]) - - def __len__(self): - return len(self.indices) - - def _update_use(self, idx, delta): - self.index_use[idx] = self.index_use.get(idx, 0) + delta - if self.index_use[idx] <= 0: - del self.index_use[idx] - - def _get_name_index(self, name: str): - if name not in self.name_indicies: - idx = len(self.name_indicies) - self.name_indicies[name] = idx - self.unique_names.append(name) - else: - idx = self.name_indicies[name] - return idx - - def append(self, name: str): - idx = self._get_name_index(name) - self.indices.append(idx) - self.index_use[idx] = self.index_use.get(idx, 0) + 1 - - def insert(self, i: int, new_string: str): - idx = self._get_name_index(new_string) - self.indices.insert(i, idx) - self.index_use[idx] = self.index_use.get(idx, 0) + 1 - - def pop(self, i: int): - idx = self.indices.pop(i).item() - val = self.unique_names[idx] - self._update_use(idx, -1) - self._check_index() - return val - - def delete_range(self, rng: range): - for i in reversed(sorted(rng)): - self.pop(i) - - -class ArrayList: - def __init__(self, ndims: int, dtype: type, length: int = 0, val=0): - if ndims == 1: - self._array = np.ndarray(shape=(max(length, 2)), dtype=dtype) - else: - self._array = np.ndarray(shape=(max(length, 2), ndims), dtype=dtype) - self.ndims = ndims - self._array[:] = val - self.length = length - # view of just the data without the extra allocated stuff - self.array = self._array[: self.length] - - def convert_negative_slice(self, slice_obj): - start = slice_obj.start if slice_obj.start is not None else 0 - stop = slice_obj.stop if slice_obj.stop is not None else self.length - - if start < 0: - start = self.length + start - if stop < 0: - stop = self.length + stop - - return slice(start, stop, slice_obj.step) - - def copy(self): - new_list = ArrayList(ndims=self.ndims, dtype=self.array.dtype, length=len(self)) - new_list[:] = self[:] - return new_list - - def __len__(self): - return self.length - - def capacity(self): - return self._array.shape[0] - - def __getitem__(self, i: int): - return self.array[i] - - def __setitem__(self, i: int, row: list): - self.array[i] = row - - def resize(self, delta): - # for speed, hard-code instead of calling len() and capacity() - new_length = self.length + delta - cap = self._array.shape[0] - if (new_length > cap) or (new_length < cap / 3): - new_capacity = 2 * new_length - self._resize(new_capacity) - self.length = new_length - self.array = self._array[: self.length] - - def _resize(self, new_size): - if self.ndims == 1: - self._array.resize((new_size), refcheck=False) - else: - self._array.resize((new_size, self.ndims), refcheck=False) - - def items(self): - for i in range(self.length): - yield self.array[i, :] - - def append(self, row: list): - self.resize(1) - self.array[-1] = row - - def insert(self, i: int, row: list): - """Insert the row such that it ends up being at index ``i`` in the new arrayList""" - # resize by +1 - self.resize(1) - - # everything in range [i:end-1) moves over by +1 - self.array[i + 1 :] = self.array[i:-1] - - # set the value at index i - self.array[i] = row - - def pop(self, i: int): - """Remove and return element at index i""" - - # get the element at index i - row = self.array[i].copy() - - # everything from [i+1; end) moves over by -1 - self.array[i:-1] = self.array[i + 1 :] - - # resize by -1 - self.resize(-1) - - return row - - def delete_range(self, rng: range): - i, j = min(rng), max(rng) - - # move over to the left to account for the removed part - cut_length = j - i + 1 - new_length = len(self) - cut_length - self.array[i:new_length] = self.array[j + 1 :] - - # resize by -1 - self.resize(-cut_length) - - def __str__(self): - return str([self[i] for i in range(len(self))]) - - -@dataclass -class HierarchicList: - """A utility class that represents a hierarchy of lists. Each level represents - a list of elements, each element having a set of properties (each property being - stored as an array-like object over elements). Further, each element has a number - of children corresponding to a range of elements in a lower-hierarhy list.""" - - _properties: dict - _parent_list: HierarchicList - _child_list: HierarchicList - _num_children: ArrayList # (1, n) - _child_offset: ArrayList # (1, n) - - def __init__( - self, - properties: dict, - parent_list: HierarchicList = None, - num_children: ArrayList = ArrayList(1, dtype=int), - ): - self._properties = dict() - for key in properties: - self._properties[key] = properties[key].copy() - self._parent_list = parent_list - if self._parent_list is not None: - self._parent_list._child_list = self - self._child_list = None - self._num_children = num_children.copy() if num_children is not None else None - # start off with lazy offsets, self.reindex() creates them - self._child_offset = None - - def copy(self): - new_list = HierarchicList( - self._properties, self._parent_list, self._num_children - ) - new_list._child_list = self._child_list - if self._child_offset is None: - new_list._child_offset = None - else: - new_list._child_offset = self._child_offset.copy() - return new_list - - def set_parent(self, parent_list: HierarchicList): - self._parent_list = parent_list - - def child_index(self, i: int, at: int): - if self._child_offset is not None: - return self._child_offset[i] + at - return self._num_children[0:i].sum() + at - - def reindex(self): - if self._num_children is not None: - self._child_offset = ArrayList( - 1, dtype=int, length=len(self._num_children), val=0 - ) - for i in range(1, len(self)): - self._child_offset[i] = ( - self._child_offset[i - 1] + self._num_children[i - 1] - ) - - def append_child(self, properties): - self._num_children[len(self._num_children) - 1] += 1 - self._child_list.append(properties) - - def insert_child(self, i: int, at: int, properties): - idx = self.child_index(i, at) - self._num_children[i] += 1 - self._child_offset = None - self._child_list.insert(idx, properties) - return idx - - def delete_child(self, i: int, at: int): - idx = self.child_index(i, at) - self._num_children[i] -= 1 - self._child_offset = None - self._child_list.delete(idx) - - def append(self, properties): - if set(properties.keys()) != set(self._properties.keys()): - raise Exception(f"unexpected set of attributes '{list(properties.keys())}") - for key, value in properties.items(): - self._properties[key].append(value) - if self._child_offset is not None: - self._child_offset.append( - self._child_offset[-1:].sum() + self._num_children[-1:].sum() - ) - if self._num_children is not None: - self._num_children.append(0) - - def insert(self, i: int, properties): - if set(properties.keys()) != set(self._properties.keys()): - raise Exception(f"unexpected set of attributes '{list(properties.keys())}") - for key, value in properties.items(): - self._properties[key].insert(i, value) - if self._child_offset is not None: - if i >= len(self._child_offset): - off = self._child_offset[-1:].sum() + self._num_children[-1:].sum() - else: - off = self._child_offset[i] - self._child_offset.insert(i, off) - if self._num_children is not None: - self._num_children.insert(i, 0) - - def delete(self, i: int): - for key in self._properties: - self._properties[key].pop(i) - if self._num_children is not None and self._num_children[i] != 0: - for at in range(self._num_children[i] - 1, -1, -1): - self.delete_child(i, at) - self._num_children.pop(i) - self._child_offset = None - - def delete_range(self, rng: range): - for key in self._properties: - self._properties[key].delete_range(rng) - # iterating in descending order so that child offsets remain valid for subsequent elements - for i in reversed(sorted(rng)): - if self._num_children is not None and self._num_children[i] != 0: - idx = self.child_index(i, 0) - self._child_list.delete_range( - self, range(idx, idx + self._num_children[i]) - ) - self._num_children[i] = 0 - self._child_offset = None - - def __len__(self): - for key in self._properties: - return len(self._properties[key]) - return None - - def __getitem__(self, i: str): - return self._properties[i] - - # def __setitem__(self, i: tuple, val): - # self._properties[i[0]][i[1]] = val - - def num_children(self, i: int): - return self._num_children[i] - - def has_children(self, i: int): - return self._num_children is not None and self._num_children[i] - - def __str__(self): - string = "Properties:\n" - for key in self._properties: - string += f"{key}: {str(self._properties[key])}\n" - string += f"num_children: {str(self._num_children)}\n" - string += f"child_offset: {str(self._child_offset)}\n" - string += "----\n" - string += str(self._child_list) - return string - - -@dataclass -class System: - """A class for storing, accessing, managing, and manipulating a molecular - system's structure, sequence, and topological information. The class is - organized as a hierarchy of objects: - - System: top-level class containing all information about a molecular system - -> Chain: a sub-portion of the System; for polymers this is generally a - chemically connected molecular graph belong to a System (e.g., for - protein complexes, this would be one of the proteins). - -> Residue: a generally chemically-connected molecular unit (for polymers, - the repeating unit), belonging to a Chain. - -> Atom: an atom belonging to a Residue with zero, one, or more locations. - -> AtomLocation: the location of an Atom (3D coordinates and other information). - - Attributes: - name (str): given name for System - _chains (list): a list of Chain objects - _entities (dict): a dictionary of SystemEntity objects, with keys being entity IDs - _chain_entities (list): `chain_entities[ci]` stores entity IDs (i.e., keys into - `entities`) corresponding to the entity for chain `ci` - _extra_models (list): a list of hierarchicList object, representing locations - for alternative models - _labels (dict): a dictionary of residue labels. A label is a string value, - under some category (also a string), associated with a residue. E.g., - the category could be "SSE" and the value could be "H" or "S". If entry - `labels[category][gti]` exists and is equal to `value`, this means that - residue with global template index `gti` has the label `category:value`. - _selections (dict): a dictionary of selections. Keys are selection names and - values are lists of corresponding gti indices. - _assembly_info (SystemAssemblyInfo): information on symmetric assemblies that can - be constructed from components of the molecular system. See ``SystemAssemblyInfo``. - """ - - name: str - _chains: HierarchicList - _residues: HierarchicList - _atoms: HierarchicList - _locations: HierarchicList - _entities: Dict[int, SystemEntity] - _chain_entities: List[int] - _extra_models: List[HierarchicList] - _labels: Dict[str, Dict[int, str]] - _selections: Dict[str, List[int]] - _assembly_info: SystemAssemblyInfo - - def __init__(self, name: str = "system"): - self.name = name - self._chains = HierarchicList( - properties={ - "cid": StringList(), - "segid": StringList(), - "authid": StringList(), - } - ) - self._residues = HierarchicList( - properties={ - "name": NameList(), - "resnum": ArrayList(1, dtype=int), - "authresid": StringList(), - "icode": ArrayList(1, dtype="U1"), - }, - parent_list=self._chains, - ) - self._atoms = HierarchicList( - properties={"name": NameList(), "het": ArrayList(1, dtype=bool)}, - parent_list=self._residues, - ) - self._locations = HierarchicList( - properties={ - "coor": ArrayList(5, dtype=float), - "alt": ArrayList(1, dtype="U1"), - }, - parent_list=self._atoms, - num_children=None, - ) - self._entities = dict() - self._chain_entities = [] - self._extra_models = [] - self._labels = dict() - self._selections = dict() - self._assembly_info = SystemAssemblyInfo() - - def _reindex(self): - self._chains.reindex() - self._residues.reindex() - self._atoms.reindex() - self._locations.reindex() - - def _print_indexing(self): - for chain in self.chains(): - off = self._chains.child_index(chain._ix, 0) - num = self._chains.num_children(chain._ix) - print(f"chain {chain._ix}, {chain}: [{off} - {off + num})") - for residue in chain.residues(): - off = self._residues.child_index(residue._ix, 0) - num = self._residues.num_children(residue._ix) - print(f"\tresidue {residue._ix}, {residue}: [{off} - {off + num})") - for atom in residue.atoms(): - off = self._atoms.child_index(atom._ix, 0) - num = self._atoms.num_children(atom._ix) - print(f"\t\tatom {atom._ix}, {atom}: [{off} - {off + num})") - for loc in atom.locations(): - has_children = self._locations.has_children(loc._ix) - print( - f"\t\t\tlocation {loc._ix}, {loc}: has children? {has_children}" - ) - - @classmethod - def from_XCS( - cls, - X: torch.Tensor, - C: torch.Tensor, - S: torch.Tensor, - alternate_alphabet: str = None, - ) -> System: - """Convert an XCS set of pytorch tensors to a new System object. - - B is batch size (Function only handles batch size of one now) - N is the number of residues - - Args: - X (torch.Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`. - `num_atoms` will be 14 if `all_atom=True` or 4 otherwise. - C (torch.LongTensor): Chain map with shape `(1, num_residues)`. It codes - positions as 0 when masked, positive integers for chain indices, - and negative integers to represent missing residues of the - corresponding positive integers. - S (torch.LongTensor): Sequence with shape `(1, num_residues)`. - alternate_alphabet (str, optional): Optional alternative alphabet for - sequence encoding. Otherwise the default alphabet is set in - `constants.AA20`.Amino acid alphabet for embedding. - Returns: - System: A System object with the new XCS data. - - """ - alphabet = constants.AA20 if alternate_alphabet is None else alternate_alphabet - all_atom = X.shape[2] == 14 - - assert X.shape[0] == 1 - assert C.shape[0] == 1 - assert S.shape[0] == 1 - assert X.shape[1] == S.shape[1] - assert C.shape[1] == C.shape[1] - - X, C, S = [T.squeeze(0).cpu().data.numpy() for T in [X, C, S]] - - chain_ids = np.abs(C) - - atom_count = 0 - new_system = cls("system") - - for i, chain_id in enumerate(np.unique(chain_ids)): - if chain_id == 0: - continue - - chain_bool = chain_ids == chain_id - X_chain = X[chain_bool, :, :].tolist() - C_chain = C[chain_bool].tolist() - S_chain = S[chain_bool].tolist() - - # Build chain - chain = new_system.add_chain("A") - for chain_ix, (X_i, C_i, S_i) in enumerate(zip(X_chain, C_chain, S_chain)): - resname = polyseq.to_triple(alphabet[int(S_i)]) - - # Build residue - residue = chain.add_residue( - resname, chain_ix + 1, str(chain_ix + 1), " " - ) - - if C_i > 0: - atom_names = constants.ATOMS_BB - - if all_atom and resname in constants.AA_GEOMETRY: - atom_names = ( - atom_names + constants.AA_GEOMETRY[resname]["atoms"] - ) - - for atom_ix, atom_name in enumerate(atom_names): - x, y, z = X_i[atom_ix] - atom_count += 1 - residue.add_atom(atom_name, False, x, y, z, 1.0, 0.0, " ") - - # add an entity for each chain (copy from chain information) - for ci, chain in enumerate(new_system.chains()): - seq = [None] * chain.num_residues() - het = [None] * chain.num_residues() - for ri, res in enumerate(chain.residues()): - seq[ri] = res.name - het[ri] = all(a.het for a in res.atoms()) - entity_type, polymer_type = SystemEntity.guess_entity_and_polymer_type(seq) - entity = SystemEntity( - entity_type, f"chain {chain.cid}", polymer_type, seq, het - ) - new_system.add_new_entity(entity, [ci]) - - return new_system - - def to_XCS( - self, - all_atom: bool = False, - batch_dimension: bool = True, - mask_unknown: bool = True, - unknown_token: int = 0, - reorder_chain: bool = True, - alternate_alphabet=None, - alternate_atoms=None, - get_indices=False, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Convert System object to XCS format. - - `C` tensor has shape [num_residues], where it codes positions as 0 - when masked, positive integers for chain indices, and negative integers - to represent missing residues of the corresponding positive integers. - - `S` tensor has shape [num_residues], it will map residue amino acid to alphabet integers. - If it is not found in `alphabet`, it will default to `unknown_token`. Set `mask_unknown` to true if - also want to mask `unk residue` in `chain_map` - - This function takes into account missing residues and updates chain_map - accordingly. - - Args: - system (type): generate System object to convert. - all_atom (bool): Include side chain atoms. Default is `False`. - batch_dimension (bool): Include a batch dimension. Default is `True`. - mask_unknown (bool): Mask residues not found in the alphabet. Default is - `True`. - unknown_token (int): Default token index if a residue is not found in - the alphabet. Default is `0`. - reorder_chain (bool): If set to true will start indexing chain at 1, - else will use the alphabet index (Default: True) - altenate_alphabet (str): Alternative alphabet if not `None`. - alternate_atoms (list): Alternate atom name subset for `X` if not `None`. - get_indices (bool): Also return the location indices corresponding to the - returned `X` tensor. - - Returns: - X (torch.Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`. - `num_atoms` will be 14 if `all_atom=True` or 4 otherwise. - C (torch.LongTensor): Chain map with shape `(1, num_residues)`. It codes - positions as 0 when masked, positive integers for chain indices, - and negative integers to represent missing residues of the - corresponding positive integers. - S (torch.LongTensor): Sequence with shape `(1, num_residues)`. - location_indices (np.ndaray, optional): location indices corresponding to - the coordinates in `X`. - - """ - alphabet = constants.AA20 if alternate_alphabet is None else alternate_alphabet - - # Get chain map grabbing each chain in system and look at length - C = [] - for ch_id, chain in enumerate(self.chains()): - ch_str = chain.cid - if ch_str in list(constants.CHAIN_ALPHABET): - map_ch_id = list(constants.CHAIN_ALPHABET).index(ch_str) - else: - # fmt: off - map_ch_id = np.setdiff1d(np.arange(1, len(constants.CHAIN_ALPHABET)), np.unique(C))[0] - # fmt: on - if reorder_chain: - map_ch_id = ch_id + 1 - C += [map_ch_id] * chain.num_residues() - - # Grab full sequence - oneLetterSeq = self.sequence(format="one-letter-string") - - if len(oneLetterSeq) != len(C): - logging.warning("Warning, System and chain_map length don't agree") - - # Initialize recipient arrays - atom_names = None - if all_atom: - num_atoms = 14 if all_atom else 4 - else: - if alternate_atoms is not None: - atom_names = alternate_atoms - else: - atom_names = constants.ATOMS_BB - num_atoms = len(atom_names) - atom_names = {a: i for (i, a) in enumerate(atom_names)} - num_residues = self.num_residues() - X = np.zeros([num_residues, num_atoms, 3]) - location_indices = ( - np.zeros([num_residues * num_atoms], dtype=int) if get_indices else None - ) - - S = [] # Array will contain sequence indices - for i in range(num_residues): - # If residue should be mask or not - is_mask = False - - # Add sequence - if oneLetterSeq[i] in list(alphabet): - S.append(alphabet.index(oneLetterSeq[i])) - else: - S.append(unknown_token) - if mask_unknown: - is_mask = True - - # Get residue from system - res = self.get_residue(i) - if res is None or not res.has_structure(): - is_mask = True - - # If residue is mask because no structure or not found in alphabet - if is_mask: - # Set chain map to -x - C[i] = -abs(C[i]) - else: - # Loop through atoms - if all_atom: - code3 = constants.AA20_1_TO_3[oneLetterSeq[i]] - atom_names = ( - constants.ATOMS_BB + constants.AA_GEOMETRY[code3]["atoms"] - ) - atom_names = {a: i for (i, a) in enumerate(atom_names)} - - X[ - i, : - ] = np.nan # so we can tell whether some atom was previously found - num_rem = len(atom_names) - for atom in res.atoms(): - name = System.protein_backbone_atom_type(atom.name, False, True) - if name is None: - name = atom.name - ix = atom_names.get(name, None) - if ix is None or not np.isnan(X[i, ix, 0]): - continue - for loc in atom.locations(): - X[i, ix] = loc.coors - if location_indices is not None: - location_indices[i * num_atoms + ix] = loc.get_index() - num_rem -= 1 - break - if num_rem == 0: - break - if num_rem != 0: - C[i] = -abs(C[i]) - X[i, :] = 0 - np.nan_to_num(X[i, :], copy=False, nan=0) - - # Tensor everything - X = torch.tensor(X).float() - C = torch.tensor(C).type(torch.long) - S = torch.tensor(S).type(torch.long) - - # Unsqueeze all the thing - if batch_dimension: - X = X.unsqueeze(0) - C = C.unsqueeze(0) - S = S.unsqueeze(0) - - if location_indices is not None: - return X, C, S, location_indices - - return X, C, S - - def update_with_XCS(self, X, C=None, S=None, alternate_alphabet=None): - """Update the System with XCS coordinates. NOTE: if the System has - more than one model, and if the shape of the System changes (i.e., - atoms are added or deleted), the additional models will be wiped. - - Args: - X (Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`. - `num_atoms` will be 14 if `all_atom=True` or 4 otherwise. - C (LongTensor): Chain map with shape `(1, num_residues)`. It codes - positions as 0 when masked, positive integers for chain indices, - and negative integers to represent missing residues of the - corresponding positive integers. Defaults to the current System's - chain map. - S (LongTensor): Sequence with shape `(1, num_residues)`. Defaults to - the current System's sequence. - """ - if C is None or S is None: - _, _C, _S = self.to_XCS() - if C is None: - C = _C - if S is None: - S = _S - - # check to make sure sizes agree - if not ( - (X.shape[1] == self.num_residues()) - and (X.shape[1] == C.shape[1]) - and (X.shape[1] == S.shape[1]) - ): - raise Exception( - f"input tensor sizes {X.shape}, {C.shape}, and {S.shape}, disagree with System size {self.num_residues()}" - ) - - def _process_inputs(T): - if T is not None: - if len(T.shape) == 2 or len(T.shape) == 4: - T = T.squeeze(0) - T = T.to("cpu").detach().numpy() - return T - - X, C, S = map(_process_inputs, [X, C, S]) - - shape_changed = False - alphabet = constants.AA20 if alternate_alphabet is None else alternate_alphabet - for i, res in enumerate(self.residues()): - # atoms to update must have structure and are present in the chain map - if not res.has_structure() or C[i] <= 0: - continue - - # First, determine if the sequence has changed - resname = "UNK" - if S is not None and S[i] < len(alphabet): - resname = polyseq.to_triple(alphabet[S[i]]) - # If the identity changes, rename and delete side chain atoms - if res.name != resname: - res.rename(resname) - - # Second, delete all atoms that are missing in XCS or have duplicate names - atoms_sys = [atom.name for atom in res.atoms()] - atoms_XCS = constants.ATOMS_BB - if resname in constants.AA_GEOMETRY: - atoms_XCS = atoms_XCS + constants.AA_GEOMETRY[resname]["atoms"] - atoms_XCS = atoms_XCS[: X.shape[1]] - to_delete = [] - for ix_a, atom in enumerate(res.atoms()): - name = atom.name - if name not in atoms_XCS or name in atoms_sys[:ix_a]: - to_delete.append(atom) - if len(to_delete) > 0: - shape_changed = True - res.delete_atoms(to_delete) - - # Finally, update all atom coordinates and manufacture any missing atoms - for x_id, atom_name in enumerate(atoms_XCS): - atom = res.find_atom(atom_name) - x, y, z = [X[i][x_id][k].item() for k in range(3)] - if atom is not None and atom.num_locations() > 0: - atom.x = x - atom.y = y - atom.z = z - else: - shape_changed = True - if atom is not None: - atom.add_location(x, y, z) - else: - res.add_atom(atom_name, False, x, y, z, 1.0, 0.0) - - # wipe extra models if the shape of the System changed - if shape_changed: - self._extra_models = [] - - def __str__(self): - return "system " + self.name - - def chains(self): - """Chain iterator (generator function).""" - for ci in range(len(self._chains)): - yield ChainView(ci, self) - - def get_chain(self, ci: int): - """Returns the chain by index. - - Args: - ci (int): Chain index (from 0) - - Returns: - ChainView object corresponding to the chain in question. - """ - return ChainView(ci, self) - - def get_chain_by_id(self, cid: str, segid=False): - """Returns the chain by its string ID. - - Args: - cid (str): Chain ID. - segid (bool, optional): If set to True (default is False) will - return the chain with the matching segment ID and not chain ID. - - Returns: - ChainView object corresponding to the chain in question. - """ - for ci, chain in enumerate(self.chains()): - if (not segid and cid == chain.cid) or (segid and cid == chain.segid): - return ChainView(ci, self) - return None - - def get_chains(self): - """Returns the list of all chains.""" - return [ChainView(ci, self) for ci in range(len(self._chains))] - - def get_chains_of_entity(self, entity_id: int, by=None): - """Returns the list of chains that correspond to the given entity ID. - - Args: - entity_id (int): Entity ID. - by (str, optional): If specified as "index", will return a - list of chain indices instead of ChainView objects. - - Returns: - List of ChainView objects or chain indices. - """ - cixs = [ci for (ci, eid) in enumerate(self._chain_entities) if entity_id == eid] - if by == "index": - return cixs - return [ChainView(ci, self) for ci in cixs] - - def residues(self): - """Residue iterator (generator function).""" - for chain in self.chains(): - for residue in chain.residues(): - yield residue - - def get_residue(self, gti: int): - """Returns the residue at the given global index. - - Args: - gti (int): Global residue index. - - Returns: - ResidueView object corresponding to the index. - """ - if gti < 0: - raise Exception(f"negative residue index: {gti}") - off = 0 - for chain in self.chains(): - nr = chain.num_residues() - if gti < off + nr: - return chain.get_residue(gti - off) - off = off + nr - raise Exception( - f"residue index {gti} out of range for System, which has {self.num_residues()} residues" - ) - - def atoms(self): - """Iterator of atoms in this System (generator function).""" - for chain in self.chains(): - for residue in chain.residues(): - for atom in residue.atoms(): - yield atom - - def get_atom(self, aidx: int): - """Returns the atom at the given global atom index. - - Args: - gti (int): Global atom index. - - Returns: - AtomView object corresponding to the index. - """ - if aidx < 0: - raise Exception(f"negative atom index: {aidx}") - off = 0 - for chain in self.chains(): - na = chain.num_atoms() - if aidx < off + na: - return chain.get_atom(aidx - off) - off = off + na - raise Exception( - f"atom index {aidx} out of range for System, which has {self.num_atoms()} atoms" - ) - - def locations(self): - """Iterator of atoms in this System (generator function).""" - for chain in self.chains(): - for residue in chain.residues(): - for atom in residue.atoms(): - for loc in atom.locations(): - yield loc - - def _new_locations(self): - new_locs = self._locations.copy() - for li in range(len(new_locs)): - new_locs["coor"][li] = [np.nan] * 5 - return new_locs - - def select(self, expression: str, left_associativity: bool = True): - """Evalates the given selection expression and returns all atoms - involved in the result as a list of AtomView's. - - Args: - expression (str): selection expression. - left_associativity (bool, optional): determines whether operators - in the expression are left-associative. - - Returns: - List of AtomView's. - """ - val, selex_info = self._select( - expression, left_associativity=left_associativity - ) - - # make a list of AtomView - result = [selex_info["all_atoms"][i].atom for i in sorted(val)] - - return result - - def select_residues( - self, - expression: str, - gti: bool = False, - allow_unstructured=False, - left_associativity: bool = True, - ): - """Evalates the given selection expression and returns all residues with any - atoms involved in the result as a list of ResidueView's or list of gti's. - - Args: - expression (str): selection expression. - gti (bool): if True (default is False), will return a list of gti - instead of a list of ResidueView's. - allow_unstructured (bool): If True (default is False), will allow - unstructured residues to be selected. - left_associativity (bool, optional): determines whether operators - in the expression are left-associative. - - Returns: - List of ResidueView's or gti's (ints). - """ - val, selex_info = self._select( - expression, - unstructured=allow_unstructured, - left_associativity=left_associativity, - ) - - # make a list of ResidueView or gti's - if gti: - result = sorted(set([selex_info["all_atoms"][i].rix for i in val])) - else: - residues = dict() - for i in val: - a = selex_info["all_atoms"][i] - residues[a.rix] = a.atom.residue - result = [residues[rix] for rix in sorted(residues.keys())] - - return result - - def select_chains( - self, expression: str, allow_unstructured=False, left_associativity: bool = True - ): - """Evalates the given selection expression and returns all chains with any - atoms involved in the result as a list of ChainView's. - - Args: - expression (str): selection expression. - allow_unstructured (bool): If True (default is False), will allow - unstructured chains to be selected. - left_associativity (bool, optional): determines whether operators - in the expression are left-associative. - - Returns: - List of ResidueView's or gti's (ints). - """ - val, selex_info = self._select( - expression, - unstructured=allow_unstructured, - left_associativity=left_associativity, - ) - - # make a list of ResidueView or gti's - chains = dict() - for i in val: - a = selex_info["all_atoms"][i] - chains[a.cix] = a.atom.chain - result = [chains[rix] for rix in sorted(chains.keys())] - - return result - - def _select( - self, - expression: str, - unstructured: bool = False, - left_associativity: bool = True, - ): - # Build some helpful data structures to support _selex_eval - @dataclass(frozen=True) - class MappableAtom: - atom: AtomView - aix: int - rix: int - cix: int - - def __hash__(self) -> int: - return self.aix - - # first collect all real atoms - all_atoms = [None] * self.num_atoms() - cix, rix, aix = 0, 0, 0 - for chain in self.chains(): - for residue in chain.residues(): - for atom in residue.atoms(): - all_atoms[aix] = MappableAtom(atom, aix, rix, cix) - aix = aix + 1 - - # for residues that do not have atoms, add a dummy atom with no location - # or name; that way, we can still select the residue even though selection - # algebra fundamentally works on atoms - if residue.num_atoms() == 0: - view = DummyAtomView(residue) - view.dummy = True - # make more room at the end of the list since as this is an "extra" atom - all_atoms.append(None) - all_atoms[aix] = MappableAtom(view, aix, rix, cix) - aix = aix + 1 - rix = rix + 1 - cix = cix + 1 - - _selex_info = {"all_atoms": all_atoms} - _selex_info["all_indices_set"] = set([a.aix for a in all_atoms]) - - # fmt: off - # make an expression tree object - tree = ExpressionTreeEvaluator( - ["hyd", "all", "none"], - ["not", "byres", "bychain", "first", "last", - "chain", "authchain", "segid", "namesel", "gti", "resix", "resid", - "authresid", "resname", "re", "x", "y", "z", "b", "icode", "name"], - ["and", "or", "around", "saround"], - eval_function=partial(self._selex_eval, _selex_info), - left_associativity=left_associativity, - debug=False, - ) - # fmt: on - - # evaluate the expression - val = tree.evaluate(expression) - - # if we are not looking to select unstructured residues, remove any dummy - # atoms. NOTE: making dummy atoms can still impact what structured atoms - # are selected (e.g., consider `saround` relative to an unstructured residue) - if not unstructured: - val = { - i for i in val if not hasattr(_selex_info["all_atoms"][i].atom, "dummy") - } - - return val, _selex_info - - def save_selection( - self, - expression: Optional[str] = None, - gti: Optional[List[int]] = None, - selname: str = "_default", - allow_unstructured=False, - left_associativity: bool = True, - ): - """Performs a selection on the System according to the given - selection string and saves the indices of residues involved in - the result (global template indices) under the given name. - - Args: - expression (str): (optional) selection expression. - gti (list of int): (optional) list of gti to define selection expression - selname (str): selection name. - allow_unstructured (bool): If True (default is False), will allow - unstructured residues to be selected. - left_associativity (bool, optional): determines whether operators - in the expression are left-associative. - """ - if gti is not None: - if expression is not None: - warnings.warn( - f"Expression and gti are both not null, expression will be ignored" - f" and gti will be used!" - ) - result = sorted(gti) - else: - result = self.select_residues( - expression, - allow_unstructured=allow_unstructured, - left_associativity=left_associativity, - gti=True, - ) - - # save the list of gti's - self._selections[selname] = result - - def get_selected(self, selname: str = "_default"): - """Returns the list of gti saved under the specified name. - - Args: - selname (str): selection name. - - Returns: - List of global template indices. - """ - if selname not in self._selections: - raise Exception( - f"selection by name '{selname}' does not exist in the System" - ) - return self._selections[selname] - - def has_selection(self, selname: str = "_default"): - """Returns whether the given named selection exists. - - Args: - selname (str): selection name. - - Returns: - Whether the selection exists in the System. - """ - return selname in self._selections - - def get_selection_names(self): - """Returns the list of all currently stored named selections.""" - return list(self._selections.keys()) - - def remove_selection(self, selname: str = "_default"): - """Deletes the selection under the specified name. - - Args: - selname (str): selection name. - """ - if selname not in self._selections: - raise Exception( - f"selection by name '{selname}' does not exist in the System" - ) - del self._selections[selname] - - def _selex_eval(self, _selex_info, op: str, left, right): - def _is_numeric(string: str) -> bool: - try: - float(string) - return True - except ValueError: - return False - - def _is_int(string: str) -> bool: - try: - int(string) - return True - except ValueError: - return False - - def _unpack_operands(operands, dests): - assert len(operands) == len(dests) - unpacked = [None] * len(operands) - succ = True - for i, (operand, dest) in enumerate(zip(operands, dests)): - if dest is None: - if operand is not None: - succ = False - break - elif dest == "result": - if not (isinstance(operand, dict) and "result" in operand): - succ = False - break - unpacked[i] = operand["result"] - elif dest == "string": - if not (len(operand) == 1 and isinstance(operand[0], str)): - succ = False - break - unpacked[i] = operand[0] - elif dest == "strings": - if not ( - isinstance(operand, list) - and all([isinstance(val, str) for val in operands]) - ): - succ = False - break - unpacked[i] = operands - elif dest == "float": - if not (len(operand) == 1 and _is_numeric(operand[0])): - succ = False - break - unpacked[i] = float(operand[0]) - elif dest == "floats": - if not ( - isinstance(operand, list) - and all([_is_numeric(val) for val in operands]) - ): - succ = False - break - unpacked[i] = [float(val) for val in operands] - elif dest == "range": - test = _parse_range(operand) - if test is None: - succ = False - break - unpacked[i] = test - elif dest == "int": - if not (len(operand) == 1 and _is_int(operand[0])): - succ = False - break - unpacked[i] = int(operand[0]) - elif dest == "ints": - if not ( - isinstance(operand, list) - and all([_is_int(val) for val in operands]) - ): - succ = False - break - unpacked[i] = [int(val) for val in operands] - elif dest == "int_range": - test = _parse_int_range(operand) - if test is None: - succ = False - break - unpacked[i] = test - elif dest == "int_range_string": - test = _parse_int_range(operand, string=True) - if test is None: - succ = False - break - unpacked[i] = test - return unpacked, succ - - def _parse_range(operands: list): - """Parses range information given a list of operands that were originally separated - by spaces. Allowed range expressiosn are of the form: `< n`, `> n`, `n:m` with - optional spaces allowed between operands.""" - if not ( - isinstance(operands, list) - and all([isinstance(opr, str) for opr in operands]) - ): - return None - operand = "".join(operands) - if operand.startswith(">") or operand.startswith("<"): - if not _is_numeric(operand[1:]): - return None - num = float(operand[1:]) - if operand.startswith(">"): - test = lambda x, cut=num: x > cut - else: - test = lambda x, cut=num: x < cut - elif ":" in operand: - parts = operand.split(":") - if (len(parts) != 2) or not all([_is_numeric(p) for p in parts]): - return None - parts = [float(p) for p in parts] - test = lambda x, lims=parts: lims[0] < x < lims[1] - elif _is_numeric(operand): - target = float(operand) - test = lambda x, t=target: x == t - else: - return None - return test - - def _parse_int_range(operands: list, string: bool = False): - """Parses range of integers information given a list of operands that were - originally separated by spaces. Allowed range expressiosn are of the form: - `n`, `n-m`, `n+m`, with optional spaces allowed anywhere and combinations - also allowed (e.g., "n+m+s+r-p+a").""" - if not ( - isinstance(operands, list) - and all([isinstance(opr, str) for opr in operands]) - ): - return None - operand = "".join(operands) - operands = operand.split("+") - ranges = [] - for operand in operands: - m = re.fullmatch("(.*\d)-(.+)", operand) - if m: - if not all([_is_int(g) for g in m.groups()]): - return None - r = range(int(m.group(1)), int(m.group(2)) + 1) - ranges.append(r) - else: - if not _is_int(operand): - return None - if string: - ranges.append(set([operand])) - else: - ranges.append(set([int(operand)])) - if string: - ranges = [[str(x) for x in r] for r in ranges] - test = lambda x, ranges=ranges: any([x in r for r in ranges]) - return test - - # evaluate expression and store result in list `result` - result = set() - if op in ("and", "or"): - (Si, Sj), succ = _unpack_operands([left, right], ["result", "result"]) - if not succ: - return None - if op == "and": - result = set(Si).intersection(set(Sj)) - else: - result = set(Si).union(set(Sj)) - elif op == "not": - (_, S), succ = _unpack_operands([left, right], [None, "result"]) - if not succ: - return None - result = _selex_info["all_indices_set"].difference(S) - elif op == "all": - (_, _), succ = _unpack_operands([left, right], [None, None]) - if not succ: - return None - result = _selex_info["all_indices_set"] - elif op == "none": - (_, _), succ = _unpack_operands([left, right], [None, None]) - if not succ: - return None - elif op == "around": - (S, rad), succ = _unpack_operands([left, right], ["result", "float"]) - if not succ: - return None - - # Convert to numpy for distance calculation - atom_indices = np.asarray( - [ - ai.aix - for ai in _selex_info["all_atoms"] - for xi in ai.atom.locations() - ] - ) - X_i = np.asarray( - [ - [xi.x, xi.y, xi.z] - for ai in _selex_info["all_atoms"] - for xi in ai.atom.locations() - ] - ) - X_j = np.asarray( - [ - [xi.x, xi.y, xi.z] - for j in S - for xi in _selex_info["all_atoms"][j].atom.locations() - ] - ) - D = np.sqrt(((X_j[np.newaxis, :, :] - X_i[:, np.newaxis, :]) ** 2).sum(-1)) - ix_match = (D <= rad).sum(1) > 0 - match_hits = atom_indices[ix_match] - result = set(match_hits.tolist()) - elif op == "saround": - (S, srad), succ = _unpack_operands([left, right], ["result", "int"]) - if not succ: - return None - for j in S: - aj = _selex_info["all_atoms"][j] - rj = aj.rix - for ai in _selex_info["all_atoms"]: - if aj.atom.residue.chain != ai.atom.residue.chain: - continue - ri = ai.rix - if abs(ri - rj) <= srad: - result.add(ai.aix) - elif op == "byres": - (_, S), succ = _unpack_operands([left, right], [None, "result"]) - if not succ: - return None - gtis = set() - for j in S: - gtis.add(_selex_info["all_atoms"][j].rix) - for a in _selex_info["all_atoms"]: - if a.rix in gtis: - result.add(a.aix) - elif op == "bychain": - (_, S), succ = _unpack_operands([left, right], [None, "result"]) - if not succ: - return None - cixs = set() - for j in S: - cixs.add(_selex_info["all_atoms"][j].cix) - for a in _selex_info["all_atoms"]: - if a.cix in cixs: - result.add(a.aix) - elif op in ("first", "last"): - (_, S), succ = _unpack_operands([left, right], [None, "result"]) - if not succ: - return None - if op == "first": - mi = min([_selex_info["all_atoms"][i].aix for i in S]) - else: - mi = max([_selex_info["all_atoms"][i].aix for i in S]) - result.add(mi) - elif op == "name": - (_, name), succ = _unpack_operands([left, right], [None, "string"]) - if not succ: - return None - for a in _selex_info["all_atoms"]: - if a.atom.name == name: - result.add(a.aix) - elif op in ("re", "hyd"): - if op == "re": - (_, regex), succ = _unpack_operands([left, right], [None, "string"]) - else: - (_, _), succ = _unpack_operands([left, right], [None, None]) - regex = "[0123456789]?H.*" - if not succ: - return None - ex = re.compile(regex) - for a in _selex_info["all_atoms"]: - if a.atom.name is not None and ex.fullmatch(a.atom.name): - result.add(a.aix) - elif op in ("chain", "authchain", "segid"): - (_, match_id), succ = _unpack_operands([left, right], [None, "string"]) - if not succ: - return None - if op == "chain": - prop = "cid" - elif op == "authchain": - prop = "authid" - elif op == "segid": - prop = "segid" - for a in _selex_info["all_atoms"]: - if getattr(a.atom.residue.chain, prop) == match_id: - result.add(a.aix) - elif op == "resid": - (_, test), succ = _unpack_operands([left, right], [None, "int_range"]) - if not succ: - return None - for a in _selex_info["all_atoms"]: - if test(a.atom.residue.num): - result.add(a.aix) - elif op in ("resname", "icode"): - (_, match_id), succ = _unpack_operands([left, right], [None, "string"]) - if not succ: - return None - if op == "resname": - prop = "name" - elif op == "icode": - prop = "icode" - for a in _selex_info["all_atoms"]: - if getattr(a.atom.residue, prop) == match_id: - result.add(a.aix) - elif op == "authresid": - (_, test), succ = _unpack_operands( - [left, right], [None, "int_range_string"] - ) - if not succ: - return None - for a in _selex_info["all_atoms"]: - if test(a.atom.residue.authid): - result.add(a.aix) - elif op == "gti": - (_, test), succ = _unpack_operands([left, right], [None, "int_range"]) - if not succ: - return None - for a in _selex_info["all_atoms"]: - if test(a.rix): - result.add(a.aix) - elif op in ("x", "y", "z", "b", "occ"): - (_, test), succ = _unpack_operands([left, right], [None, "range"]) - if not succ: - return None - prop = op - if op == "b": - prop = "B" - for a in _selex_info["all_atoms"]: - for loc in a.atom.locations(): - if test(getattr(loc, prop)): - result.add(a.aix) - break - elif op == "namesel": - (_, selname), succ = _unpack_operands([left, right], [None, "string"]) - if not succ: - return None - if selname not in self._selections: - return None - gtis = set(self._selections[selname]) - for a in _selex_info["all_atoms"]: - if a.rix in gtis: - result.add(a.aix) - else: - return None - - return {"result": result} - - def __getitem__(self, chain_idx: int): - """Returns the chain at the given index.""" - return self.get_chain(chain_idx) - - def add_chain( - self, - cid: str, - segid: str = None, - authid: str = None, - entity_id: int = None, - auto_rename: bool = True, - at: int = None, - ): - """Adds a new chain to the System and returns a reference to it. - - Args: - cid (str): Chain ID. - segid (str): Segment ID. - authid (str): Author chain ID. - entity_id (int, optional): Entity ID of the entity corresponding to this chain. - auto_rename (bool, optional): If True, will pick a unique chain ID if the specified - one clashes with an already existing chain. - - Returns: - AtomView object corresponding to the index. - """ - if auto_rename: - cid = self._pick_unique_chain_name(cid) - if segid is None: - segid = cid - if authid is None: - authid = cid - if at is None: - at = self.num_chains() - self._chains.append({"cid": cid, "segid": segid, "authid": authid}) - self._chain_entities.append(entity_id) - else: - self._chains.insert(at, {"cid": cid, "segid": segid, "authid": authid}) - self._chain_entities.insert(at, entity_id) - return ChainView(at, self) - - def _append_residue(self, name: str, num: int, authid: str, icode: str): - """Add a new residue to the end this System. Internal method, do not use. - - Args: - name (str): Residue name. - num (int): Residue number (i.e., residue ID). - authid (str): Author residue ID. - icode (str): Insertion code. - - Returns: - Global index to the newly added residue. - """ - self._chains.append_child( - {"name": name, "resnum": num, "authresid": authid, "icode": icode} - ) - return len(self._residues) - 1 - - def _append_atom( - self, - name: str, - het: bool, - x: float = None, - y: float = None, - z: float = None, - occ: float = None, - B: float = None, - alt: str = None, - ): - """Adds a new atom to the end of this System. Internal method, do not use. - - Args: - name (str): Atom name. - het (bool): Whether it is a hetero-atom. - x, y, z (float): Atom location coordinates. - occ (float): Occupancy. - B (float): B-factor. - alt (str): Alternative position character. - - Returns: - Global index to the newly added atom. - """ - self._residues.append_child({"name": name, "het": het}) - return len(self._atoms) - 1 - - def _append_location(self, x, y, z, occ, B, alt): - """Adds a location to the end of this System. Internal method, do not use. - - Args: - x, y, z (float): coordinates of the location. - occ (float): occupancy for the location. - B (float): B-factor for the location. - alt (str): alternative location character. - - Returns: - Global index to the newly added location. - """ - self._atoms.append_child({"coor": [x, y, z, occ, B], "alt": alt}) - return len(self._locations) - 1 - - def add_new_entity(self, entity: SystemEntity, chain_indices: list): - """Adds a new entity to the list contained within the System and - assigns chains with provided indices to this entity. - - Args: - entity (SystemEntity): The new entity to add to the System. - chain_indices (list): a list of Chain indices for chains to - assign to this entity. - - Returns: - The entity ID of the newly added entity. - """ - new_entity_id = len(self._entities) - while new_entity_id in self._entities: - new_entity_id = new_entity_id + 1 - self._entities[new_entity_id] = entity - for ci in chain_indices: - self._chain_entities[ci] = new_entity_id - return new_entity_id - - def delete_entity(self, entity_id: int): - """Deletes the entity with the specified ID. Takes care to unlink - any chains belonging to this entity from it. - - Args: - entity_id (int): Entity ID. - """ - chain_indices = self.get_chains_of_entity(entity_id) - for ci in chain_indices: - self._chain_entities[ci] = None - del self._entities[entity_id] - - def _pick_unique_chain_name(self, hint: str, verbose=False): - goodNames = list( - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" - ) - taken = set([chain.cid for chain in self.chains()]) - - # first try to pick a conforming chain name (single alpha-numeric character) - for cid in [hint] + goodNames: - if cid not in taken: - return cid - if verbose: - warnings.warn( - "ran out of reasonable single-letter chain names, will use more than one character (PDB sctructure may be repeating chain IDs upon writing, but should still have unique segment IDs)!" - ) - - # if that does not work, get a longer chain name - for i in range(-1, len(goodNames)): - # first try to expand the original chain ID - base = hint if i < 0 else goodNames[i : i + 1] - if base == "": - continue - for k in range(1000): - longName = f"{base}{k}" - if longName not in taken: - return longName - raise Exception( - "ran out of even multi-character chain names; PDB structure appears to have an enormous number of chains" - ) - - def _ensure_unique_entity(self, ci: int): - """Any time we need to update some piece of information about a Chain that - relates to its entity (e.g., sequence info or hetero info), we cannot just - update it directly because other Chains may be pointing to the same entity. - This function checks for any other chains pointing to the same entity as the - specified chain, and (if so) assigns the given chain to a new (duplicate) - entity and returns its new ID. This clears the way updates of this Chain's entity. - - Args: - ci (int): Index of the Chain for which we are trying to update - entity information. - - Returns: - entity ID for either a newly created entity mapped to the Chain or its - initial entity ID if no other chains point to the same entity. - """ - chain = self.get_chain(ci) - entity_id = chain.get_entity_id() - if entity_id is None: - return entity_id - - # see if any other chains point to the same entity - unique = True - for other in self.chains(): - if (other != chain) and (entity_id == other.get_entity_id()): - unique = False - break - if unique: - return entity_id - - # if so, we need to make a new entity and point the chain to it - new_entity = copy.deepcopy(self._entities[entity_id]) - new_entity_id = self.add_new_entity(new_entity, [ci]) - return new_entity_id - - def num_chains(self): - """Returns the number of chains in the System.""" - return len(self._chains) - - def num_chains_of_entity(self, entity_id: int): - """Returns the number of chains of a given entity. - - Args: - entity_id (int): Entity ID. - - Returns: - number of chains mapping to the entity. - """ - - return sum([entity_id == eid for eid in self._chain_entities]) - - def num_molecules_of_entity(self, entity_id: int): - if self._entities[entity_id].is_polymer(): - return self.num_chains_of_entity(entity_id) - cixs = [ci for (ci, id) in enumerate(self._chain_entities) if id == entity_id] - return sum([self[ci].num_residues() for ci in cixs]) - - def num_entities(self): - """Returns the number of entities in the System.""" - return len(self._entities) - - def num_residues(self): - """Returns the number of residues in the System.""" - return len(self._residues) - - def num_structured_residues(self): - """Returns the number of residues with any structure information.""" - return sum([chain.num_structured_residues() for chain in self.chains()]) - - def num_atoms(self): - """Returns the number of atoms in the System.""" - return len(self._atoms) - - def num_structured_atoms(self): - """Returns the number of atoms with any location information.""" - num = 0 - for chain in self.chains(): - for residue in chain.residues(): - for atom in residue.atoms(): - num = num + (atom.num_locations() > 0) - return num - - def num_atom_locations(self): - """Returns the number of atom locations. Note that an atom can have - multiple (alternative) locations and this functions counts all. - """ - return len(self._locations) - - def num_models(self): - """Returns the number of models in the System. A model is effectively - a conformation of the molecular system and each System object can have - an arbitrary number of different conformations. - """ - return len(self._extra_models) + 1 - - def swap_model(self, i: int): - """Swaps the model at index `i` with the current model (i.e., the - model at index 0). - - Args: - i (int): Model index - """ - if i == 0: - return - if i < 0 or i >= self.num_models(): - raise Exception(f"model index {i} out of range") - tmp = self._locations - self._locations = self._extra_models[i - 1] - self._extra_models[i - 1] = tmp - - def add_model(self, other: System): - """Adds a new model to the System by taking the current model from the - specified System `other`. Note that `other` and the present System - must have the same number of atom locations of matching atom and - residue names. - - Args: - other (System): The System to take the model from. - """ - if len(self._locations) != len(other._locations): - raise Exception( - f"System has {len(self._locations)} atom locations while {len(other._locations)} were provided" - ) - self._extra_models.append(other._locations.copy()) - self._extra_models[-1].set_parent(self._atoms) - - def add_model_from_X(self, X: torch.Tensor): - """Adds a new model to the System with given coordinates. - - Args: - X (torch.Tensor): Coordinate tensor of shape - `(residues, atoms (4 or 14), coordinates (3))` - """ - if len(self._locations) != X.numel() / 3: - raise Exception( - f"System has {len(self._locations)} atom locations while provided tensor shape is {X.shape}" - ) - X = X.detach().cpu() - self._extra_models.append(self._locations.copy()) - self._extra_models[-1]["coor"][:, 0:3] = X.flatten(0, 1) - return None - - def num_assemblies(self): - """Returns the number of biological assemblies defined in this System.""" - return len(self._assembly_info.assemblies) - - @staticmethod - def from_CIF_string(cif_string: str): - """Initializes and returns a System object from a CIF string.""" - import io - - f = io.StringIO(cif_string) - return System._read_cif(f)[0] - - @staticmethod - def from_CIF(input_file: str): - """Initializes and returns a System object from a CIF file.""" - f = open(input_file, "r") - return System._read_cif(f)[0] - - @staticmethod - def _read_cif(f, strict=False): - def _warn_or_error(strict: bool, msg: str): - if strict: - raise Exception(msg) - else: - warnings.warn(msg) - - is_read = { - part: False for part in ["coors", "entities", "sequence", "entity_poly"] - } - category = "" - (in_loop, success) = (False, True) - peeked = sp.PeekedLine("", 0) - # number of molecules per entity prescribed in the CIF file - num_of_mols = dict() - - system = System("system") - while sp.peek_line(f, peeked): - if peeked.line.startswith("#"): - # nothing to do, skip comments - sp.advance(f, peeked) - elif peeked.line.startswith("data_"): - # nothing to do, this is the beginning of the file - sp.advance(f, peeked) - elif peeked.line.startswith("loop_"): - in_loop = True - category = "" - sp.advance(f, peeked) - else: - (cat, name, val) = ("", "", "") - if peeked.line.startswith("_"): - (cat, name, val) = sp.star_item_parse(peeked.line) - if cat != category: - if category != "": - in_loop = False - category = cat - - if (cat == "_entry") and (name == "id"): - if val != "": - system.name = val - sp.advance(f, peeked) - elif cat == "_entity_poly": - if is_read["entity_poly"]: - raise Exception("entity_poly block encountered multiple times") - tab = sp.star_read_data(f, ["entity_id", "type"], in_loop) - for row in tab: - ent_id = int(row[0]) - 1 - if ent_id not in system._entities: - system._entities[ent_id] = SystemEntity( - None, None, row[1], None, None - ) - else: - system._entities[ent_id]._polymer_type = row[1] - is_read["entity_poly"] = True - elif cat == "_entity": - if is_read["entities"]: - raise Exception( - f"entities block encountered multiple times: {peeked.line}" - ) - tab = sp.star_read_data( - f, - ["id", "type", "pdbx_description", "pdbx_number_of_molecules"], - in_loop, - ) - for row in tab: - ent_id = int(row[0]) - 1 - if ent_id not in system._entities: - system._entities[ent_id] = SystemEntity( - row[1], row[2], None, None, None - ) - else: - system._entities[ent_id]._type = row[1] - system._entities[ent_id]._desc = row[2] - if row[3].isnumeric(): - num_of_mols[ent_id] = int(row[3]) - is_read["entities"] = True - elif cat == "_entity_poly_seq": - if is_read["sequence"]: - raise Exception(f"sequence block encountered multiple times") - tab = sp.star_read_data( - f, ["entity_id", "num", "mon_id", "hetero"], in_loop - ) - (seq, het) = ([], []) - for i in range(len(tab)): - # accumulate sequence information until we reach the end or a new entity ID - seq.append(tab[i][2]) - het.append(tab[i][3].startswith("y")) - if (i == len(tab) - 1) or (tab[i][0] != tab[i + 1][0]): - ent_id = int(tab[i][0]) - 1 - system._entities[ent_id]._seq = seq - system._entities[ent_id]._het = het - (seq, het) = ([], []) - is_read["sequence"] = True - elif cat == "_pdbx_struct_assembly": - tab = sp.star_read_data(f, ["id", "details"], in_loop) - for row in tab: - system._assembly_info.assemblies[row[0]] = {"details": row[1]} - elif cat == "_pdbx_struct_assembly_gen": - tab = sp.star_read_data( - f, ["assembly_id", "oper_expression", "asym_id_list"], in_loop - ) - for row in tab: - assembly = system._assembly_info.assemblies[row[0]] - if "instructions" not in assembly: - assembly["instructions"] = [] - chain_ids = [cid.strip() for cid in row[2].strip().split(",")] - assembly["instructions"].append( - {"oper_expression": row[1], "chains": chain_ids} - ) - elif cat == "_pdbx_struct_oper_list": - tab = sp.star_read_data( - f, - [ - "id", - "type", - "name", - "matrix[1][1]", - "matrix[1][2]", - "matrix[1][3]", - "matrix[2][1]", - "matrix[2][2]", - "matrix[2][3]", - "matrix[3][1]", - "matrix[3][2]", - "matrix[3][3]", - "vector[1]", - "vector[2]", - "vector[3]", - ], - in_loop, - ) - for row in tab: - system._assembly_info.operations[ - row[0] - ] = SystemAssemblyInfo.make_operation( - row[1], row[2], row[3:12], row[12:15] - ) - elif cat == "_generate_selections": - tab = sp.star_read_data(f, ["name", "indices"], in_loop) - for row in tab: - system._selections[row[0]] = [ - int(gti.strip()) for gti in row[1].strip().split() - ] - elif cat == "_generate_labels": - tab = sp.star_read_data(f, ["name", "index", "value"], in_loop) - for row in tab: - if row[0] not in system._labels: - system._labels[row[0]] = dict() - idx = int(row[1]) - system._labels[row[0]][int(row[1])] = row[2] - elif cat == "_atom_site": - if is_read["coors"]: - raise Exception(f"ATOM_SITE block encountered multiple times") - # this section is special as it cannot have quoted blocks (because some atom names have the single quote character in them) - tab = sp.star_read_data( - f, - [ - "group_PDB", - "id", - "label_atom_id", - "label_alt_id", - "label_comp_id", - "label_asym_id", - "label_entity_id", - "label_seq_id", - "pdbx_PDB_ins_code", - "Cartn_x", - "Cartn_y", - "Cartn_z", - "occupancy", - "B_iso_or_equiv", - "pdbx_PDB_model_num", - "auth_seq_id", - "auth_asym_id", - ], - in_loop, - cols=False, - has_blocks=False, - ) - - groupCol = 0 - idxCol = 1 - atomNameCol = 2 - altIdCol = 3 - resNameCol = 4 - chainNameCol = 5 - entityIdCol = 6 - seqIdCol = 7 - insCodeCol = 8 - xCol = 9 - yCol = 10 - zCol = 11 - occCol = 12 - bCol = 13 - modelCol = 14 - authSeqIdCol = 15 - authChainNameCol = 16 - - ( - atom, - residue, - chain, - prev_chain, - prev_residue, - prev_atom, - prev_entity_id, - prev_seq_id, - prev_auth_seq_id, - ) = (None, None, None, None, None, None, None, None, None) - loc = None # first model location - aIdx = 0 - for i in range(len(tab)): - if i == 0: - first_model = tab[i][modelCol] - prev_model = first_model - elif (tab[i][modelCol] != prev_model) or ( - tab[i][modelCol] != first_model - ): - if tab[i][modelCol] != prev_model: - aIdx = 0 - num_loc = system.num_atom_locations() - # setting the default value to None allows us to tell when the - # same coordinate in a subsequent model was not specified (e.g., - # when an alternative coordinate is not specified) - system._extra_models.append(system._new_locations()) - prev_model = tab[i][modelCol] - locations_generator = (l for l in system.locations()) - - loc = next(locations_generator, None) - if aIdx >= num_loc: - _warn_or_error( - strict, - f"at atom id: {tab[i][idxCol]} -- too many atoms in model {tab[i][modelCol]} relative to first model {first_model}", - ) - success = False - system._extra_models.clear() - break - - # check that the atoms correspond - same = ( - (loc is not None) - and (tab[i][chainNameCol] == loc.atom.residue.chain.cid) - and (tab[i][resNameCol] == loc.atom.residue.name) - and ( - int( - sp.star_value( - tab[i][seqIdCol], loc.atom.residue.num - ) - ) - == loc.atom.residue.num - ) - and (tab[i][atomNameCol] == loc.atom.name) - ) - if not same: - _warn_or_error( - strict, - f"at atom id: {tab[i][idxCol]} -- atoms in model {tab[i][modelCol]} do not correspond exactly to atoms in first model", - ) - success = False - system._extra_models.clear() - break - - coor = [ - float(tab[i][c]) - for c in [xCol, yCol, zCol, occCol, bCol] - ] - system._extra_models[-1]["coor"][aIdx] = coor - system._extra_models[-1]["alt"][aIdx] = sp.star_value( - tab[i][altIdCol], " " - )[0] - aIdx = aIdx + 1 - continue - - # new chain? - if ( - (chain is None) - or (prev_entity_id != tab[i][entityIdCol]) - or (tab[i][chainNameCol] != chain.cid) - ): - authid = ( - tab[i][authChainNameCol] - if (tab[i][authChainNameCol] != "") - else tab[i][chainNameCol] - ) - chain = system.add_chain( - tab[i][chainNameCol], - tab[i][chainNameCol], - authid, - int(tab[i][entityIdCol]) - 1, - ) - - # new residue - if ( - (residue is None) - or (chain != prev_chain) - or (prev_seq_id != tab[i][seqIdCol]) - or (prev_auth_seq_id != tab[i][authSeqIdCol]) - ): - resnum = ( - int(tab[i][seqIdCol]) - if sp.star_value_defined(tab[i][seqIdCol]) - else chain.num_residues() + 1 - ) - ri = system._append_residue( - tab[i][resNameCol], - resnum, - tab[i][authSeqIdCol], - sp.star_value(tab[i][insCodeCol], " ")[0], - ) - residue = ResidueView(ri, chain) - - # usually will be a new atom, but may be an alternative coordinate - # TODO: this only covers cases where alternative atom coordinates are listed next to each other, - # but sometimes they are not -- need to actively use the altIdCol information - x, y, z, occ, B = [ - float(tab[i][col]) - for col in [xCol, yCol, zCol, occCol, bCol] - ] - alt = sp.star_value(tab[i][altIdCol], " ")[0] - if ( - (atom is None) - or (residue != prev_residue) - or (tab[i][atomNameCol] != atom.name) - ): - ai = system._append_atom( - tab[i][atomNameCol], (tab[i][groupCol] == "HETATM") - ) - atom = AtomView(ai, residue) - system._append_location(x, y, z, occ, B, alt) - - prev_chain = chain - prev_residue = residue - prev_entity_id = tab[i][entityIdCol] - prev_seq_id = tab[i][seqIdCol] - prev_auth_seq_id = tab[i][authSeqIdCol] - is_read["coors"] = True - else: - sp.advance(f, peeked) - - # fill in any "missing" polymer chains (e.g., chains with no visible density - # or known structure, but which are nevertheless present) - for entity_id in num_of_mols: - if system._entities[entity_id].is_polymer(): - rem = num_of_mols[entity_id] - system.num_chains_of_entity(entity_id) - for _ in range(rem): - # the chain will get renamed to avoid clashes - system.add_chain("A", None, None, entity_id, auto_rename=True) - - # fill in missing residues (i.e., those that exist in the entity but not - # the atomistic section) - for chain in system.chains(): - entity = chain.get_entity() - if not entity.is_polymer() or entity._seq is None: - continue - k = 0 - for ri in range(len(entity._seq)): - cur_res = chain.get_residue(k) if k < chain.num_residues() else None - if (cur_res is None) or (cur_res.num > ri + 1): - # insert new residue to correspond to entity monomer with index ri - chain.add_residue(entity._seq[ri], ri + 1, str(ri + 1), " ", at=k) - elif cur_res.num < ri + 1: - _warn_or_error( - strict, f"inconsistent numbering in chain {chain.cid}" - ) - break - k = k + 1 - - # do an entity-to-structure sequence check for all chains - for chain in system.chains(): - if not chain.check_sequence(): - _warn_or_error( - strict, - f"chain {chain.cid} did not pass sequence check against corresponding entity", - ) - - system._reindex() - return system, success - - @staticmethod - def from_PDB_string(cif_string: str, options=""): - """Initializes and returns a System object from a PDB string.""" - import io - - f = io.StringIO(cif_string) - sys = System._read_pdb(f, options) - sys.name = "from_string" - return sys - - @staticmethod - def from_PDB(input_file: str, options=""): - """Initializes and returns a System object from a PDB file.""" - f = open(input_file, "r") - sys = System._read_pdb(f, options) - sys.name = input_file - return sys - - @staticmethod - def _read_pdb(f, strict=False, options=""): - def _to_float(strval, default): - v = default - try: - v = float(strval) - except: - pass - return v - - last_resnum = None - last_resname = None - last_icode = None - last_chain_id = None - last_alt = None - chain = None - residue = None - - # flag to indicate that chain terminus was reached. Initialize to true so as to create a new chain upon reading the first atom. - ter = True - - # various parsing options (the wonders of dealing with the good-old PDB format) - # and any user-specified overrides - options = options.upper() - # use segment IDs to name chains instead of chain IDs? (useful when the latter - # are absent OR when too many chains, so need multi-letter names) - usese_gid = True if ("USESEGID" in options) else False - - # the PDB file was written by CHARMM (slightly different format) - charmm_format = True if ("CHARMM" in options) else False - - # upon reading, convert from all-hydrogen topology (param22 and higher) to - # the CHARMM19 united-atom topology (matters for HIS protonation states) - charmm19_format = True if ("CHARMM19" in options) else False - - # make sure chain IDs are unique, even if they are not unique in the read file - uniq_chain_ids = False if ("ALLOW DUPLICATE CIDS" in options) else True - - # rename CD in ILE to CD1 (as is standard in PDB, but not some MM packages) - fix_Ile_CD = False if ("ALLOW ILE CD" in options) else True - - # consequtive residues that differ only in their insertion code will be treated - # as separate residues - icodes_as_sep_res = True - - # if true, will not pay attention to TER lines in deciding when chains end/begin - ignore_ter = True if ("IGNORE-TER" in options) else False - - # report various warnings when weird things are found and fixed? - verbose = False if ("QUIET" in options) else True - - chains_to_rename = [] - - # read line by line and build the System - system = System("system") - all_system = system - model_index = 0 - for line in f: - line = line.strip() - if line.startswith("ENDMDL"): - # merge the last read model with the overall System - if model_index: - try: - all_system.add_model(system) - except Exception as e: - warnings.warn( - f"error when adding model {model_index + 1}: {str(e)}, skipping model..." - ) - system = System("system") - model_index = model_index + 1 - last_resnum = None - last_resname = None - last_icode = None - last_chain_id = None - last_alt = None - chain = None - residue = None - continue - if line.startswith("END"): - break - if line.startswith("MODEL"): - # new model - continue - if line.startswith("TER") and not ignore_ter: - ter = True - continue - if not (line.startswith("ATOM") or line.startswith("HETATM")): - continue - - """ Now read atom record. Sometimes PDB lines are too short (if they do not contain some - of the last optional columns). We don't want to read past the end of the string!""" - line += " " * 100 - atominx = int(line[6:11]) - atomname = line[12:16].strip() - alt = line[16:17] - resname = line[17:21].strip() - chain_id = line[21:22].strip() - resnum = int(line[23:27]) if charmm_format else int(line[22:26]) - icode = " " if charmm_format else line[26:27] - x = float(line[30:38]) - y = float(line[38:46]) - z = float(line[46:54]) - seg_id = line[72:76].strip() - B = _to_float(line[60:66], 0.0) - occ = _to_float(line[54:60], 0.0) - het = line.startswith("HETATM") - - # use segment ID's instead of chain ID's? - if usese_gid: - chain_id = seg_id - elif (chain_id == "") and (len(seg_id) > 0) and seg_id[0].isalnum(): - # use first character of segment name if no chain name is specified, a segment ID - # is specified, and the latter starts with an alphanumeric character - chain_id = seg_id[0:1] - - # create a new chain object, if necessary - if (chain_id != last_chain_id) or ter: - cid_used = system.get_chain_by_id(chain_id) is not None - chain = system.add_chain(chain_id, seg_id, chain_id, auto_rename=False) - # non-unique chains will be automatically renamed (unless the user specified not to rename chains), BUT we need to - # remember the name that was actually read, since this name is what will be used to determine when the next chain comes - if uniq_chain_ids and cid_used: - chain.cid = chain.cid + f"|to rename {len(chains_to_rename)}" - if model_index == 0: - chains_to_rename.append(chain) - if verbose: - warnings.warn( - "chain name '" - + chain_id - + "' was repeated while reading, will rename at the end..." - ) - - # start to count residue numbers in this chain - last_resnum = None - last_resname = None - ter = False - - if charmm19_format: - if resname == "HSE": - resname = "HSD" # neutral HIS, proton on ND1 - if resname == "HSD": - resname = "HIS" # neutral HIS, proton on NE2 - if resname == "HSC": - resname = "HSP" # doubley-protonated +1 HIS - - # many PDB files in the Protein Data Bank call the delta carbon of isoleucine CD1, but - # the convention in basically all MM packages is to call it CD, since there is only one - if fix_Ile_CD and (resname == "ILE") and (atomname == "CD"): - atomname = "CD1" - - # if necessary, make a new residue - really_new_atom = True # is this a truely new atom, as opposed to an alternative position? - if ( - (resnum != last_resnum) - or (resname != last_resname) - or (icodes_as_sep_res and (icode != last_icode)) - ): - # this corresponds to a case, where the alternative location flag is being used to - # designate two (or more) different possible amino acids at a particular position - # (e.g., where the density is not clear to assign one). In this case, we shall keep - # only the first option, because we don't know any better. But we need to separate - # this from the case, where we end up here because we are trying to separate residues - # by insertion code. - if ( - (resnum == last_resnum) - and (resname != last_resname) - and (alt != last_alt) - and (not icodes_as_sep_res or (icode == last_icode)) - ): - continue - - residue = chain.add_residue( - resname, chain.num_residues() + 1, str(resnum), icode[0] - ) - elif alt != " ": - # if this is not a new residue AND the alternative location flag is specified, - # figure out if another location for this atom has already been given. If not, - # then treat this as the "primary" location, and whatever other locations - # are specified will be treated as alternatives. - a = residue.find_atom(atomname) - if a is not None: - really_new_atom = False - a.add_location(x, y, z, occ, B, alt[0]) - - # if necessary, make a new atom - if really_new_atom: - a = residue.add_atom(atomname, het, x, y, z, occ, B, alt[0]) - - # remember previous values for determining whether something interesting happens next - last_resnum = resnum - last_icode = icode - last_resname = resname - last_chain_id = chain_id - last_alt = alt - - # take care of renaming any chains that had duplicate IDs - for chain in chains_to_rename: - parts = chain.cid.split("|") - assert ( - len(parts) > 1 - ), "something went wrong when renaming a chain at the end of reading" - name = all_system._pick_unique_chain_name(parts[0], verbose) - chain.cid = name - if len(name): - chain.segid = name - - # add an entity for each chain (copy from chain information) - for ci, chain in enumerate(all_system.chains()): - seq = [None] * chain.num_residues() - het = [None] * chain.num_residues() - for ri, res in enumerate(chain.residues()): - seq[ri] = res.name - het[ri] = all(a.het for a in res.atoms()) - entity_type, polymer_type = SystemEntity.guess_entity_and_polymer_type(seq) - entity = SystemEntity( - entity_type, f"chain {chain.cid}", polymer_type, seq, het - ) - all_system.add_new_entity(entity, [ci]) - - return all_system - - def to_CIF(self, output_file: str): - """Writes the System to a CIF file.""" - f = open(output_file, "w") - self._write_cif(f) - - def to_CIF_string(self): - """Returns a CIF string representing the System.""" - import io - - f = io.StringIO("") - self._write_cif(f) - cif_str = f.getvalue() - f.close() - return cif_str - - def _write_cif(self, f): - # fmt: off - _specials_atom_names = [ - "MG", "CL", "FE", "ZN", "MN", "NI", "SE", "CU", "BR", "CO", "AS", - "BE", "RU", "RB", "ZR", "OS", "SR", "GD", "MO", "AU", "AG", "PT", - "AL", "XE", "BE", "CS", "EU", "IR", "AM", "TE", "BA", "SB" - ] - # fmt: on - _ambiguous_atom_names = ["CA", "CD", "NA", "HG", "PB"] - - def _guess_type(atom_name, res_name): - if len(atom_name) > 0 and atom_name[0] == '"': - atom_name = atom_name.replace('"', "") - if atom_name[:2] in _specials_atom_names: - return atom_name[:2] - else: - if atom_name in _ambiguous_atom_names and res_name == atom_name: - return atom_name - elif atom_name == "UNK": - return "X" - return atom_name[:1] - - entry_id = self.name.strip() - if entry_id == "": - entry_id = "system" - f.write( - "data_GNR8\n#\n" - + "_entry.id " - + sp.star_string_escape(entry_id) - + "\n#\n" - ) - - # write entities table - sp.star_loop_header_write( - f, "_entity", ["id", "type", "pdbx_description", "pdbx_number_of_molecules"] - ) - for id, entity in self._entities.items(): - num_mol = self.num_molecules_of_entity(id) - f.write( - f"{id + 1} {sp.star_string_escape(entity._type)} {sp.star_string_escape(entity._desc)} {num_mol}\n" - ) - f.write("#\n") - - # write entity polymer sequences - sp.star_loop_header_write( - f, "_entity_poly_seq", ["entity_id", "num", "mon_id", "hetero"] - ) - for id, entity in self._entities.items(): - if entity._seq is not None: - for i, (res, het) in enumerate(zip(entity._seq, entity._het)): - f.write(f"{id + 1} {i + 1} {res} {'y' if het else 'n'}\n") - f.write("#\n") - - # write entity polymer types - sp.star_loop_header_write(f, "_entity_poly", ["entity_id", "type"]) - for id, entity in self._entities.items(): - if entity.is_polymer(): - f.write(f"{id + 1} {sp.star_string_escape(entity._polymer_type)}\n") - f.write("#\n") - - if self.num_assemblies(): - assemblies = self._assembly_info.assemblies - ops = self._assembly_info.operations - # assembly info table - sp.star_loop_header_write(f, "_pdbx_struct_assembly", ["id", "details"]) - for assembly_id, assembly in assemblies.items(): - f.write(f"{assembly_id} {sp.star_string_escape(assembly['details'])}\n") - f.write("#\n") - - # assembly generation instructions table - sp.star_loop_header_write( - f, - "_pdbx_struct_assembly_gen", - ["assembly_id", "oper_expression", "asym_id_list"], - ) - for assembly_id, assembly in assemblies.items(): - for instruction in assembly["instructions"]: - chain_list = ",".join([str(ci) for ci in instruction["chains"]]) - f.write( - f"{assembly_id} {sp.star_string_escape(instruction['oper_expression'])} {chain_list}\n" - ) - f.write("#\n") - - # symmetry operations table - sp.star_loop_header_write( - f, - "_pdbx_struct_oper_list", - [ - "id", - "type", - "name", - "matrix[1][1]", - "matrix[1][2]", - "matrix[1][3]", - "matrix[2][1]", - "matrix[2][2]", - "matrix[2][3]", - "matrix[3][1]", - "matrix[3][2]", - "matrix[3][3]", - "vector[1]", - "vector[2]", - "vector[3]", - ], - ) - for op_id, op in ops.items(): - f.write( - f"{op_id} {sp.star_string_escape(op['type'])} {sp.star_string_escape(op['name'])} " - ) - f.write( - f"{float(op['matrix'][0][0]):g} {float(op['matrix'][0][1]):g} {float(op['matrix'][0][2]):g} " - ) - f.write( - f"{float(op['matrix'][1][0]):g} {float(op['matrix'][1][1]):g} {float(op['matrix'][1][2]):g} " - ) - f.write( - f"{float(op['matrix'][2][0]):g} {float(op['matrix'][2][1]):g} {float(op['matrix'][2][2]):g} " - ) - f.write( - f"{float(op['vector'][0]):g} {float(op['vector'][1]):g} {float(op['vector'][2]):g}\n" - ) - f.write("#\n") - - sp.star_loop_header_write( - f, - "_atom_site", - [ - "group_PDB", - "id", - "label_atom_id", - "label_alt_id", - "label_comp_id", - "label_asym_id", - "label_entity_id", - "label_seq_id", - "pdbx_PDB_ins_code", - "Cartn_x", - "Cartn_y", - "Cartn_z", - "occupancy", - "B_iso_or_equiv", - "pdbx_PDB_model_num", - "auth_seq_id", - "auth_asym_id", - "type_symbol", - ], - ) - idx = -1 - for model_index in range(self.num_models()): - self.swap_model(model_index) - for chain, entity_id in zip(self.chains(), self._chain_entities): - authchainid = ( - chain.authid if sp.star_value_defined(chain.authid) else chain.cid - ) - for residue in chain.residues(): - authresid = ( - residue.authid - if sp.star_value_defined(residue.authid) - else residue.num - ) - for atom in residue.atoms(): - idx = idx + 1 - for location in atom.locations(): - # this means this coordinate was not specified for this model - if not location.defined(): - continue - - coor = location.coor_info - f.write("HETATM " if atom.het else "ATOM ") - f.write( - f"{idx + 1} {atom.name} {sp.atom_site_token(location.alt)} " - ) - entity_id_str = ( - f"{entity_id + 1}" if entity_id is not None else "?" - ) - f.write( - f"{residue.name} {chain.cid} {entity_id_str} {residue.num} " - ) - f.write( - f"{sp.atom_site_token(residue.icode)} {coor[0]:g} {coor[1]:g} {coor[2]:g} " - ) - f.write(f"{coor[3]:g} {coor[4]:g} {model_index} ") - f.write( - f"{authresid} {authchainid} {_guess_type(atom.name, residue.name)}\n" - ) - self.swap_model(model_index) - f.write("#\n") - - # write out selections - if len(self._selections): - sp.star_loop_header_write(f, "_generate_selections", ["name", "indices"]) - for name, indices in self._selections.items(): - f.write( - f"{sp.star_string_escape(name)} \"{' '.join([str(i) for i in indices])}\"\n" - ) - f.write("#\n") - - # write out labels - if len(self._labels): - sp.star_loop_header_write(f, "_generate_labels", ["name", "index", "value"]) - for category, label_dict in self._labels.items(): - for gti, label in label_dict.items(): - f.write( - f"{sp.star_string_escape(category)} {gti} {sp.star_string_escape(label)}\n" - ) - f.write("#\n") - - def to_PDB(self, output_file: str, options: str = ""): - """Writes the System to a PDB file. - - Args: - output_file (str): output PDB file name. - options (str, optional): a string specifying various options for - the writing process. The presence of certain sub-strings will - trigger specific behaviors. Currently recognized sub-strings - include "CHARMM", "CHARMM19", "CHARMM22", "RENUMBER", "NOEND", - "NOTER", and "NOALT". This option is case-insensitive. - """ - f = open(output_file, "w") - self._write_pdb(f, options) - - def to_PDB_string(self, options=""): - """Writes the System to a PDB string. The options string has the same - interpretation as with System::toPDB. - """ - import io - - f = io.StringIO("") - self._write_pdb(f, options) - cif_str = f.getvalue() - f.close() - return cif_str - - def _write_pdb(self, f, options=""): - def _pdb_line(loc: AtomLocationView, ai: int, ri=None, rn=None, an=None): - if rn is None: - rn = loc.atom.residue.name - if ri is None: - ri = loc.atom.residue.num - if an is None: - an = loc.atom.name - icode = loc.atom.residue.icode - cid = loc.atom.residue.chain.cid - if len(cid) > 1: - cid = cid[0] - segid = loc.atom.residue.chain.segid - if len(segid) > 4: - segid = segid[0:4] - - # atom name placement is different when it is 4 characters long - if len(an) < 4: - an_str = " %-.3s" % an - else: - an_str = "%.4s" % an - - # moduli are used to make sure numbers do not go over prescribe field widths - # (this is not enforced by sprintf like with strings) - line = ( - "%6s%5d %-4s%c%-4s%.1s%4d%c %8.3f%8.3f%8.3f%6.2f%6.2f %.4s" - % ( - "HETATM" if loc.atom.het else "ATOM ", - ai % 100000, - an_str, - loc.alt, - rn, - cid, - ri % 10000, - icode, - loc.x, - loc.y, - loc.z, - loc.occ, - loc.B, - segid, - ) - ) - - return line - - # various formating options (the wonders of dealing with the good-old PDB format) - # and user-defined overrides - options = options.upper() - # the PDB file is intended for use in CHARMM or some other MM package - charmmFormat = True if "CHARMM" in options else False - - # upon writing, convert from all-hydrogen topology (param 22 and higher) - # to CHARMM19 united-atom topology (matters for HIS protonation states) - charmm19Format = True if "CHARMM19" in options else False - - # upon writing, convert from CHARMM19 united-atom topology to all-hydrogen - # param 22 topology (matters for HIS protonation states). Also works for - # converting generic PDB files downloaded from the PDB. - charmm22Format = True if "CHARMM22" in options else False - - # upon writing, renumber residue and atom names to start from 1 and go in order - renumber = True if "RENUMBER" in options else False - - # do not write END at the end of the PDB file (e.g., useful for - # concatenating chains from several structures) - noend = True if "NOEND" in options else False - - # do not demark the end of each chain with TER (this is not _really_ - # necessary, assuming chain names are unique, and it is sometimes nice - # not to have extra lines other than atoms) - noter = True if "NOTER" in options else False - - # write alternative locations by default - writeAlt = True if "NOALT" in options else False - - # upon writing, convert to a generic PDB naming convention (no - # protonation state specified for HIS) - genericFormat = False - - if charmm19Format and charmm22Format: - raise Exception( - "CHARMM 19 and 22 formatting options cannot be specified together" - ) - - atomIndex = 1 - for ci, chain in enumerate(self.chains()): - for ri, residue in enumerate(chain.residues()): - for ai, atom in enumerate(residue.atoms()): - # dirty details of formating for MM purposes converting - atomname = atom.name - resname = residue.name - if charmmFormat: - if (residue.name == "ILE") and (atom.name == "CD1"): - atomname = "CD" - if (atom.name == "O") and (ri == chain.num_residues() - 1): - atomname = "OT1" - if (atom.name == "OXT") and (ri == chain.num_residues() - 1): - atomname = "OT2" - if residue.name == "HOH": - resname = "TIP3" - - if charmm19Format: - if residue.name == "HSD": # neutral HIS, proton on ND1 - resname = "HIS" - if residue.name == "HSE": # neutral HIS, proton on NE2 - resname = "HSD" - if residue.name == "HSC": # doubley-protonated +1 HIS - resname = "HSP" - elif charmm22Format: - """This will convert from CHARMM19 to CHARMM22 as well as from a generic downlodaded - * PDB file to one ready for use in CHARMM22. The latter is because in the all-hydrogen - * topology, HIS protonation state must be explicitly specified, so there is no HIS per se. - * Whereas in typical downloaded PDB files HIS is used for all histidines (usually, one - * does not even really know the protonation state). Whether sometimes people do specify it - * nevertheless, and what naming format they use to do so, I am not sure (welcome to the - * PDB file format). But certainly almost always it is just HIS. Below HIS is renamed to - * HSD, the neutral form with proton on ND1. This is an assumption; not a perfect one, but - * something needs to be assumed. Doing this renaming will make the PDB file work in MM - * packages with the all-hydrogen model.""" - if residue.name == "HSD": # neutral HIS, proton on NE2 - resname = "HSE" - if residue.name == "HIS": # neutral HIS, proton on ND1 - resname = "HSD" - if residue.name == "HSP": # doubley-protonated +1 HIS - resname = "HSC" - elif genericFormat: - if residue.name in ["HSD", "HSP", "HSE", "HSC"]: - resname = "HIS" - if (residue.name == "ILE") and (atom.name == "CD"): - atomname = "CD1" - - # write the atom line - for li in range(atom.num_locations()): - if renumber: - f.write( - _pdb_line( - atom.get_location(li), - atomIndex, - ri=ri + 1, - rn=resname, - an=atomname, - ) - + "\n" - ) - else: - f.write( - _pdb_line( - atom.get_location(li), - atomIndex, - rn=resname, - an=atomname, - ) - + "\n" - ) - atomIndex = atomIndex + 1 - - if not noter and (ri == chain.num_residues() - 1): - f.write("TER\n") - if not noend and (ci == self.num_chains() - 1): - f.write("END\n") - - def canonicalize_protein( - self, - level=2, - drop_coors_unknowns=False, - drop_coors_missing_backbone=False, - filter_by_entity=False, - ): - """Canonicalize the calling System object (in place) by assuming that it represents - a protein molecular system. Different canonicalization rigor and options - can be specified but are all optional. - - Args: - level (int): Canonicalization level that determines which nonstandard-to-standard - residue mappings are performed. Possible values are 1, 2 or 3, with 2 being - the default and higher values meaning more rigorous (and less conservative) - canonicalization. With level 1, only truly equivalent mappings are performed - (e.g., different His protonation states are mapped to the canonical residue - name HIS that does not specify protonation). Level 2 adds to this some less - exact but still quite close mappings--i.e., seleno-methionine (MSE) and seleno- - cystine (SEC) to methionine (MET) and cystine (CYS). Level 3 further adds - even less equivalent but still reasonable mappings--i.e., phosphorylated SER, - THR, TYR, and HIS to their unphosphorylated counterparts as well as S-oxy Cys - to Cys. - drop_coors_unknowns (bool, optional): if True, will discard structural information - for all residues that are not natural or mappable under the current level. - NOTE: any sequence record for these residues (i.e., if they are part of a - polymer entity) will be preserved. - drop_coors_missing_backbone (bool, optional): if True, will discard structural - information for residues that do not have at least the N, CA, C, and O - backbone atoms. Same note applies regarding the full sequence record as in - the above. - filter_by_entity (bool, optional): if True, will remove any chains that do not - represent polymer/polypeptide entities. This is convenient for cases where a - System object has both protein and non-protein components. However, depending - on how the System object was generated, entity metadata may not have been filled, - so applying this canonicalization approach will remove the entire structure. - For this reason, the option is False by default. - """ - - def _mod_to_standard_aa_mappings( - less_standard: bool, almost_standard: bool, standard: bool - ): - # Perfectly corresponding to standard residues - standard_map = {"HSD": "HIS", "HSE": "HIS", "HSC": "HIS", "HSP": "HIS"} - - # Almost perfectly corresponding to standard residues: - # * MSE -- selenomethyonine; SEC -- selenocysteine - almost_standard_map = {"MSE": "MET", "SEC": "CYS"} - - # A little less perfectly corresponding pairings, but can be acceptable (depends): - # * HIP -- ND1-phosphohistidine; SEP -- phosphoserine; TPO -- phosphothreonine; - # * PTR -- o-phosphotyrosine. - less_standard_map = { - "HIP": "HIS", - "CSX": "CYS", - "SEP": "SER", - "TPO": "THR", - "PTR": "TYR", - } - - ret = dict() - if standard: - ret.update(standard_map) - if almost_standard: - ret.update(almost_standard_map) - if less_standard: - ret.update(less_standard_map) - return ret - - def _to_standard_aa_mappings( - less_standard: bool, almost_standard: bool, standard: bool - ): - # get the mapping between modifications and their corresponding standard forms - mapping = _mod_to_standard_aa_mappings( - less_standard, almost_standard, standard - ) - - # add mapping between standard names and themselves - import chroma.utility.polyseq as polyseq - - for aa in polyseq.canonical_amino_acids(): - mapping[aa] = aa - - return mapping - - less_standard, almost_standard, standard = False, False, False - if level == 3: - less_standard, almost_standard, standard = True, True, True - elif level == 2: - less_standard, almost_standard, standard = False, True, True - elif level == 1: - less_standard, almost_standard, standard = False, False, True - else: - raise Exception(f"unknown canonicalization level {level}") - - to_standard = _to_standard_aa_mappings(less_standard, almost_standard, standard) - - # NOTE: need to re-implement the canonicalization procedure such that it: - # 1. checks to make sure entity sequence and structure sequence agree (error if not) - # 2. goes over entities and looks for residues to rename, does the renaming on the entities - # and all chains simultaneously (so that no new entities are created) - # 3. then goes over the structured part and fixes atoms - - # For residue renamings, we will first record all edits and will perform them - # afterwards in one go, so we can judge whether any new entities have to be - # created. The dictionary `esidues_to_rename`` will be as follows: - # entity_id: { - # chain_index: [list of (residue index, rew name) tuples] - # } - chains_to_delete = [] - residues_to_rename = dict() - for ci, chain in enumerate(self.chains()): - entity = chain.get_entity() - if filter_by_entity: - if ( - (entity is None) - or (entity._type != "polymer") - or ("polypeptide" not in entity.polymer_time) - ): - chains_to_delete.append(chain) - continue - - # iterate in reverse order so we can safely delete any residues we find necessary - cleared_residues = 0 - for residue in reversed(list(chain.residues())): - aa = residue.name - delete_atoms = False - # canonicalize amino acid (delete structure if unknown, provided this was asked for) - if aa in to_standard: - aa_new = to_standard[aa] - if aa != aa_new: - # edit any atoms to reflect the mutation - if ( - (aa == "HSD") - or (aa == "HSE") - or (aa == "HSC") - or (aa == "HSP") - ) and (aa_new == "HIS"): - pass - elif ((aa == "MSE") and (aa_new == "MET")) or ( - (aa == "SEC") and (aa_new == "CYS") - ): - SE = residue.find_atom("SE") - if SE is not None: - if aa == "MSE": - SE.residue.rename("SD") - else: - SE.residue.rename("SG") - elif ( - ((aa == "HIP") and (aa_new == "HIS")) - or ((aa == "SEP") and (aa_new == "SER")) - or ((aa == "TPO") and (aa_new == "THR")) - or ((aa == "PTR") and (aa_new == "TYR")) - ): - # delete the phosphate group - for atomname in ["P", "O1P", "O2P", "O3P", "HOP2", "HOP3"]: - a = residue.find_atom(atomname) - if a is not None: - a.delete() - elif (aa == "CSX") and (aa_new == "CYS"): - a = residue.find_atom("OD") - if a is not None: - a.delete() - - # record residue renaming operation to be done later - entity_id = chain.get_entity_id() - if entity_id is None: - residue.rename(aa_new) - else: - if entity_id not in residues_to_rename: - residues_to_rename[entity_id] = dict() - if ci not in residues_to_rename[entity_id]: - residues_to_rename[entity_id][ci] = list() - residues_to_rename[entity_id][ci].append( - (residue.get_index_in_chain(), aa_new) - ) - else: - if aa == "ARG": - A = {an: None for an in ["CD", "NE", "CZ", "NH1", "NH2"]} - for an in A: - atom = residue.find_atom(an) - if atom is not None and atom.num_locations(): - A[an] = atom.get_location(0) - if all([a is not None for n, a in A.items()]): - dihe1 = System.dihedral( - A["CD"], A["NE"], A["CZ"], A["NH1"] - ) - dihe2 = System.dihedral( - A["CD"], A["NE"], A["CZ"], A["NH2"] - ) - if abs(dihe1) > abs(dihe2): - A["NH1"].name = "NH2" - A["NH2"].name = "NH1" - elif drop_coors_unknowns: - delete_atoms = True - - if not drop_coors_missing_backbone: - if not delete_atoms and not residue.has_full_backbone(): - delete_atoms = True - - if delete_atoms: - residue.delete_atoms() - cleared_residues += 1 - - # If we have deleted all residues in this chain, then this is probably not - # a protein chain, so get rid of it. Unless we are asked to pay attention to - # the entity type (i.e., whether it is peptidic), in which case the decision - # of whether to keep the chain would have been made previously. - if ( - not filter_by_entity - and (cleared_residues != 0) - and (cleared_residues == chain.num_residues()) - ): - chains_to_delete.append(chain) - - # rename residues differently depending on whether all chains of a given entity - # have the same set of renamings - for entity_id, ops in residues_to_rename.items(): - chain_indices = set(ops.keys()) - entity_chains = set(self.get_chains_of_entity(entity_id, by="index")) - unique_renames = set([tuple(v) for v in ops.values()]) - fork = True - if (chain_indices == entity_chains) and (len(unique_renames) == 1): - # we can rename without updating entities, because all entity chains are updated the same way - fork = False - for ci, renames in ops.items(): - chain = self.get_chain(ci) - for ri, new_name in renames: - chain.get_residue(ri).rename(new_name, fork_entity=fork) - - # now delete any chains - for chain in reversed(chains_to_delete): - chain.delete() - - self._reindex() - - def sequence(self, format="three-letter-list"): - """Returns the full sequence of this System, concatenated over all - chains in their order within the System. - - Args: - format (str): sequence format. Possible options are either - "three-letter-list" (default) or "one-letter-string". - - Returns: - List (default) or string. - """ - if format == "three-letter-list": - seq = [] - else: - seq = "" - - for chain in self.chains(): - seq = seq + chain.sequence(format) - return seq - - @staticmethod - def distance(a1: AtomLocationView, a2: AtomLocationView): - """Computes the distance between atom locations `a1` and `a2`.""" - v21 = a1.coors - a2.coors - return np.linalg.norm(v21) - - @staticmethod - def angle( - a1: AtomLocationView, a2: AtomLocationView, a3: AtomLocationView, radians=False - ): - """Computes the angle formed by three 3D points represented by AtomLocationView objects. - - Args: - a1, a2, a3 (AtomLocationView): three 3D points. - radian (bool, optional): if True (default False), will return the angle in radians. - Otherwise, in degrees. - - Returns: - Angle `a1`-`a2`-`a3`. - """ - v21 = a1.coors - a2.coors - v23 = a3.coors - a2.coors - v21 = v21 / np.linalg.norm(v21) - v23 = v23 / np.linalg.norm(v23) - c = np.dot(v21, v23) - return np.arctan2(np.sqrt(1 - c * c), c) * (1 if radians else 180.0 / np.pi) - - @staticmethod - def dihedral( - a1: AtomLocationView, - a2: AtomLocationView, - a3: AtomLocationView, - a4: AtomLocationView, - radians=False, - ): - """Computes the dihedral angle formed by four 3D points represented by AtomLocationView objects. - - Args: - a1, a2, a3, a4 (AtomLocationView): four 3D points. - radian (bool, optional): if True (default False), will return the angle in radians. - Otherwise, in degrees. - - Returns: - Dihedral angle `a1`-`a2`-`a3`-`a4`. - """ - AB = a1.coors - a2.coors - CB = a3.coors - a2.coors - DC = a4.coors - a3.coors - - if min([np.linalg.norm(p) for p in [AB, CB, DC]]) == 0.0: - raise Exception("some points coincide in dihedral calculation") - - ABxCB = np.cross(AB, CB) - ABxCB = ABxCB / np.linalg.norm(ABxCB) - DCxCB = np.cross(DC, CB) - DCxCB = DCxCB / np.linalg.norm(DCxCB) - - # the following is necessary for values very close to 1 but just above - dotp = np.dot(ABxCB, DCxCB) - if dotp > 1.0: - dotp = 1.0 - elif dotp < -1.0: - dotp = -1.0 - - angle = np.arccos(dotp) - if np.dot(ABxCB, DC) > 0: - angle *= -1 - if not radians: - angle *= 180.0 / np.pi - - return angle - - @staticmethod - def protein_backbone_atom_type(atom_name: str, no_hyd=True, by_name=True): - """Backbone atoms can be either nitrogens, carbons, oxigens, or hydrogens. - Specifically, possible known names in each category are: - 'N', 'NT' - 'CA', 'C', 'CY', 'CAY' - 'OY', 'O', 'OCT*', 'OXT', 'OT1', 'OT2' - 'H', 'HY*', 'HA*', 'HN', 'HT*', '1H', '2H', '3H' - """ - array = ["N", "CA", "C", "O", "H"] if by_name else [0, 1, 2, 3, 4] - if atom_name in ["N", "NT"]: - return array[0] - if atom_name == "CA": - return array[1] - if (atom_name == "C") or (atom_name == "CY"): - return array[2] - if atom_name in ["O", "OY", "OXT", "OT1", "OT2"] or atom_name.startswith("OCT"): - return array[3] - if not no_hyd: - if atom_name in ["H", "HA", "HN"]: - return array[4] - if atom_name.startswith("HT") or atom_name.startswith("HY"): - return array[4] - # Rosetta's N-terinal amine has hydrogens named 1H, 2H, and 3H - if ( - atom_name.startswith("1H") - or atom_name.startswith("2H") - or atom_name.startswith("3H") - ): - return array[4] - return None - - -@dataclass -class SystemEntity: - """A molecular entity represented in a molecular system.""" - - _type: str - _desc: str - _polymer_type: str - _seq: list - _het: list - - def is_polymer(self): - """Returns whether the entity represents a polymer.""" - return self._type == "polymer" - - @classmethod - def guess_entity_and_polymer_type(cls, seq: List): - is_poly = np.mean([polyseq.is_polymer_residue(res, None) for res in seq]) > 0.8 - polymer_type = None - if is_poly: - entity_type = "polymer" - for ptype in polyseq.polymerType: - if ( - np.mean([polyseq.is_polymer_residue(res, ptype) for res in seq]) - > 0.8 - ): - polymer_type = polyseq.polymer_type_name(ptype) - break - else: - entity_type = "unknown" - - return entity_type, polymer_type - - @property - def type(self): - return self._type - - @property - def description(self): - return self._desc - - @property - def polymer_type(self): - return self._polymer_type - - @property - def sequence(self): - return self._seq - - @property - def hetero(self): - return self._het - - -@dataclass -class BaseView: - """An abstract base "view" class for accessing different parts of System.""" - - _ix: int - _parent: object - - def get_index(self): - """Return the index of this atom location in its System.""" - return self._ix - - def is_valid(self): - return self._ix >= 0 and self._parent is not None - - def _delete(self): - at = self._ix - self.parent._siblings.child_index(self.parent._ix, 0) - self.parent._siblings.delete_child(self.parent._ix, at) - - @property - def parent(self): - return self._parent - - -@dataclass -class ChainView(BaseView): - """A Chain view, allowing hierarchical exploration and editing.""" - - def __init__(self, ix: int, system: System): - self._ix = ix - self._parent = system - self._siblings = system._chains - - def __str__(self): - return f"{self.cid} ({self.segid}/{self.authid}) -> {str(self.system)}" - - def residues(self): - for rn in range(self.num_residues()): - ri = self._siblings.child_index(self._ix, rn) - yield ResidueView(ri, self) - - def num_residues(self): - """Returns the number of residues in the Chain.""" - return self._siblings.num_children(self._ix) - - def num_structured_residues(self): - return sum([res.has_structure() for res in self.residues()]) - - def num_atoms(self): - return sum([res.num_atoms() for res in self.residues()]) - - def num_atom_locations(self): - return sum([res.num_atom_locations() for res in self.residues()]) - - def sequence(self, format="three-letter-list"): - """Returns the sequence of this chain. See `System::sequence()` for - possible formats. - """ - if format == "three-letter-list": - seq = [None] * self.num_residues() - for ri, residue in enumerate(self.residues()): - seq[ri] = residue.name - return seq - elif format == "one-letter-string": - import chroma.utility.polyseq as polyseq - - seq = [None] * self.num_residues() - for ri, residue in enumerate(self.residues()): - seq[ri] = polyseq.to_single(residue.name) - return "".join(seq) - else: - raise Exception(f"unknown sequence format {format}") - - def get_residue(self, ri: int): - """Get the residue at the specified index within the Chain. - - Args: - ri (int): Residue index within the Chain. - - Returns: - ResidueView object corresponding to the residue in question. - """ - if ri < 0 or ri >= self.num_residues(): - raise Exception( - f"residue index {ri} out of range for Chain, which has {self.num_residues()} residues" - ) - ri = self._siblings.child_index(self._ix, ri) - return ResidueView(ri, self) - - def get_residue_index(self, residue: ResidueView): - """Get the index of the given residue in this Chain.""" - return residue._ix - self._siblings.child_index(self._ix, 0) - - def get_atom(self, aidx: int): - """Get the atom at index `aidx` within this chain.""" - if aidx < 0: - raise Exception(f"negative atom index: {aidx}") - off = 0 - for residue in self.residues(): - na = residue.num_atoms() - if aidx < off + na: - return residue.get_atom(aidx - off) - off = off + na - raise Exception( - f"atom index {aidx} out of range for System, which has {self.num_atoms()} atoms" - ) - - def get_atoms(self): - """Return a list of all atoms in this chain.""" - atoms_views = [] - for residue in self.residues(): - atoms_views.extend(residue.get_atoms()) - return atoms_views - - def __getitem__(self, res_idx: int): - return self.get_residue(res_idx) - - def get_entity_id(self): - """Return the entity ID corresponding to this chain.""" - return self.system._chain_entities[self._ix] - - def get_entity(self): - """Return the entity this chain belongs to.""" - entity_id = self.get_entity_id() - if entity_id is None: - return None - return self.system._entities[entity_id] - - def check_sequence(self): - """Compare the list of residue names of this chain to the corresponding entity sequence record.""" - entity = self.get_entity() - if entity is not None and entity.is_polymer(): - if self.num_residues() != len(entity._seq): - return False - for res, ent_aan in zip(self.residues(), entity._seq): - if res.name != ent_aan: - return False - return True - - def add_residue(self, name: str, num: int, authid: str, icode: str = " ", at=None): - """Add a new residue to this chain. - - Args: - name (str): Residue name. - num (int): Residue number (i.e., residue ID). - authid (str): Author residue ID. - icode (str): Insertion code. - at (int, optional): Index at which to insert the residue. Default - is to append to the end of the chain (i.e., equivalent of ``at` - being equal to the present length of the chain). - """ - if at is None: - at = self.num_residues() - ri = self._siblings.insert_child( - self._ix, - at, - {"name": name, "resnum": num, "authresid": authid, "icode": icode}, - ) - return ResidueView(ri, self) - - def delete(self, keep_entity=False): - """Deletes this Chain from its System. - - Args: - keep_entity (bool, optional): If False (default) and if the chain - being deleted happens to be the last representative of the - entity it belongs to, the entity will be deleted. If True, the - entity will always be kept. - """ - # delete the mention of the chain from assembly information - self.system._assembly_info.delete_chain(self.cid) - - # optionally, delete the corresponding entity if no other chains point to it - if not keep_entity: - eid = self.get_entity_id() - if self.system.num_chains_of_entity(eid) == 0: - self.system.delete_entity(eid) - - self.system._chain_entities.pop(self._ix) - self._siblings.delete(self._ix) - self._ix = -1 # invalidate the view - - @property - def system(self): - return self._parent - - @property - def cid(self): - return self._siblings["cid"][self._ix] - - @property - def segid(self): - return self._siblings["segid"][self._ix] - - @property - def authid(self): - return self._siblings["authid"][self._ix] - - @cid.setter - def cid(self, val): - self._siblings["cid"][self._ix] = val - - @segid.setter - def segid(self, val): - self._siblings["segid"][self._ix] = val - - @authid.setter - def authid(self, val): - self._siblings["authid"][self._ix] = val - - -@dataclass -class ResidueView(BaseView): - """A Residue view, allowing hierarchical exploration and editing.""" - - def __init__(self, ix: int, chain: ChainView): - self._ix = ix - self._parent = chain - self._siblings = chain.system._residues - - def __str__(self): - return f"{self.name} {self.num} ({self.authid}) -> {str(self.chain)}" - - def atoms(self): - off = self._siblings.child_index(self._ix, 0) - for an in range(self.num_atoms()): - yield AtomView(off + an, self) - - def num_atoms(self): - return self._siblings.num_children(self._ix) - - def num_atom_locations(self): - return sum([a.num_locations() for a in self.atoms()]) - - def has_structure(self): - """Returns whether the atom has any structural information (i.e., one or more locations).""" - for a in self.atoms(): - if a.num_locations(): - return True - return False - - def get_atom(self, ai: int): - """Get the atom at the specified index within the Residue. - - Args: - atom_idx (int): Atom index within the Residue. - - Returns: - AtomView object corresponding to the atom in question. - """ - - if ai < 0 or ai >= self.num_atoms(): - raise Exception( - f"atom index {ai} out of range for Residue, which has {self.num_atoms()} atoms" - ) - ai = self._siblings.child_index(self._ix, ai) - return AtomView(ai, self) - - def get_atom_index(self, atom: AtomView): - """Get the index of the given atom in this Residue.""" - return atom._ix - self._siblings.child_index(self._ix, 0) - - def find_atom(self, name): - """Find and return the first atom (as AtomView object) with the given name - within the Residue or None.""" - for atom in self.atoms(): - if atom.name == name: - return atom - return None - - def __getitem__(self, atom_idx: int): - return self.get_atom(atom_idx) - - def get_index_in_chain(self): - """Return the index of the Residue in its parent Chain.""" - return self.chain.get_residue_index(self) - - def rename(self, new_name: str, fork_entity=True): - """Assigns the residue a new name with all proper updates. - - Args: - new_name (str): New residue name. - fork_entity (bool, optional): If True (default) and if parent - chain corresponds to an entity that has other chains - associated with it and there is a real renaming (i.e., - the old name is not the same as the new name), will - make a new (duplicate) entity for to this chain and - will edit the new one, leaving the old one unchanged. - If False, will not perform this regardless. NOTE: - setting this to False can create an inconsistent state - between chain and entity sequence information. - """ - entity_id = self.chain.get_entity_id() - if entity_id is not None: - entity = self.system._entities[entity_id] - ri = self.get_index_in_chain() - if fork_entity and (entity._seq[ri] != new_name): - ci = self.chain.get_index() - entity_id = self.system._ensure_unique_entity(ci) - entity = self.system._entities[entity_id] - entity._seq[ri] = new_name - self._siblings["name"][self._ix] = new_name - - def add_atom( - self, - name: str, - het: bool, - x: float = None, - y: float = None, - z: float = None, - occ: float = 1.0, - B: float = 0.0, - alt: str = " ", - at=None, - ): - """Adds a new atom to the residue (appending it at the end) and - returns an AtomView to it. If atom location information is - specified, will also add a location to the atom. - - Args: - name (str): Atom name. - het (bool): Whether it is a hetero-atom. - x, y, z (float): Atom location coordinates. - occ (float): Occupancy. - B (float): B-factor. - alt (str): Alternative position character. - at (int, optional): Index at which to insert the atom. Default - is to append to the end of the residue (i.e., equivalent of - ``at` being equal to the number of atoms in the residue). - - Returns: - AtomView object corresponding to the newly added atom. - """ - if at is None: - at = self.num_atoms() - ai = self._siblings.insert_child(self._ix, at, {"name": name, "het": het}) - atom = AtomView(ai, self) - - # now add a location to this atom - if x is not None: - atom.add_location(x, y, z, occ, B, alt) - - return atom - - def delete(self, fork_entity=True): - """Deletes this residue from its Chain/System. - - Args: - fork_entity (bool, optional): If True (default) and if parent - chain corresponds to an entity that has other chains - associated with it, will make a new (duplicate) entity - for to this chain and will edit the new one, leaving the - old one unchanged. If False, will not perform this. - NOTE: setting this to False can create an inconsistent state - between chain and entity sequence information. - """ - # update the entity (duplicating, if necessary) - entity_id = self.chain.get_entity_id() - if entity_id is not None: - entity = self.system._entities[entity_id] - ri = self.get_index_in_chain() - if fork_entity: - ci = self.chain.get_index() - entity_id = self.system._ensure_unique_entity(ci) - entity = self.system._entities[entity_id] - entity._seq.pop(ri) - - # delete the residue - self._delete() - self._ix = -1 # invalidate the view - - def delete_atoms(self, atoms=None): - """Delete either the specified list of atoms or all atoms from the residue. - - Args: - atoms (list, optional): List of AtomView objects corresponding to the - atoms to delete. If not specified, will delete all atoms in the residue. - """ - if atoms is None: - atoms = list(self.atoms()) - for atom in reversed(atoms): - if atom.residue != self: - raise Exception(f"Atom {atom} does not belong to Residue {self}") - atom.delete() - - @property - def chain(self): - return self._parent - - @property - def system(self): - return self.chain.system - - @property - def name(self): - return self._siblings["name"][self._ix] - - @property - def num(self): - return self._siblings["resnum"][self._ix] - - @property - def authid(self): - return self._siblings["authresid"][self._ix] - - @property - def icode(self): - return self._siblings["icode"][self._ix] - - def get_backbone(self, no_hyd=True): - """Assuming that this is a protein residue (i.e., an amino acid), returns the - list of atoms corresponding to the residue's backbone, in the order: - backbone amide (N), alpha carbon (CA), carbonyl carbon (C), carbonyl oxygen (O), - and amide hydrogen (H, optional). - - Args: - no_hyd (bool, optional): If True (default), will exclude the amide hydrogen - and only return four atoms. If False, will include the amide hydrogen. - - Returns: - A list with each entry being an AtomView object corresponding to the backbone - atom in the order above or None if the atom does not exist in the residue. - """ - bb = [None] * (4 if no_hyd else 5) - left = len(bb) - for atom in self.atoms(): - i = System.protein_backbone_atom_type(atom.name, no_hyd) - if i is None or bb[i] is not None: - continue - bb[i] = atom - left = left - 1 - if left == 0: - break - return bb - - def has_full_backbone(self, no_hyd=True): - """Assuming that this is a protein residue (i.e., an amino acid), returns - whether the residue harbors a structurally defined backbone (i.e., has - all backbone atoms each of which has location information). - - Args: - no_hyd (bool, optional): If True (default), will ignore whether the amide - hydrogen exists or not (if False will consider it). - - Returns: - Boolean indicating whether there is a full backbone in the residue. - """ - bb = self.get_backbone(no_hyd) - return all([(a is not None) and a.num_locations() for a in bb]) - - def delete_non_backbone(self, no_hyd=True): - """Assuming that this is a protein residue (i.e., an amino acid), deletes - all atoms except backbone atoms. - - Args: - no_hyd (bool, optional): If True (default), will not consider the amide - hydrogen as a backbone atom (if False will consider it). - """ - to_delete = [] - for atom in self.atoms(): - if System.protein_backbone_atom_type(atom.name, no_hyd) is None: - to_delete.append(atom) - self.delete_atoms(to_delete) - - -@dataclass -class AtomView(BaseView): - """An Atom view, allowing hierarchical exploration and editing.""" - - def __init__(self, ix: int, residue: ResidueView): - self._ix = ix - self._parent = residue - self._siblings = residue.system._atoms - - def __str__(self): - string = self.name + (" (HET) " if self.het else " ") - if self.num_locations() > 0: - string = string + str(self.get_location(0)) - string = string + f" ({self.num_locations()})" - return string + " -> " + str(self.residue) - - def locations(self): - off = self._siblings.child_index(self._ix, 0) - for ln in range(self.num_locations()): - yield AtomLocationView(off + ln, self) - - def num_locations(self): - return self._siblings.num_children(self._ix) - - def __getitem__(self, loc_idx: int): - return self.get_location(loc_idx) - - def get_location(self, li: int = 0): - """Returns the (li+1)-th location of the atom.""" - if li < 0 or li >= self.num_locations(): - raise Exception( - f"location index {li} out of range for Atom with {self.num_locations()} locations" - ) - li = self._siblings.child_index(self._ix, li) - return AtomLocationView(li, self) - - def add_location(self, x, y, z, occ=1.0, B=0.0, alt=" ", at=None): - """Adds a location to this atom, append it to the end. - - Args: - x, y, z (float): coordinates of the location. - occ (float): occupancy for the location. - B (float): B-factor for the location. - alt (str): alternative location character. - at (int, optional): Index at which to insert the location. Default - is to append at the end (i.e., equivalent of ``at` being equal - to the current number of locations). - """ - if at is None: - at = self.num_locations() - li = self._siblings.insert_child( - self._ix, at, {"coor": [x, y, z, occ, B], "alt": alt} - ) - return AtomLocationView(li, self) - - def delete(self): - """Deletes this atom from its Residue/Chain/System.""" - self._delete() - self._ix = -1 # invalidate the view - - @property - def residue(self): - return self._parent - - @property - def chain(self): - return self.residue.chain - - @property - def system(self): - return self.chain.system - - @property - def name(self): - return self._siblings["name"][self._ix] - - @property - def het(self): - return self._siblings["het"][self._ix] - - """Location information getters and setters operate on the default (first) - location for this atom and throw an index error if there are no locations.""" - - @property - def x(self): - if self._siblings.num_children(self._ix) == 0: - raise Exception("atom has no locations") - ix = self._siblings.child_index(self._ix, 0) - return self.system._locations["coor"][ix, 0] - - @property - def y(self): - if self._siblings.num_children(self._ix) == 0: - raise Exception("atom has no locations") - ix = self._siblings.child_index(self._ix, 0) - return self.system._locations["coor"][ix, 1] - - @property - def z(self): - if self._siblings.num_children(self._ix) == 0: - raise Exception("atom has no locations") - ix = self._siblings.child_index(self._ix, 0) - return self.system._locations["coor"][ix, 2] - - @property - def coors(self): - if self._siblings.num_children(self._ix) == 0: - raise Exception("atom has no locations") - ix = self._siblings.child_index(self._ix, 0) - return self.system._locations["coor"][ix, 0:3] - - @property - def occ(self): - if self._siblings.num_children(self._ix) == 0: - raise Exception("atom has no locations") - ix = self._siblings.child_index(self._ix, 0) - return self.system._locations["coor"][ix, 3] - - @property - def B(self): - if self._siblings.num_children(self._ix) == 0: - raise Exception("atom has no locations") - ix = self._siblings.child_index(self._ix, 0) - return self.system._locations["coor"][ix, 4] - - @property - def alt(self): - if self._siblings.num_children(self._ix) == 0: - raise Exception("atom has no locations") - ix = self._siblings.child_index(self._ix, 0) - return self.system._locations["alt"][ix] - - @x.setter - def x(self, val): - if self._siblings.num_children(self._ix) == 0: - raise Exception("atom has no locations") - ix = self._siblings.child_index(self._ix, 0) - self.system._locations["coor"][ix, 0] = val - - @y.setter - def y(self, val): - if self._siblings.num_children(self._ix) == 0: - raise Exception("atom has no locations") - ix = self._siblings.child_index(self._ix, 0) - self.system._locations["coor"][ix, 1] = val - - @z.setter - def z(self, val): - if self._siblings.num_children(self._ix) == 0: - raise Exception("atom has no locations") - ix = self._siblings.child_index(self._ix, 0) - self.system._locations["coor"][ix, 2] = val - - @occ.setter - def occ(self, val): - if self._siblings.num_children(self._ix) == 0: - raise Exception("atom has no locations") - ix = self._siblings.child_index(self._ix, 0) - self.system._locations["coor"][ix, 3] = val - - @B.setter - def B(self, val): - if self._siblings.num_children(self._ix) == 0: - raise Exception("atom has no locations") - ix = self._siblings.child_index(self._ix, 0) - self.system._locations["coor"][ix, 4] = val - - @alt.setter - def alt(self, val): - if self._siblings.num_children(self._ix) == 0: - raise Exception("atom has no locations") - ix = self._siblings.child_index(self._ix, 0) - self.system._locations["alt"][ix] = val - - -class DummyAtomView(AtomView): - """An dummy Atom view that can be attached to a residue but that does not - have any locations and with no other information.""" - - def __init__(self, residue: ResidueView): - self._ix = -1 - self._parent = residue - - def __str__(self): - return "DUMMY -> " + str(self.residue) - - def locations(self): - return - yield - - def num_locations(self): - return 0 - - def __getitem__(self, loc_idx: int): - return None - - def get_location(self, li: int = 0): - raise Exception(f"no locations in DUMMY atom") - - def add_location(self, x, y, z, occ, B, alt, at=None): - raise Exception(f"can't add no locations to DUMMY atom") - - @property - def residue(self): - return self._parent - - @property - def chain(self): - return self.residue.chain - - @property - def system(self): - return self.chain.system - - @property - def name(self): - return None - - @property - def het(self): - return None - - @property - def x(self): - raise Exception(f"no coordinates in DUMMY atom") - - @property - def y(self): - raise Exception(f"no coordinates in DUMMY atom") - - @property - def z(self): - raise Exception(f"no coordinates in DUMMY atom") - - @property - def occ(self): - raise Exception(f"no occupancy in DUMMY atom") - - @property - def B(self): - raise Exception(f"no B-factor in DUMMY atom") - - @property - def alt(self): - raise Exception(f"no alt flag in DUMMY atom") - - @x.setter - def x(self, val): - raise Exception(f"can't set coordinate for DUMMY atom") - - @y.setter - def y(self, val): - raise Exception(f"can't set coordinate for DUMMY atom") - - @z.setter - def z(self, val): - raise Exception(f"can't set coordinate for DUMMY atom") - - @occ.setter - def occ(self, val): - raise Exception(f"can't set occupancy for DUMMY atom") - - @B.setter - def B(self, val): - raise Exception(f"can't set B-factor for DUMMY atom") - - @alt.setter - def alt(self, val): - raise Exception(f"can't set alt flag for DUMMY atom") - - -@dataclass -class AtomLocationView(BaseView): - """An AtomLocation view, allowing hierarchical exploration and editing.""" - - def __init__(self, ix: int, atom: AtomView): - self._ix = ix - self._parent = atom - self._siblings = atom.system._locations - - def __str__(self): - return f"{self.x} {self.y} {self.z}" - - def swap(self, other: AtomLocationView): - """Swaps information between itself and the provided atom location. - - Args: - other (AtomLocationView): the other atom location to swap with. - """ - self.x, other.x = other.x, self.x - self.y, other.y = other.y, self.y - self.z, other.z = other.z, self.z - self.occ, other.occ = other.occ, self.occ - self.B, other.B = other.B, self.B - self.alt, other.alt = other.alt, self.alt - - def defined(self): - """Return whether this is a valid location.""" - return (self.x is not None) and (self.y is not None) and (self.z is not None) - - @property - def atom(self): - return self._parent - - @property - def residue(self): - return self.atom.residue - - @property - def chain(self): - return self.residue.chain - - @property - def system(self): - return self.chain.system - - @property - def x(self): - return self.system._locations["coor"][self._ix, 0] - - @property - def y(self): - return self.system._locations["coor"][self._ix, 1] - - @property - def z(self): - return self.system._locations["coor"][self._ix, 2] - - @property - def occ(self): - return self.system._locations["coor"][self._ix, 3] - - @property - def B(self): - return self.system._locations["coor"][self._ix, 4] - - @property - def alt(self): - return self.system._locations["alt"][self._ix] - - @property - def coors(self): - return np.array(self.system._locations["coor"][self._ix, 0:3]) - - @property - def coor_info(self): - return np.array(self.system._locations["coor"][self._ix]) - - @x.setter - def x(self, val): - self.system._locations["coor"][self._ix, 0] = val - - @y.setter - def y(self, val): - self.system._locations["coor"][self._ix, 1] = val - - @z.setter - def z(self, val): - self.system._locations["coor"][self._ix, 2] = val - - @coors.setter - def coors(self, val): - self.system._locations["coor"][self._ix, 0:3] = val - - @coor_info.setter - def coor_info(self, val): - self.system._locations["coor"][self._ix] = val - - @occ.setter - def occ(self, val): - self.system._locations["coor"][self._ix, 3] = val - - @B.setter - def B(self, val): - self.system._locations["coor"][self._ix, 4] = val - - @alt.setter - def alt(self, val): - self.system._locations["alt"][self._ix] = val - - -class ExpressionTreeEvaluator: - """A class for evaluating custom logical parenthetical expressions. The - implementation is very generic, supports nullary, unary, and binary - operators, and does not know anything about what the expressions actually - mean. Instead the class interprets the expression as a tree of sub- - expressions, governed by parentheses and operators, and traverses the - calling upon a user-specified evaluation function to evaluate leaf - nodes as the tree is gradually collapsed into a single node. This - can be used for evaluating set expressions, algebraic expressions, and - others. - - Args: - operators_nullary (list): A list of strings designating nullary operators - (i.e., operators that do not have any operands). E.g., if the language - describes selection algebra, these could be "hyd", "all", or "none"]. - operators_unary (list): A list of strings designating unary operators - (i.e., operators that have one operand, which must comes to the right - of the operator). E.g., if the language describes selection algebra, - these could be "name", "resid", or "chain". - operators_binary (list): A list of strings designating binary operators - (i.e., operators that have two operands, one on each side of the - operator). E.g., if the language describes selection algebra, thse - could be "and", "or", or "around". - eval_function (str): A function that is able to evaluate a leaf node of - the expression tree. It shall accept three parameters: - - operator (str): name of the operator - left: the left operand. Will be None if the left operand is missing or - not relevant. Otherwise, can be either a list of strings, which - should represent an evaluatable sub-expression corresponding to the - left operand, or the result of a prior evaluation of this function. - right: Same as `left` but for the right operand. - - The function should attempt to evaluate the resulting expression and - return None in the case of failing or a dictionary with the result of - the evaluation stored under key "result". - left_associativity (bool): If True (the default), operators are taken to be - left-associative. Meaning something like "A and B or C" is "(A and B) or C". - If False, the operators are taken to be right-associative, such that - the same expression becomes "A and (B or C)". NOTE: MST is right-associative - but often human intiution tends to be left-associative. - debug (bool): If True (default is false), will print a great deal of debugging - messages to help diagnose any evaluation problems. - """ - - def __init__( - self, - operators_nullary: list, - operators_unary: list, - operators_binary: list, - eval_function: function, - left_associativity: bool = True, - debug: bool = False, - ): - self.operators_nullary = operators_nullary - self.operators_unary = operators_unary - self.operators_binary = operators_binary - self.operators = operators_nullary + operators_unary + operators_binary - self.eval_function = eval_function - self.debug = debug - self.left_associativity = left_associativity - - def _traverse_expression_tree(self, E, i=0, eval_all=True, debug=False): - def _collect_operands(E, j): - # collect all operands before hitting an operator - operands = [] - for k in range(len(E[j:])): - if E[j + k] in self.operators: - k = k - 1 - break - operands.append(E[j + k]) - return operands, j + k + 1 - - def _find_matching_close_paren(E, beg: int): - c = 0 - for i in range(beg, len(E)): - if E[i] == "(": - c = c + 1 - elif E[i] == ")": - c = c - 1 - if c == 0: - return i - return None - - def _my_eval(op, left, right, debug=False): - if debug: - print( - f"\t-> evaluating {operand_str(left)} | {op} | {operand_str(right)}" - ) - result = self.eval_function(op, left, right) - if debug: - print(f"\t-> got result {operand_str(result)}") - return result - - def operand_str(operand): - if isinstance(operand, dict): - if "result" in operand and len(operand["result"]) > 15: - vec = list(operand["result"]) - beg = ", ".join([str(i) for i in vec[:5]]) - end = ", ".join([str(i) for i in vec[-5:]]) - return "{'result': " + f"{beg} ... {end} ({len(vec)} long)" + "}" - return str(operand) - return str(operand) - - left, right, op = None, None, None - if debug: - print(f"-> received {E[i:]}") - - while i < len(E): - if all([x is None for x in (left, right, op)]): - # first part can either be a left parenthesis, a left operand, a nullary operator, or a unary operator - if E[i] == "(": - end = _find_matching_close_paren(E, i) - if end is None: - return None, f"parenthesis imbalance starting with {E[i:]}" - # evaluate expression inside the parentheses, and it becomes the left operand - left, rem = self._traverse_expression_tree( - E[i + 1 : end], 0, eval_all=True, debug=debug - ) - if left is None: - return None, rem - i = end + 1 - if not eval_all: - return left, i - elif E[i] in self.operators_nullary: - # evaluate nullary op - left = _my_eval(E[i], None, None, debug) - if left is None: - return None, f"failed to evaluate nullary operator '{E[i]}'" - i = i + 1 - elif E[i] in self.operators_unary: - op = E[i] - i = i + 1 - elif E[i] in self.operators: - # an operator other than a unary operator cannot appear first - return None, f"unexpected binary operator in the context {E[i:]}" - else: - # if not an operator, then we are looking at operand(s) - left, i = _collect_operands(E, i) - elif (left is not None) and (op is None) and (right is None): - # we have a left operand and now looking for a binary operator - if E[i] not in self.operators_binary: - return ( - None, - f"expected end or a binary operator when got '{E[i]}' in expression: {E}", - ) - op = E[i] - i = i + 1 - elif ( - (left is None) and (op in self.operators_unary) and (right is None) - ) or ( - (left is not None) and (op in self.operators_binary) and (right is None) - ): - # we saw a unary operator before and now looking for a right operand, another unary operator, or a nullary operator - # OR - # we have a left operand and a binary operator before, now looking for a right operand, a unary operator, or a nullary operator - if ( - E[i] in (self.operators_nullary + self.operators_unary) - or E[i] == "(" - ): - right, i = self._traverse_expression_tree( - E, i, eval_all=not self.left_associativity, debug=debug - ) - if right is None: - return None, i - else: - right, i = _collect_operands(E, i) - - # We are now ready to evaluate, because: - # we have a unary operator and a right operand - # OR - # we have a left operand, a binary operator, and a right operand - result = _my_eval(op, left, right, debug) - if result is None: - return ( - None, - f"failed to evaluate operator '{op}' (in expression {E}) with operands {operand_str(left)} and {operand_str(right)}", - ) - if not eval_all: - return result, i - left = result - op, right = None, None - - else: - return ( - None, - f"encountered an unexpected condition when evaluating {E}: left is {operand_str(left)}, op is {op}, or right {operand_str(right)}", - ) - - if (op is not None) or (right is not None): - return None, f"expression ended unexpectedly" - if left is None: - return None, f"failed to evaluate expression: {E}" - - return left, i - - def evaluate(self, expression: str): - """Evaluates the expression and returns the result.""" - - def _split_tokens(expr): - # first split by parentheses (preserving the parentheses themselves) - parts = list(re.split("([()])", expr)) - # then split by space (getting rid of space) - return [ - t.strip() - for p in parts - for t in re.split("\s+", p.strip()) - if t.strip() != "" - ] - - # parse expression into tokens - E = _split_tokens(expression) - val, rem = self._traverse_expression_tree(E, debug=self.debug) - if val is None: - raise Exception( - f"failed to evaluate expression: '{expression}', reason: {rem}" - ) - - return val["result"]