Spaces:
Running
on
T4
Running
on
T4
# Copyright 2021 DeepMind Technologies Limited | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Protein data type. | |
Adapted from original code by alexechu. | |
""" | |
import dataclasses | |
import io | |
from typing import Any, Mapping, Optional | |
from core import residue_constants | |
from Bio.PDB import PDBParser | |
import numpy as np | |
FeatureDict = Mapping[str, np.ndarray] | |
ModelOutput = Mapping[str, Any] # Is a nested dict. | |
# Complete sequence of chain IDs supported by the PDB format. | |
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" | |
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62. | |
class Protein: | |
"""Protein structure representation.""" | |
# Cartesian coordinates of atoms in angstroms. The atom types correspond to | |
# residue_constants.atom_types, i.e. the first three are N, CA, CB. | |
atom_positions: np.ndarray # [num_res, num_atom_type, 3] | |
# Amino-acid type for each residue represented as an integer between 0 and | |
# 20, where 20 is 'X'. | |
aatype: np.ndarray # [num_res] | |
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom | |
# is present and 0.0 if not. This should be used for loss masking. | |
atom_mask: np.ndarray # [num_res, num_atom_type] | |
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed. | |
residue_index: np.ndarray # [num_res] | |
# 0-indexed number corresponding to the chain in the protein that this residue | |
# belongs to. | |
chain_index: np.ndarray # [num_res] | |
# B-factors, or temperature factors, of each residue (in sq. angstroms units), | |
# representing the displacement of the residue from its ground truth mean | |
# value. | |
b_factors: np.ndarray # [num_res, num_atom_type] | |
def __post_init__(self): | |
if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS: | |
raise ValueError( | |
f"Cannot build an instance with more than {PDB_MAX_CHAINS} chains " | |
"because these cannot be written to PDB format." | |
) | |
def from_pdb_string( | |
pdb_str: str, chain_id: Optional[str] = None, protein_only: bool = False | |
) -> Protein: | |
"""Takes a PDB string and constructs a Protein object. | |
WARNING: All non-standard residue types will be converted into UNK. All | |
non-standard atoms will be ignored. | |
Args: | |
pdb_str: The contents of the pdb file | |
chain_id: If chain_id is specified (e.g. A), then only that chain | |
is parsed. Otherwise all chains are parsed. | |
Returns: | |
A new `Protein` parsed from the pdb contents. | |
""" | |
pdb_fh = io.StringIO(pdb_str) | |
parser = PDBParser(QUIET=True) | |
structure = parser.get_structure("none", pdb_fh) | |
models = list(structure.get_models()) | |
if len(models) != 1: | |
raise ValueError( | |
f"Only single model PDBs are supported. Found {len(models)} models." | |
) | |
model = models[0] | |
atom_positions = [] | |
aatype = [] | |
atom_mask = [] | |
residue_index = [] | |
chain_ids = [] | |
b_factors = [] | |
for chain in model: | |
if chain_id is not None and chain.id != chain_id: | |
continue | |
for res in chain: | |
if protein_only and res.id[0] != " ": | |
continue | |
if res.id[2] != " ": | |
pass | |
# raise ValueError( | |
# f"PDB contains an insertion code at chain {chain.id} and residue " | |
# f"index {res.id[1]}. These are not supported." | |
# ) | |
res_shortname = residue_constants.restype_3to1.get(res.resname, "X") | |
restype_idx = residue_constants.restype_order.get( | |
res_shortname, residue_constants.restype_num | |
) | |
pos = np.zeros((residue_constants.atom_type_num, 3)) | |
mask = np.zeros((residue_constants.atom_type_num,)) | |
res_b_factors = np.zeros((residue_constants.atom_type_num,)) | |
for atom in res: | |
if atom.name not in residue_constants.atom_types: | |
continue | |
pos[residue_constants.atom_order[atom.name]] = atom.coord | |
mask[residue_constants.atom_order[atom.name]] = 1.0 | |
res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor | |
if np.sum(mask) < 0.5: | |
# If no known atom positions are reported for the residue then skip it. | |
continue | |
aatype.append(restype_idx) | |
atom_positions.append(pos) | |
atom_mask.append(mask) | |
residue_index.append(res.id[1]) | |
chain_ids.append(chain.id) | |
b_factors.append(res_b_factors) | |
# Chain IDs are usually characters so map these to ints. | |
unique_chain_ids = np.unique(chain_ids) | |
chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)} | |
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids]) | |
return Protein( | |
atom_positions=np.array(atom_positions), | |
atom_mask=np.array(atom_mask), | |
aatype=np.array(aatype), | |
residue_index=np.array(residue_index), | |
chain_index=chain_index, | |
b_factors=np.array(b_factors), | |
) | |
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str: | |
chain_end = "TER" | |
return ( | |
f"{chain_end:<6}{atom_index:>5} {end_resname:>3} " | |
f"{chain_name:>1}{residue_index:>4}" | |
) | |
def are_atoms_bonded(res3name, atom1_name, atom2_name): | |
lookup_table = residue_constants.standard_residue_bonds | |
for bond in lookup_table[res3name]: | |
if bond.atom1_name == atom1_name and bond.atom2_name == atom2_name: | |
return True | |
elif bond.atom1_name == atom2_name and bond.atom2_name == atom1_name: | |
return True | |
return False | |
def to_pdb(prot: Protein, conect=False) -> str: | |
"""Converts a `Protein` instance to a PDB string. | |
Args: | |
prot: The protein to convert to PDB. | |
Returns: | |
PDB string. | |
""" | |
restypes = residue_constants.restypes + ["X"] | |
res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK") | |
atom_types = residue_constants.atom_types | |
pdb_lines = [] | |
atom_mask = prot.atom_mask | |
aatype = prot.aatype | |
atom_positions = prot.atom_positions | |
residue_index = prot.residue_index.astype(np.int32) | |
chain_index = prot.chain_index.astype(np.int32) | |
b_factors = prot.b_factors | |
if np.any(aatype > residue_constants.restype_num): | |
raise ValueError("Invalid aatypes.") | |
# Construct a mapping from chain integer indices to chain ID strings. | |
chain_ids = {} | |
for i in np.unique(chain_index): # np.unique gives sorted output. | |
if i >= PDB_MAX_CHAINS: | |
raise ValueError( | |
f"The PDB format supports at most {PDB_MAX_CHAINS} chains." | |
) | |
chain_ids[i] = PDB_CHAIN_IDS[i] | |
pdb_lines.append("MODEL 1") | |
atom_index = 1 | |
last_chain_index = chain_index[0] | |
conect_lines = [] | |
# Add all atom sites. | |
for i in range(aatype.shape[0]): | |
# Close the previous chain if in a multichain PDB. | |
if last_chain_index != chain_index[i]: | |
pdb_lines.append( | |
_chain_end( | |
atom_index, | |
res_1to3(aatype[i - 1]), | |
chain_ids[chain_index[i - 1]], | |
residue_index[i - 1], | |
) | |
) | |
last_chain_index = chain_index[i] | |
atom_index += 1 # Atom index increases at the TER symbol. | |
res_name_3 = res_1to3(aatype[i]) | |
atoms_appended_for_res = [] | |
for atom_name, pos, mask, b_factor in zip( | |
atom_types, atom_positions[i], atom_mask[i], b_factors[i] | |
): | |
if mask < 0.5: | |
continue | |
record_type = "ATOM" | |
name = atom_name if len(atom_name) == 4 else f" {atom_name}" | |
alt_loc = "" | |
insertion_code = "" | |
occupancy = 1.00 | |
element = atom_name[0] # Protein supports only C, N, O, S, this works. | |
charge = "" | |
# PDB is a columnar format, every space matters here! | |
atom_line = ( | |
f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}" | |
f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}" | |
f"{residue_index[i]:>4}{insertion_code:>1} " | |
f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}" | |
f"{occupancy:>6.2f}{b_factor:>6.2f} " | |
f"{element:>2}{charge:>2}" | |
) | |
pdb_lines.append(atom_line) | |
for prev_atom_idx, prev_atom in atoms_appended_for_res: | |
if are_atoms_bonded(res_name_3, atom_name, prev_atom): | |
conect_line = f"CONECT{prev_atom_idx:5d}{atom_index:5d}\n" | |
conect_lines.append(conect_line) | |
atoms_appended_for_res.append((atom_index, atom_name)) | |
if atom_name == "N": | |
n_atom_idx = atom_index | |
if atom_name == "C": | |
c_atom_idx = atom_index | |
atom_index += 1 | |
if i > 0: | |
conect_line = f"CONECT{prev_c_atom_idx:5d}{n_atom_idx:5d}\n" | |
conect_lines.append(conect_line) | |
prev_c_atom_idx = c_atom_idx | |
# Close the final chain. | |
pdb_lines.append( | |
_chain_end( | |
atom_index, | |
res_1to3(aatype[-1]), | |
chain_ids[chain_index[-1]], | |
residue_index[-1], | |
) | |
) | |
pdb_lines.append("ENDMDL") | |
pdb_lines.append("END") | |
# Pad all lines to 80 characters. | |
pdb_lines = [line.ljust(80) for line in pdb_lines] | |
pdb_str = "\n".join(pdb_lines) + "\n" # Add terminating newline. | |
if conect: | |
conect_str = "".join(conect_lines) + "\n" | |
return pdb_str, conect_str | |
return pdb_str | |
def ideal_atom_mask(prot: Protein) -> np.ndarray: | |
"""Computes an ideal atom mask. | |
`Protein.atom_mask` typically is defined according to the atoms that are | |
reported in the PDB. This function computes a mask according to heavy atoms | |
that should be present in the given sequence of amino acids. | |
Args: | |
prot: `Protein` whose fields are `numpy.ndarray` objects. | |
Returns: | |
An ideal atom mask. | |
""" | |
return residue_constants.STANDARD_ATOM_MASK[prot.aatype] | |
def from_prediction( | |
features: FeatureDict, | |
result: ModelOutput, | |
b_factors: Optional[np.ndarray] = None, | |
remove_leading_feature_dimension: bool = True, | |
) -> Protein: | |
"""Assembles a protein from a prediction. | |
Args: | |
features: Dictionary holding model inputs. | |
result: Dictionary holding model outputs. | |
b_factors: (Optional) B-factors to use for the protein. | |
remove_leading_feature_dimension: Whether to remove the leading dimension | |
of the `features` values. | |
Returns: | |
A protein instance. | |
""" | |
fold_output = result["structure_module"] | |
def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray: | |
return arr[0] if remove_leading_feature_dimension else arr | |
if "asym_id" in features: | |
chain_index = _maybe_remove_leading_dim(features["asym_id"]) | |
else: | |
chain_index = np.zeros_like(_maybe_remove_leading_dim(features["aatype"])) | |
if b_factors is None: | |
b_factors = np.zeros_like(fold_output["final_atom_mask"]) | |
return Protein( | |
aatype=_maybe_remove_leading_dim(features["aatype"]), | |
atom_positions=fold_output["final_atom_positions"], | |
atom_mask=fold_output["final_atom_mask"], | |
residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1, | |
chain_index=chain_index, | |
b_factors=b_factors, | |
) | |