from functools import cache
from pathlib import Path

from esm import FastaBatchedDataset, pretrained
from rdkit.Chem import AddHs
from torch_geometric.data import Dataset, HeteroData
import numpy as np
import torch
import prody as pr
import esm
import pandas as pd

from datasets.process_mols import generate_conformer, read_molecule, get_lig_graph_with_matching, moad_extract_receptor_structure
from datasets.parse_chi import aa_idx2aa_short, get_onehot_sequence


def get_sequences_from_pdbfile(file_path):
    """Read a PDB file and return its chain sequences joined by ':'.

    Residues are grouped by the concatenation of segment name and chain id,
    and groups are emitted in the sorted order produced by ``np.unique``.

    Args:
        file_path: Path (``str`` or ``Path``) of the PDB file to parse.

    Returns:
        A string such as ``"SEQA:SEQB"``, or ``None`` when no chain is found.

    Raises:
        ValueError: If the file yields no structure or no C-alpha atoms.
    """
    # ProDy requires a str path, not a Path object.
    pdb = pr.parsePDB(str(file_path))
    if pdb is None or pdb.ca is None:
        raise ValueError(f"No C-alpha atoms could be read from {file_path}")

    # NOTE(review): one residue per C-alpha atom is assumed here — confirm
    # against get_onehot_sequence, which must return an array indexable by a
    # boolean mask of the same length.
    seq = pdb.ca.getSequence()
    one_hot = get_onehot_sequence(seq)

    res_chain_ids = pdb.ca.getChids()
    res_seg_ids = pdb.ca.getSegnames()
    # Segment name + chain id disambiguates chains that reuse a chain letter.
    res_chain_ids = np.asarray([s + c for s, c in zip(res_seg_ids, res_chain_ids)])

    chain_sequences = []
    for chain_key in np.unique(res_chain_ids):
        residue_indices = np.argmax(one_hot[res_chain_ids == chain_key], axis=1)
        chain_sequences.append(''.join(aa_idx2aa_short[aa_idx] for aa_idx in residue_indices))

    return ':'.join(chain_sequences) if chain_sequences else None


@cache
def process_protein(protein_string):
    """Interpret a protein description as either a PDB path or a raw sequence.

    The input counts as a path when it is absolute or has more than one path
    component; a bare file name such as ``"protein.pdb"`` is therefore treated
    as a sequence. Results are memoised per input string.

    Args:
        protein_string: Path to a ``.pdb`` file, or an amino-acid sequence
            (chains separated by ':').

    Returns:
        ``(sequence, path)`` where ``path`` is ``None`` for sequence input.

    Raises:
        FileNotFoundError: If the input looks like a path but is not an
            existing file with the ``.pdb`` suffix.
    """
    input_path = Path(protein_string)
    looks_like_path = input_path.is_absolute() or len(input_path.parts) > 1

    if not looks_like_path:
        # Assume the input is already a sequence.
        return protein_string, None

    if input_path.is_file() and input_path.suffix == '.pdb':
        return get_sequences_from_pdbfile(input_path), str(input_path)

    raise FileNotFoundError(f"File {protein_string} not found or not a PDB file")


