# GenFBDD/datasets/pdb.py
# Significant contribution from Ben Fry
import copy
import os.path
import pickle
import random
from multiprocessing import Pool
import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, MolFromSmiles
from scipy.spatial.distance import pdist, squareform
from torch_geometric.data import Dataset, HeteroData
from torch_geometric.utils import subgraph
from tqdm import tqdm
from datasets.constants import aa_to_cg_indices, amino_acid_smiles, cg_rdkit_indices
from datasets.parse_chi import aa_long2short, atom_order
from datasets.process_mols import new_extract_receptor_structure, get_lig_graph, generate_conformer
from utils.torsion import get_transformation_mask


def read_strings_from_txt(path):
    """Return the lines of a text file as a list, with trailing whitespace stripped."""
    with open(path) as file:
        lines = file.readlines()
    return [line.rstrip() for line in lines]


def compute_num_ca_neighbors(coords, cg_coords, idx, is_valid_bb_node, max_dist=5, buffer_residue_num=7):
    """
    Counts the residues with heavy atoms within max_dist (Angstroms) of this sidechain,
    excluding residues within +/- buffer_residue_num of it in the primary sequence.
    From Ben's code.
    Note: Gabriele removed the chain_index.
    """
# Extract coordinates of all residues in the protein.
bb_coords = coords
    # Compute the indices of residues whose interactions should not be counted.
excluded_neighbors = [idx - x for x in reversed(range(0, buffer_residue_num+1)) if (idx - x) >= 0]
excluded_neighbors.extend([idx + x for x in range(1, buffer_residue_num+1)])
# Create indices of an N x M distance matrix where N is num BB nodes and M is num CG nodes.
e_idx = torch.stack([
torch.arange(bb_coords.shape[0]).unsqueeze(-1).expand((-1, cg_coords.shape[0])).flatten(),
torch.arange(cg_coords.shape[0]).unsqueeze(0).expand((bb_coords.shape[0], -1)).flatten()
])
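    # e_idx has shape [2, N*M]: row 0 holds the backbone-residue index and row 1 the
    # chemical-group atom index of every (residue, atom) pair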
# Expand bb_coords and cg_coords into the same dimensionality.
bb_coords_exp = bb_coords[e_idx[0]]
cg_coords_exp = cg_coords[e_idx[1]].unsqueeze(1)
    # Each row holds the distances from one chemical-group atom to the 14 atom slots of one backbone residue.
bb_exp_idces, _ = (torch.cdist(bb_coords_exp, cg_coords_exp).squeeze(-1) < max_dist).nonzero(as_tuple=True)
bb_idces_within_thresh = torch.unique(e_idx[0][bb_exp_idces])
# Only count residues that are not adjacent or origin in primary sequence and are valid backbone residues (fully resolved coordinate frame).
bb_idces_within_thresh = bb_idces_within_thresh[~torch.isin(bb_idces_within_thresh, torch.tensor(excluded_neighbors)) & is_valid_bb_node[bb_idces_within_thresh]]
return len(bb_idces_within_thresh)
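
# A hedged usage sketch for compute_num_ca_neighbors above (shapes are assumptions read
# off the code, not part of the original API: coords is [R, 14, 3] with NaNs for
# unresolved atoms):
#   coords = torch.randn(50, 14, 3)
#   is_valid = (coords[:, :4].isnan().sum(dim=(1, 2)) == 0).bool()
#   n = compute_num_ca_neighbors(coords, coords[25][:5], 25, is_valid)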


def identify_valid_vandermers(args):
    """
    Constructs a tensor with the number of contacts for each residue, from which chemical
    groups can be sampled. Every sidechain is treated as a chemical group here; the actual
    chemical groups are loaded at training time. Dividing the counts by their sum turns
    them into sampling probabilities.
    """
complex_graph, max_dist, buffer_residue_num = args
# Constructs a mask tracking whether index is a valid coordinate frame / residue label to train over.
#is_in_residue_vocabulary = torch.tensor([x in aa_short2long for x in data['seq']]).bool()
coords, seq = complex_graph.coords, complex_graph.seq
is_valid_bb_node = (coords[:, :4].isnan().sum(dim=(1,2)) == 0).bool() #* is_in_residue_vocabulary
valid_cg_idces = []
for idx, aa in enumerate(seq):
if aa not in aa_to_cg_indices:
valid_cg_idces.append(0)
else:
indices = aa_to_cg_indices[aa]
cg_coordinates = coords[idx][indices]
# remove chemical group residues that aren't fully resolved.
if torch.any(cg_coordinates.isnan()).item():
valid_cg_idces.append(0)
continue
nbr_count = compute_num_ca_neighbors(coords, cg_coordinates, idx, is_valid_bb_node,
max_dist=max_dist, buffer_residue_num=buffer_residue_num)
valid_cg_idces.append(nbr_count)
return complex_graph.name, torch.tensor(valid_cg_idces)


def fast_identify_valid_vandermers(coords, seq, max_dist=5, buffer_residue_num=7):
offset = 10000 + max_dist
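    # sentinel distance: large enough that masked entries (NaNs, the diagonal, and
    # sequence neighbors) can never fall below max_dist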
R = coords.shape[0]
    coords = coords.numpy().reshape(-1, 3)
    # minimum heavy-atom distance between every pair of residues (NaNs masked with the sentinel)
    pdist_mat = squareform(pdist(coords))
    pdist_mat = pdist_mat.reshape((R, 14, R, 14))
    pdist_mat = np.nan_to_num(pdist_mat, nan=offset)
    pdist_mat = np.min(pdist_mat, axis=(1, 3))
    # mask out self-distances and sequence neighbors within the buffer
    pdist_mat = pdist_mat + np.diag(np.ones(len(seq)) * offset)
    for i in range(1, buffer_residue_num + 1):
        pdist_mat += np.diag(np.ones(len(seq) - i) * offset, k=i) + np.diag(np.ones(len(seq) - i) * offset, k=-i)
    # count residues within max_dist of each residue
    nbr_count = np.sum(pdist_mat < max_dist, axis=1)
return torch.tensor(nbr_count)


def compute_cg_features(aa, aa_smile):
    """
    Given an amino acid and a SMILES string, returns the stacked tensor of chemical-group
    atom encodings. The rows of the output tensor follow the order in which the atoms
    appear in aa_to_cg_indices from constants.
    """
# Handle any residues that we don't have chemical groups for (ex: GLY if not using bb_cnh and bb_cco)
aa_short = aa_long2short[aa]
if aa_short not in aa_to_cg_indices:
return None
# Create rdkit molecule from smiles string.
mol = Chem.MolFromSmiles(aa_smile)
complex_graph = HeteroData()
get_lig_graph(mol, complex_graph)
    atoms_to_keep = torch.tensor(list(cg_rdkit_indices[aa].keys())).long()
complex_graph['ligand', 'ligand'].edge_index, complex_graph['ligand', 'ligand'].edge_attr = \
subgraph(atoms_to_keep, complex_graph['ligand', 'ligand'].edge_index, complex_graph['ligand', 'ligand'].edge_attr, relabel_nodes=True)
complex_graph['ligand'].x = complex_graph['ligand'].x[atoms_to_keep]
edge_mask, mask_rotate = get_transformation_mask(complex_graph)
complex_graph['ligand'].edge_mask = torch.tensor(edge_mask)
complex_graph['ligand'].mask_rotate = mask_rotate
return complex_graph
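
# A hedged usage sketch for compute_cg_features above (assumes 'ALA' is a key of
# amino_acid_smiles from datasets.constants; returns None for residues without a
# chemical group, e.g. GLY):
#   cg_graph = compute_cg_features('ALA', amino_acid_smiles['ALA'])
#   cg_graph['ligand'].x then holds features for the chemical-group atoms only.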


class PDBSidechain(Dataset):
    def __init__(self, root, transform=None, cache_path='data/cache', split='train', limit_complexes=0,
                 receptor_radius=30, num_workers=1, c_alpha_max_neighbors=None, remove_hs=True, all_atoms=False,
                 atom_radius=5, atom_max_neighbors=None, sequences_to_embeddings=None,
                 knn_only_graph=True, multiplicity=1, vandermers_max_dist=5, vandermers_buffer_residue_num=7,
                 vandermers_min_contacts=5, remove_second_segment=False, merge_clusters=1, vandermers_extraction=True,
                 add_random_ligand=False):
        super(PDBSidechain, self).__init__(root, transform)
        assert remove_hs, "remove_hs=False is not implemented yet"
self.root = root
self.split = split
self.limit_complexes = limit_complexes
self.receptor_radius = receptor_radius
self.knn_only_graph = knn_only_graph
self.multiplicity = multiplicity
self.c_alpha_max_neighbors = c_alpha_max_neighbors
self.num_workers = num_workers
self.sequences_to_embeddings = sequences_to_embeddings
self.remove_second_segment = remove_second_segment
self.merge_clusters = merge_clusters
self.vandermers_extraction = vandermers_extraction
self.add_random_ligand = add_random_ligand
self.all_atoms = all_atoms
self.atom_radius = atom_radius
self.atom_max_neighbors = atom_max_neighbors
if vandermers_extraction:
self.cg_node_feature_lookup_dict = {aa_long2short[aa]: compute_cg_features(aa, aa_smile) for aa, aa_smile in
amino_acid_smiles.items()}
        self.cache_path = os.path.join(cache_path, f'PDB3_limit{self.limit_complexes}_INDEX{self.split}'
                                                   f'_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}'
                                       + ('' if not all_atoms else f'_atomRad{atom_radius}_atomMax{atom_max_neighbors}')
                                       + ('' if not self.knn_only_graph else '_knnOnly'))
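        # the cache key encodes every preprocessing option so incompatible caches are never reused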
self.read_split()
if not self.check_all_proteins():
os.makedirs(self.cache_path, exist_ok=True)
self.preprocess()
self.vandermers_max_dist = vandermers_max_dist
self.vandermers_buffer_residue_num = vandermers_buffer_residue_num
self.vandermers_min_contacts = vandermers_min_contacts
self.collect_proteins()
filtered_proteins = []
if vandermers_extraction:
for complex_graph in tqdm(self.protein_graphs):
if complex_graph.name in self.vandermers and torch.any(self.vandermers[complex_graph.name] >= 10):
filtered_proteins.append(complex_graph)
print(f"Computed vandermers and kept {len(filtered_proteins)} proteins out of {len(self.protein_graphs)}")
else:
filtered_proteins = self.protein_graphs
second_filter = []
for complex_graph in tqdm(filtered_proteins):
if sequences_to_embeddings is None or complex_graph.orig_seq in sequences_to_embeddings:
second_filter.append(complex_graph)
print(f"Checked embeddings available and kept {len(second_filter)} proteins out of {len(filtered_proteins)}")
self.protein_graphs = second_filter
# filter clusters that have no protein graphs
self.split_clusters = list(set([g.cluster for g in self.protein_graphs]))
self.cluster_to_complexes = {c: [] for c in self.split_clusters}
for p in self.protein_graphs:
self.cluster_to_complexes[p['cluster']].append(p)
self.split_clusters = [c for c in self.split_clusters if len(self.cluster_to_complexes[c]) > 0]
print("Total elements in set", len(self.split_clusters) * self.multiplicity // self.merge_clusters)
self.name_to_complex = {p.name: p for p in self.protein_graphs}
self.define_probabilities()
if self.add_random_ligand:
# read csv with all smiles
with open('data/smiles_list.csv', 'r') as f:
self.smiles_list = f.readlines()
self.smiles_list = [s.split(',')[0] for s in self.smiles_list]

    def define_probabilities(self):
if not self.vandermers_extraction:
return
if self.vandermers_min_contacts is not None:
self.probabilities = torch.arange(1000) - self.vandermers_min_contacts + 1
self.probabilities[:self.vandermers_min_contacts] = 0
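            # e.g. with vandermers_min_contacts=5, a residue with c contacts gets weight
            # max(c - 4, 0): contacts of 5, 10, 20 -> weights 1, 6, 16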
else:
with open('data/pdbbind_counts.pkl', 'rb') as f:
pdbbind_counts = pickle.load(f)
pdb_counts = torch.ones(1000)
for contacts in self.vandermers.values():
pdb_counts.index_add_(0, contacts, torch.ones(contacts.shape))
print(pdbbind_counts[:30])
print(pdb_counts[:30])
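            # reweight PDB contact counts so sampling matches the PDBBind contact distribution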
self.probabilities = pdbbind_counts / pdb_counts
self.probabilities[:7] = 0

    def len(self):
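        # one item per cluster, repeated `multiplicity` times; merge_clusters folds
        # several clusters into a single item (see get())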
return len(self.split_clusters) * self.multiplicity // self.merge_clusters

    def get(self, idx=None, protein=None, smiles=None):
        assert idx is not None or (protein is not None and smiles is not None), \
            "provide either idx or both protein and smiles"
if protein is None or smiles is None:
idx = idx % len(self.split_clusters)
if self.merge_clusters > 1:
idx = idx * self.merge_clusters
idx = idx + random.randint(0, self.merge_clusters - 1)
idx = min(idx, len(self.split_clusters) - 1)
cluster = self.split_clusters[idx]
protein_graph = copy.deepcopy(random.choice(self.cluster_to_complexes[cluster]))
else:
protein_graph = copy.deepcopy(self.name_to_complex[protein])
if self.sequences_to_embeddings is not None:
#print(self.sequences_to_embeddings[protein_graph.orig_seq].shape, len(protein_graph.orig_seq), protein_graph.to_keep.shape)
if len(protein_graph.orig_seq) != len(self.sequences_to_embeddings[protein_graph.orig_seq]):
                print('ESM embedding length does not match sequence length, retrying with a new complex')
return self.get(random.randint(0, self.len()))
lm_embeddings = self.sequences_to_embeddings[protein_graph.orig_seq][protein_graph.to_keep]
protein_graph['receptor'].x = torch.cat([protein_graph['receptor'].x, lm_embeddings], dim=1)
if self.vandermers_extraction:
# select sidechain to remove
vandermers_contacts = self.vandermers[protein_graph.name]
vandermers_probs = self.probabilities[vandermers_contacts].numpy()
if not np.any(vandermers_contacts.numpy() >= 10):
                print('no vandermers with >= 10 contacts, retrying with a new complex')
return self.get(random.randint(0, self.len()))
sidechain_idx = np.random.choice(np.arange(len(vandermers_probs)), p=vandermers_probs / np.sum(vandermers_probs))
# remove part of the sequence
residues_to_keep = np.ones(len(protein_graph.seq), dtype=bool)
residues_to_keep[max(0, sidechain_idx - self.vandermers_buffer_residue_num):
min(sidechain_idx + self.vandermers_buffer_residue_num + 1, len(protein_graph.seq))] = False
if self.remove_second_segment:
pos_idx = protein_graph['receptor'].pos[sidechain_idx]
limit_closeness = 10
far_enough = torch.sum((protein_graph['receptor'].pos - pos_idx[None, :]) ** 2, dim=-1) > limit_closeness ** 2
vandermers_probs = vandermers_probs * far_enough.float().numpy()
vandermers_probs[max(0, sidechain_idx - self.vandermers_buffer_residue_num):
min(sidechain_idx + self.vandermers_buffer_residue_num + 1, len(protein_graph.seq))] = 0
                if np.all(vandermers_probs <= 0):
                    print('no second vandermer available, retrying with a new complex')
return self.get(random.randint(0, self.len()))
sc2_idx = np.random.choice(np.arange(len(vandermers_probs)), p=vandermers_probs / np.sum(vandermers_probs))
residues_to_keep[max(0, sc2_idx - self.vandermers_buffer_residue_num):
min(sc2_idx + self.vandermers_buffer_residue_num + 1, len(protein_graph.seq))] = False
residues_to_keep = torch.from_numpy(residues_to_keep)
protein_graph['receptor'].pos = protein_graph['receptor'].pos[residues_to_keep]
protein_graph['receptor'].x = protein_graph['receptor'].x[residues_to_keep]
protein_graph['receptor'].side_chain_vecs = protein_graph['receptor'].side_chain_vecs[residues_to_keep]
protein_graph['receptor', 'rec_contact', 'receptor'].edge_index = \
subgraph(residues_to_keep, protein_graph['receptor', 'rec_contact', 'receptor'].edge_index, relabel_nodes=True)[0]
# create the sidechain ligand
sidechain_aa = protein_graph.seq[sidechain_idx]
ligand_graph = self.cg_node_feature_lookup_dict[sidechain_aa]
ligand_graph['ligand'].pos = protein_graph.coords[sidechain_idx][protein_graph.mask[sidechain_idx]]
            for node_or_edge_type in ligand_graph.node_types + ligand_graph.edge_types:
                for key, value in ligand_graph[node_or_edge_type].items():
                    protein_graph[node_or_edge_type][key] = value
protein_graph['ligand'].orig_pos = protein_graph['ligand'].pos.numpy()
protein_center = torch.mean(protein_graph['receptor'].pos, dim=0, keepdim=True)
protein_graph['receptor'].pos = protein_graph['receptor'].pos - protein_center
protein_graph['ligand'].pos = protein_graph['ligand'].pos - protein_center
protein_graph.original_center = protein_center
protein_graph['receptor_name'] = protein_graph.name
else:
protein_center = torch.mean(protein_graph['receptor'].pos, dim=0, keepdim=True)
protein_graph['receptor'].pos = protein_graph['receptor'].pos - protein_center
protein_graph.original_center = protein_center
protein_graph['receptor_name'] = protein_graph.name
if self.add_random_ligand:
if smiles is not None:
mol = MolFromSmiles(smiles)
try:
generate_conformer(mol)
except Exception as e:
print("failed to generate the given ligand returning None", e)
return None
else:
success = False
while not success:
smiles = random.choice(self.smiles_list)
mol = MolFromSmiles(smiles)
try:
                        # generate_conformer (datasets.process_mols) returns a truthy value
                        # when embedding fails and it falls back to random coordinates
                        success = not generate_conformer(mol)
except Exception as e:
print(e, "changing ligand")
lig_graph = HeteroData()
get_lig_graph(mol, lig_graph)
edge_mask, mask_rotate = get_transformation_mask(lig_graph)
lig_graph['ligand'].edge_mask = torch.tensor(edge_mask)
lig_graph['ligand'].mask_rotate = mask_rotate
lig_graph['ligand'].smiles = smiles
lig_graph['ligand'].pos = lig_graph['ligand'].pos - torch.mean(lig_graph['ligand'].pos, dim=0, keepdim=True)
            for node_or_edge_type in lig_graph.node_types + lig_graph.edge_types:
                for key, value in lig_graph[node_or_edge_type].items():
                    protein_graph[node_or_edge_type][key] = value
for a in ['random_coords', 'coords', 'seq', 'sequence', 'mask', 'rmsd_matching', 'cluster', 'orig_seq', 'to_keep', 'chain_ids']:
if hasattr(protein_graph, a):
delattr(protein_graph, a)
if hasattr(protein_graph['receptor'], a):
delattr(protein_graph['receptor'], a)
return protein_graph

    def read_split(self):
# read CSV file
df = pd.read_csv(self.root + "/list.csv")
print("Loaded list CSV file")
# get clusters and filter by split
if self.split == "train":
val_clusters = set(read_strings_from_txt(self.root + "/valid_clusters.txt"))
test_clusters = set(read_strings_from_txt(self.root + "/test_clusters.txt"))
clusters = df["CLUSTER"].unique()
            # cast to str for the membership test: the cluster files are read as strings,
            # while the CSV column may be parsed as integers
            clusters = [int(c) for c in clusters if str(c) not in val_clusters and str(c) not in test_clusters]
elif self.split == "val":
clusters = [int(s) for s in read_strings_from_txt(self.root + "/valid_clusters.txt")]
elif self.split == "test":
clusters = [int(s) for s in read_strings_from_txt(self.root + "/test_clusters.txt")]
else:
raise ValueError("Split must be train, val or test")
print(self.split, "clusters", len(clusters))
clusters = set(clusters)
self.chains_in_cluster = []
complexes_in_cluster = set()
for chain, cluster in zip(df["CHAINID"], df["CLUSTER"]):
if cluster not in clusters:
continue
# limit to one chain per complex
if chain[:4] not in complexes_in_cluster:
self.chains_in_cluster.append((chain, cluster))
complexes_in_cluster.add(chain[:4])
print("Filtered chains in cluster", len(self.chains_in_cluster))
if self.limit_complexes > 0:
self.chains_in_cluster = self.chains_in_cluster[:self.limit_complexes]

    def check_all_proteins(self):
for i in range(len(self.chains_in_cluster)//10000+1):
if not os.path.exists(os.path.join(self.cache_path, f"protein_graphs{i}.pkl")):
return False
return True

    def collect_proteins(self):
self.protein_graphs = []
self.vandermers = {}
total_recovered = 0
print(f'Loading {len(self.chains_in_cluster)} protein graphs.')
list_indices = list(range(len(self.chains_in_cluster) // 10000 + 1))
random.shuffle(list_indices)
for i in list_indices:
            with open(os.path.join(self.cache_path, f"protein_graphs{i}.pkl"), 'rb') as f:
                print(f'loading protein graph chunk {i}')
                graphs = pickle.load(f)
                total_recovered += len(graphs)
                self.protein_graphs.extend(graphs)
if not self.vandermers_extraction:
continue
if os.path.exists(os.path.join(self.cache_path, f'vandermers{i}_{self.vandermers_max_dist}_{self.vandermers_buffer_residue_num}.pkl')):
with open(os.path.join(self.cache_path, f'vandermers{i}_{self.vandermers_max_dist}_{self.vandermers_buffer_residue_num}.pkl'), 'rb') as f:
vandermers = pickle.load(f)
self.vandermers.update(vandermers)
continue
vandermers = {}
if self.num_workers > 1:
p = Pool(self.num_workers, maxtasksperchild=1)
p.__enter__()
            with tqdm(total=len(graphs), desc=f'computing vandermers {i}') as pbar:
                map_fn = p.imap_unordered if self.num_workers > 1 else map
                arguments = zip(graphs, [self.vandermers_max_dist] * len(graphs),
                                [self.vandermers_buffer_residue_num] * len(graphs))
for t in map_fn(identify_valid_vandermers, arguments):
if t is not None:
vandermers[t[0]] = t[1]
pbar.update()
if self.num_workers > 1: p.__exit__(None, None, None)
with open(os.path.join(self.cache_path, f'vandermers{i}_{self.vandermers_max_dist}_{self.vandermers_buffer_residue_num}.pkl'), 'wb') as f:
pickle.dump(vandermers, f)
self.vandermers.update(vandermers)
print(f"Kept {len(self.protein_graphs)} proteins out of {len(self.chains_in_cluster)} total")
return

    def preprocess(self):
        # run preprocessing in parallel across workers, saving progress every 10000 proteins
        list_indices = list(range(len(self.chains_in_cluster) // 10000 + 1))
random.shuffle(list_indices)
for i in list_indices:
if os.path.exists(os.path.join(self.cache_path, f"protein_graphs{i}.pkl")):
continue
chains_names = self.chains_in_cluster[10000 * i:10000 * (i + 1)]
protein_graphs = []
if self.num_workers > 1:
p = Pool(self.num_workers, maxtasksperchild=1)
p.__enter__()
with tqdm(total=len(chains_names),
desc=f'loading protein batch {i}/{len(self.chains_in_cluster) // 10000 + 1}') as pbar:
map_fn = p.imap_unordered if self.num_workers > 1 else map
for t in map_fn(self.load_chain, chains_names):
if t is not None:
protein_graphs.append(t)
pbar.update()
if self.num_workers > 1: p.__exit__(None, None, None)
with open(os.path.join(self.cache_path, f"protein_graphs{i}.pkl"), 'wb') as f:
pickle.dump(protein_graphs, f)
print("Finished preprocessing and saving protein graphs")

    def load_chain(self, c):
chain, cluster = c
if not os.path.exists(self.root + f"/pdb/{chain[1:3]}/{chain}.pt"):
print("File not found", chain)
return None
data = torch.load(self.root + f"/pdb/{chain[1:3]}/{chain}.pt")
complex_graph = HeteroData()
complex_graph['name'] = chain
orig_seq = data["seq"]
coords = data["xyz"]
mask = data["mask"].bool()
# remove residues with NaN backbone coordinates
to_keep = torch.logical_not(torch.any(torch.isnan(coords[:, :4, 0]), dim=1))
coords = coords[to_keep]
seq = ''.join(np.asarray(list(orig_seq))[to_keep.numpy()].tolist())
mask = mask[to_keep]
if len(coords) == 0:
print("All coords were NaN", chain)
return None
try:
new_extract_receptor_structure(seq, coords.numpy(), complex_graph=complex_graph, neighbor_cutoff=self.receptor_radius,
max_neighbors=self.c_alpha_max_neighbors, knn_only_graph=self.knn_only_graph,
all_atoms=self.all_atoms, atom_cutoff=self.atom_radius,
atom_max_neighbors=self.atom_max_neighbors)
except Exception as e:
print("Error in extracting receptor", chain)
print(e)
return None
if torch.any(torch.isnan(complex_graph['receptor'].pos)):
print("NaN in pos receptor", chain)
return None
complex_graph.coords = coords
complex_graph.seq = seq
complex_graph.mask = mask
complex_graph.cluster = cluster
complex_graph.orig_seq = orig_seq
complex_graph.to_keep = to_keep
return complex_graph
if __name__ == "__main__":
dataset = PDBSidechain(root="data/pdb_2021aug02_sample", split="train", multiplicity=1, limit_complexes=150)
print(len(dataset))
print(dataset[0])
for p in dataset:
print(p)
pass
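    # Hedged sketch of retrieval by name (assumes the chain exists in this split and a
    # conformer can be generated for the SMILES):
    #   graph = dataset.get(protein=dataset.protein_graphs[0].name, smiles="CCO")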