diff --git "a/chroma/chroma/data/system.py" "b/chroma/chroma/data/system.py" new file mode 100644--- /dev/null +++ "b/chroma/chroma/data/system.py" @@ -0,0 +1,4524 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import copy +import logging +import re +import warnings +from dataclasses import dataclass +from functools import partial +from typing import Dict, List, Tuple + +import numpy as np +import torch + +import chroma.utility.polyseq as polyseq +import chroma.utility.starparser as sp +from chroma import constants + + +@dataclass +class SystemAssemblyInfo: + """A class for representing the assembly information for System objects. + + assemblies (dict): a dictionary of assemblies with keys being assembly IDs + and values being dictionaries with of the following structure: + { + "details": "complete icosahedral assembly", + "instructions": [ + { + "oper_expression": "(1-60)", + "chains": [0, 1, 2], + + # Each assembly instruction has information for generating + # one or more images, with image `i` generated by applying + # the sequence of operations with IDs in `operations[i]` to the + # list of chains in `chains`. The corresponding operations + # are described under `assembly_info["operations"][ID]`. + "operations": [["X0", "1", "2", "3"], ["X0", "4", "5", "6"]]], + }, + ... + ], + } + + operations (dict): a dictionary with symmetry operations. Keys are operation IDs + and values being dictionaries with the following structure: + { + "type": "identity operation", + "name": "1_555", + "matrix": np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]), + "vector": np.array([0., 0., 0.]), + }, + ... + """ + + assemblies: dict + operations: dict + + def __init__(self, assemblies: dict = dict(), operations: dict = dict()): + self.assemblies = assemblies + self.operations = operations + + @staticmethod + def make_operation(type: str, name: str, matrix: list, vector: list): + op = { + "type": type, + "name": name, + "matrix": np.zeros([3, 3]), + "vector": np.zeros([3, 1]), + } + assert len(matrix) == 9, "expected 9 elements in rotation matrix" + assert len(vector) == 3, "expected 3 elements in translation vector" + for i in range(3): + op["vector"][i] = float(vector[i]) + for j in range(3): + op["matrix"][i][j] = float(matrix[i * 3 + j]) + return op + + def delete_chain(self, cid: str): + """Deletes the mention of the chain from assembly information. + + Args: + cid (str): Chain ID to delete. + """ + for ass_id, assembly in self.assemblies.items(): + for ins in assembly["instructions"]: + ins["chains"] = [_id for _id in ins["chains"] if _id != cid] + + def rename_chain(self, old_cid: str, new_cid: str): + """Renames all mentions of a chain to its new chain ID. + + Args: + old_cid (str): Chain ID to rename. + new_cid (str): Newly assigned Chain ID. + """ + for ass_id, assembly in self.assemblies.items(): + for ins in assembly["instructions"]: + ins["chains"] = [ + new_cid if cid == old_cid else cid for cid in ins["chains"] + ] + + +class StringList: + """A class for representing and accessing a list of strings in a highly memory-efficient + manner. Access is constant time, but modification is linear time in length of list. + """ + + def __init__(self, init_list: List[str] = []): + self.string = "" + self.rng = ArrayList(2, dtype=int) + for i in range(len(init_list)): + self.append(init_list[i]) + + def __getitem__(self, i: int): + beg, length = self.rng[i] + return self.string[beg : beg + length] + + def __setitem__(self, i: int, new_string: str): + beg, length = self.rng[i] + self.string = self.string[:beg] + new_string + self.string[beg + length :] + if len(new_string) != length: + self.rng[i, 1] = len(new_string) + self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] + len(new_string) - length + + def __str__(self): + return self.string + + def __len__(self): + return len(self.rng) + + def copy(self): + new_list = StringList() + new_list.string = self.string + new_list.rng = self.rng.copy() + return new_list + + def append(self, new_string: str): + self.rng.append([len(self.string), len(new_string)]) + self.string = self.string + new_string + + def insert(self, i: int, new_string: str): + if i < len(self): + ix, _ = self.rng[i] + elif i == len(self): + if len(self) == 0: + ix = 0 + else: + ix = self.rng[i - 1].sum() + else: + raise Exception( + f"cannot insert in position {i} for stringList of length {len(self)}" + ) + self.string = self.string[0:ix] + new_string + self.string[ix:] + self.rng.insert(i, [ix, len(new_string)]) + if len(new_string) > 0: + self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] + len(new_string) + + def pop(self, i: int): + beg, length = self.rng[i] + val = self.string[beg : beg + length] + self.string = self.string[0:beg] + self.string[beg + length :] + self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] - len(val) + self.rng.pop(i) + return val + + def delete_range(self, rng: range): + rng = sorted(rng) + [i, j] = [rng[0], rng[-1]] + beg, _ = self.rng[i] + end = self.rng[j].sum() + self.string = self.string[0:beg] + self.string[end:] + self.rng[j + 1 :, 0] = self.rng[j + 1 :, 0] - (end - beg + 1) + self.rng.delete_range(rng) + + +class NameList: + """A class for representing and accessing a list of "names"--i.e., strings that tend to + have generic values, such that many repeat values are expected in a given list.""" + + def __init__(self, init_list: List[str] = []): + self._reindex(init_list) + + def _reindex(self, init_list: List[str]): + self.unique_names = [] + self.name_indicies = dict() + self.index_use = dict() + self.indices = ArrayList(1, dtype=int) + for name in init_list: + self.append(name) + + def copy(self): + new_list = NameList() + new_list.unique_names = self.unique_names.copy() + new_list.name_indicies = self.name_indicies.copy() + new_list.index_use = self.index_use.copy() + new_list.indices = self.indices.copy() + return new_list + + def _check_index(self): + L = len(self.unique_names) + I = len(self.index_use) + if (L > 2 * I) and (L - I > 10): + self._reindex([self[i] for i in range(len(self))]) + + def __getitem__(self, i: int): + try: + idx = self.indices[i].item() + except IndexError as e: + raise IndexError(f"index {i} out of range for nameList\n" + str(e)) + return self.unique_names[idx] + + def __setitem__(self, i: int, new_name: str): + try: + idx = self.indices[i] + except IndexError as e: + raise IndexError(f"index {i} out of range for nameList\n" + str(e)) + self.index_use[idx] = self.index_use[idx] - 1 + if self.index_use[idx] == 0: + del self.index_use[idx] + if new_name not in self.name_indicies: + idx = len(self.name_indicies) + self.name_indicies[new_name] = idx + self.unique_names.append(new_name) + else: + idx = self.name_indicies[new_name] + self.indices[i] = idx + self._update_use(idx, 1) + self._check_index() + + def __str__(self): + return str([self[i] for i in range(len(self))]) + + def __len__(self): + return len(self.indices) + + def _update_use(self, idx, delta): + self.index_use[idx] = self.index_use.get(idx, 0) + delta + if self.index_use[idx] <= 0: + del self.index_use[idx] + + def _get_name_index(self, name: str): + if name not in self.name_indicies: + idx = len(self.name_indicies) + self.name_indicies[name] = idx + self.unique_names.append(name) + else: + idx = self.name_indicies[name] + return idx + + def append(self, name: str): + idx = self._get_name_index(name) + self.indices.append(idx) + self.index_use[idx] = self.index_use.get(idx, 0) + 1 + + def insert(self, i: int, new_string: str): + idx = self._get_name_index(new_string) + self.indices.insert(i, idx) + self.index_use[idx] = self.index_use.get(idx, 0) + 1 + + def pop(self, i: int): + idx = self.indices.pop(i).item() + val = self.unique_names[idx] + self._update_use(idx, -1) + self._check_index() + return val + + def delete_range(self, rng: range): + for i in reversed(sorted(rng)): + self.pop(i) + + +class ArrayList: + def __init__(self, ndims: int, dtype: type, length: int = 0, val=0): + if ndims == 1: + self._array = np.ndarray(shape=(max(length, 2)), dtype=dtype) + else: + self._array = np.ndarray(shape=(max(length, 2), ndims), dtype=dtype) + self.ndims = ndims + self._array[:] = val + self.length = length + # view of just the data without the extra allocated stuff + self.array = self._array[: self.length] + + def convert_negative_slice(self, slice_obj): + start = slice_obj.start if slice_obj.start is not None else 0 + stop = slice_obj.stop if slice_obj.stop is not None else self.length + + if start < 0: + start = self.length + start + if stop < 0: + stop = self.length + stop + + return slice(start, stop, slice_obj.step) + + def copy(self): + new_list = ArrayList(ndims=self.ndims, dtype=self.array.dtype, length=len(self)) + new_list[:] = self[:] + return new_list + + def __len__(self): + return self.length + + def capacity(self): + return self._array.shape[0] + + def __getitem__(self, i: int): + return self.array[i] + + def __setitem__(self, i: int, row: list): + self.array[i] = row + + def resize(self, delta): + # for speed, hard-code instead of calling len() and capacity() + new_length = self.length + delta + cap = self._array.shape[0] + if (new_length > cap) or (new_length < cap / 3): + new_capacity = 2 * new_length + self._resize(new_capacity) + self.length = new_length + self.array = self._array[: self.length] + + def _resize(self, new_size): + if self.ndims == 1: + self._array.resize((new_size), refcheck=False) + else: + self._array.resize((new_size, self.ndims), refcheck=False) + + def items(self): + for i in range(self.length): + yield self.array[i, :] + + def append(self, row: list): + self.resize(1) + self.array[-1] = row + + def insert(self, i: int, row: list): + """Insert the row such that it ends up being at index ``i`` in the new arrayList""" + # resize by +1 + self.resize(1) + + # everything in range [i:end-1) moves over by +1 + self.array[i + 1 :] = self.array[i:-1] + + # set the value at index i + self.array[i] = row + + def pop(self, i: int): + """Remove and return element at index i""" + + # get the element at index i + row = self.array[i].copy() + + # everything from [i+1; end) moves over by -1 + self.array[i:-1] = self.array[i + 1 :] + + # resize by -1 + self.resize(-1) + + return row + + def delete_range(self, rng: range): + i, j = min(rng), max(rng) + + # move over to the left to account for the removed part + cut_length = j - i + 1 + new_length = len(self) - cut_length + self.array[i:new_length] = self.array[j + 1 :] + + # resize by -1 + self.resize(-cut_length) + + def __str__(self): + return str([self[i] for i in range(len(self))]) + + +@dataclass +class HierarchicList: + """A utility class that represents a hierarchy of lists. Each level represents + a list of elements, each element having a set of properties (each property being + stored as an array-like object over elements). Further, each element has a number + of children corresponding to a range of elements in a lower-hierarhy list.""" + + _properties: dict + _parent_list: HierarchicList + _child_list: HierarchicList + _num_children: ArrayList # (1, n) + _child_offset: ArrayList # (1, n) + + def __init__( + self, + properties: dict, + parent_list: HierarchicList = None, + num_children: ArrayList = ArrayList(1, dtype=int), + ): + self._properties = dict() + for key in properties: + self._properties[key] = properties[key].copy() + self._parent_list = parent_list + if self._parent_list is not None: + self._parent_list._child_list = self + self._child_list = None + self._num_children = num_children.copy() if num_children is not None else None + # start off with lazy offsets, self.reindex() creates them + self._child_offset = None + + def copy(self): + new_list = HierarchicList( + self._properties, self._parent_list, self._num_children + ) + new_list._child_list = self._child_list + if self._child_offset is None: + new_list._child_offset = None + else: + new_list._child_offset = self._child_offset.copy() + return new_list + + def set_parent(self, parent_list: HierarchicList): + self._parent_list = parent_list + + def child_index(self, i: int, at: int): + if self._child_offset is not None: + return self._child_offset[i] + at + return self._num_children[0:i].sum() + at + + def reindex(self): + if self._num_children is not None: + self._child_offset = ArrayList( + 1, dtype=int, length=len(self._num_children), val=0 + ) + for i in range(1, len(self)): + self._child_offset[i] = ( + self._child_offset[i - 1] + self._num_children[i - 1] + ) + + def append_child(self, properties): + self._num_children[len(self._num_children) - 1] += 1 + self._child_list.append(properties) + + def insert_child(self, i: int, at: int, properties): + idx = self.child_index(i, at) + self._num_children[i] += 1 + self._child_offset = None + self._child_list.insert(idx, properties) + return idx + + def delete_child(self, i: int, at: int): + idx = self.child_index(i, at) + self._num_children[i] -= 1 + self._child_offset = None + self._child_list.delete(idx) + + def append(self, properties): + if set(properties.keys()) != set(self._properties.keys()): + raise Exception(f"unexpected set of attributes '{list(properties.keys())}") + for key, value in properties.items(): + self._properties[key].append(value) + if self._child_offset is not None: + self._child_offset.append( + self._child_offset[-1:].sum() + self._num_children[-1:].sum() + ) + if self._num_children is not None: + self._num_children.append(0) + + def insert(self, i: int, properties): + if set(properties.keys()) != set(self._properties.keys()): + raise Exception(f"unexpected set of attributes '{list(properties.keys())}") + for key, value in properties.items(): + self._properties[key].insert(i, value) + if self._child_offset is not None: + if i >= len(self._child_offset): + off = self._child_offset[-1:].sum() + self._num_children[-1:].sum() + else: + off = self._child_offset[i] + self._child_offset.insert(i, off) + if self._num_children is not None: + self._num_children.insert(i, 0) + + def delete(self, i: int): + for key in self._properties: + self._properties[key].pop(i) + if self._num_children is not None and self._num_children[i] != 0: + for at in range(self._num_children[i] - 1, -1, -1): + self.delete_child(i, at) + self._num_children.pop(i) + self._child_offset = None + + def delete_range(self, rng: range): + for key in self._properties: + self._properties[key].delete_range(rng) + # iterating in descending order so that child offsets remain valid for subsequent elements + for i in reversed(sorted(rng)): + if self._num_children is not None and self._num_children[i] != 0: + idx = self.child_index(i, 0) + self._child_list.delete_range( + self, range(idx, idx + self._num_children[i]) + ) + self._num_children[i] = 0 + self._child_offset = None + + def __len__(self): + for key in self._properties: + return len(self._properties[key]) + return None + + def __getitem__(self, i: str): + return self._properties[i] + + # def __setitem__(self, i: tuple, val): + # self._properties[i[0]][i[1]] = val + + def num_children(self, i: int): + return self._num_children[i] + + def has_children(self, i: int): + return self._num_children is not None and self._num_children[i] + + def __str__(self): + string = "Properties:\n" + for key in self._properties: + string += f"{key}: {str(self._properties[key])}\n" + string += f"num_children: {str(self._num_children)}\n" + string += f"child_offset: {str(self._child_offset)}\n" + string += "----\n" + string += str(self._child_list) + return string + + +@dataclass +class System: + """A class for storing, accessing, managing, and manipulating a molecular + system's structure, sequence, and topological information. The class is + organized as a hierarchy of objects: + + System: top-level class containing all information about a molecular system + -> Chain: a sub-portion of the System; for polymers this is generally a + chemically connected molecular graph belong to a System (e.g., for + protein complexes, this would be one of the proteins). + -> Residue: a generally chemically-connected molecular unit (for polymers, + the repeating unit), belonging to a Chain. + -> Atom: an atom belonging to a Residue with zero, one, or more locations. + -> AtomLocation: the location of an Atom (3D coordinates and other information). + + Attributes: + name (str): given name for System + _chains (list): a list of Chain objects + _entities (dict): a dictionary of SystemEntity objects, with keys being entity IDs + _chain_entities (list): `chain_entities[ci]` stores entity IDs (i.e., keys into + `entities`) corresponding to the entity for chain `ci` + _extra_models (list): a list of hierarchicList object, representing locations + for alternative models + _labels (dict): a dictionary of residue labels. A label is a string value, + under some category (also a string), associated with a residue. E.g., + the category could be "SSE" and the value could be "H" or "S". If entry + `labels[category][gti]` exists and is equal to `value`, this means that + residue with global template index `gti` has the label `category:value`. + _selections (dict): a dictionary of selections. Keys are selection names and + values are lists of corresponding gti indices. + _assembly_info (SystemAssemblyInfo): information on symmetric assemblies that can + be constructed from components of the molecular system. See ``SystemAssemblyInfo``. + """ + + name: str + _chains: HierarchicList + _residues: HierarchicList + _atoms: HierarchicList + _locations: HierarchicList + _entities: Dict[int, SystemEntity] + _chain_entities: List[int] + _extra_models: List[HierarchicList] + _labels: Dict[str, Dict[int, str]] + _selections: Dict[str, List[int]] + _assembly_info: SystemAssemblyInfo + + def __init__(self, name: str = "system"): + self.name = name + self._chains = HierarchicList( + properties={ + "cid": StringList(), + "segid": StringList(), + "authid": StringList(), + } + ) + self._residues = HierarchicList( + properties={ + "name": NameList(), + "resnum": ArrayList(1, dtype=int), + "authresid": StringList(), + "icode": ArrayList(1, dtype="U1"), + }, + parent_list=self._chains, + ) + self._atoms = HierarchicList( + properties={"name": NameList(), "het": ArrayList(1, dtype=bool)}, + parent_list=self._residues, + ) + self._locations = HierarchicList( + properties={ + "coor": ArrayList(5, dtype=float), + "alt": ArrayList(1, dtype="U1"), + }, + parent_list=self._atoms, + num_children=None, + ) + self._entities = dict() + self._chain_entities = [] + self._extra_models = [] + self._labels = dict() + self._selections = dict() + self._assembly_info = SystemAssemblyInfo() + + def _reindex(self): + self._chains.reindex() + self._residues.reindex() + self._atoms.reindex() + self._locations.reindex() + + def _print_indexing(self): + for chain in self.chains(): + off = self._chains.child_index(chain._ix, 0) + num = self._chains.num_children(chain._ix) + print(f"chain {chain._ix}, {chain}: [{off} - {off + num})") + for residue in chain.residues(): + off = self._residues.child_index(residue._ix, 0) + num = self._residues.num_children(residue._ix) + print(f"\tresidue {residue._ix}, {residue}: [{off} - {off + num})") + for atom in residue.atoms(): + off = self._atoms.child_index(atom._ix, 0) + num = self._atoms.num_children(atom._ix) + print(f"\t\tatom {atom._ix}, {atom}: [{off} - {off + num})") + for loc in atom.locations(): + has_children = self._locations.has_children(loc._ix) + print( + f"\t\t\tlocation {loc._ix}, {loc}: has children? {has_children}" + ) + + @classmethod + def from_XCS( + cls, + X: torch.Tensor, + C: torch.Tensor, + S: torch.Tensor, + alternate_alphabet: str = None, + ) -> System: + """Convert an XCS set of pytorch tensors to a new System object. + + B is batch size (Function only handles batch size of one now) + N is the number of residues + + Args: + X (torch.Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`. + `num_atoms` will be 14 if `all_atom=True` or 4 otherwise. + C (torch.LongTensor): Chain map with shape `(1, num_residues)`. It codes + positions as 0 when masked, positive integers for chain indices, + and negative integers to represent missing residues of the + corresponding positive integers. + S (torch.LongTensor): Sequence with shape `(1, num_residues)`. + alternate_alphabet (str, optional): Optional alternative alphabet for + sequence encoding. Otherwise the default alphabet is set in + `constants.AA20`.Amino acid alphabet for embedding. + Returns: + System: A System object with the new XCS data. + + """ + alphabet = constants.AA20 if alternate_alphabet is None else alternate_alphabet + all_atom = X.shape[2] == 14 + + assert X.shape[0] == 1 + assert C.shape[0] == 1 + assert S.shape[0] == 1 + assert X.shape[1] == S.shape[1] + assert C.shape[1] == C.shape[1] + + X, C, S = [T.squeeze(0).cpu().data.numpy() for T in [X, C, S]] + + chain_ids = np.abs(C) + + atom_count = 0 + new_system = cls("system") + + for i, chain_id in enumerate(np.unique(chain_ids)): + if chain_id == 0: + continue + + chain_bool = chain_ids == chain_id + X_chain = X[chain_bool, :, :].tolist() + C_chain = C[chain_bool].tolist() + S_chain = S[chain_bool].tolist() + + # Build chain + chain = new_system.add_chain("A") + for chain_ix, (X_i, C_i, S_i) in enumerate(zip(X_chain, C_chain, S_chain)): + resname = polyseq.to_triple(alphabet[int(S_i)]) + + # Build residue + residue = chain.add_residue( + resname, chain_ix + 1, str(chain_ix + 1), " " + ) + + if C_i > 0: + atom_names = constants.ATOMS_BB + + if all_atom and resname in constants.AA_GEOMETRY: + atom_names = ( + atom_names + constants.AA_GEOMETRY[resname]["atoms"] + ) + + for atom_ix, atom_name in enumerate(atom_names): + x, y, z = X_i[atom_ix] + atom_count += 1 + residue.add_atom(atom_name, False, x, y, z, 1.0, 0.0, " ") + + # add an entity for each chain (copy from chain information) + for ci, chain in enumerate(new_system.chains()): + seq = [None] * chain.num_residues() + het = [None] * chain.num_residues() + for ri, res in enumerate(chain.residues()): + seq[ri] = res.name + het[ri] = all(a.het for a in res.atoms()) + entity_type, polymer_type = SystemEntity.guess_entity_and_polymer_type(seq) + entity = SystemEntity( + entity_type, f"chain {chain.cid}", polymer_type, seq, het + ) + new_system.add_new_entity(entity, [ci]) + + return new_system + + def to_XCS( + self, + all_atom: bool = False, + batch_dimension: bool = True, + mask_unknown: bool = True, + unknown_token: int = 0, + reorder_chain: bool = True, + alternate_alphabet=None, + alternate_atoms=None, + get_indices=False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert System object to XCS format. + + `C` tensor has shape [num_residues], where it codes positions as 0 + when masked, positive integers for chain indices, and negative integers + to represent missing residues of the corresponding positive integers. + + `S` tensor has shape [num_residues], it will map residue amino acid to alphabet integers. + If it is not found in `alphabet`, it will default to `unknown_token`. Set `mask_unknown` to true if + also want to mask `unk residue` in `chain_map` + + This function takes into account missing residues and updates chain_map + accordingly. + + Args: + system (type): generate System object to convert. + all_atom (bool): Include side chain atoms. Default is `False`. + batch_dimension (bool): Include a batch dimension. Default is `True`. + mask_unknown (bool): Mask residues not found in the alphabet. Default is + `True`. + unknown_token (int): Default token index if a residue is not found in + the alphabet. Default is `0`. + reorder_chain (bool): If set to true will start indexing chain at 1, + else will use the alphabet index (Default: True) + altenate_alphabet (str): Alternative alphabet if not `None`. + alternate_atoms (list): Alternate atom name subset for `X` if not `None`. + get_indices (bool): Also return the location indices corresponding to the + returned `X` tensor. + + Returns: + X (torch.Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`. + `num_atoms` will be 14 if `all_atom=True` or 4 otherwise. + C (torch.LongTensor): Chain map with shape `(1, num_residues)`. It codes + positions as 0 when masked, positive integers for chain indices, + and negative integers to represent missing residues of the + corresponding positive integers. + S (torch.LongTensor): Sequence with shape `(1, num_residues)`. + location_indices (np.ndaray, optional): location indices corresponding to + the coordinates in `X`. + + """ + alphabet = constants.AA20 if alternate_alphabet is None else alternate_alphabet + + # Get chain map grabbing each chain in system and look at length + C = [] + for ch_id, chain in enumerate(self.chains()): + ch_str = chain.cid + if ch_str in list(constants.CHAIN_ALPHABET): + map_ch_id = list(constants.CHAIN_ALPHABET).index(ch_str) + else: + # fmt: off + map_ch_id = np.setdiff1d(np.arange(1, len(constants.CHAIN_ALPHABET)), np.unique(C))[0] + # fmt: on + if reorder_chain: + map_ch_id = ch_id + 1 + C += [map_ch_id] * chain.num_residues() + + # Grab full sequence + oneLetterSeq = self.sequence(format="one-letter-string") + + if len(oneLetterSeq) != len(C): + logging.warning("Warning, System and chain_map length don't agree") + + # Initialize recipient arrays + atom_names = None + if all_atom: + num_atoms = 14 if all_atom else 4 + else: + if alternate_atoms is not None: + atom_names = alternate_atoms + else: + atom_names = constants.ATOMS_BB + num_atoms = len(atom_names) + atom_names = {a: i for (i, a) in enumerate(atom_names)} + num_residues = self.num_residues() + X = np.zeros([num_residues, num_atoms, 3]) + location_indices = ( + np.zeros([num_residues * num_atoms], dtype=int) if get_indices else None + ) + + S = [] # Array will contain sequence indices + for i in range(num_residues): + # If residue should be mask or not + is_mask = False + + # Add sequence + if oneLetterSeq[i] in list(alphabet): + S.append(alphabet.index(oneLetterSeq[i])) + else: + S.append(unknown_token) + if mask_unknown: + is_mask = True + + # Get residue from system + res = self.get_residue(i) + if res is None or not res.has_structure(): + is_mask = True + + # If residue is mask because no structure or not found in alphabet + if is_mask: + # Set chain map to -x + C[i] = -abs(C[i]) + else: + # Loop through atoms + if all_atom: + code3 = constants.AA20_1_TO_3[oneLetterSeq[i]] + atom_names = ( + constants.ATOMS_BB + constants.AA_GEOMETRY[code3]["atoms"] + ) + atom_names = {a: i for (i, a) in enumerate(atom_names)} + + X[ + i, : + ] = np.nan # so we can tell whether some atom was previously found + num_rem = len(atom_names) + for atom in res.atoms(): + name = System.protein_backbone_atom_type(atom.name, False, True) + if name is None: + name = atom.name + ix = atom_names.get(name, None) + if ix is None or not np.isnan(X[i, ix, 0]): + continue + for loc in atom.locations(): + X[i, ix] = loc.coors + if location_indices is not None: + location_indices[i * num_atoms + ix] = loc.get_index() + num_rem -= 1 + break + if num_rem == 0: + break + if num_rem != 0: + C[i] = -abs(C[i]) + X[i, :] = 0 + np.nan_to_num(X[i, :], copy=False, nan=0) + + # Tensor everything + X = torch.tensor(X).float() + C = torch.tensor(C).type(torch.long) + S = torch.tensor(S).type(torch.long) + + # Unsqueeze all the thing + if batch_dimension: + X = X.unsqueeze(0) + C = C.unsqueeze(0) + S = S.unsqueeze(0) + + if location_indices is not None: + return X, C, S, location_indices + + return X, C, S + + def update_with_XCS(self, X, C=None, S=None, alternate_alphabet=None): + """Update the System with XCS coordinates. NOTE: if the System has + more than one model, and if the shape of the System changes (i.e., + atoms are added or deleted), the additional models will be wiped. + + Args: + X (Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`. + `num_atoms` will be 14 if `all_atom=True` or 4 otherwise. + C (LongTensor): Chain map with shape `(1, num_residues)`. It codes + positions as 0 when masked, positive integers for chain indices, + and negative integers to represent missing residues of the + corresponding positive integers. Defaults to the current System's + chain map. + S (LongTensor): Sequence with shape `(1, num_residues)`. Defaults to + the current System's sequence. + """ + if C is None or S is None: + _, _C, _S = self.to_XCS() + if C is None: + C = _C + if S is None: + S = _S + + # check to make sure sizes agree + if not ( + (X.shape[1] == self.num_residues()) + and (X.shape[1] == C.shape[1]) + and (X.shape[1] == S.shape[1]) + ): + raise Exception( + f"input tensor sizes {X.shape}, {C.shape}, and {S.shape}, disagree with System size {self.num_residues()}" + ) + + def _process_inputs(T): + if T is not None: + if len(T.shape) == 2 or len(T.shape) == 4: + T = T.squeeze(0) + T = T.to("cpu").detach().numpy() + return T + + X, C, S = map(_process_inputs, [X, C, S]) + + shape_changed = False + alphabet = constants.AA20 if alternate_alphabet is None else alternate_alphabet + for i, res in enumerate(self.residues()): + # atoms to update must have structure and are present in the chain map + if not res.has_structure() or C[i] <= 0: + continue + + # First, determine if the sequence has changed + resname = "UNK" + if S is not None and S[i] < len(alphabet): + resname = polyseq.to_triple(alphabet[S[i]]) + # If the identity changes, rename and delete side chain atoms + if res.name != resname: + res.rename(resname) + + # Second, delete all atoms that are missing in XCS or have duplicate names + atoms_sys = [atom.name for atom in res.atoms()] + atoms_XCS = constants.ATOMS_BB + if resname in constants.AA_GEOMETRY: + atoms_XCS = atoms_XCS + constants.AA_GEOMETRY[resname]["atoms"] + atoms_XCS = atoms_XCS[: X.shape[1]] + to_delete = [] + for ix_a, atom in enumerate(res.atoms()): + name = atom.name + if name not in atoms_XCS or name in atoms_sys[:ix_a]: + to_delete.append(atom) + if len(to_delete) > 0: + shape_changed = True + res.delete_atoms(to_delete) + + # Finally, update all atom coordinates and manufacture any missing atoms + for x_id, atom_name in enumerate(atoms_XCS): + atom = res.find_atom(atom_name) + x, y, z = [X[i][x_id][k].item() for k in range(3)] + if atom is not None and atom.num_locations() > 0: + atom.x = x + atom.y = y + atom.z = z + else: + shape_changed = True + if atom is not None: + atom.add_location(x, y, z) + else: + res.add_atom(atom_name, False, x, y, z, 1.0, 0.0) + + # wipe extra models if the shape of the System changed + if shape_changed: + self._extra_models = [] + + def __str__(self): + return "system " + self.name + + def chains(self): + """Chain iterator (generator function).""" + for ci in range(len(self._chains)): + yield ChainView(ci, self) + + def get_chain(self, ci: int): + """Returns the chain by index. + + Args: + ci (int): Chain index (from 0) + + Returns: + ChainView object corresponding to the chain in question. + """ + return ChainView(ci, self) + + def get_chain_by_id(self, cid: str, segid=False): + """Returns the chain by its string ID. + + Args: + cid (str): Chain ID. + segid (bool, optional): If set to True (default is False) will + return the chain with the matching segment ID and not chain ID. + + Returns: + ChainView object corresponding to the chain in question. + """ + for ci, chain in enumerate(self.chains()): + if (not segid and cid == chain.cid) or (segid and cid == chain.segid): + return ChainView(ci, self) + return None + + def get_chains(self): + """Returns the list of all chains.""" + return [ChainView(ci, self) for ci in range(len(self._chains))] + + def get_chains_of_entity(self, entity_id: int, by=None): + """Returns the list of chains that correspond to the given entity ID. + + Args: + entity_id (int): Entity ID. + by (str, optional): If specified as "index", will return a + list of chain indices instead of ChainView objects. + + Returns: + List of ChainView objects or chain indices. + """ + cixs = [ci for (ci, eid) in enumerate(self._chain_entities) if entity_id == eid] + if by == "index": + return cixs + return [ChainView(ci, self) for ci in cixs] + + def residues(self): + """Residue iterator (generator function).""" + for chain in self.chains(): + for residue in chain.residues(): + yield residue + + def get_residue(self, gti: int): + """Returns the residue at the given global index. + + Args: + gti (int): Global residue index. + + Returns: + ResidueView object corresponding to the index. + """ + if gti < 0: + raise Exception(f"negative residue index: {gti}") + off = 0 + for chain in self.chains(): + nr = chain.num_residues() + if gti < off + nr: + return chain.get_residue(gti - off) + off = off + nr + raise Exception( + f"residue index {gti} out of range for System, which has {self.num_residues()} residues" + ) + + def atoms(self): + """Iterator of atoms in this System (generator function).""" + for chain in self.chains(): + for residue in chain.residues(): + for atom in residue.atoms(): + yield atom + + def get_atom(self, aidx: int): + """Returns the atom at the given global atom index. + + Args: + gti (int): Global atom index. + + Returns: + AtomView object corresponding to the index. + """ + if aidx < 0: + raise Exception(f"negative atom index: {aidx}") + off = 0 + for chain in self.chains(): + na = chain.num_atoms() + if aidx < off + na: + return chain.get_atom(aidx - off) + off = off + na + raise Exception( + f"atom index {aidx} out of range for System, which has {self.num_atoms()} atoms" + ) + + def locations(self): + """Iterator of atoms in this System (generator function).""" + for chain in self.chains(): + for residue in chain.residues(): + for atom in residue.atoms(): + for loc in atom.locations(): + yield loc + + def _new_locations(self): + new_locs = self._locations.copy() + for li in range(len(new_locs)): + new_locs["coor"][li] = [np.nan] * 5 + return new_locs + + def select(self, expression: str, left_associativity: bool = True): + """Evalates the given selection expression and returns all atoms + involved in the result as a list of AtomView's. + + Args: + expression (str): selection expression. + left_associativity (bool, optional): determines whether operators + in the expression are left-associative. + + Returns: + List of AtomView's. + """ + val, selex_info = self._select( + expression, left_associativity=left_associativity + ) + + # make a list of AtomView + result = [selex_info["all_atoms"][i].atom for i in sorted(val)] + + return result + + def select_residues( + self, + expression: str, + gti: bool = False, + allow_unstructured=False, + left_associativity: bool = True, + ): + """Evalates the given selection expression and returns all residues with any + atoms involved in the result as a list of ResidueView's or list of gti's. + + Args: + expression (str): selection expression. + gti (bool): if True (default is False), will return a list of gti + instead of a list of ResidueView's. + allow_unstructured (bool): If True (default is False), will allow + unstructured residues to be selected. + left_associativity (bool, optional): determines whether operators + in the expression are left-associative. + + Returns: + List of ResidueView's or gti's (ints). + """ + val, selex_info = self._select( + expression, + unstructured=allow_unstructured, + left_associativity=left_associativity, + ) + + # make a list of ResidueView or gti's + if gti: + result = sorted(set([selex_info["all_atoms"][i].rix for i in val])) + else: + residues = dict() + for i in val: + a = selex_info["all_atoms"][i] + residues[a.rix] = a.atom.residue + result = [residues[rix] for rix in sorted(residues.keys())] + + return result + + def select_chains( + self, expression: str, allow_unstructured=False, left_associativity: bool = True + ): + """Evalates the given selection expression and returns all chains with any + atoms involved in the result as a list of ChainView's. + + Args: + expression (str): selection expression. + allow_unstructured (bool): If True (default is False), will allow + unstructured chains to be selected. + left_associativity (bool, optional): determines whether operators + in the expression are left-associative. + + Returns: + List of ResidueView's or gti's (ints). + """ + val, selex_info = self._select( + expression, + unstructured=allow_unstructured, + left_associativity=left_associativity, + ) + + # make a list of ResidueView or gti's + chains = dict() + for i in val: + a = selex_info["all_atoms"][i] + chains[a.cix] = a.atom.chain + result = [chains[rix] for rix in sorted(chains.keys())] + + return result + + def _select( + self, + expression: str, + unstructured: bool = False, + left_associativity: bool = True, + ): + # Build some helpful data structures to support _selex_eval + @dataclass(frozen=True) + class MappableAtom: + atom: AtomView + aix: int + rix: int + cix: int + + def __hash__(self) -> int: + return self.aix + + # first collect all real atoms + all_atoms = [None] * self.num_atoms() + cix, rix, aix = 0, 0, 0 + for chain in self.chains(): + for residue in chain.residues(): + for atom in residue.atoms(): + all_atoms[aix] = MappableAtom(atom, aix, rix, cix) + aix = aix + 1 + + # for residues that do not have atoms, add a dummy atom with no location + # or name; that way, we can still select the residue even though selection + # algebra fundamentally works on atoms + if residue.num_atoms() == 0: + view = DummyAtomView(residue) + view.dummy = True + # make more room at the end of the list since as this is an "extra" atom + all_atoms.append(None) + all_atoms[aix] = MappableAtom(view, aix, rix, cix) + aix = aix + 1 + rix = rix + 1 + cix = cix + 1 + + _selex_info = {"all_atoms": all_atoms} + _selex_info["all_indices_set"] = set([a.aix for a in all_atoms]) + + # fmt: off + # make an expression tree object + tree = ExpressionTreeEvaluator( + ["hyd", "all", "none"], + ["not", "byres", "bychain", "first", "last", + "chain", "authchain", "segid", "namesel", "gti", "resix", "resid", + "authresid", "resname", "re", "x", "y", "z", "b", "icode", "name"], + ["and", "or", "around", "saround"], + eval_function=partial(self._selex_eval, _selex_info), + left_associativity=left_associativity, + debug=False, + ) + # fmt: on + + # evaluate the expression + val = tree.evaluate(expression) + + # if we are not looking to select unstructured residues, remove any dummy + # atoms. NOTE: making dummy atoms can still impact what structured atoms + # are selected (e.g., consider `saround` relative to an unstructured residue) + if not unstructured: + val = { + i for i in val if not hasattr(_selex_info["all_atoms"][i].atom, "dummy") + } + + return val, _selex_info + + def save_selection( + self, + expression: Optional[str] = None, + gti: Optional[List[int]] = None, + selname: str = "_default", + allow_unstructured=False, + left_associativity: bool = True, + ): + """Performs a selection on the System according to the given + selection string and saves the indices of residues involved in + the result (global template indices) under the given name. + + Args: + expression (str): (optional) selection expression. + gti (list of int): (optional) list of gti to define selection expression + selname (str): selection name. + allow_unstructured (bool): If True (default is False), will allow + unstructured residues to be selected. + left_associativity (bool, optional): determines whether operators + in the expression are left-associative. + """ + if gti is not None: + if expression is not None: + warnings.warn( + f"Expression and gti are both not null, expression will be ignored" + f" and gti will be used!" + ) + result = sorted(gti) + else: + result = self.select_residues( + expression, + allow_unstructured=allow_unstructured, + left_associativity=left_associativity, + gti=True, + ) + + # save the list of gti's + self._selections[selname] = result + + def get_selected(self, selname: str = "_default"): + """Returns the list of gti saved under the specified name. + + Args: + selname (str): selection name. + + Returns: + List of global template indices. + """ + if selname not in self._selections: + raise Exception( + f"selection by name '{selname}' does not exist in the System" + ) + return self._selections[selname] + + def has_selection(self, selname: str = "_default"): + """Returns whether the given named selection exists. + + Args: + selname (str): selection name. + + Returns: + Whether the selection exists in the System. + """ + return selname in self._selections + + def get_selection_names(self): + """Returns the list of all currently stored named selections.""" + return list(self._selections.keys()) + + def remove_selection(self, selname: str = "_default"): + """Deletes the selection under the specified name. + + Args: + selname (str): selection name. + """ + if selname not in self._selections: + raise Exception( + f"selection by name '{selname}' does not exist in the System" + ) + del self._selections[selname] + + def _selex_eval(self, _selex_info, op: str, left, right): + def _is_numeric(string: str) -> bool: + try: + float(string) + return True + except ValueError: + return False + + def _is_int(string: str) -> bool: + try: + int(string) + return True + except ValueError: + return False + + def _unpack_operands(operands, dests): + assert len(operands) == len(dests) + unpacked = [None] * len(operands) + succ = True + for i, (operand, dest) in enumerate(zip(operands, dests)): + if dest is None: + if operand is not None: + succ = False + break + elif dest == "result": + if not (isinstance(operand, dict) and "result" in operand): + succ = False + break + unpacked[i] = operand["result"] + elif dest == "string": + if not (len(operand) == 1 and isinstance(operand[0], str)): + succ = False + break + unpacked[i] = operand[0] + elif dest == "strings": + if not ( + isinstance(operand, list) + and all([isinstance(val, str) for val in operands]) + ): + succ = False + break + unpacked[i] = operands + elif dest == "float": + if not (len(operand) == 1 and _is_numeric(operand[0])): + succ = False + break + unpacked[i] = float(operand[0]) + elif dest == "floats": + if not ( + isinstance(operand, list) + and all([_is_numeric(val) for val in operands]) + ): + succ = False + break + unpacked[i] = [float(val) for val in operands] + elif dest == "range": + test = _parse_range(operand) + if test is None: + succ = False + break + unpacked[i] = test + elif dest == "int": + if not (len(operand) == 1 and _is_int(operand[0])): + succ = False + break + unpacked[i] = int(operand[0]) + elif dest == "ints": + if not ( + isinstance(operand, list) + and all([_is_int(val) for val in operands]) + ): + succ = False + break + unpacked[i] = [int(val) for val in operands] + elif dest == "int_range": + test = _parse_int_range(operand) + if test is None: + succ = False + break + unpacked[i] = test + elif dest == "int_range_string": + test = _parse_int_range(operand, string=True) + if test is None: + succ = False + break + unpacked[i] = test + return unpacked, succ + + def _parse_range(operands: list): + """Parses range information given a list of operands that were originally separated + by spaces. Allowed range expressiosn are of the form: `< n`, `> n`, `n:m` with + optional spaces allowed between operands.""" + if not ( + isinstance(operands, list) + and all([isinstance(opr, str) for opr in operands]) + ): + return None + operand = "".join(operands) + if operand.startswith(">") or operand.startswith("<"): + if not _is_numeric(operand[1:]): + return None + num = float(operand[1:]) + if operand.startswith(">"): + test = lambda x, cut=num: x > cut + else: + test = lambda x, cut=num: x < cut + elif ":" in operand: + parts = operand.split(":") + if (len(parts) != 2) or not all([_is_numeric(p) for p in parts]): + return None + parts = [float(p) for p in parts] + test = lambda x, lims=parts: lims[0] < x < lims[1] + elif _is_numeric(operand): + target = float(operand) + test = lambda x, t=target: x == t + else: + return None + return test + + def _parse_int_range(operands: list, string: bool = False): + """Parses range of integers information given a list of operands that were + originally separated by spaces. Allowed range expressiosn are of the form: + `n`, `n-m`, `n+m`, with optional spaces allowed anywhere and combinations + also allowed (e.g., "n+m+s+r-p+a").""" + if not ( + isinstance(operands, list) + and all([isinstance(opr, str) for opr in operands]) + ): + return None + operand = "".join(operands) + operands = operand.split("+") + ranges = [] + for operand in operands: + m = re.fullmatch("(.*\d)-(.+)", operand) + if m: + if not all([_is_int(g) for g in m.groups()]): + return None + r = range(int(m.group(1)), int(m.group(2)) + 1) + ranges.append(r) + else: + if not _is_int(operand): + return None + if string: + ranges.append(set([operand])) + else: + ranges.append(set([int(operand)])) + if string: + ranges = [[str(x) for x in r] for r in ranges] + test = lambda x, ranges=ranges: any([x in r for r in ranges]) + return test + + # evaluate expression and store result in list `result` + result = set() + if op in ("and", "or"): + (Si, Sj), succ = _unpack_operands([left, right], ["result", "result"]) + if not succ: + return None + if op == "and": + result = set(Si).intersection(set(Sj)) + else: + result = set(Si).union(set(Sj)) + elif op == "not": + (_, S), succ = _unpack_operands([left, right], [None, "result"]) + if not succ: + return None + result = _selex_info["all_indices_set"].difference(S) + elif op == "all": + (_, _), succ = _unpack_operands([left, right], [None, None]) + if not succ: + return None + result = _selex_info["all_indices_set"] + elif op == "none": + (_, _), succ = _unpack_operands([left, right], [None, None]) + if not succ: + return None + elif op == "around": + (S, rad), succ = _unpack_operands([left, right], ["result", "float"]) + if not succ: + return None + + # Convert to numpy for distance calculation + atom_indices = np.asarray( + [ + ai.aix + for ai in _selex_info["all_atoms"] + for xi in ai.atom.locations() + ] + ) + X_i = np.asarray( + [ + [xi.x, xi.y, xi.z] + for ai in _selex_info["all_atoms"] + for xi in ai.atom.locations() + ] + ) + X_j = np.asarray( + [ + [xi.x, xi.y, xi.z] + for j in S + for xi in _selex_info["all_atoms"][j].atom.locations() + ] + ) + D = np.sqrt(((X_j[np.newaxis, :, :] - X_i[:, np.newaxis, :]) ** 2).sum(-1)) + ix_match = (D <= rad).sum(1) > 0 + match_hits = atom_indices[ix_match] + result = set(match_hits.tolist()) + elif op == "saround": + (S, srad), succ = _unpack_operands([left, right], ["result", "int"]) + if not succ: + return None + for j in S: + aj = _selex_info["all_atoms"][j] + rj = aj.rix + for ai in _selex_info["all_atoms"]: + if aj.atom.residue.chain != ai.atom.residue.chain: + continue + ri = ai.rix + if abs(ri - rj) <= srad: + result.add(ai.aix) + elif op == "byres": + (_, S), succ = _unpack_operands([left, right], [None, "result"]) + if not succ: + return None + gtis = set() + for j in S: + gtis.add(_selex_info["all_atoms"][j].rix) + for a in _selex_info["all_atoms"]: + if a.rix in gtis: + result.add(a.aix) + elif op == "bychain": + (_, S), succ = _unpack_operands([left, right], [None, "result"]) + if not succ: + return None + cixs = set() + for j in S: + cixs.add(_selex_info["all_atoms"][j].cix) + for a in _selex_info["all_atoms"]: + if a.cix in cixs: + result.add(a.aix) + elif op in ("first", "last"): + (_, S), succ = _unpack_operands([left, right], [None, "result"]) + if not succ: + return None + if op == "first": + mi = min([_selex_info["all_atoms"][i].aix for i in S]) + else: + mi = max([_selex_info["all_atoms"][i].aix for i in S]) + result.add(mi) + elif op == "name": + (_, name), succ = _unpack_operands([left, right], [None, "string"]) + if not succ: + return None + for a in _selex_info["all_atoms"]: + if a.atom.name == name: + result.add(a.aix) + elif op in ("re", "hyd"): + if op == "re": + (_, regex), succ = _unpack_operands([left, right], [None, "string"]) + else: + (_, _), succ = _unpack_operands([left, right], [None, None]) + regex = "[0123456789]?H.*" + if not succ: + return None + ex = re.compile(regex) + for a in _selex_info["all_atoms"]: + if a.atom.name is not None and ex.fullmatch(a.atom.name): + result.add(a.aix) + elif op in ("chain", "authchain", "segid"): + (_, match_id), succ = _unpack_operands([left, right], [None, "string"]) + if not succ: + return None + if op == "chain": + prop = "cid" + elif op == "authchain": + prop = "authid" + elif op == "segid": + prop = "segid" + for a in _selex_info["all_atoms"]: + if getattr(a.atom.residue.chain, prop) == match_id: + result.add(a.aix) + elif op == "resid": + (_, test), succ = _unpack_operands([left, right], [None, "int_range"]) + if not succ: + return None + for a in _selex_info["all_atoms"]: + if test(a.atom.residue.num): + result.add(a.aix) + elif op in ("resname", "icode"): + (_, match_id), succ = _unpack_operands([left, right], [None, "string"]) + if not succ: + return None + if op == "resname": + prop = "name" + elif op == "icode": + prop = "icode" + for a in _selex_info["all_atoms"]: + if getattr(a.atom.residue, prop) == match_id: + result.add(a.aix) + elif op == "authresid": + (_, test), succ = _unpack_operands( + [left, right], [None, "int_range_string"] + ) + if not succ: + return None + for a in _selex_info["all_atoms"]: + if test(a.atom.residue.authid): + result.add(a.aix) + elif op == "gti": + (_, test), succ = _unpack_operands([left, right], [None, "int_range"]) + if not succ: + return None + for a in _selex_info["all_atoms"]: + if test(a.rix): + result.add(a.aix) + elif op in ("x", "y", "z", "b", "occ"): + (_, test), succ = _unpack_operands([left, right], [None, "range"]) + if not succ: + return None + prop = op + if op == "b": + prop = "B" + for a in _selex_info["all_atoms"]: + for loc in a.atom.locations(): + if test(getattr(loc, prop)): + result.add(a.aix) + break + elif op == "namesel": + (_, selname), succ = _unpack_operands([left, right], [None, "string"]) + if not succ: + return None + if selname not in self._selections: + return None + gtis = set(self._selections[selname]) + for a in _selex_info["all_atoms"]: + if a.rix in gtis: + result.add(a.aix) + else: + return None + + return {"result": result} + + def __getitem__(self, chain_idx: int): + """Returns the chain at the given index.""" + return self.get_chain(chain_idx) + + def add_chain( + self, + cid: str, + segid: str = None, + authid: str = None, + entity_id: int = None, + auto_rename: bool = True, + at: int = None, + ): + """Adds a new chain to the System and returns a reference to it. + + Args: + cid (str): Chain ID. + segid (str): Segment ID. + authid (str): Author chain ID. + entity_id (int, optional): Entity ID of the entity corresponding to this chain. + auto_rename (bool, optional): If True, will pick a unique chain ID if the specified + one clashes with an already existing chain. + + Returns: + AtomView object corresponding to the index. + """ + if auto_rename: + cid = self._pick_unique_chain_name(cid) + if segid is None: + segid = cid + if authid is None: + authid = cid + if at is None: + at = self.num_chains() + self._chains.append({"cid": cid, "segid": segid, "authid": authid}) + self._chain_entities.append(entity_id) + else: + self._chains.insert(at, {"cid": cid, "segid": segid, "authid": authid}) + self._chain_entities.insert(at, entity_id) + return ChainView(at, self) + + def _append_residue(self, name: str, num: int, authid: str, icode: str): + """Add a new residue to the end this System. Internal method, do not use. + + Args: + name (str): Residue name. + num (int): Residue number (i.e., residue ID). + authid (str): Author residue ID. + icode (str): Insertion code. + + Returns: + Global index to the newly added residue. + """ + self._chains.append_child( + {"name": name, "resnum": num, "authresid": authid, "icode": icode} + ) + return len(self._residues) - 1 + + def _append_atom( + self, + name: str, + het: bool, + x: float = None, + y: float = None, + z: float = None, + occ: float = None, + B: float = None, + alt: str = None, + ): + """Adds a new atom to the end of this System. Internal method, do not use. + + Args: + name (str): Atom name. + het (bool): Whether it is a hetero-atom. + x, y, z (float): Atom location coordinates. + occ (float): Occupancy. + B (float): B-factor. + alt (str): Alternative position character. + + Returns: + Global index to the newly added atom. + """ + self._residues.append_child({"name": name, "het": het}) + return len(self._atoms) - 1 + + def _append_location(self, x, y, z, occ, B, alt): + """Adds a location to the end of this System. Internal method, do not use. + + Args: + x, y, z (float): coordinates of the location. + occ (float): occupancy for the location. + B (float): B-factor for the location. + alt (str): alternative location character. + + Returns: + Global index to the newly added location. + """ + self._atoms.append_child({"coor": [x, y, z, occ, B], "alt": alt}) + return len(self._locations) - 1 + + def add_new_entity(self, entity: SystemEntity, chain_indices: list): + """Adds a new entity to the list contained within the System and + assigns chains with provided indices to this entity. + + Args: + entity (SystemEntity): The new entity to add to the System. + chain_indices (list): a list of Chain indices for chains to + assign to this entity. + + Returns: + The entity ID of the newly added entity. + """ + new_entity_id = len(self._entities) + while new_entity_id in self._entities: + new_entity_id = new_entity_id + 1 + self._entities[new_entity_id] = entity + for ci in chain_indices: + self._chain_entities[ci] = new_entity_id + return new_entity_id + + def delete_entity(self, entity_id: int): + """Deletes the entity with the specified ID. Takes care to unlink + any chains belonging to this entity from it. + + Args: + entity_id (int): Entity ID. + """ + chain_indices = self.get_chains_of_entity(entity_id) + for ci in chain_indices: + self._chain_entities[ci] = None + del self._entities[entity_id] + + def _pick_unique_chain_name(self, hint: str, verbose=False): + goodNames = list( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + ) + taken = set([chain.cid for chain in self.chains()]) + + # first try to pick a conforming chain name (single alpha-numeric character) + for cid in [hint] + goodNames: + if cid not in taken: + return cid + if verbose: + warnings.warn( + "ran out of reasonable single-letter chain names, will use more than one character (PDB sctructure may be repeating chain IDs upon writing, but should still have unique segment IDs)!" + ) + + # if that does not work, get a longer chain name + for i in range(-1, len(goodNames)): + # first try to expand the original chain ID + base = hint if i < 0 else goodNames[i : i + 1] + if base == "": + continue + for k in range(1000): + longName = f"{base}{k}" + if longName not in taken: + return longName + raise Exception( + "ran out of even multi-character chain names; PDB structure appears to have an enormous number of chains" + ) + + def _ensure_unique_entity(self, ci: int): + """Any time we need to update some piece of information about a Chain that + relates to its entity (e.g., sequence info or hetero info), we cannot just + update it directly because other Chains may be pointing to the same entity. + This function checks for any other chains pointing to the same entity as the + specified chain, and (if so) assigns the given chain to a new (duplicate) + entity and returns its new ID. This clears the way updates of this Chain's entity. + + Args: + ci (int): Index of the Chain for which we are trying to update + entity information. + + Returns: + entity ID for either a newly created entity mapped to the Chain or its + initial entity ID if no other chains point to the same entity. + """ + chain = self.get_chain(ci) + entity_id = chain.get_entity_id() + if entity_id is None: + return entity_id + + # see if any other chains point to the same entity + unique = True + for other in self.chains(): + if (other != chain) and (entity_id == other.get_entity_id()): + unique = False + break + if unique: + return entity_id + + # if so, we need to make a new entity and point the chain to it + new_entity = copy.deepcopy(self._entities[entity_id]) + new_entity_id = self.add_new_entity(new_entity, [ci]) + return new_entity_id + + def num_chains(self): + """Returns the number of chains in the System.""" + return len(self._chains) + + def num_chains_of_entity(self, entity_id: int): + """Returns the number of chains of a given entity. + + Args: + entity_id (int): Entity ID. + + Returns: + number of chains mapping to the entity. + """ + + return sum([entity_id == eid for eid in self._chain_entities]) + + def num_molecules_of_entity(self, entity_id: int): + if self._entities[entity_id].is_polymer(): + return self.num_chains_of_entity(entity_id) + cixs = [ci for (ci, id) in enumerate(self._chain_entities) if id == entity_id] + return sum([self[ci].num_residues() for ci in cixs]) + + def num_entities(self): + """Returns the number of entities in the System.""" + return len(self._entities) + + def num_residues(self): + """Returns the number of residues in the System.""" + return len(self._residues) + + def num_structured_residues(self): + """Returns the number of residues with any structure information.""" + return sum([chain.num_structured_residues() for chain in self.chains()]) + + def num_atoms(self): + """Returns the number of atoms in the System.""" + return len(self._atoms) + + def num_structured_atoms(self): + """Returns the number of atoms with any location information.""" + num = 0 + for chain in self.chains(): + for residue in chain.residues(): + for atom in residue.atoms(): + num = num + (atom.num_locations() > 0) + return num + + def num_atom_locations(self): + """Returns the number of atom locations. Note that an atom can have + multiple (alternative) locations and this functions counts all. + """ + return len(self._locations) + + def num_models(self): + """Returns the number of models in the System. A model is effectively + a conformation of the molecular system and each System object can have + an arbitrary number of different conformations. + """ + return len(self._extra_models) + 1 + + def swap_model(self, i: int): + """Swaps the model at index `i` with the current model (i.e., the + model at index 0). + + Args: + i (int): Model index + """ + if i == 0: + return + if i < 0 or i >= self.num_models(): + raise Exception(f"model index {i} out of range") + tmp = self._locations + self._locations = self._extra_models[i - 1] + self._extra_models[i - 1] = tmp + + def add_model(self, other: System): + """Adds a new model to the System by taking the current model from the + specified System `other`. Note that `other` and the present System + must have the same number of atom locations of matching atom and + residue names. + + Args: + other (System): The System to take the model from. + """ + if len(self._locations) != len(other._locations): + raise Exception( + f"System has {len(self._locations)} atom locations while {len(other._locations)} were provided" + ) + self._extra_models.append(other._locations.copy()) + self._extra_models[-1].set_parent(self._atoms) + + def add_model_from_X(self, X: torch.Tensor): + """Adds a new model to the System with given coordinates. + + Args: + X (torch.Tensor): Coordinate tensor of shape + `(residues, atoms (4 or 14), coordinates (3))` + """ + if len(self._locations) != X.numel() / 3: + raise Exception( + f"System has {len(self._locations)} atom locations while provided tensor shape is {X.shape}" + ) + X = X.detach().cpu() + self._extra_models.append(self._locations.copy()) + self._extra_models[-1]["coor"][:, 0:3] = X.flatten(0, 1) + return None + + def num_assemblies(self): + """Returns the number of biological assemblies defined in this System.""" + return len(self._assembly_info.assemblies) + + @staticmethod + def from_CIF_string(cif_string: str): + """Initializes and returns a System object from a CIF string.""" + import io + + f = io.StringIO(cif_string) + return System._read_cif(f)[0] + + @staticmethod + def from_CIF(input_file: str): + """Initializes and returns a System object from a CIF file.""" + f = open(input_file, "r") + return System._read_cif(f)[0] + + @staticmethod + def _read_cif(f, strict=False): + def _warn_or_error(strict: bool, msg: str): + if strict: + raise Exception(msg) + else: + warnings.warn(msg) + + is_read = { + part: False for part in ["coors", "entities", "sequence", "entity_poly"] + } + category = "" + (in_loop, success) = (False, True) + peeked = sp.PeekedLine("", 0) + # number of molecules per entity prescribed in the CIF file + num_of_mols = dict() + + system = System("system") + while sp.peek_line(f, peeked): + if peeked.line.startswith("#"): + # nothing to do, skip comments + sp.advance(f, peeked) + elif peeked.line.startswith("data_"): + # nothing to do, this is the beginning of the file + sp.advance(f, peeked) + elif peeked.line.startswith("loop_"): + in_loop = True + category = "" + sp.advance(f, peeked) + else: + (cat, name, val) = ("", "", "") + if peeked.line.startswith("_"): + (cat, name, val) = sp.star_item_parse(peeked.line) + if cat != category: + if category != "": + in_loop = False + category = cat + + if (cat == "_entry") and (name == "id"): + if val != "": + system.name = val + sp.advance(f, peeked) + elif cat == "_entity_poly": + if is_read["entity_poly"]: + raise Exception("entity_poly block encountered multiple times") + tab = sp.star_read_data(f, ["entity_id", "type"], in_loop) + for row in tab: + ent_id = int(row[0]) - 1 + if ent_id not in system._entities: + system._entities[ent_id] = SystemEntity( + None, None, row[1], None, None + ) + else: + system._entities[ent_id]._polymer_type = row[1] + is_read["entity_poly"] = True + elif cat == "_entity": + if is_read["entities"]: + raise Exception( + f"entities block encountered multiple times: {peeked.line}" + ) + tab = sp.star_read_data( + f, + ["id", "type", "pdbx_description", "pdbx_number_of_molecules"], + in_loop, + ) + for row in tab: + ent_id = int(row[0]) - 1 + if ent_id not in system._entities: + system._entities[ent_id] = SystemEntity( + row[1], row[2], None, None, None + ) + else: + system._entities[ent_id]._type = row[1] + system._entities[ent_id]._desc = row[2] + if row[3].isnumeric(): + num_of_mols[ent_id] = int(row[3]) + is_read["entities"] = True + elif cat == "_entity_poly_seq": + if is_read["sequence"]: + raise Exception(f"sequence block encountered multiple times") + tab = sp.star_read_data( + f, ["entity_id", "num", "mon_id", "hetero"], in_loop + ) + (seq, het) = ([], []) + for i in range(len(tab)): + # accumulate sequence information until we reach the end or a new entity ID + seq.append(tab[i][2]) + het.append(tab[i][3].startswith("y")) + if (i == len(tab) - 1) or (tab[i][0] != tab[i + 1][0]): + ent_id = int(tab[i][0]) - 1 + system._entities[ent_id]._seq = seq + system._entities[ent_id]._het = het + (seq, het) = ([], []) + is_read["sequence"] = True + elif cat == "_pdbx_struct_assembly": + tab = sp.star_read_data(f, ["id", "details"], in_loop) + for row in tab: + system._assembly_info.assemblies[row[0]] = {"details": row[1]} + elif cat == "_pdbx_struct_assembly_gen": + tab = sp.star_read_data( + f, ["assembly_id", "oper_expression", "asym_id_list"], in_loop + ) + for row in tab: + assembly = system._assembly_info.assemblies[row[0]] + if "instructions" not in assembly: + assembly["instructions"] = [] + chain_ids = [cid.strip() for cid in row[2].strip().split(",")] + assembly["instructions"].append( + {"oper_expression": row[1], "chains": chain_ids} + ) + elif cat == "_pdbx_struct_oper_list": + tab = sp.star_read_data( + f, + [ + "id", + "type", + "name", + "matrix[1][1]", + "matrix[1][2]", + "matrix[1][3]", + "matrix[2][1]", + "matrix[2][2]", + "matrix[2][3]", + "matrix[3][1]", + "matrix[3][2]", + "matrix[3][3]", + "vector[1]", + "vector[2]", + "vector[3]", + ], + in_loop, + ) + for row in tab: + system._assembly_info.operations[ + row[0] + ] = SystemAssemblyInfo.make_operation( + row[1], row[2], row[3:12], row[12:15] + ) + elif cat == "_generate_selections": + tab = sp.star_read_data(f, ["name", "indices"], in_loop) + for row in tab: + system._selections[row[0]] = [ + int(gti.strip()) for gti in row[1].strip().split() + ] + elif cat == "_generate_labels": + tab = sp.star_read_data(f, ["name", "index", "value"], in_loop) + for row in tab: + if row[0] not in system._labels: + system._labels[row[0]] = dict() + idx = int(row[1]) + system._labels[row[0]][int(row[1])] = row[2] + elif cat == "_atom_site": + if is_read["coors"]: + raise Exception(f"ATOM_SITE block encountered multiple times") + # this section is special as it cannot have quoted blocks (because some atom names have the single quote character in them) + tab = sp.star_read_data( + f, + [ + "group_PDB", + "id", + "label_atom_id", + "label_alt_id", + "label_comp_id", + "label_asym_id", + "label_entity_id", + "label_seq_id", + "pdbx_PDB_ins_code", + "Cartn_x", + "Cartn_y", + "Cartn_z", + "occupancy", + "B_iso_or_equiv", + "pdbx_PDB_model_num", + "auth_seq_id", + "auth_asym_id", + ], + in_loop, + cols=False, + has_blocks=False, + ) + + groupCol = 0 + idxCol = 1 + atomNameCol = 2 + altIdCol = 3 + resNameCol = 4 + chainNameCol = 5 + entityIdCol = 6 + seqIdCol = 7 + insCodeCol = 8 + xCol = 9 + yCol = 10 + zCol = 11 + occCol = 12 + bCol = 13 + modelCol = 14 + authSeqIdCol = 15 + authChainNameCol = 16 + + ( + atom, + residue, + chain, + prev_chain, + prev_residue, + prev_atom, + prev_entity_id, + prev_seq_id, + prev_auth_seq_id, + ) = (None, None, None, None, None, None, None, None, None) + loc = None # first model location + aIdx = 0 + for i in range(len(tab)): + if i == 0: + first_model = tab[i][modelCol] + prev_model = first_model + elif (tab[i][modelCol] != prev_model) or ( + tab[i][modelCol] != first_model + ): + if tab[i][modelCol] != prev_model: + aIdx = 0 + num_loc = system.num_atom_locations() + # setting the default value to None allows us to tell when the + # same coordinate in a subsequent model was not specified (e.g., + # when an alternative coordinate is not specified) + system._extra_models.append(system._new_locations()) + prev_model = tab[i][modelCol] + locations_generator = (l for l in system.locations()) + + loc = next(locations_generator, None) + if aIdx >= num_loc: + _warn_or_error( + strict, + f"at atom id: {tab[i][idxCol]} -- too many atoms in model {tab[i][modelCol]} relative to first model {first_model}", + ) + success = False + system._extra_models.clear() + break + + # check that the atoms correspond + same = ( + (loc is not None) + and (tab[i][chainNameCol] == loc.atom.residue.chain.cid) + and (tab[i][resNameCol] == loc.atom.residue.name) + and ( + int( + sp.star_value( + tab[i][seqIdCol], loc.atom.residue.num + ) + ) + == loc.atom.residue.num + ) + and (tab[i][atomNameCol] == loc.atom.name) + ) + if not same: + _warn_or_error( + strict, + f"at atom id: {tab[i][idxCol]} -- atoms in model {tab[i][modelCol]} do not correspond exactly to atoms in first model", + ) + success = False + system._extra_models.clear() + break + + coor = [ + float(tab[i][c]) + for c in [xCol, yCol, zCol, occCol, bCol] + ] + system._extra_models[-1]["coor"][aIdx] = coor + system._extra_models[-1]["alt"][aIdx] = sp.star_value( + tab[i][altIdCol], " " + )[0] + aIdx = aIdx + 1 + continue + + # new chain? + if ( + (chain is None) + or (prev_entity_id != tab[i][entityIdCol]) + or (tab[i][chainNameCol] != chain.cid) + ): + authid = ( + tab[i][authChainNameCol] + if (tab[i][authChainNameCol] != "") + else tab[i][chainNameCol] + ) + chain = system.add_chain( + tab[i][chainNameCol], + tab[i][chainNameCol], + authid, + int(tab[i][entityIdCol]) - 1, + ) + + # new residue + if ( + (residue is None) + or (chain != prev_chain) + or (prev_seq_id != tab[i][seqIdCol]) + or (prev_auth_seq_id != tab[i][authSeqIdCol]) + ): + resnum = ( + int(tab[i][seqIdCol]) + if sp.star_value_defined(tab[i][seqIdCol]) + else chain.num_residues() + 1 + ) + ri = system._append_residue( + tab[i][resNameCol], + resnum, + tab[i][authSeqIdCol], + sp.star_value(tab[i][insCodeCol], " ")[0], + ) + residue = ResidueView(ri, chain) + + # usually will be a new atom, but may be an alternative coordinate + # TODO: this only covers cases where alternative atom coordinates are listed next to each other, + # but sometimes they are not -- need to actively use the altIdCol information + x, y, z, occ, B = [ + float(tab[i][col]) + for col in [xCol, yCol, zCol, occCol, bCol] + ] + alt = sp.star_value(tab[i][altIdCol], " ")[0] + if ( + (atom is None) + or (residue != prev_residue) + or (tab[i][atomNameCol] != atom.name) + ): + ai = system._append_atom( + tab[i][atomNameCol], (tab[i][groupCol] == "HETATM") + ) + atom = AtomView(ai, residue) + system._append_location(x, y, z, occ, B, alt) + + prev_chain = chain + prev_residue = residue + prev_entity_id = tab[i][entityIdCol] + prev_seq_id = tab[i][seqIdCol] + prev_auth_seq_id = tab[i][authSeqIdCol] + is_read["coors"] = True + else: + sp.advance(f, peeked) + + # fill in any "missing" polymer chains (e.g., chains with no visible density + # or known structure, but which are nevertheless present) + for entity_id in num_of_mols: + if system._entities[entity_id].is_polymer(): + rem = num_of_mols[entity_id] - system.num_chains_of_entity(entity_id) + for _ in range(rem): + # the chain will get renamed to avoid clashes + system.add_chain("A", None, None, entity_id, auto_rename=True) + + # fill in missing residues (i.e., those that exist in the entity but not + # the atomistic section) + for chain in system.chains(): + entity = chain.get_entity() + if not entity.is_polymer() or entity._seq is None: + continue + k = 0 + for ri in range(len(entity._seq)): + cur_res = chain.get_residue(k) if k < chain.num_residues() else None + if (cur_res is None) or (cur_res.num > ri + 1): + # insert new residue to correspond to entity monomer with index ri + chain.add_residue(entity._seq[ri], ri + 1, str(ri + 1), " ", at=k) + elif cur_res.num < ri + 1: + _warn_or_error( + strict, f"inconsistent numbering in chain {chain.cid}" + ) + break + k = k + 1 + + # do an entity-to-structure sequence check for all chains + for chain in system.chains(): + if not chain.check_sequence(): + _warn_or_error( + strict, + f"chain {chain.cid} did not pass sequence check against corresponding entity", + ) + + system._reindex() + return system, success + + @staticmethod + def from_PDB_string(cif_string: str, options=""): + """Initializes and returns a System object from a PDB string.""" + import io + + f = io.StringIO(cif_string) + sys = System._read_pdb(f, options) + sys.name = "from_string" + return sys + + @staticmethod + def from_PDB(input_file: str, options=""): + """Initializes and returns a System object from a PDB file.""" + f = open(input_file, "r") + sys = System._read_pdb(f, options) + sys.name = input_file + return sys + + @staticmethod + def _read_pdb(f, strict=False, options=""): + def _to_float(strval, default): + v = default + try: + v = float(strval) + except: + pass + return v + + last_resnum = None + last_resname = None + last_icode = None + last_chain_id = None + last_alt = None + chain = None + residue = None + + # flag to indicate that chain terminus was reached. Initialize to true so as to create a new chain upon reading the first atom. + ter = True + + # various parsing options (the wonders of dealing with the good-old PDB format) + # and any user-specified overrides + options = options.upper() + # use segment IDs to name chains instead of chain IDs? (useful when the latter + # are absent OR when too many chains, so need multi-letter names) + usese_gid = True if ("USESEGID" in options) else False + + # the PDB file was written by CHARMM (slightly different format) + charmm_format = True if ("CHARMM" in options) else False + + # upon reading, convert from all-hydrogen topology (param22 and higher) to + # the CHARMM19 united-atom topology (matters for HIS protonation states) + charmm19_format = True if ("CHARMM19" in options) else False + + # make sure chain IDs are unique, even if they are not unique in the read file + uniq_chain_ids = False if ("ALLOW DUPLICATE CIDS" in options) else True + + # rename CD in ILE to CD1 (as is standard in PDB, but not some MM packages) + fix_Ile_CD = False if ("ALLOW ILE CD" in options) else True + + # consequtive residues that differ only in their insertion code will be treated + # as separate residues + icodes_as_sep_res = True + + # if true, will not pay attention to TER lines in deciding when chains end/begin + ignore_ter = True if ("IGNORE-TER" in options) else False + + # report various warnings when weird things are found and fixed? + verbose = False if ("QUIET" in options) else True + + chains_to_rename = [] + + # read line by line and build the System + system = System("system") + all_system = system + model_index = 0 + for line in f: + line = line.strip() + if line.startswith("ENDMDL"): + # merge the last read model with the overall System + if model_index: + try: + all_system.add_model(system) + except Exception as e: + warnings.warn( + f"error when adding model {model_index + 1}: {str(e)}, skipping model..." + ) + system = System("system") + model_index = model_index + 1 + last_resnum = None + last_resname = None + last_icode = None + last_chain_id = None + last_alt = None + chain = None + residue = None + continue + if line.startswith("END"): + break + if line.startswith("MODEL"): + # new model + continue + if line.startswith("TER") and not ignore_ter: + ter = True + continue + if not (line.startswith("ATOM") or line.startswith("HETATM")): + continue + + """ Now read atom record. Sometimes PDB lines are too short (if they do not contain some + of the last optional columns). We don't want to read past the end of the string!""" + line += " " * 100 + atominx = int(line[6:11]) + atomname = line[12:16].strip() + alt = line[16:17] + resname = line[17:21].strip() + chain_id = line[21:22].strip() + resnum = int(line[23:27]) if charmm_format else int(line[22:26]) + icode = " " if charmm_format else line[26:27] + x = float(line[30:38]) + y = float(line[38:46]) + z = float(line[46:54]) + seg_id = line[72:76].strip() + B = _to_float(line[60:66], 0.0) + occ = _to_float(line[54:60], 0.0) + het = line.startswith("HETATM") + + # use segment ID's instead of chain ID's? + if usese_gid: + chain_id = seg_id + elif (chain_id == "") and (len(seg_id) > 0) and seg_id[0].isalnum(): + # use first character of segment name if no chain name is specified, a segment ID + # is specified, and the latter starts with an alphanumeric character + chain_id = seg_id[0:1] + + # create a new chain object, if necessary + if (chain_id != last_chain_id) or ter: + cid_used = system.get_chain_by_id(chain_id) is not None + chain = system.add_chain(chain_id, seg_id, chain_id, auto_rename=False) + # non-unique chains will be automatically renamed (unless the user specified not to rename chains), BUT we need to + # remember the name that was actually read, since this name is what will be used to determine when the next chain comes + if uniq_chain_ids and cid_used: + chain.cid = chain.cid + f"|to rename {len(chains_to_rename)}" + if model_index == 0: + chains_to_rename.append(chain) + if verbose: + warnings.warn( + "chain name '" + + chain_id + + "' was repeated while reading, will rename at the end..." + ) + + # start to count residue numbers in this chain + last_resnum = None + last_resname = None + ter = False + + if charmm19_format: + if resname == "HSE": + resname = "HSD" # neutral HIS, proton on ND1 + if resname == "HSD": + resname = "HIS" # neutral HIS, proton on NE2 + if resname == "HSC": + resname = "HSP" # doubley-protonated +1 HIS + + # many PDB files in the Protein Data Bank call the delta carbon of isoleucine CD1, but + # the convention in basically all MM packages is to call it CD, since there is only one + if fix_Ile_CD and (resname == "ILE") and (atomname == "CD"): + atomname = "CD1" + + # if necessary, make a new residue + really_new_atom = True # is this a truely new atom, as opposed to an alternative position? + if ( + (resnum != last_resnum) + or (resname != last_resname) + or (icodes_as_sep_res and (icode != last_icode)) + ): + # this corresponds to a case, where the alternative location flag is being used to + # designate two (or more) different possible amino acids at a particular position + # (e.g., where the density is not clear to assign one). In this case, we shall keep + # only the first option, because we don't know any better. But we need to separate + # this from the case, where we end up here because we are trying to separate residues + # by insertion code. + if ( + (resnum == last_resnum) + and (resname != last_resname) + and (alt != last_alt) + and (not icodes_as_sep_res or (icode == last_icode)) + ): + continue + + residue = chain.add_residue( + resname, chain.num_residues() + 1, str(resnum), icode[0] + ) + elif alt != " ": + # if this is not a new residue AND the alternative location flag is specified, + # figure out if another location for this atom has already been given. If not, + # then treat this as the "primary" location, and whatever other locations + # are specified will be treated as alternatives. + a = residue.find_atom(atomname) + if a is not None: + really_new_atom = False + a.add_location(x, y, z, occ, B, alt[0]) + + # if necessary, make a new atom + if really_new_atom: + a = residue.add_atom(atomname, het, x, y, z, occ, B, alt[0]) + + # remember previous values for determining whether something interesting happens next + last_resnum = resnum + last_icode = icode + last_resname = resname + last_chain_id = chain_id + last_alt = alt + + # take care of renaming any chains that had duplicate IDs + for chain in chains_to_rename: + parts = chain.cid.split("|") + assert ( + len(parts) > 1 + ), "something went wrong when renaming a chain at the end of reading" + name = all_system._pick_unique_chain_name(parts[0], verbose) + chain.cid = name + if len(name): + chain.segid = name + + # add an entity for each chain (copy from chain information) + for ci, chain in enumerate(all_system.chains()): + seq = [None] * chain.num_residues() + het = [None] * chain.num_residues() + for ri, res in enumerate(chain.residues()): + seq[ri] = res.name + het[ri] = all(a.het for a in res.atoms()) + entity_type, polymer_type = SystemEntity.guess_entity_and_polymer_type(seq) + entity = SystemEntity( + entity_type, f"chain {chain.cid}", polymer_type, seq, het + ) + all_system.add_new_entity(entity, [ci]) + + return all_system + + def to_CIF(self, output_file: str): + """Writes the System to a CIF file.""" + f = open(output_file, "w") + self._write_cif(f) + + def to_CIF_string(self): + """Returns a CIF string representing the System.""" + import io + + f = io.StringIO("") + self._write_cif(f) + cif_str = f.getvalue() + f.close() + return cif_str + + def _write_cif(self, f): + # fmt: off + _specials_atom_names = [ + "MG", "CL", "FE", "ZN", "MN", "NI", "SE", "CU", "BR", "CO", "AS", + "BE", "RU", "RB", "ZR", "OS", "SR", "GD", "MO", "AU", "AG", "PT", + "AL", "XE", "BE", "CS", "EU", "IR", "AM", "TE", "BA", "SB" + ] + # fmt: on + _ambiguous_atom_names = ["CA", "CD", "NA", "HG", "PB"] + + def _guess_type(atom_name, res_name): + if len(atom_name) > 0 and atom_name[0] == '"': + atom_name = atom_name.replace('"', "") + if atom_name[:2] in _specials_atom_names: + return atom_name[:2] + else: + if atom_name in _ambiguous_atom_names and res_name == atom_name: + return atom_name + elif atom_name == "UNK": + return "X" + return atom_name[:1] + + entry_id = self.name.strip() + if entry_id == "": + entry_id = "system" + f.write( + "data_GNR8\n#\n" + + "_entry.id " + + sp.star_string_escape(entry_id) + + "\n#\n" + ) + + # write entities table + sp.star_loop_header_write( + f, "_entity", ["id", "type", "pdbx_description", "pdbx_number_of_molecules"] + ) + for id, entity in self._entities.items(): + num_mol = self.num_molecules_of_entity(id) + f.write( + f"{id + 1} {sp.star_string_escape(entity._type)} {sp.star_string_escape(entity._desc)} {num_mol}\n" + ) + f.write("#\n") + + # write entity polymer sequences + sp.star_loop_header_write( + f, "_entity_poly_seq", ["entity_id", "num", "mon_id", "hetero"] + ) + for id, entity in self._entities.items(): + if entity._seq is not None: + for i, (res, het) in enumerate(zip(entity._seq, entity._het)): + f.write(f"{id + 1} {i + 1} {res} {'y' if het else 'n'}\n") + f.write("#\n") + + # write entity polymer types + sp.star_loop_header_write(f, "_entity_poly", ["entity_id", "type"]) + for id, entity in self._entities.items(): + if entity.is_polymer(): + f.write(f"{id + 1} {sp.star_string_escape(entity._polymer_type)}\n") + f.write("#\n") + + if self.num_assemblies(): + assemblies = self._assembly_info.assemblies + ops = self._assembly_info.operations + # assembly info table + sp.star_loop_header_write(f, "_pdbx_struct_assembly", ["id", "details"]) + for assembly_id, assembly in assemblies.items(): + f.write(f"{assembly_id} {sp.star_string_escape(assembly['details'])}\n") + f.write("#\n") + + # assembly generation instructions table + sp.star_loop_header_write( + f, + "_pdbx_struct_assembly_gen", + ["assembly_id", "oper_expression", "asym_id_list"], + ) + for assembly_id, assembly in assemblies.items(): + for instruction in assembly["instructions"]: + chain_list = ",".join([str(ci) for ci in instruction["chains"]]) + f.write( + f"{assembly_id} {sp.star_string_escape(instruction['oper_expression'])} {chain_list}\n" + ) + f.write("#\n") + + # symmetry operations table + sp.star_loop_header_write( + f, + "_pdbx_struct_oper_list", + [ + "id", + "type", + "name", + "matrix[1][1]", + "matrix[1][2]", + "matrix[1][3]", + "matrix[2][1]", + "matrix[2][2]", + "matrix[2][3]", + "matrix[3][1]", + "matrix[3][2]", + "matrix[3][3]", + "vector[1]", + "vector[2]", + "vector[3]", + ], + ) + for op_id, op in ops.items(): + f.write( + f"{op_id} {sp.star_string_escape(op['type'])} {sp.star_string_escape(op['name'])} " + ) + f.write( + f"{float(op['matrix'][0][0]):g} {float(op['matrix'][0][1]):g} {float(op['matrix'][0][2]):g} " + ) + f.write( + f"{float(op['matrix'][1][0]):g} {float(op['matrix'][1][1]):g} {float(op['matrix'][1][2]):g} " + ) + f.write( + f"{float(op['matrix'][2][0]):g} {float(op['matrix'][2][1]):g} {float(op['matrix'][2][2]):g} " + ) + f.write( + f"{float(op['vector'][0]):g} {float(op['vector'][1]):g} {float(op['vector'][2]):g}\n" + ) + f.write("#\n") + + sp.star_loop_header_write( + f, + "_atom_site", + [ + "group_PDB", + "id", + "label_atom_id", + "label_alt_id", + "label_comp_id", + "label_asym_id", + "label_entity_id", + "label_seq_id", + "pdbx_PDB_ins_code", + "Cartn_x", + "Cartn_y", + "Cartn_z", + "occupancy", + "B_iso_or_equiv", + "pdbx_PDB_model_num", + "auth_seq_id", + "auth_asym_id", + "type_symbol", + ], + ) + idx = -1 + for model_index in range(self.num_models()): + self.swap_model(model_index) + for chain, entity_id in zip(self.chains(), self._chain_entities): + authchainid = ( + chain.authid if sp.star_value_defined(chain.authid) else chain.cid + ) + for residue in chain.residues(): + authresid = ( + residue.authid + if sp.star_value_defined(residue.authid) + else residue.num + ) + for atom in residue.atoms(): + idx = idx + 1 + for location in atom.locations(): + # this means this coordinate was not specified for this model + if not location.defined(): + continue + + coor = location.coor_info + f.write("HETATM " if atom.het else "ATOM ") + f.write( + f"{idx + 1} {atom.name} {sp.atom_site_token(location.alt)} " + ) + entity_id_str = ( + f"{entity_id + 1}" if entity_id is not None else "?" + ) + f.write( + f"{residue.name} {chain.cid} {entity_id_str} {residue.num} " + ) + f.write( + f"{sp.atom_site_token(residue.icode)} {coor[0]:g} {coor[1]:g} {coor[2]:g} " + ) + f.write(f"{coor[3]:g} {coor[4]:g} {model_index} ") + f.write( + f"{authresid} {authchainid} {_guess_type(atom.name, residue.name)}\n" + ) + self.swap_model(model_index) + f.write("#\n") + + # write out selections + if len(self._selections): + sp.star_loop_header_write(f, "_generate_selections", ["name", "indices"]) + for name, indices in self._selections.items(): + f.write( + f"{sp.star_string_escape(name)} \"{' '.join([str(i) for i in indices])}\"\n" + ) + f.write("#\n") + + # write out labels + if len(self._labels): + sp.star_loop_header_write(f, "_generate_labels", ["name", "index", "value"]) + for category, label_dict in self._labels.items(): + for gti, label in label_dict.items(): + f.write( + f"{sp.star_string_escape(category)} {gti} {sp.star_string_escape(label)}\n" + ) + f.write("#\n") + + def to_PDB(self, output_file: str, options: str = ""): + """Writes the System to a PDB file. + + Args: + output_file (str): output PDB file name. + options (str, optional): a string specifying various options for + the writing process. The presence of certain sub-strings will + trigger specific behaviors. Currently recognized sub-strings + include "CHARMM", "CHARMM19", "CHARMM22", "RENUMBER", "NOEND", + "NOTER", and "NOALT". This option is case-insensitive. + """ + f = open(output_file, "w") + self._write_pdb(f, options) + + def to_PDB_string(self, options=""): + """Writes the System to a PDB string. The options string has the same + interpretation as with System::toPDB. + """ + import io + + f = io.StringIO("") + self._write_pdb(f, options) + cif_str = f.getvalue() + f.close() + return cif_str + + def _write_pdb(self, f, options=""): + def _pdb_line(loc: AtomLocationView, ai: int, ri=None, rn=None, an=None): + if rn is None: + rn = loc.atom.residue.name + if ri is None: + ri = loc.atom.residue.num + if an is None: + an = loc.atom.name + icode = loc.atom.residue.icode + cid = loc.atom.residue.chain.cid + if len(cid) > 1: + cid = cid[0] + segid = loc.atom.residue.chain.segid + if len(segid) > 4: + segid = segid[0:4] + + # atom name placement is different when it is 4 characters long + if len(an) < 4: + an_str = " %-.3s" % an + else: + an_str = "%.4s" % an + + # moduli are used to make sure numbers do not go over prescribe field widths + # (this is not enforced by sprintf like with strings) + line = ( + "%6s%5d %-4s%c%-4s%.1s%4d%c %8.3f%8.3f%8.3f%6.2f%6.2f %.4s" + % ( + "HETATM" if loc.atom.het else "ATOM ", + ai % 100000, + an_str, + loc.alt, + rn, + cid, + ri % 10000, + icode, + loc.x, + loc.y, + loc.z, + loc.occ, + loc.B, + segid, + ) + ) + + return line + + # various formating options (the wonders of dealing with the good-old PDB format) + # and user-defined overrides + options = options.upper() + # the PDB file is intended for use in CHARMM or some other MM package + charmmFormat = True if "CHARMM" in options else False + + # upon writing, convert from all-hydrogen topology (param 22 and higher) + # to CHARMM19 united-atom topology (matters for HIS protonation states) + charmm19Format = True if "CHARMM19" in options else False + + # upon writing, convert from CHARMM19 united-atom topology to all-hydrogen + # param 22 topology (matters for HIS protonation states). Also works for + # converting generic PDB files downloaded from the PDB. + charmm22Format = True if "CHARMM22" in options else False + + # upon writing, renumber residue and atom names to start from 1 and go in order + renumber = True if "RENUMBER" in options else False + + # do not write END at the end of the PDB file (e.g., useful for + # concatenating chains from several structures) + noend = True if "NOEND" in options else False + + # do not demark the end of each chain with TER (this is not _really_ + # necessary, assuming chain names are unique, and it is sometimes nice + # not to have extra lines other than atoms) + noter = True if "NOTER" in options else False + + # write alternative locations by default + writeAlt = True if "NOALT" in options else False + + # upon writing, convert to a generic PDB naming convention (no + # protonation state specified for HIS) + genericFormat = False + + if charmm19Format and charmm22Format: + raise Exception( + "CHARMM 19 and 22 formatting options cannot be specified together" + ) + + atomIndex = 1 + for ci, chain in enumerate(self.chains()): + for ri, residue in enumerate(chain.residues()): + for ai, atom in enumerate(residue.atoms()): + # dirty details of formating for MM purposes converting + atomname = atom.name + resname = residue.name + if charmmFormat: + if (residue.name == "ILE") and (atom.name == "CD1"): + atomname = "CD" + if (atom.name == "O") and (ri == chain.num_residues() - 1): + atomname = "OT1" + if (atom.name == "OXT") and (ri == chain.num_residues() - 1): + atomname = "OT2" + if residue.name == "HOH": + resname = "TIP3" + + if charmm19Format: + if residue.name == "HSD": # neutral HIS, proton on ND1 + resname = "HIS" + if residue.name == "HSE": # neutral HIS, proton on NE2 + resname = "HSD" + if residue.name == "HSC": # doubley-protonated +1 HIS + resname = "HSP" + elif charmm22Format: + """This will convert from CHARMM19 to CHARMM22 as well as from a generic downlodaded + * PDB file to one ready for use in CHARMM22. The latter is because in the all-hydrogen + * topology, HIS protonation state must be explicitly specified, so there is no HIS per se. + * Whereas in typical downloaded PDB files HIS is used for all histidines (usually, one + * does not even really know the protonation state). Whether sometimes people do specify it + * nevertheless, and what naming format they use to do so, I am not sure (welcome to the + * PDB file format). But certainly almost always it is just HIS. Below HIS is renamed to + * HSD, the neutral form with proton on ND1. This is an assumption; not a perfect one, but + * something needs to be assumed. Doing this renaming will make the PDB file work in MM + * packages with the all-hydrogen model.""" + if residue.name == "HSD": # neutral HIS, proton on NE2 + resname = "HSE" + if residue.name == "HIS": # neutral HIS, proton on ND1 + resname = "HSD" + if residue.name == "HSP": # doubley-protonated +1 HIS + resname = "HSC" + elif genericFormat: + if residue.name in ["HSD", "HSP", "HSE", "HSC"]: + resname = "HIS" + if (residue.name == "ILE") and (atom.name == "CD"): + atomname = "CD1" + + # write the atom line + for li in range(atom.num_locations()): + if renumber: + f.write( + _pdb_line( + atom.get_location(li), + atomIndex, + ri=ri + 1, + rn=resname, + an=atomname, + ) + + "\n" + ) + else: + f.write( + _pdb_line( + atom.get_location(li), + atomIndex, + rn=resname, + an=atomname, + ) + + "\n" + ) + atomIndex = atomIndex + 1 + + if not noter and (ri == chain.num_residues() - 1): + f.write("TER\n") + if not noend and (ci == self.num_chains() - 1): + f.write("END\n") + + def canonicalize_protein( + self, + level=2, + drop_coors_unknowns=False, + drop_coors_missing_backbone=False, + filter_by_entity=False, + ): + """Canonicalize the calling System object (in place) by assuming that it represents + a protein molecular system. Different canonicalization rigor and options + can be specified but are all optional. + + Args: + level (int): Canonicalization level that determines which nonstandard-to-standard + residue mappings are performed. Possible values are 1, 2 or 3, with 2 being + the default and higher values meaning more rigorous (and less conservative) + canonicalization. With level 1, only truly equivalent mappings are performed + (e.g., different His protonation states are mapped to the canonical residue + name HIS that does not specify protonation). Level 2 adds to this some less + exact but still quite close mappings--i.e., seleno-methionine (MSE) and seleno- + cystine (SEC) to methionine (MET) and cystine (CYS). Level 3 further adds + even less equivalent but still reasonable mappings--i.e., phosphorylated SER, + THR, TYR, and HIS to their unphosphorylated counterparts as well as S-oxy Cys + to Cys. + drop_coors_unknowns (bool, optional): if True, will discard structural information + for all residues that are not natural or mappable under the current level. + NOTE: any sequence record for these residues (i.e., if they are part of a + polymer entity) will be preserved. + drop_coors_missing_backbone (bool, optional): if True, will discard structural + information for residues that do not have at least the N, CA, C, and O + backbone atoms. Same note applies regarding the full sequence record as in + the above. + filter_by_entity (bool, optional): if True, will remove any chains that do not + represent polymer/polypeptide entities. This is convenient for cases where a + System object has both protein and non-protein components. However, depending + on how the System object was generated, entity metadata may not have been filled, + so applying this canonicalization approach will remove the entire structure. + For this reason, the option is False by default. + """ + + def _mod_to_standard_aa_mappings( + less_standard: bool, almost_standard: bool, standard: bool + ): + # Perfectly corresponding to standard residues + standard_map = {"HSD": "HIS", "HSE": "HIS", "HSC": "HIS", "HSP": "HIS"} + + # Almost perfectly corresponding to standard residues: + # * MSE -- selenomethyonine; SEC -- selenocysteine + almost_standard_map = {"MSE": "MET", "SEC": "CYS"} + + # A little less perfectly corresponding pairings, but can be acceptable (depends): + # * HIP -- ND1-phosphohistidine; SEP -- phosphoserine; TPO -- phosphothreonine; + # * PTR -- o-phosphotyrosine. + less_standard_map = { + "HIP": "HIS", + "CSX": "CYS", + "SEP": "SER", + "TPO": "THR", + "PTR": "TYR", + } + + ret = dict() + if standard: + ret.update(standard_map) + if almost_standard: + ret.update(almost_standard_map) + if less_standard: + ret.update(less_standard_map) + return ret + + def _to_standard_aa_mappings( + less_standard: bool, almost_standard: bool, standard: bool + ): + # get the mapping between modifications and their corresponding standard forms + mapping = _mod_to_standard_aa_mappings( + less_standard, almost_standard, standard + ) + + # add mapping between standard names and themselves + import chroma.utility.polyseq as polyseq + + for aa in polyseq.canonical_amino_acids(): + mapping[aa] = aa + + return mapping + + less_standard, almost_standard, standard = False, False, False + if level == 3: + less_standard, almost_standard, standard = True, True, True + elif level == 2: + less_standard, almost_standard, standard = False, True, True + elif level == 1: + less_standard, almost_standard, standard = False, False, True + else: + raise Exception(f"unknown canonicalization level {level}") + + to_standard = _to_standard_aa_mappings(less_standard, almost_standard, standard) + + # NOTE: need to re-implement the canonicalization procedure such that it: + # 1. checks to make sure entity sequence and structure sequence agree (error if not) + # 2. goes over entities and looks for residues to rename, does the renaming on the entities + # and all chains simultaneously (so that no new entities are created) + # 3. then goes over the structured part and fixes atoms + + # For residue renamings, we will first record all edits and will perform them + # afterwards in one go, so we can judge whether any new entities have to be + # created. The dictionary `esidues_to_rename`` will be as follows: + # entity_id: { + # chain_index: [list of (residue index, rew name) tuples] + # } + chains_to_delete = [] + residues_to_rename = dict() + for ci, chain in enumerate(self.chains()): + entity = chain.get_entity() + if filter_by_entity: + if ( + (entity is None) + or (entity._type != "polymer") + or ("polypeptide" not in entity.polymer_time) + ): + chains_to_delete.append(chain) + continue + + # iterate in reverse order so we can safely delete any residues we find necessary + cleared_residues = 0 + for residue in reversed(list(chain.residues())): + aa = residue.name + delete_atoms = False + # canonicalize amino acid (delete structure if unknown, provided this was asked for) + if aa in to_standard: + aa_new = to_standard[aa] + if aa != aa_new: + # edit any atoms to reflect the mutation + if ( + (aa == "HSD") + or (aa == "HSE") + or (aa == "HSC") + or (aa == "HSP") + ) and (aa_new == "HIS"): + pass + elif ((aa == "MSE") and (aa_new == "MET")) or ( + (aa == "SEC") and (aa_new == "CYS") + ): + SE = residue.find_atom("SE") + if SE is not None: + if aa == "MSE": + SE.residue.rename("SD") + else: + SE.residue.rename("SG") + elif ( + ((aa == "HIP") and (aa_new == "HIS")) + or ((aa == "SEP") and (aa_new == "SER")) + or ((aa == "TPO") and (aa_new == "THR")) + or ((aa == "PTR") and (aa_new == "TYR")) + ): + # delete the phosphate group + for atomname in ["P", "O1P", "O2P", "O3P", "HOP2", "HOP3"]: + a = residue.find_atom(atomname) + if a is not None: + a.delete() + elif (aa == "CSX") and (aa_new == "CYS"): + a = residue.find_atom("OD") + if a is not None: + a.delete() + + # record residue renaming operation to be done later + entity_id = chain.get_entity_id() + if entity_id is None: + residue.rename(aa_new) + else: + if entity_id not in residues_to_rename: + residues_to_rename[entity_id] = dict() + if ci not in residues_to_rename[entity_id]: + residues_to_rename[entity_id][ci] = list() + residues_to_rename[entity_id][ci].append( + (residue.get_index_in_chain(), aa_new) + ) + else: + if aa == "ARG": + A = {an: None for an in ["CD", "NE", "CZ", "NH1", "NH2"]} + for an in A: + atom = residue.find_atom(an) + if atom is not None and atom.num_locations(): + A[an] = atom.get_location(0) + if all([a is not None for n, a in A.items()]): + dihe1 = System.dihedral( + A["CD"], A["NE"], A["CZ"], A["NH1"] + ) + dihe2 = System.dihedral( + A["CD"], A["NE"], A["CZ"], A["NH2"] + ) + if abs(dihe1) > abs(dihe2): + A["NH1"].name = "NH2" + A["NH2"].name = "NH1" + elif drop_coors_unknowns: + delete_atoms = True + + if not drop_coors_missing_backbone: + if not delete_atoms and not residue.has_full_backbone(): + delete_atoms = True + + if delete_atoms: + residue.delete_atoms() + cleared_residues += 1 + + # If we have deleted all residues in this chain, then this is probably not + # a protein chain, so get rid of it. Unless we are asked to pay attention to + # the entity type (i.e., whether it is peptidic), in which case the decision + # of whether to keep the chain would have been made previously. + if ( + not filter_by_entity + and (cleared_residues != 0) + and (cleared_residues == chain.num_residues()) + ): + chains_to_delete.append(chain) + + # rename residues differently depending on whether all chains of a given entity + # have the same set of renamings + for entity_id, ops in residues_to_rename.items(): + chain_indices = set(ops.keys()) + entity_chains = set(self.get_chains_of_entity(entity_id, by="index")) + unique_renames = set([tuple(v) for v in ops.values()]) + fork = True + if (chain_indices == entity_chains) and (len(unique_renames) == 1): + # we can rename without updating entities, because all entity chains are updated the same way + fork = False + for ci, renames in ops.items(): + chain = self.get_chain(ci) + for ri, new_name in renames: + chain.get_residue(ri).rename(new_name, fork_entity=fork) + + # now delete any chains + for chain in reversed(chains_to_delete): + chain.delete() + + self._reindex() + + def sequence(self, format="three-letter-list"): + """Returns the full sequence of this System, concatenated over all + chains in their order within the System. + + Args: + format (str): sequence format. Possible options are either + "three-letter-list" (default) or "one-letter-string". + + Returns: + List (default) or string. + """ + if format == "three-letter-list": + seq = [] + else: + seq = "" + + for chain in self.chains(): + seq = seq + chain.sequence(format) + return seq + + @staticmethod + def distance(a1: AtomLocationView, a2: AtomLocationView): + """Computes the distance between atom locations `a1` and `a2`.""" + v21 = a1.coors - a2.coors + return np.linalg.norm(v21) + + @staticmethod + def angle( + a1: AtomLocationView, a2: AtomLocationView, a3: AtomLocationView, radians=False + ): + """Computes the angle formed by three 3D points represented by AtomLocationView objects. + + Args: + a1, a2, a3 (AtomLocationView): three 3D points. + radian (bool, optional): if True (default False), will return the angle in radians. + Otherwise, in degrees. + + Returns: + Angle `a1`-`a2`-`a3`. + """ + v21 = a1.coors - a2.coors + v23 = a3.coors - a2.coors + v21 = v21 / np.linalg.norm(v21) + v23 = v23 / np.linalg.norm(v23) + c = np.dot(v21, v23) + return np.arctan2(np.sqrt(1 - c * c), c) * (1 if radians else 180.0 / np.pi) + + @staticmethod + def dihedral( + a1: AtomLocationView, + a2: AtomLocationView, + a3: AtomLocationView, + a4: AtomLocationView, + radians=False, + ): + """Computes the dihedral angle formed by four 3D points represented by AtomLocationView objects. + + Args: + a1, a2, a3, a4 (AtomLocationView): four 3D points. + radian (bool, optional): if True (default False), will return the angle in radians. + Otherwise, in degrees. + + Returns: + Dihedral angle `a1`-`a2`-`a3`-`a4`. + """ + AB = a1.coors - a2.coors + CB = a3.coors - a2.coors + DC = a4.coors - a3.coors + + if min([np.linalg.norm(p) for p in [AB, CB, DC]]) == 0.0: + raise Exception("some points coincide in dihedral calculation") + + ABxCB = np.cross(AB, CB) + ABxCB = ABxCB / np.linalg.norm(ABxCB) + DCxCB = np.cross(DC, CB) + DCxCB = DCxCB / np.linalg.norm(DCxCB) + + # the following is necessary for values very close to 1 but just above + dotp = np.dot(ABxCB, DCxCB) + if dotp > 1.0: + dotp = 1.0 + elif dotp < -1.0: + dotp = -1.0 + + angle = np.arccos(dotp) + if np.dot(ABxCB, DC) > 0: + angle *= -1 + if not radians: + angle *= 180.0 / np.pi + + return angle + + @staticmethod + def protein_backbone_atom_type(atom_name: str, no_hyd=True, by_name=True): + """Backbone atoms can be either nitrogens, carbons, oxigens, or hydrogens. + Specifically, possible known names in each category are: + 'N', 'NT' + 'CA', 'C', 'CY', 'CAY' + 'OY', 'O', 'OCT*', 'OXT', 'OT1', 'OT2' + 'H', 'HY*', 'HA*', 'HN', 'HT*', '1H', '2H', '3H' + """ + array = ["N", "CA", "C", "O", "H"] if by_name else [0, 1, 2, 3, 4] + if atom_name in ["N", "NT"]: + return array[0] + if atom_name == "CA": + return array[1] + if (atom_name == "C") or (atom_name == "CY"): + return array[2] + if atom_name in ["O", "OY", "OXT", "OT1", "OT2"] or atom_name.startswith("OCT"): + return array[3] + if not no_hyd: + if atom_name in ["H", "HA", "HN"]: + return array[4] + if atom_name.startswith("HT") or atom_name.startswith("HY"): + return array[4] + # Rosetta's N-terinal amine has hydrogens named 1H, 2H, and 3H + if ( + atom_name.startswith("1H") + or atom_name.startswith("2H") + or atom_name.startswith("3H") + ): + return array[4] + return None + + +@dataclass +class SystemEntity: + """A molecular entity represented in a molecular system.""" + + _type: str + _desc: str + _polymer_type: str + _seq: list + _het: list + + def is_polymer(self): + """Returns whether the entity represents a polymer.""" + return self._type == "polymer" + + @classmethod + def guess_entity_and_polymer_type(cls, seq: List): + is_poly = np.mean([polyseq.is_polymer_residue(res, None) for res in seq]) > 0.8 + polymer_type = None + if is_poly: + entity_type = "polymer" + for ptype in polyseq.polymerType: + if ( + np.mean([polyseq.is_polymer_residue(res, ptype) for res in seq]) + > 0.8 + ): + polymer_type = polyseq.polymer_type_name(ptype) + break + else: + entity_type = "unknown" + + return entity_type, polymer_type + + @property + def type(self): + return self._type + + @property + def description(self): + return self._desc + + @property + def polymer_type(self): + return self._polymer_type + + @property + def sequence(self): + return self._seq + + @property + def hetero(self): + return self._het + + +@dataclass +class BaseView: + """An abstract base "view" class for accessing different parts of System.""" + + _ix: int + _parent: object + + def get_index(self): + """Return the index of this atom location in its System.""" + return self._ix + + def is_valid(self): + return self._ix >= 0 and self._parent is not None + + def _delete(self): + at = self._ix - self.parent._siblings.child_index(self.parent._ix, 0) + self.parent._siblings.delete_child(self.parent._ix, at) + + @property + def parent(self): + return self._parent + + +@dataclass +class ChainView(BaseView): + """A Chain view, allowing hierarchical exploration and editing.""" + + def __init__(self, ix: int, system: System): + self._ix = ix + self._parent = system + self._siblings = system._chains + + def __str__(self): + return f"{self.cid} ({self.segid}/{self.authid}) -> {str(self.system)}" + + def residues(self): + for rn in range(self.num_residues()): + ri = self._siblings.child_index(self._ix, rn) + yield ResidueView(ri, self) + + def num_residues(self): + """Returns the number of residues in the Chain.""" + return self._siblings.num_children(self._ix) + + def num_structured_residues(self): + return sum([res.has_structure() for res in self.residues()]) + + def num_atoms(self): + return sum([res.num_atoms() for res in self.residues()]) + + def num_atom_locations(self): + return sum([res.num_atom_locations() for res in self.residues()]) + + def sequence(self, format="three-letter-list"): + """Returns the sequence of this chain. See `System::sequence()` for + possible formats. + """ + if format == "three-letter-list": + seq = [None] * self.num_residues() + for ri, residue in enumerate(self.residues()): + seq[ri] = residue.name + return seq + elif format == "one-letter-string": + import chroma.utility.polyseq as polyseq + + seq = [None] * self.num_residues() + for ri, residue in enumerate(self.residues()): + seq[ri] = polyseq.to_single(residue.name) + return "".join(seq) + else: + raise Exception(f"unknown sequence format {format}") + + def get_residue(self, ri: int): + """Get the residue at the specified index within the Chain. + + Args: + ri (int): Residue index within the Chain. + + Returns: + ResidueView object corresponding to the residue in question. + """ + if ri < 0 or ri >= self.num_residues(): + raise Exception( + f"residue index {ri} out of range for Chain, which has {self.num_residues()} residues" + ) + ri = self._siblings.child_index(self._ix, ri) + return ResidueView(ri, self) + + def get_residue_index(self, residue: ResidueView): + """Get the index of the given residue in this Chain.""" + return residue._ix - self._siblings.child_index(self._ix, 0) + + def get_atom(self, aidx: int): + """Get the atom at index `aidx` within this chain.""" + if aidx < 0: + raise Exception(f"negative atom index: {aidx}") + off = 0 + for residue in self.residues(): + na = residue.num_atoms() + if aidx < off + na: + return residue.get_atom(aidx - off) + off = off + na + raise Exception( + f"atom index {aidx} out of range for System, which has {self.num_atoms()} atoms" + ) + + def get_atoms(self): + """Return a list of all atoms in this chain.""" + atoms_views = [] + for residue in self.residues(): + atoms_views.extend(residue.get_atoms()) + return atoms_views + + def __getitem__(self, res_idx: int): + return self.get_residue(res_idx) + + def get_entity_id(self): + """Return the entity ID corresponding to this chain.""" + return self.system._chain_entities[self._ix] + + def get_entity(self): + """Return the entity this chain belongs to.""" + entity_id = self.get_entity_id() + if entity_id is None: + return None + return self.system._entities[entity_id] + + def check_sequence(self): + """Compare the list of residue names of this chain to the corresponding entity sequence record.""" + entity = self.get_entity() + if entity is not None and entity.is_polymer(): + if self.num_residues() != len(entity._seq): + return False + for res, ent_aan in zip(self.residues(), entity._seq): + if res.name != ent_aan: + return False + return True + + def add_residue(self, name: str, num: int, authid: str, icode: str = " ", at=None): + """Add a new residue to this chain. + + Args: + name (str): Residue name. + num (int): Residue number (i.e., residue ID). + authid (str): Author residue ID. + icode (str): Insertion code. + at (int, optional): Index at which to insert the residue. Default + is to append to the end of the chain (i.e., equivalent of ``at` + being equal to the present length of the chain). + """ + if at is None: + at = self.num_residues() + ri = self._siblings.insert_child( + self._ix, + at, + {"name": name, "resnum": num, "authresid": authid, "icode": icode}, + ) + return ResidueView(ri, self) + + def delete(self, keep_entity=False): + """Deletes this Chain from its System. + + Args: + keep_entity (bool, optional): If False (default) and if the chain + being deleted happens to be the last representative of the + entity it belongs to, the entity will be deleted. If True, the + entity will always be kept. + """ + # delete the mention of the chain from assembly information + self.system._assembly_info.delete_chain(self.cid) + + # optionally, delete the corresponding entity if no other chains point to it + if not keep_entity: + eid = self.get_entity_id() + if self.system.num_chains_of_entity(eid) == 0: + self.system.delete_entity(eid) + + self.system._chain_entities.pop(self._ix) + self._siblings.delete(self._ix) + self._ix = -1 # invalidate the view + + @property + def system(self): + return self._parent + + @property + def cid(self): + return self._siblings["cid"][self._ix] + + @property + def segid(self): + return self._siblings["segid"][self._ix] + + @property + def authid(self): + return self._siblings["authid"][self._ix] + + @cid.setter + def cid(self, val): + self._siblings["cid"][self._ix] = val + + @segid.setter + def segid(self, val): + self._siblings["segid"][self._ix] = val + + @authid.setter + def authid(self, val): + self._siblings["authid"][self._ix] = val + + +@dataclass +class ResidueView(BaseView): + """A Residue view, allowing hierarchical exploration and editing.""" + + def __init__(self, ix: int, chain: ChainView): + self._ix = ix + self._parent = chain + self._siblings = chain.system._residues + + def __str__(self): + return f"{self.name} {self.num} ({self.authid}) -> {str(self.chain)}" + + def atoms(self): + off = self._siblings.child_index(self._ix, 0) + for an in range(self.num_atoms()): + yield AtomView(off + an, self) + + def num_atoms(self): + return self._siblings.num_children(self._ix) + + def num_atom_locations(self): + return sum([a.num_locations() for a in self.atoms()]) + + def has_structure(self): + """Returns whether the atom has any structural information (i.e., one or more locations).""" + for a in self.atoms(): + if a.num_locations(): + return True + return False + + def get_atom(self, ai: int): + """Get the atom at the specified index within the Residue. + + Args: + atom_idx (int): Atom index within the Residue. + + Returns: + AtomView object corresponding to the atom in question. + """ + + if ai < 0 or ai >= self.num_atoms(): + raise Exception( + f"atom index {ai} out of range for Residue, which has {self.num_atoms()} atoms" + ) + ai = self._siblings.child_index(self._ix, ai) + return AtomView(ai, self) + + def get_atom_index(self, atom: AtomView): + """Get the index of the given atom in this Residue.""" + return atom._ix - self._siblings.child_index(self._ix, 0) + + def find_atom(self, name): + """Find and return the first atom (as AtomView object) with the given name + within the Residue or None.""" + for atom in self.atoms(): + if atom.name == name: + return atom + return None + + def __getitem__(self, atom_idx: int): + return self.get_atom(atom_idx) + + def get_index_in_chain(self): + """Return the index of the Residue in its parent Chain.""" + return self.chain.get_residue_index(self) + + def rename(self, new_name: str, fork_entity=True): + """Assigns the residue a new name with all proper updates. + + Args: + new_name (str): New residue name. + fork_entity (bool, optional): If True (default) and if parent + chain corresponds to an entity that has other chains + associated with it and there is a real renaming (i.e., + the old name is not the same as the new name), will + make a new (duplicate) entity for to this chain and + will edit the new one, leaving the old one unchanged. + If False, will not perform this regardless. NOTE: + setting this to False can create an inconsistent state + between chain and entity sequence information. + """ + entity_id = self.chain.get_entity_id() + if entity_id is not None: + entity = self.system._entities[entity_id] + ri = self.get_index_in_chain() + if fork_entity and (entity._seq[ri] != new_name): + ci = self.chain.get_index() + entity_id = self.system._ensure_unique_entity(ci) + entity = self.system._entities[entity_id] + entity._seq[ri] = new_name + self._siblings["name"][self._ix] = new_name + + def add_atom( + self, + name: str, + het: bool, + x: float = None, + y: float = None, + z: float = None, + occ: float = 1.0, + B: float = 0.0, + alt: str = " ", + at=None, + ): + """Adds a new atom to the residue (appending it at the end) and + returns an AtomView to it. If atom location information is + specified, will also add a location to the atom. + + Args: + name (str): Atom name. + het (bool): Whether it is a hetero-atom. + x, y, z (float): Atom location coordinates. + occ (float): Occupancy. + B (float): B-factor. + alt (str): Alternative position character. + at (int, optional): Index at which to insert the atom. Default + is to append to the end of the residue (i.e., equivalent of + ``at` being equal to the number of atoms in the residue). + + Returns: + AtomView object corresponding to the newly added atom. + """ + if at is None: + at = self.num_atoms() + ai = self._siblings.insert_child(self._ix, at, {"name": name, "het": het}) + atom = AtomView(ai, self) + + # now add a location to this atom + if x is not None: + atom.add_location(x, y, z, occ, B, alt) + + return atom + + def delete(self, fork_entity=True): + """Deletes this residue from its Chain/System. + + Args: + fork_entity (bool, optional): If True (default) and if parent + chain corresponds to an entity that has other chains + associated with it, will make a new (duplicate) entity + for to this chain and will edit the new one, leaving the + old one unchanged. If False, will not perform this. + NOTE: setting this to False can create an inconsistent state + between chain and entity sequence information. + """ + # update the entity (duplicating, if necessary) + entity_id = self.chain.get_entity_id() + if entity_id is not None: + entity = self.system._entities[entity_id] + ri = self.get_index_in_chain() + if fork_entity: + ci = self.chain.get_index() + entity_id = self.system._ensure_unique_entity(ci) + entity = self.system._entities[entity_id] + entity._seq.pop(ri) + + # delete the residue + self._delete() + self._ix = -1 # invalidate the view + + def delete_atoms(self, atoms=None): + """Delete either the specified list of atoms or all atoms from the residue. + + Args: + atoms (list, optional): List of AtomView objects corresponding to the + atoms to delete. If not specified, will delete all atoms in the residue. + """ + if atoms is None: + atoms = list(self.atoms()) + for atom in reversed(atoms): + if atom.residue != self: + raise Exception(f"Atom {atom} does not belong to Residue {self}") + atom.delete() + + @property + def chain(self): + return self._parent + + @property + def system(self): + return self.chain.system + + @property + def name(self): + return self._siblings["name"][self._ix] + + @property + def num(self): + return self._siblings["resnum"][self._ix] + + @property + def authid(self): + return self._siblings["authresid"][self._ix] + + @property + def icode(self): + return self._siblings["icode"][self._ix] + + def get_backbone(self, no_hyd=True): + """Assuming that this is a protein residue (i.e., an amino acid), returns the + list of atoms corresponding to the residue's backbone, in the order: + backbone amide (N), alpha carbon (CA), carbonyl carbon (C), carbonyl oxygen (O), + and amide hydrogen (H, optional). + + Args: + no_hyd (bool, optional): If True (default), will exclude the amide hydrogen + and only return four atoms. If False, will include the amide hydrogen. + + Returns: + A list with each entry being an AtomView object corresponding to the backbone + atom in the order above or None if the atom does not exist in the residue. + """ + bb = [None] * (4 if no_hyd else 5) + left = len(bb) + for atom in self.atoms(): + i = System.protein_backbone_atom_type(atom.name, no_hyd) + if i is None or bb[i] is not None: + continue + bb[i] = atom + left = left - 1 + if left == 0: + break + return bb + + def has_full_backbone(self, no_hyd=True): + """Assuming that this is a protein residue (i.e., an amino acid), returns + whether the residue harbors a structurally defined backbone (i.e., has + all backbone atoms each of which has location information). + + Args: + no_hyd (bool, optional): If True (default), will ignore whether the amide + hydrogen exists or not (if False will consider it). + + Returns: + Boolean indicating whether there is a full backbone in the residue. + """ + bb = self.get_backbone(no_hyd) + return all([(a is not None) and a.num_locations() for a in bb]) + + def delete_non_backbone(self, no_hyd=True): + """Assuming that this is a protein residue (i.e., an amino acid), deletes + all atoms except backbone atoms. + + Args: + no_hyd (bool, optional): If True (default), will not consider the amide + hydrogen as a backbone atom (if False will consider it). + """ + to_delete = [] + for atom in self.atoms(): + if System.protein_backbone_atom_type(atom.name, no_hyd) is None: + to_delete.append(atom) + self.delete_atoms(to_delete) + + +@dataclass +class AtomView(BaseView): + """An Atom view, allowing hierarchical exploration and editing.""" + + def __init__(self, ix: int, residue: ResidueView): + self._ix = ix + self._parent = residue + self._siblings = residue.system._atoms + + def __str__(self): + string = self.name + (" (HET) " if self.het else " ") + if self.num_locations() > 0: + string = string + str(self.get_location(0)) + string = string + f" ({self.num_locations()})" + return string + " -> " + str(self.residue) + + def locations(self): + off = self._siblings.child_index(self._ix, 0) + for ln in range(self.num_locations()): + yield AtomLocationView(off + ln, self) + + def num_locations(self): + return self._siblings.num_children(self._ix) + + def __getitem__(self, loc_idx: int): + return self.get_location(loc_idx) + + def get_location(self, li: int = 0): + """Returns the (li+1)-th location of the atom.""" + if li < 0 or li >= self.num_locations(): + raise Exception( + f"location index {li} out of range for Atom with {self.num_locations()} locations" + ) + li = self._siblings.child_index(self._ix, li) + return AtomLocationView(li, self) + + def add_location(self, x, y, z, occ=1.0, B=0.0, alt=" ", at=None): + """Adds a location to this atom, append it to the end. + + Args: + x, y, z (float): coordinates of the location. + occ (float): occupancy for the location. + B (float): B-factor for the location. + alt (str): alternative location character. + at (int, optional): Index at which to insert the location. Default + is to append at the end (i.e., equivalent of ``at` being equal + to the current number of locations). + """ + if at is None: + at = self.num_locations() + li = self._siblings.insert_child( + self._ix, at, {"coor": [x, y, z, occ, B], "alt": alt} + ) + return AtomLocationView(li, self) + + def delete(self): + """Deletes this atom from its Residue/Chain/System.""" + self._delete() + self._ix = -1 # invalidate the view + + @property + def residue(self): + return self._parent + + @property + def chain(self): + return self.residue.chain + + @property + def system(self): + return self.chain.system + + @property + def name(self): + return self._siblings["name"][self._ix] + + @property + def het(self): + return self._siblings["het"][self._ix] + + """Location information getters and setters operate on the default (first) + location for this atom and throw an index error if there are no locations.""" + + @property + def x(self): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + return self.system._locations["coor"][ix, 0] + + @property + def y(self): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + return self.system._locations["coor"][ix, 1] + + @property + def z(self): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + return self.system._locations["coor"][ix, 2] + + @property + def coors(self): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + return self.system._locations["coor"][ix, 0:3] + + @property + def occ(self): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + return self.system._locations["coor"][ix, 3] + + @property + def B(self): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + return self.system._locations["coor"][ix, 4] + + @property + def alt(self): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + return self.system._locations["alt"][ix] + + @x.setter + def x(self, val): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + self.system._locations["coor"][ix, 0] = val + + @y.setter + def y(self, val): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + self.system._locations["coor"][ix, 1] = val + + @z.setter + def z(self, val): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + self.system._locations["coor"][ix, 2] = val + + @occ.setter + def occ(self, val): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + self.system._locations["coor"][ix, 3] = val + + @B.setter + def B(self, val): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + self.system._locations["coor"][ix, 4] = val + + @alt.setter + def alt(self, val): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + self.system._locations["alt"][ix] = val + + +class DummyAtomView(AtomView): + """An dummy Atom view that can be attached to a residue but that does not + have any locations and with no other information.""" + + def __init__(self, residue: ResidueView): + self._ix = -1 + self._parent = residue + + def __str__(self): + return "DUMMY -> " + str(self.residue) + + def locations(self): + return + yield + + def num_locations(self): + return 0 + + def __getitem__(self, loc_idx: int): + return None + + def get_location(self, li: int = 0): + raise Exception(f"no locations in DUMMY atom") + + def add_location(self, x, y, z, occ, B, alt, at=None): + raise Exception(f"can't add no locations to DUMMY atom") + + @property + def residue(self): + return self._parent + + @property + def chain(self): + return self.residue.chain + + @property + def system(self): + return self.chain.system + + @property + def name(self): + return None + + @property + def het(self): + return None + + @property + def x(self): + raise Exception(f"no coordinates in DUMMY atom") + + @property + def y(self): + raise Exception(f"no coordinates in DUMMY atom") + + @property + def z(self): + raise Exception(f"no coordinates in DUMMY atom") + + @property + def occ(self): + raise Exception(f"no occupancy in DUMMY atom") + + @property + def B(self): + raise Exception(f"no B-factor in DUMMY atom") + + @property + def alt(self): + raise Exception(f"no alt flag in DUMMY atom") + + @x.setter + def x(self, val): + raise Exception(f"can't set coordinate for DUMMY atom") + + @y.setter + def y(self, val): + raise Exception(f"can't set coordinate for DUMMY atom") + + @z.setter + def z(self, val): + raise Exception(f"can't set coordinate for DUMMY atom") + + @occ.setter + def occ(self, val): + raise Exception(f"can't set occupancy for DUMMY atom") + + @B.setter + def B(self, val): + raise Exception(f"can't set B-factor for DUMMY atom") + + @alt.setter + def alt(self, val): + raise Exception(f"can't set alt flag for DUMMY atom") + + +@dataclass +class AtomLocationView(BaseView): + """An AtomLocation view, allowing hierarchical exploration and editing.""" + + def __init__(self, ix: int, atom: AtomView): + self._ix = ix + self._parent = atom + self._siblings = atom.system._locations + + def __str__(self): + return f"{self.x} {self.y} {self.z}" + + def swap(self, other: AtomLocationView): + """Swaps information between itself and the provided atom location. + + Args: + other (AtomLocationView): the other atom location to swap with. + """ + self.x, other.x = other.x, self.x + self.y, other.y = other.y, self.y + self.z, other.z = other.z, self.z + self.occ, other.occ = other.occ, self.occ + self.B, other.B = other.B, self.B + self.alt, other.alt = other.alt, self.alt + + def defined(self): + """Return whether this is a valid location.""" + return (self.x is not None) and (self.y is not None) and (self.z is not None) + + @property + def atom(self): + return self._parent + + @property + def residue(self): + return self.atom.residue + + @property + def chain(self): + return self.residue.chain + + @property + def system(self): + return self.chain.system + + @property + def x(self): + return self.system._locations["coor"][self._ix, 0] + + @property + def y(self): + return self.system._locations["coor"][self._ix, 1] + + @property + def z(self): + return self.system._locations["coor"][self._ix, 2] + + @property + def occ(self): + return self.system._locations["coor"][self._ix, 3] + + @property + def B(self): + return self.system._locations["coor"][self._ix, 4] + + @property + def alt(self): + return self.system._locations["alt"][self._ix] + + @property + def coors(self): + return np.array(self.system._locations["coor"][self._ix, 0:3]) + + @property + def coor_info(self): + return np.array(self.system._locations["coor"][self._ix]) + + @x.setter + def x(self, val): + self.system._locations["coor"][self._ix, 0] = val + + @y.setter + def y(self, val): + self.system._locations["coor"][self._ix, 1] = val + + @z.setter + def z(self, val): + self.system._locations["coor"][self._ix, 2] = val + + @coors.setter + def coors(self, val): + self.system._locations["coor"][self._ix, 0:3] = val + + @coor_info.setter + def coor_info(self, val): + self.system._locations["coor"][self._ix] = val + + @occ.setter + def occ(self, val): + self.system._locations["coor"][self._ix, 3] = val + + @B.setter + def B(self, val): + self.system._locations["coor"][self._ix, 4] = val + + @alt.setter + def alt(self, val): + self.system._locations["alt"][self._ix] = val + + +class ExpressionTreeEvaluator: + """A class for evaluating custom logical parenthetical expressions. The + implementation is very generic, supports nullary, unary, and binary + operators, and does not know anything about what the expressions actually + mean. Instead the class interprets the expression as a tree of sub- + expressions, governed by parentheses and operators, and traverses the + calling upon a user-specified evaluation function to evaluate leaf + nodes as the tree is gradually collapsed into a single node. This + can be used for evaluating set expressions, algebraic expressions, and + others. + + Args: + operators_nullary (list): A list of strings designating nullary operators + (i.e., operators that do not have any operands). E.g., if the language + describes selection algebra, these could be "hyd", "all", or "none"]. + operators_unary (list): A list of strings designating unary operators + (i.e., operators that have one operand, which must comes to the right + of the operator). E.g., if the language describes selection algebra, + these could be "name", "resid", or "chain". + operators_binary (list): A list of strings designating binary operators + (i.e., operators that have two operands, one on each side of the + operator). E.g., if the language describes selection algebra, thse + could be "and", "or", or "around". + eval_function (str): A function that is able to evaluate a leaf node of + the expression tree. It shall accept three parameters: + + operator (str): name of the operator + left: the left operand. Will be None if the left operand is missing or + not relevant. Otherwise, can be either a list of strings, which + should represent an evaluatable sub-expression corresponding to the + left operand, or the result of a prior evaluation of this function. + right: Same as `left` but for the right operand. + + The function should attempt to evaluate the resulting expression and + return None in the case of failing or a dictionary with the result of + the evaluation stored under key "result". + left_associativity (bool): If True (the default), operators are taken to be + left-associative. Meaning something like "A and B or C" is "(A and B) or C". + If False, the operators are taken to be right-associative, such that + the same expression becomes "A and (B or C)". NOTE: MST is right-associative + but often human intiution tends to be left-associative. + debug (bool): If True (default is false), will print a great deal of debugging + messages to help diagnose any evaluation problems. + """ + + def __init__( + self, + operators_nullary: list, + operators_unary: list, + operators_binary: list, + eval_function: function, + left_associativity: bool = True, + debug: bool = False, + ): + self.operators_nullary = operators_nullary + self.operators_unary = operators_unary + self.operators_binary = operators_binary + self.operators = operators_nullary + operators_unary + operators_binary + self.eval_function = eval_function + self.debug = debug + self.left_associativity = left_associativity + + def _traverse_expression_tree(self, E, i=0, eval_all=True, debug=False): + def _collect_operands(E, j): + # collect all operands before hitting an operator + operands = [] + for k in range(len(E[j:])): + if E[j + k] in self.operators: + k = k - 1 + break + operands.append(E[j + k]) + return operands, j + k + 1 + + def _find_matching_close_paren(E, beg: int): + c = 0 + for i in range(beg, len(E)): + if E[i] == "(": + c = c + 1 + elif E[i] == ")": + c = c - 1 + if c == 0: + return i + return None + + def _my_eval(op, left, right, debug=False): + if debug: + print( + f"\t-> evaluating {operand_str(left)} | {op} | {operand_str(right)}" + ) + result = self.eval_function(op, left, right) + if debug: + print(f"\t-> got result {operand_str(result)}") + return result + + def operand_str(operand): + if isinstance(operand, dict): + if "result" in operand and len(operand["result"]) > 15: + vec = list(operand["result"]) + beg = ", ".join([str(i) for i in vec[:5]]) + end = ", ".join([str(i) for i in vec[-5:]]) + return "{'result': " + f"{beg} ... {end} ({len(vec)} long)" + "}" + return str(operand) + return str(operand) + + left, right, op = None, None, None + if debug: + print(f"-> received {E[i:]}") + + while i < len(E): + if all([x is None for x in (left, right, op)]): + # first part can either be a left parenthesis, a left operand, a nullary operator, or a unary operator + if E[i] == "(": + end = _find_matching_close_paren(E, i) + if end is None: + return None, f"parenthesis imbalance starting with {E[i:]}" + # evaluate expression inside the parentheses, and it becomes the left operand + left, rem = self._traverse_expression_tree( + E[i + 1 : end], 0, eval_all=True, debug=debug + ) + if left is None: + return None, rem + i = end + 1 + if not eval_all: + return left, i + elif E[i] in self.operators_nullary: + # evaluate nullary op + left = _my_eval(E[i], None, None, debug) + if left is None: + return None, f"failed to evaluate nullary operator '{E[i]}'" + i = i + 1 + elif E[i] in self.operators_unary: + op = E[i] + i = i + 1 + elif E[i] in self.operators: + # an operator other than a unary operator cannot appear first + return None, f"unexpected binary operator in the context {E[i:]}" + else: + # if not an operator, then we are looking at operand(s) + left, i = _collect_operands(E, i) + elif (left is not None) and (op is None) and (right is None): + # we have a left operand and now looking for a binary operator + if E[i] not in self.operators_binary: + return ( + None, + f"expected end or a binary operator when got '{E[i]}' in expression: {E}", + ) + op = E[i] + i = i + 1 + elif ( + (left is None) and (op in self.operators_unary) and (right is None) + ) or ( + (left is not None) and (op in self.operators_binary) and (right is None) + ): + # we saw a unary operator before and now looking for a right operand, another unary operator, or a nullary operator + # OR + # we have a left operand and a binary operator before, now looking for a right operand, a unary operator, or a nullary operator + if ( + E[i] in (self.operators_nullary + self.operators_unary) + or E[i] == "(" + ): + right, i = self._traverse_expression_tree( + E, i, eval_all=not self.left_associativity, debug=debug + ) + if right is None: + return None, i + else: + right, i = _collect_operands(E, i) + + # We are now ready to evaluate, because: + # we have a unary operator and a right operand + # OR + # we have a left operand, a binary operator, and a right operand + result = _my_eval(op, left, right, debug) + if result is None: + return ( + None, + f"failed to evaluate operator '{op}' (in expression {E}) with operands {operand_str(left)} and {operand_str(right)}", + ) + if not eval_all: + return result, i + left = result + op, right = None, None + + else: + return ( + None, + f"encountered an unexpected condition when evaluating {E}: left is {operand_str(left)}, op is {op}, or right {operand_str(right)}", + ) + + if (op is not None) or (right is not None): + return None, f"expression ended unexpectedly" + if left is None: + return None, f"failed to evaluate expression: {E}" + + return left, i + + def evaluate(self, expression: str): + """Evaluates the expression and returns the result.""" + + def _split_tokens(expr): + # first split by parentheses (preserving the parentheses themselves) + parts = list(re.split("([()])", expr)) + # then split by space (getting rid of space) + return [ + t.strip() + for p in parts + for t in re.split("\s+", p.strip()) + if t.strip() != "" + ] + + # parse expression into tokens + E = _split_tokens(expression) + val, rem = self._traverse_expression_tree(E, debug=self.debug) + if val is None: + raise Exception( + f"failed to evaluate expression: '{expression}', reason: {rem}" + ) + + return val["result"]