from functools import cache
from pathlib import Path

import esm
import numpy as np
import pandas as pd
import prody as pr
import torch
from esm import FastaBatchedDataset, pretrained
from rdkit.Chem import AddHs
from torch_geometric.data import Dataset, HeteroData

from datasets.process_mols import generate_conformer, read_molecule, get_lig_graph_with_matching, moad_extract_receptor_structure
from datasets.parse_chi import aa_idx2aa_short, get_onehot_sequence

def get_sequences_from_pdbfile(file_path):
    sequence = None
    # the prody package requires a str input
    pdb = pr.parsePDB(str(file_path))
    seq = pdb.ca.getSequence()
    one_hot = get_onehot_sequence(seq)

    chain_ids = np.zeros(len(one_hot))
    res_chain_ids = pdb.ca.getChids()
    res_seg_ids = pdb.ca.getSegnames()
    res_chain_ids = np.asarray([s + c for s, c in zip(res_seg_ids, res_chain_ids)])
    ids = np.unique(res_chain_ids)

    for i, id in enumerate(ids):
        chain_ids[res_chain_ids == id] = i
        s_temp = np.argmax(one_hot[res_chain_ids == id], axis=1)
        s = ''.join([aa_idx2aa_short[aa_idx] for aa_idx in s_temp])
        if sequence is None:
            sequence = s
        else:
            sequence += (":" + s)
    return sequence
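
# Illustrative note: for a multi-chain PDB the per-chain sequences are joined
# with ':' separators, e.g. "MKTAYIAK...:GSHMDDDK..." for a two-chain protein.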

def process_protein(protein_string):
    input_path = Path(protein_string)
    # Heuristic: treat the input as a file path if it is absolute or
    # contains a directory separator; otherwise treat it as a raw sequence
    if input_path.is_absolute() or len(input_path.parts) > 1:
        if input_path.is_file() and input_path.suffix == '.pdb':
            # extract the chain sequences from the PDB file
            return get_sequences_from_pdbfile(input_path), str(input_path)
        raise FileNotFoundError(f"File {protein_string} not found or not a PDB file")
    # assume the input is already an amino-acid sequence
    return protein_string, None
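
# Illustrative examples (paths are hypothetical):
#   process_protein("data/1abc/protein.pdb") -> ("<chain sequences>", "data/1abc/protein.pdb")
#   process_protein("MKTAYIAKQRQISFVK")      -> ("MKTAYIAKQRQISFVK", None)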

def compute_esm_embeddings(model, alphabet, labels, sequences):
    # settings used for the embedding extraction
    toks_per_batch = 4096
    repr_layers = [33]
    truncation_seq_length = 1022

    dataset = FastaBatchedDataset(labels, sequences)
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
    )

    # map negative layer indices to absolute layer indices
    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in repr_layers)
    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in repr_layers]

    embeddings = {}
    with torch.no_grad():
        for batch_idx, (batch_labels, strs, toks) in enumerate(data_loader):
            print(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)
            out = model(toks, repr_layers=repr_layers, return_contacts=False)
            representations = {layer: t.to(device="cpu") for layer, t in out["representations"].items()}
            for i, label in enumerate(batch_labels):
                truncate_len = min(truncation_seq_length, len(strs[i]))
                # drop the BOS token and any padding beyond the sequence length
                embeddings[label] = representations[33][i, 1: truncate_len + 1].clone()
    return embeddings
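
# Usage sketch (illustrative): the returned dict maps each label to a
# per-residue embedding tensor. For esm2_t33_650M_UR50D the embedding
# dimension is 1280, so an 8-residue chain yields a tensor of shape (8, 1280):
#   model, alphabet = pretrained.load_model_and_alphabet("esm2_t33_650M_UR50D")
#   emb = compute_esm_embeddings(model.eval(), alphabet, ["p_chain_0"], ["MKTAYIAK"])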

def generate_esm_structure(model, filename, sequence):
    chunk_size = 256
    model.set_chunk_size(chunk_size)
    output = None

    # retry with a smaller chunk size whenever the GPU runs out of memory
    while output is None:
        try:
            with torch.no_grad():
                output = model.infer_pdb(sequence)
            with open(filename, "w") as f:
                f.write(output)
            print("saved", filename)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory on chunk_size', chunk_size)
                for p in model.parameters():
                    if p.grad is not None:
                        del p.grad  # free some memory
                torch.cuda.empty_cache()
                chunk_size = chunk_size // 2
                if chunk_size > 2:
                    model.set_chunk_size(chunk_size)
                else:
                    print("Not enough memory for ESMFold")
                    break
            else:
                raise e
    return output is not None
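
# Usage sketch (illustrative; assumes the ESMFold weights and dependencies
# are available):
#   folding_model = esm.pretrained.esmfold_v1().eval().cuda()
#   ok = generate_esm_structure(folding_model, "out/protein_esmfold.pdb", "MKTAYIAK...")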

class InferenceDataset(Dataset):
    def __init__(self, df, out_dir,
                 lm_embeddings, receptor_radius=30, c_alpha_max_neighbors=None, precomputed_lm_embeddings=None,
                 remove_hs=False, all_atoms=False, atom_radius=5, atom_max_neighbors=None, knn_only_graph=False):
        super().__init__()
        self.receptor_radius = receptor_radius
        self.c_alpha_max_neighbors = c_alpha_max_neighbors
        self.remove_hs = remove_hs
        self.all_atoms = all_atoms
        self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
        self.knn_only_graph = knn_only_graph
        self.df = df
        # derive the chain sequences and (where available) structure paths from
        # X2; the ESMFold block below relies on the 'protein_path' column, so
        # this must run in every branch
        df[['protein_sequence', 'protein_path']] = df['X2'].apply(process_protein).apply(pd.Series)

        # generate LM embeddings
        if lm_embeddings and (precomputed_lm_embeddings is None or precomputed_lm_embeddings[0] is None):
            print("Generating ESM language model embeddings")
            model_location = "esm2_t33_650M_UR50D"
            model, alphabet = pretrained.load_model_and_alphabet(model_location)
            model.eval()
            if torch.cuda.is_available():
                model = model.cuda()
            labels, sequences = [], []
            for i in range(len(df)):
                s = df['protein_sequence'].iloc[i].split(':')
                sequences.extend(s)
                labels.extend([df['name'].iloc[i] + '_chain_' + str(j) for j in range(len(s))])

            # TODO improve efficiency for repeated X2 values
            lm_embeddings_dict = compute_esm_embeddings(model, alphabet, labels, sequences)

            self.lm_embeddings = []
            for i in range(len(df)):
                s = df['protein_sequence'].iloc[i].split(':')
                self.lm_embeddings.append(
                    [lm_embeddings_dict[f"{df['name'].iloc[i]}_chain_{j}"] for j in range(len(s))]
                )
        elif not lm_embeddings:
            self.lm_embeddings = [None] * len(df)
        else:
            self.lm_embeddings = precomputed_lm_embeddings

        # generate missing structures with ESMFold
        if None in df['protein_path'].values:
            print("generating missing structures with ESMFold")
            model = esm.pretrained.esmfold_v1().eval()
            if torch.cuda.is_available():
                model = model.cuda()

            for i in range(len(df)):
                # TODO improve efficiency for repeated X2 values
                protein_sequence = df['protein_sequence'].iloc[i]
                protein_file = df['protein_path'].iloc[i]
                complex_name = df['name'].iloc[i]

                if protein_file is None:
                    protein_file = f"{out_dir}/{complex_name}/{complex_name}_esmfold.pdb"
                    if not Path(protein_file).is_file():
                        Path(protein_file).parent.mkdir(parents=True, exist_ok=True)
                        print("generating", protein_file)
                        generate_esm_structure(model, protein_file, protein_sequence)
                    # record the generated path so that get() can load the structure
                    df.loc[df.index[i], 'protein_path'] = protein_file

    def len(self):
        return len(self.df)

    def get(self, idx):
        name = self.df['name'].iloc[idx]
        protein_file = self.df['protein_path'].iloc[idx]
        ligand_description = self.df['X1'].iloc[idx]
        mol = self.df['mol'].iloc[idx]
        lm_embedding = self.lm_embeddings[idx]

        # build the PyTorch Geometric heterogeneous graph
        complex_graph = HeteroData()
        complex_graph['name'] = name

        if mol is not None:
            mol = AddHs(mol)
            generate_conformer(mol)
        else:
            print(f'Failed to read molecule {ligand_description}. Skipping...')
            complex_graph['success'] = False
            return complex_graph
        try:
            # construct the ligand graph
            get_lig_graph_with_matching(mol, complex_graph, popsize=None, maxiter=None, matching=False, keep_original=False,
                                        num_conformers=1, remove_hs=self.remove_hs)
            # parse the receptor from the PDB file
            moad_extract_receptor_structure(
                path=protein_file,
                complex_graph=complex_graph,
                neighbor_cutoff=self.receptor_radius,
                max_neighbors=self.c_alpha_max_neighbors,
                lm_embeddings=lm_embedding,
                knn_only_graph=self.knn_only_graph,
                all_atoms=self.all_atoms,
                atom_cutoff=self.atom_radius,
                atom_max_neighbors=self.atom_max_neighbors)
        except Exception as e:
            print(f'Skipping {name} because of the error:')
            print(e)
            complex_graph['success'] = False
            return complex_graph

        # center the receptor (and, in all-atom mode, its atoms) at the origin,
        # and center the ligand at its own mean position
        protein_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
        complex_graph['receptor'].pos -= protein_center
        if self.all_atoms:
            complex_graph['atom'].pos -= protein_center

        ligand_center = torch.mean(complex_graph['ligand'].pos, dim=0, keepdim=True)
        complex_graph['ligand'].pos -= ligand_center

        complex_graph.original_center = protein_center
        complex_graph.mol = mol
        complex_graph['success'] = True
        return complex_graph
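

# Minimal usage sketch (illustrative): assumes a CSV with columns 'name',
# 'X1' (ligand SMILES) and 'X2' (protein .pdb path or amino-acid sequence);
# "inference.csv" and "results" are placeholder names.
if __name__ == "__main__":
    from rdkit.Chem import MolFromSmiles

    df = pd.read_csv("inference.csv")
    # the dataset expects an RDKit molecule per row in a 'mol' column;
    # MolFromSmiles returns None for unparsable SMILES, which get() handles
    df['mol'] = df['X1'].apply(MolFromSmiles)

    dataset = InferenceDataset(df, out_dir="results", lm_embeddings=True)
    complex_graph = dataset.get(0)
    print(complex_graph['name'], complex_graph['success'])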