from functools import cache
from pathlib import Path
from esm import FastaBatchedDataset, pretrained
from rdkit.Chem import AddHs
from torch_geometric.data import Dataset, HeteroData
import numpy as np
import torch
import prody as pr
import esm
import pandas as pd
from datasets.process_mols import generate_conformer, read_molecule, get_lig_graph_with_matching, moad_extract_receptor_structure
from datasets.parse_chi import aa_idx2aa_short, get_onehot_sequence
def get_sequences_from_pdbfile(file_path):
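    """Extract per-chain amino-acid sequences from a PDB file and join them with ':'."""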
sequence = None
    # the prody package requires a str input, not a Path
pdb = pr.parsePDB(str(file_path))
seq = pdb.ca.getSequence()
one_hot = get_onehot_sequence(seq)
chain_ids = np.zeros(len(one_hot))
res_chain_ids = pdb.ca.getChids()
res_seg_ids = pdb.ca.getSegnames()
res_chain_ids = np.asarray([s + c for s, c in zip(res_seg_ids, res_chain_ids)])
ids = np.unique(res_chain_ids)
    for i, chain_id in enumerate(ids):
        chain_ids[res_chain_ids == chain_id] = i
        s_temp = np.argmax(one_hot[res_chain_ids == chain_id], axis=1)
s = ''.join([aa_idx2aa_short[aa_idx] for aa_idx in s_temp])
if sequence is None:
sequence = s
else:
sequence += (":" + s)
return sequence
@cache
def process_protein(protein_string):
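    """Resolve a protein specification to a (sequence, pdb_path) tuple.

    If the input looks like a path, it must point to an existing .pdb file and the
    sequence is read from it; otherwise the string is treated as a raw sequence and
    the path is None. Cached because the same protein may appear in many rows.
    """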
input_path = Path(protein_string)
    # Heuristic: treat absolute or multi-part paths as file paths
    if input_path.is_absolute() or len(input_path.parts) > 1:
# Check if the input is a PDB file path
if input_path.is_file() and input_path.suffix == '.pdb':
# Extract sequence from PDB file
return get_sequences_from_pdbfile(input_path), str(input_path)
else:
raise FileNotFoundError(f"File {protein_string} not found or not a PDB file")
else:
# Assume the input is already a FASTA sequence
return protein_string, None
def compute_esm_embeddings(model, alphabet, labels, sequences):
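    """Compute per-residue ESM2 embeddings (layer 33) for each labelled sequence."""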
    # embedding extraction settings: tokens per batch, representation layer, truncation length
toks_per_batch = 4096
repr_layers = [33]
truncation_seq_length = 1022
dataset = FastaBatchedDataset(labels, sequences)
batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
data_loader = torch.utils.data.DataLoader(
dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
)
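    # validate the requested representation layers and map negative indices to absolute layer numbers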
assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in repr_layers)
repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in repr_layers]
embeddings = {}
with torch.no_grad():
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
print(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
if torch.cuda.is_available():
toks = toks.to(device="cuda", non_blocking=True)
out = model(toks, repr_layers=repr_layers, return_contacts=False)
representations = {layer: t.to(device="cpu") for layer, t in out["representations"].items()}
for i, label in enumerate(labels):
truncate_len = min(truncation_seq_length, len(strs[i]))
embeddings[label] = representations[33][i, 1: truncate_len + 1].clone()
return embeddings
def generate_esm_structure(model, filename, sequence):
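    """Fold a sequence with ESMFold and write the resulting PDB to `filename`.

    Halves the chunk size and retries on CUDA out-of-memory errors; returns True on success.
    """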
    chunk_size = 256
    model.set_chunk_size(chunk_size)
output = None
while output is None:
try:
with torch.no_grad():
output = model.infer_pdb(sequence)
with open(filename, "w") as f:
f.write(output)
print("saved", filename)
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory on chunk_size', chunk_size)
for p in model.parameters():
if p.grad is not None:
del p.grad # free some memory
torch.cuda.empty_cache()
chunk_size = chunk_size // 2
if chunk_size > 2:
model.set_chunk_size(chunk_size)
else:
print("Not enough memory for ESMFold")
break
else:
raise e
return output is not None
class InferenceDataset(Dataset):
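    """PyTorch Geometric dataset that builds ligand/receptor heterographs for inference.

    Expects a DataFrame with 'name', 'X1' (ligand), 'X2' (protein sequence or PDB path)
    and 'mol' columns; missing protein structures are generated with ESMFold and
    optional ESM2 embeddings are attached to the receptor residues.
    """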
def __init__(self,
df, out_dir,
lm_embeddings, receptor_radius=30, c_alpha_max_neighbors=None, precomputed_lm_embeddings=None,
remove_hs=False, all_atoms=False, atom_radius=5, atom_max_neighbors=None, knn_only_graph=False):
super(InferenceDataset, self).__init__()
self.receptor_radius = receptor_radius
self.c_alpha_max_neighbors = c_alpha_max_neighbors
self.remove_hs = remove_hs
self.all_atoms = all_atoms
self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
self.knn_only_graph = knn_only_graph
self.df = df
# generate LM embeddings
if lm_embeddings and (precomputed_lm_embeddings is None or precomputed_lm_embeddings[0] is None):
print("Generating ESM language model embeddings")
model_location = "esm2_t33_650M_UR50D"
model, alphabet = pretrained.load_model_and_alphabet(model_location)
model.eval()
if torch.cuda.is_available():
model = model.cuda()
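            # expand X2 into a sequence column and an optional PDB-path column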
df[['protein_sequence', 'protein_path']] = df['X2'].apply(process_protein).apply(pd.Series)
labels, sequences = [], []
for i in range(len(df)):
s = df['protein_sequence'].iloc[i].split(':')
sequences.extend(s)
labels.extend([df['name'].iloc[i] + '_chain_' + str(j) for j in range(len(s))])
# TODO improve efficiency for repeated X2 values
lm_embeddings = compute_esm_embeddings(model, alphabet, labels, sequences)
self.lm_embeddings = []
for i in range(len(df)):
s = df['protein_sequence'].iloc[i].split(':')
self.lm_embeddings.append(
[lm_embeddings[f"{df['name'].iloc[i]}_chain_{j}"] for j in range(len(s))]
)
elif not lm_embeddings:
            self.lm_embeddings = [None] * len(df)
else:
self.lm_embeddings = precomputed_lm_embeddings
# generate structures with ESMFold
if None in df['protein_path'].values:
print("generating missing structures with ESMFold")
model = esm.pretrained.esmfold_v1()
model = model.eval().cuda()
for i in range(len(df)):
# TODO improve efficiency for repeated X2 values
protein_sequence = df['protein_sequence'].iloc[i]
protein_file = df['protein_path'].iloc[i]
complex_name = df['name'].iloc[i]
if protein_file is None:
protein_file = f"{out_dir}/{complex_name}/{complex_name}_esmfold.pdb"
                    if not Path(protein_file).is_file():
                        print("generating", protein_file)
                        generate_esm_structure(model, protein_file, protein_sequence)
                    # record the generated structure path so get() can load it later
                    df.loc[df.index[i], 'protein_path'] = protein_file
def len(self):
return len(self.df)
def get(self, idx):
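        """Build the heterogeneous ligand/receptor graph for row `idx`; sets complex_graph['success']."""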
name = self.df['name'].iloc[idx]
protein_file = self.df['protein_path'].iloc[idx]
ligand_description = self.df['X1'].iloc[idx]
mol = self.df['mol'].iloc[idx]
lm_embedding = self.lm_embeddings[idx]
# build the pytorch geometric heterogeneous graph
complex_graph = HeteroData()
complex_graph['name'] = name
if mol is not None:
mol = AddHs(mol)
generate_conformer(mol)
else:
print(f'Failed to read molecule {ligand_description}. Skipping...')
complex_graph['success'] = False
return complex_graph
try:
            # build the ligand graph (single conformer, no matching)
            get_lig_graph_with_matching(mol, complex_graph, popsize=None, maxiter=None, matching=False, keep_original=False,
                                        num_conformers=1, remove_hs=self.remove_hs)
            # parse the receptor from the PDB file and build the receptor graph
            moad_extract_receptor_structure(
path=protein_file,
complex_graph=complex_graph,
neighbor_cutoff=self.receptor_radius,
max_neighbors=self.c_alpha_max_neighbors,
lm_embeddings=lm_embedding,
knn_only_graph=self.knn_only_graph,
all_atoms=self.all_atoms,
atom_cutoff=self.atom_radius,
atom_max_neighbors=self.atom_max_neighbors)
except Exception as e:
print(f'Skipping {name} because of the error:')
print(e)
complex_graph['success'] = False
return complex_graph
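        # center the receptor (and, if used, the all-atom graph) at the origin, place the
        # ligand at the origin as well, and keep the original receptor center for later use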
protein_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
complex_graph['receptor'].pos -= protein_center
if self.all_atoms:
complex_graph['atom'].pos -= protein_center
ligand_center = torch.mean(complex_graph['ligand'].pos, dim=0, keepdim=True)
complex_graph['ligand'].pos -= ligand_center
complex_graph.original_center = protein_center
complex_graph.mol = mol
complex_graph['success'] = True
return complex_graph
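
# A minimal usage sketch (not part of the original module): the DataFrame columns
# ('name', 'X1', 'X2', 'mol') and the output directory below are assumptions based on
# how this class reads its inputs; building 'mol' from SMILES via RDKit is one option.
# Note that ESM2 embedding generation downloads the 650M-parameter model and the
# ESMFold structure-generation branch requires a CUDA GPU.
if __name__ == "__main__":
    from rdkit import Chem

    example_df = pd.DataFrame({
        'name': ['example_complex'],
        'X1': ['CCO'],  # ligand SMILES
        'X2': ['MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ'],  # protein sequence (or a .pdb path)
    })
    example_df['mol'] = example_df['X1'].apply(Chem.MolFromSmiles)

    # the per-complex output directory must exist before ESMFold writes its PDB there
    out_dir = 'results/example_inference'
    Path(f"{out_dir}/example_complex").mkdir(parents=True, exist_ok=True)

    dataset = InferenceDataset(example_df, out_dir, lm_embeddings=True)
    graph = dataset.get(0)
    print('graph built:', bool(graph['success']))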