Navvye committed on
Commit 04579ee · 1 Parent(s): f45e077

Make it better

.DS_Store ADDED
Binary file (6.15 kB)
 
__pycache__/plapt.cpython-312.pyc ADDED
Binary file (9.54 kB)
 
index.py ADDED
@@ -0,0 +1,65 @@
+ import numpy as np
+ import json
+ import re
+ import onnxruntime
+ import torch
+ from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel
+
+ def init():
+     global session, input_name, prot_tokenizer, prot_encoder, mol_tokenizer, mol_encoder
+     session = onnxruntime.InferenceSession("models/affinity_predictor0734-seed2101.onnx")
+     input_name = session.get_inputs()[0].name
+     prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
+     prot_encoder = BertModel.from_pretrained("Rostlab/prot_bert")
+     mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
+     mol_encoder = RobertaModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
+
+ def run(raw_data):
+     try:
+         data = json.loads(raw_data)
+         prot_seq = data['protein']
+         mol_smiles = data['smiles']
+
+         # Tokenize and encode the protein; the pooler output is its sequence-level representation
+         prot_tokens = prot_tokenizer(preprocess_sequence(prot_seq),
+                                      padding=True,
+                                      max_length=3200,
+                                      truncation=True,
+                                      return_tensors='pt')
+         with torch.no_grad():
+             prot_representation = prot_encoder(**prot_tokens).pooler_output.squeeze(0)
+
+         # Tokenize and encode the molecule the same way
+         mol_tokens = mol_tokenizer(mol_smiles,
+                                    padding=True,
+                                    max_length=278,
+                                    truncation=True,
+                                    return_tensors='pt')
+         with torch.no_grad():
+             mol_representation = mol_encoder(**mol_tokens).pooler_output.squeeze(0)
+
+         # Concatenate the two representations into a single feature vector
+         features = torch.cat((prot_representation, mol_representation), dim=0)
+
+         # Run inference with the ONNX affinity predictor
+         affinity_normalized = session.run(None, {input_name: [features.numpy()], 'TrainingMode': np.array(False)})[0][0][0]
+
+         # De-normalize the model output into affinity values
+         return convert_to_affinity(affinity_normalized)
+     except Exception as e:
+         return json.dumps({"error": str(e)})
+
+ def preprocess_sequence(seq):
+     # Replace rare amino acids (U, Z, O, B) with X and space-separate residues, as ProtBERT expects
+     return " ".join(re.sub(r"[UZOB]", "X", seq))
+
+ def convert_to_affinity(normalized):
+     mean = 6.51286529169358
+     scale = 1.5614094578916633
+     return {
+         "neg_log10_affinity_M": (normalized * scale) + mean,
+         "affinity_uM": (10**6) * (10**(-((normalized * scale) + mean)))
+     }
+
+ init()
+ print(run(json.dumps({"protein": "MILK", "smiles": "CCO"})))
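
As a quick sanity check on the de-normalization in convert_to_affinity, here is a minimal worked example using only the mean and scale constants above; the normalized value 0.0 is an arbitrary illustration, not a real model output:

    # Worked example: de-normalize an arbitrary model output of 0.0
    mean = 6.51286529169358
    scale = 1.5614094578916633
    normalized = 0.0
    neg_log10_affinity_M = normalized * scale + mean      # = 6.5129
    affinity_uM = (10**6) * 10**(-neg_log10_affinity_M)   # ~0.307 uM
    print(neg_log10_affinity_M, affinity_uM)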
models/.DS_Store ADDED
File without changes
models/affinity_predictor0734-seed2101.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bbb242b307274215e542bae5cd524f81d06e6f1102b4cc0cf31042e2a601509c
+ size 5924195
plapt.py ADDED
@@ -0,0 +1,171 @@
+ import torch
+ from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel
+ import re
+ import onnxruntime
+ import numpy as np
+
+ torch.set_num_threads(1)
+
+ def flatten_list(nested_list):
+     # Recursively flatten an arbitrarily nested list
+     flat_list = []
+     for element in nested_list:
+         if isinstance(element, list):
+             flat_list.extend(flatten_list(element))
+         else:
+             flat_list.append(element)
+     return flat_list
+
+ class PredictionModule:
+     def __init__(self, model_path="models/affinity_predictor0734-seed2101.onnx"):
+         self.session = onnxruntime.InferenceSession(model_path)
+         self.input_name = self.session.get_inputs()[0].name
+
+         # Normalization scaling parameters
+         self.mean = 6.51286529169358
+         self.scale = 1.5614094578916633
+
+     def convert_to_affinity(self, normalized):
+         return {
+             "neg_log10_affinity_M": (normalized * self.scale) + self.mean,
+             "affinity_uM": (10**6) * (10**(-((normalized * self.scale) + self.mean)))
+         }
+
+     def predict(self, batch_data):
+         """Run predictions on a batch of feature tensors."""
+         # Convert each tensor to a numpy array and stack them
+         batch_data = np.array([t.numpy() for t in batch_data])
+
+         # Process each feature in the batch individually and store the results
+         affinities = []
+         for feature in batch_data:
+             # Run the ONNX model on the single feature
+             affinity_normalized = self.session.run(None, {self.input_name: [feature], 'TrainingMode': np.array(False)})[0][0][0]
+             affinities.append(self.convert_to_affinity(affinity_normalized))
+
+         return affinities
+
+ class Plapt:
+     def __init__(self, prediction_module_path="models/affinity_predictor0734-seed2101.onnx", caching=True, device='cuda'):
+         # Set device for computation, falling back to CPU if CUDA is unavailable
+         self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
+
+         # Load protein tokenizer and encoder
+         self.prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
+         self.prot_encoder = BertModel.from_pretrained("Rostlab/prot_bert").to(self.device)
+
+         # Load molecule tokenizer and encoder
+         self.mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
+         self.mol_encoder = RobertaModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(self.device)
+
+         self.caching = caching
+         self.cache = {}
+
+         # Load the ONNX prediction module
+         self.prediction_module = PredictionModule(prediction_module_path)
+
+     def set_prediction_module(self, prediction_module_path):
+         self.prediction_module = PredictionModule(prediction_module_path)
+
+     @staticmethod
+     def preprocess_sequence(seq):
+         # Replace rare amino acids (U, Z, O, B) with X and space-separate residues, as ProtBERT expects
+         return " ".join(re.sub(r"[UZOB]", "X", seq))
+
+     def tokenize(self, mol_smiles):
+         # Tokenize SMILES strings for the molecule encoder
+         mol_tokens = self.mol_tokenizer(mol_smiles,
+                                         padding=True,
+                                         max_length=278,
+                                         truncation=True,
+                                         return_tensors='pt')
+         return mol_tokens
+
+     def tokenize_prot(self, prot_seq):
+         # Tokenize a protein sequence for the protein encoder
+         prot_tokens = self.prot_tokenizer(self.preprocess_sequence(prot_seq),
+                                           padding=True,
+                                           max_length=3200,
+                                           truncation=True,
+                                           return_tensors='pt')
+         return prot_tokens
+
+     @staticmethod
+     def make_batches(iterable, n=1):
+         # Yield successive batches of size n
+         length = len(iterable)
+         for ndx in range(0, length, n):
+             yield iterable[ndx:min(ndx + n, length)]
+
+     def predict_affinity(self, prot_seq, mol_smiles, batch_size=2):
+         prot_tokens = self.tokenize_prot(prot_seq)
+         with torch.no_grad():
+             prot_representations = self.prot_encoder(**prot_tokens.to(self.device)).pooler_output.cpu()
+             prot_representations = prot_representations.squeeze(0)
+         # Repeat the protein representation so it can be zipped with each molecule in a batch
+         prot_representations = [prot_representations for _ in range(batch_size)]
+
+         affinities = []
+         for batch in self.make_batches(mol_smiles, batch_size):
+             batch_key = str(batch)  # Use the stringified batch as a cache key
+
+             if self.caching and batch_key in self.cache:
+                 # Use cached features if available
+                 features = self.cache[batch_key]
+             else:
+                 # Tokenize and encode the batch, then cache the results
+                 mol_tokens = self.tokenize(batch)
+                 with torch.no_grad():
+                     mol_representations = self.mol_encoder(**mol_tokens.to(self.device)).pooler_output.cpu()
+                 mol_representations = [mol_representations[i, :] for i in range(mol_representations.shape[0])]
+
+                 # zip truncates to the actual batch size, so a short final batch is handled correctly
+                 features = [torch.cat((prot, mol), dim=0) for prot, mol in
+                             zip(prot_representations, mol_representations)]
+
+                 if self.caching:
+                     self.cache[batch_key] = features
+
+             affinities.extend(self.prediction_module.predict(features))
+
+         return affinities
+
+     def score_candidates(self, target_protein, mol_smiles, batch_size=2):
+         target_tokens = self.prot_tokenizer([self.preprocess_sequence(target_protein)],
+                                             padding=True,
+                                             max_length=3200,
+                                             truncation=True,
+                                             return_tensors='pt')
+
+         with torch.no_grad():
+             target_representation = self.prot_encoder(**target_tokens.to(self.device)).pooler_output.cpu()
+
+         affinities = []
+         for mol in mol_smiles:
+             mol_tokens = self.mol_tokenizer(mol,
+                                             padding=True,
+                                             max_length=278,
+                                             truncation=True,
+                                             return_tensors='pt')
+
+             with torch.no_grad():
+                 mol_representations = self.mol_encoder(**mol_tokens.to(self.device)).pooler_output.cpu()
+
+             features = torch.cat((target_representation[0], mol_representations[0]), dim=0)
+
+             affinities.extend(self.prediction_module.predict([features]))
+
+         return affinities
+
+     def get_cached_features(self):
+         return [tensor.tolist() for tensor in flatten_list(list(self.cache.values()))]
+
+     def clear_cache(self):
+         self.cache = {}
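
For context, here is a minimal usage sketch of the Plapt class added above; the protein sequence and SMILES strings are illustrative placeholders, not values prescribed by this commit:

    from plapt import Plapt

    # device falls back to CPU automatically when CUDA is unavailable
    plapt = Plapt(device='cpu')
    results = plapt.predict_affinity('MILK', ['CCO', 'CC(=O)O'])
    print(results)  # list of {'neg_log10_affinity_M': ..., 'affinity_uM': ...} dicts

With caching=True (the default), encoded molecule batches are memoized in self.cache, so repeated calls with the same SMILES batches skip re-encoding.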
plapt_cli.py ADDED
@@ -0,0 +1,53 @@
+ import warnings
+ warnings.filterwarnings("ignore")  # silence library warnings before importing plapt
+
+ import argparse
+ import json
+ import csv
+ import os
+ from plapt import Plapt
+
+ def write_json(results, filename):
+     with open(filename, 'w') as json_file:
+         json.dump(results, json_file)
+
+ def write_csv(results, filename):
+     with open(filename, 'w', newline='') as csv_file:
+         writer = csv.writer(csv_file)
+         for result in results:
+             writer.writerow([result])
+
+ def determine_format_and_update_filename(output_arg, format_arg):
+     if not output_arg:
+         return None, "json"
+     _, ext = os.path.splitext(output_arg)
+     if ext in (".csv", ".json"):
+         # An explicit file extension takes precedence over the --format flag
+         return output_arg, ext[1:]
+     output_format = format_arg or "json"
+     return f"{output_arg}.{output_format}", output_format
+
+ def main():
+     parser = argparse.ArgumentParser(description="Predict affinity using Plapt.")
+     parser.add_argument("-t", "--target", nargs="+", required=True, help="The target protein sequence")
+     parser.add_argument("-m", "--smiles", nargs="+", required=True, help="List of SMILES strings")
+     parser.add_argument("-o", "--output", help="Optional output file path")
+     parser.add_argument("-f", "--format", choices=["json", "csv"], help="Optional output file format; required if output is specified without an extension")
+
+     args = parser.parse_args()
+
+     plapt = Plapt()
+     results = plapt.predict_affinity(args.target[0], args.smiles)
+
+     args.output, output_format = determine_format_and_update_filename(args.output, args.format)
+
+     if args.output:
+         if output_format == "json":
+             write_json(results, args.output)
+         elif output_format == "csv":
+             write_csv(results, args.output)
+         print(f"Output written to {args.output}")
+     else:
+         print(results)
+
+ if __name__ == "__main__":
+     main()
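
By way of illustration, the CLI above could be invoked as python plapt_cli.py -t MILK -m CCO CCN -o results, where the sequence and SMILES strings are placeholders; since "results" has no extension, the default JSON format is appended and the file is written as results.json. Omitting -o prints the predictions to stdout instead.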
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ azureml-core
+ azureml-defaults
+ torch
+ transformers
+ onnxruntime
+ numpy
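
Assuming a standard Python environment, these dependencies would be installed with pip install -r requirements.txt. The azureml-core and azureml-defaults entries, together with the init()/run() pair in index.py, suggest the entry script follows the Azure ML scoring-script convention, though the commit does not state a deployment target.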