from functools import cache
from pathlib import Path

from esm import FastaBatchedDataset, pretrained
from rdkit.Chem import AddHs
from torch_geometric.data import Dataset, HeteroData
import numpy as np
import torch
import prody as pr
import esm
import pandas as pd

from datasets.process_mols import generate_conformer, read_molecule, get_lig_graph_with_matching, moad_extract_receptor_structure
from datasets.parse_chi import aa_idx2aa_short, get_onehot_sequence


def get_sequences_from_pdbfile(file_path):
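    """Extract per-chain amino-acid sequences from a PDB file.

    Chains are identified by segment name plus chain ID, and the one-letter
    sequences of the individual chains are joined with ':' separators.
    """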
    sequence = None
    # prody's parsePDB requires a str path
    pdb = pr.parsePDB(str(file_path))
    seq = pdb.ca.getSequence()
    one_hot = get_onehot_sequence(seq)

    chain_ids = np.zeros(len(one_hot))
    res_chain_ids = pdb.ca.getChids()
    res_seg_ids = pdb.ca.getSegnames()
    res_chain_ids = np.asarray([s + c for s, c in zip(res_seg_ids, res_chain_ids)])
    ids = np.unique(res_chain_ids)

    for i, chain_id in enumerate(ids):
        chain_ids[res_chain_ids == chain_id] = i

        s_temp = np.argmax(one_hot[res_chain_ids == chain_id], axis=1)
        s = ''.join([aa_idx2aa_short[aa_idx] for aa_idx in s_temp])

        if sequence is None:
            sequence = s
        else:
            sequence += (":" + s)

    return sequence

@cache
def process_protein(protein_string):
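    """Resolve a protein specification into a (sequence, pdb_path) tuple.

    If the input looks like a file path it must point to an existing .pdb file,
    whose sequence is extracted; otherwise the string is treated as a raw
    amino-acid sequence and the path is None. Cached because the same protein
    may appear in many rows.
    """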
    input_path = Path(protein_string)
    # Check if the input is a path to a file
    if input_path.is_absolute() or len(input_path.parts) > 1:
        # Check if the input is a PDB file path
        if input_path.is_file() and input_path.suffix == '.pdb':
            # Extract sequence from PDB file
            return get_sequences_from_pdbfile(input_path), str(input_path)
        else:
            raise FileNotFoundError(f"File {protein_string} not found or not a PDB file")
    else:
        # Assume the input is already a FASTA sequence
        return protein_string, None


def compute_esm_embeddings(model, alphabet, labels, sequences):
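    """Compute per-residue ESM2 embeddings (layer 33) for each labeled sequence.

    Returns a dict mapping each label to a (num_residues, embedding_dim) tensor,
    with sequences truncated to 1022 residues.
    """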
    # settings used
    toks_per_batch = 4096
    repr_layers = [33]
    truncation_seq_length = 1022

    dataset = FastaBatchedDataset(labels, sequences)
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
    )

    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in repr_layers)
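    # map (possibly negative) layer indices onto absolute layer numbers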
    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in repr_layers]
    embeddings = {}

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=False)
            representations = {layer: t.to(device="cpu") for layer, t in out["representations"].items()}

            for i, label in enumerate(labels):
                truncate_len = min(truncation_seq_length, len(strs[i]))
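                # position 0 is the BOS token, so slice from 1 to keep one embedding per residue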
                embeddings[label] = representations[33][i, 1: truncate_len + 1].clone()
    return embeddings


def generate_esm_structure(model, filename, sequence):
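    """Fold a sequence with ESMFold and write the predicted PDB to `filename`.

    Halves the chunk size and retries on CUDA out-of-memory errors; returns
    True if a structure was written, False otherwise.
    """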
    model.set_chunk_size(256)
    chunk_size = 256
    output = None

    while output is None:
        try:
            with torch.no_grad():
                output = model.infer_pdb(sequence)

            with open(filename, "w") as f:
                f.write(output)
                print("saved", filename)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory on chunk_size', chunk_size)
                for p in model.parameters():
                    if p.grad is not None:
                        del p.grad  # free some memory
                torch.cuda.empty_cache()
                chunk_size = chunk_size // 2
                if chunk_size > 2:
                    model.set_chunk_size(chunk_size)
                else:
                    print("Not enough memory for ESMFold")
                    break
            else:
                raise e
    return output is not None


class InferenceDataset(Dataset):
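    """Builds torch_geometric HeteroData complexes from a dataframe of protein/ligand rows.

    The dataframe must provide 'name', 'X1' (ligand description), 'X2' (protein
    sequence or PDB path) and 'mol' (an RDKit molecule parsed upstream).
    ESM2 embeddings are computed unless precomputed ones are passed, and
    missing receptor structures are folded with ESMFold.
    """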
    def __init__(self,
                 df, out_dir,
                 lm_embeddings, receptor_radius=30, c_alpha_max_neighbors=None, precomputed_lm_embeddings=None,
                 remove_hs=False, all_atoms=False, atom_radius=5, atom_max_neighbors=None, knn_only_graph=False):

        super(InferenceDataset, self).__init__()
        self.receptor_radius = receptor_radius
        self.c_alpha_max_neighbors = c_alpha_max_neighbors
        self.remove_hs = remove_hs
        self.all_atoms = all_atoms
        self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
        self.knn_only_graph = knn_only_graph

        self.df = df

        # resolve each X2 entry into a sequence and, when available, a PDB path;
        # 'protein_path' is needed below even when embeddings are precomputed
        df[['protein_sequence', 'protein_path']] = df['X2'].apply(process_protein).apply(pd.Series)

        # generate LM embeddings
        if lm_embeddings and (precomputed_lm_embeddings is None or precomputed_lm_embeddings[0] is None):
            print("Generating ESM language model embeddings")
            model_location = "esm2_t33_650M_UR50D"
            model, alphabet = pretrained.load_model_and_alphabet(model_location)
            model.eval()
            if torch.cuda.is_available():
                model = model.cuda()
            labels, sequences = [], []
            for i in range(len(df)):
                s = df['protein_sequence'].iloc[i].split(':')
                sequences.extend(s)
                labels.extend([df['name'].iloc[i] + '_chain_' + str(j) for j in range(len(s))])
            # TODO improve efficiency for repeated X2 values
            lm_embeddings = compute_esm_embeddings(model, alphabet, labels, sequences)

            self.lm_embeddings = []
            for i in range(len(df)):
                s = df['protein_sequence'].iloc[i].split(':')
                self.lm_embeddings.append(
                    [lm_embeddings[f"{df['name'].iloc[i]}_chain_{j}"] for j in range(len(s))]
                )

        elif not lm_embeddings:
            self.lm_embeddings = [None] * len(df)

        else:
            self.lm_embeddings = precomputed_lm_embeddings

        # generate structures with ESMFold
        if None in df['protein_path'].values:
            print("generating missing structures with ESMFold")
            model = esm.pretrained.esmfold_v1()
            model = model.eval().cuda()

            for i in range(len(df)):
                # TODO improve efficiency for repeated X2 values
                protein_sequence = df['protein_sequence'].iloc[i]
                protein_file = df['protein_path'].iloc[i]
                complex_name = df['name'].iloc[i]
                if protein_file is None:
                    protein_file = f"{out_dir}/{complex_name}/{complex_name}_esmfold.pdb"
                    if not Path(protein_file).is_file():
                        # make sure the per-complex output directory exists before writing
                        Path(protein_file).parent.mkdir(parents=True, exist_ok=True)
                        print("generating", protein_file)
                        generate_esm_structure(model, protein_file, protein_sequence)
                    # record the generated structure so get() can load it later
                    df.loc[df.index[i], 'protein_path'] = protein_file

    def len(self):
        return len(self.df)

    def get(self, idx):
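        """Build the heterogeneous ligand/receptor graph for row `idx`.

        Positions are centered (receptor on the protein center, ligand on its
        own center) and 'success' is set to False if the ligand or receptor
        cannot be processed.
        """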
        name = self.df['name'].iloc[idx]
        protein_file = self.df['protein_path'].iloc[idx]
        ligand_description = self.df['X1'].iloc[idx]
        mol = self.df['mol'].iloc[idx]
        lm_embedding = self.lm_embeddings[idx]

        # build the pytorch geometric heterogeneous graph
        complex_graph = HeteroData()
        complex_graph['name'] = name

        if mol is not None:
            mol = AddHs(mol)
            generate_conformer(mol)
        else:
            print(f'Failed to read molecule {ligand_description}. Skipping...')
            complex_graph['success'] = False
            return complex_graph

        try:
            # build the ligand graph from the generated conformer
            get_lig_graph_with_matching(mol, complex_graph, popsize=None, maxiter=None, matching=False, keep_original=False,
                                        num_conformers=1, remove_hs=self.remove_hs)

            # parse the receptor from the pdb file
            moad_extract_receptor_structure(
                path=protein_file,
                complex_graph=complex_graph,
                neighbor_cutoff=self.receptor_radius,
                max_neighbors=self.c_alpha_max_neighbors,
                lm_embeddings=lm_embedding,
                knn_only_graph=self.knn_only_graph,
                all_atoms=self.all_atoms,
                atom_cutoff=self.atom_radius,
                atom_max_neighbors=self.atom_max_neighbors)

        except Exception as e:
            print(f'Skipping {name} because of the error:')
            print(e)
            complex_graph['success'] = False
            return complex_graph

        protein_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
        complex_graph['receptor'].pos -= protein_center
        if self.all_atoms:
            complex_graph['atom'].pos -= protein_center

        ligand_center = torch.mean(complex_graph['ligand'].pos, dim=0, keepdim=True)
        complex_graph['ligand'].pos -= ligand_center

        complex_graph.original_center = protein_center
        complex_graph.mol = mol
        complex_graph['success'] = True
        return complex_graph
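

# Illustrative usage sketch (assumes a dataframe with the columns described in
# the InferenceDataset docstring and an RDKit mol prepared upstream, e.g. via
# read_molecule; paths are placeholders):
#
#   df = pd.DataFrame([{'name': 'example', 'X1': 'CCO',
#                       'X2': '/path/to/protein.pdb', 'mol': mol}])
#   dataset = InferenceDataset(df, out_dir='results/example', lm_embeddings=True)
#   complex_graph = dataset.get(0)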