def compute_esm_embeddings(model, alphabet, labels, sequences):
    """Compute per-residue ESM embeddings for a list of labelled sequences.

    Args:
        model: ESM language model exposing ``num_layers`` and accepting
            ``repr_layers`` / ``return_contacts`` keyword arguments.
        alphabet: ESM alphabet providing ``get_batch_converter``.
        labels: One unique label per sequence; used as keys of the result.
        sequences: Amino-acid sequences, aligned with ``labels``.

    Returns:
        Dict mapping each label to a CPU tensor holding the layer-33
        representation of its (possibly truncated) sequence.

    Raises:
        ValueError: If the requested layer does not exist in ``model``.
    """
    # settings used
    toks_per_batch = 4096
    repr_layers = [33]
    truncation_seq_length = 1022

    dataset = FastaBatchedDataset(labels, sequences)
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(truncation_seq_length),
        batch_sampler=batches,
    )

    num_layers = model.num_layers
    if not all(-(num_layers + 1) <= i <= num_layers for i in repr_layers):
        raise ValueError(f"repr_layers {repr_layers} out of range for a model with {num_layers} layers")
    # Map negative layer indices onto their non-negative equivalents.
    repr_layers = [(i + num_layers + 1) % (num_layers + 1) for i in repr_layers]
    output_layer = repr_layers[0]

    # Feed tokens to whichever device the model actually lives on.
    device = next(model.parameters()).device
    embeddings = {}

    with torch.no_grad():
        for batch_idx, (batch_labels, strs, toks) in enumerate(data_loader):
            print(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
            toks = toks.to(device=device, non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=False)
            representations = {layer: t.to(device="cpu") for layer, t in out["representations"].items()}

            for i, label in enumerate(batch_labels):
                truncate_len = min(truncation_seq_length, len(strs[i]))
                # Position 0 is skipped — presumably the prepended start token
                # (consistent with extra_toks_per_seq=1); confirm against ESM.
                embeddings[label] = representations[output_layer][i, 1: truncate_len + 1].clone()
    return embeddings


def generate_esm_structure(model, filename, sequence):
    """Fold ``sequence`` with ESMFold and write the resulting PDB text.

    Inference starts with a chunk size of 256 and halves it after every CUDA
    out-of-memory error, giving up once the chunk size would drop to 2 or less.

    Args:
        model: ESMFold model exposing ``set_chunk_size`` and ``infer_pdb``.
        filename: Destination path of the PDB file.
        sequence: Amino-acid sequence to fold.

    Returns:
        ``True`` if a structure was produced and written, ``False`` if memory
        ran out at every attempted chunk size.

    Raises:
        RuntimeError: Any runtime error that is not an out-of-memory error.
    """
    chunk_size = 256
    model.set_chunk_size(chunk_size)
    output = None

    while output is None:
        try:
            with torch.no_grad():
                output = model.infer_pdb(sequence)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            print('| WARNING: ran out of memory on chunk_size', chunk_size)
            for p in model.parameters():
                if p.grad is not None:
                    del p.grad  # free some memory
            torch.cuda.empty_cache()
            chunk_size = chunk_size // 2
            if chunk_size > 2:
                model.set_chunk_size(chunk_size)
            else:
                print("Not enough memory for ESMFold")
                break

    if output is None:
        return False

    with open(filename, "w") as f:
        f.write(output)
    print("saved", filename)
    return True


class InferenceDataset(Dataset):
    """Dataset that builds one protein-ligand graph per row of a DataFrame.

    The DataFrame must provide the columns ``name``, ``X1`` (ligand
    description), ``X2`` (PDB path or sequence) and ``mol`` (molecule object
    or ``None``). The columns ``protein_sequence`` and ``protein_path`` are
    added to, or updated in, the DataFrame that is passed in.
    """

    def __init__(self, df, out_dir, lm_embeddings, receptor_radius=30, c_alpha_max_neighbors=None,
                 precomputed_lm_embeddings=None, remove_hs=False, all_atoms=False, atom_radius=5,
                 atom_max_neighbors=None, knn_only_graph=False):
        """Prepare embeddings and missing structures for every complex.

        Args:
            df: DataFrame describing the complexes (mutated in place).
            out_dir: Directory under which ESMFold structures are written as
                ``{out_dir}/{name}/{name}_esmfold.pdb``.
            lm_embeddings: Whether language-model embeddings are used.
            receptor_radius: Forwarded as ``neighbor_cutoff`` to the receptor
                graph builder.
            c_alpha_max_neighbors: Forwarded as ``max_neighbors``.
            precomputed_lm_embeddings: Optional per-complex embeddings used
                instead of computing them.
            remove_hs: Forwarded to the ligand graph builder.
            all_atoms: Whether an all-atom receptor graph is built.
            atom_radius: Forwarded as ``atom_cutoff``.
            atom_max_neighbors: Forwarded as ``atom_max_neighbors``.
            knn_only_graph: Forwarded as ``knn_only_graph``.
        """
        super().__init__()
        self.receptor_radius = receptor_radius
        self.c_alpha_max_neighbors = c_alpha_max_neighbors
        self.remove_hs = remove_hs
        self.all_atoms = all_atoms
        self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
        self.knn_only_graph = knn_only_graph
        self.df = df

        needs_embeddings = bool(lm_embeddings) and (
            precomputed_lm_embeddings is None
            or len(precomputed_lm_embeddings) == 0
            or precomputed_lm_embeddings[0] is None
        )

        # The sequence/path columns are needed by every branch below (and by
        # get()), so resolve them whenever they are absent, not only when
        # embeddings are generated.
        has_protein_columns = {'protein_sequence', 'protein_path'}.issubset(df.columns)
        if needs_embeddings or not has_protein_columns:
            resolved = [process_protein(x2) for x2 in df['X2']]
            df['protein_sequence'] = [sequence for sequence, _ in resolved]
            df['protein_path'] = [path for _, path in resolved]

        # generate LM embeddings
        if needs_embeddings:
            print("Generating ESM language model embeddings")
            model_location = "esm2_t33_650M_UR50D"
            model, alphabet = pretrained.load_model_and_alphabet(model_location)
            model.eval()
            if torch.cuda.is_available():
                model = model.cuda()

            names = list(df['name'])
            chains_per_complex = [sequence.split(':') for sequence in df['protein_sequence']]

            labels, sequences = [], []
            for name, chains in zip(names, chains_per_complex):
                sequences.extend(chains)
                labels.extend(f"{name}_chain_{j}" for j in range(len(chains)))

            # TODO improve efficiency for repeated X2 values
            computed = compute_esm_embeddings(model, alphabet, labels, sequences)

            self.lm_embeddings = [
                [computed[f"{name}_chain_{j}"] for j in range(len(chains))]
                for name, chains in zip(names, chains_per_complex)
            ]
        elif not lm_embeddings:
            self.lm_embeddings = [None] * len(df)
        else:
            self.lm_embeddings = precomputed_lm_embeddings

        # generate structures with ESMFold
        if df['protein_path'].isna().any():
            print("generating missing structures with ESMFold")
            model = esm.pretrained.esmfold_v1()
            model = model.eval()
            if torch.cuda.is_available():
                model = model.cuda()

            path_column = df.columns.get_loc('protein_path')
            for i in range(len(df)):
                # TODO improve efficiency for repeated X2 values
                if not pd.isna(df['protein_path'].iloc[i]):
                    continue

                complex_name = df['name'].iloc[i]
                protein_file = f"{out_dir}/{complex_name}/{complex_name}_esmfold.pdb"
                if not Path(protein_file).is_file():
                    print("generating", protein_file)
                    Path(protein_file).parent.mkdir(parents=True, exist_ok=True)
                    generate_esm_structure(model, protein_file, df['protein_sequence'].iloc[i])

                # Record the path so get() can load the generated structure.
                df.iloc[i, path_column] = protein_file

    def len(self):
        """Return the number of complexes in the dataset."""
        return len(self.df)

    def get(self, idx):
        """Build the heterogeneous graph of the complex at position ``idx``.

        Returns:
            A ``HeteroData`` object whose ``'success'`` entry is ``False``
            when the molecule is missing or graph construction raised, and
            ``True`` otherwise. On success, receptor and ligand positions
            are each centred on their own mean, and the receptor centre is
            stored as ``original_center``.
        """
        name = self.df['name'].iloc[idx]
        protein_file = self.df['protein_path'].iloc[idx]
        ligand_description = self.df['X1'].iloc[idx]
        mol = self.df['mol'].iloc[idx]
        lm_embedding = self.lm_embeddings[idx]

        # build the pytorch geometric heterogeneous graph
        complex_graph = HeteroData()
        complex_graph['name'] = name

        if mol is None:
            print(f'Failed to read molecule {ligand_description}. Skipping...')
            complex_graph['success'] = False
            return complex_graph

        mol = AddHs(mol)
        generate_conformer(mol)

        try:
            # build the ligand graph, then parse the receptor from the pdb file
            get_lig_graph_with_matching(mol, complex_graph, popsize=None, maxiter=None, matching=False,
                                        keep_original=False, num_conformers=1, remove_hs=self.remove_hs)

            moad_extract_receptor_structure(
                path=protein_file,
                complex_graph=complex_graph,
                neighbor_cutoff=self.receptor_radius,
                max_neighbors=self.c_alpha_max_neighbors,
                lm_embeddings=lm_embedding,
                knn_only_graph=self.knn_only_graph,
                all_atoms=self.all_atoms,
                atom_cutoff=self.atom_radius,
                atom_max_neighbors=self.atom_max_neighbors)
        except Exception as e:
            # Deliberately broad: one bad complex must not abort the whole run.
            print(f'Skipping {name} because of the error:')
            print(e)
            complex_graph['success'] = False
            return complex_graph

        protein_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
        complex_graph['receptor'].pos -= protein_center
        if self.all_atoms:
            complex_graph['atom'].pos -= protein_center

        ligand_center = torch.mean(complex_graph['ligand'].pos, dim=0, keepdim=True)
        complex_graph['ligand'].pos -= ligand_center

        complex_graph.original_center = protein_center
        complex_graph.mol = mol
        complex_graph['success'] = True
        return complex_graph