trrt8 committed
Commit 5753139 · 1 Parent(s): 4a4e689

Upload 2 files

Files changed (2):
  1. models/predictionModule.onnx +3 -0
  2. plapt.py +111 -0
models/predictionModule.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7601e11190efd0f06e7aa1ae161e09efab5b674cc938807b901abddd9ef13594
+ size 4867404
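The three lines above are a Git LFS pointer, not the network itself; the actual ~4.9 MB ONNX graph is downloaded when the LFS object is pulled. A minimal sketch for loading and inspecting it with onnxruntime (assuming the weights have been fetched to the path used in this commit):

import onnxruntime

# Load the ONNX prediction head referenced by the pointer above.
session = onnxruntime.InferenceSession("models/predictionModule.onnx")

# List the graph's declared inputs and outputs to confirm expected shapes.
for inp in session.get_inputs():
    print("input:", inp.name, inp.shape, inp.type)
for out in session.get_outputs():
    print("output:", out.name, out.shape, out.type)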
plapt.py ADDED
@@ -0,0 +1,111 @@
+ import torch
+ from torch.utils.data import DataLoader
+ from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel
+ import re
+ import onnxruntime
+ import numpy as np
+
+ class PredictionModule:
+     def __init__(self, model_path="models/predictionModule.onnx"):
+         """Initialize the PredictionModule with the given ONNX model."""
+         self.session = onnxruntime.InferenceSession(model_path)
+         self.input_name = self.session.get_inputs()[0].name
+
+         # Normalization scaling parameters
+         self.mean = 6.51286529169358
+         self.scale = 1.5614094578916633
+
+     def convert_to_affinity(self, normalized):
+         return {
+             "neg_log10_affinity_M": (normalized * self.scale) + self.mean,
+             "affinity_uM": (10**6) * (10**(-((normalized * self.scale) + self.mean)))
+         }
+
+     def predict(self, batch_data):
+         """Run predictions on a batch of data."""
+         # Ensure data is in numpy array format and the correct dtype
+         batch_data = np.array(batch_data).astype(np.float32)
+
+         # Process each feature in the batch individually and store results
+         affinities = []
+         for feature in batch_data:
+             # Reshape the feature to match the model's expected input shape
+             feature = feature.reshape(1, -1)
+             # Run the model on the single feature
+             affinity_normalized = self.session.run(None, {self.input_name: feature, 'TrainingMode': np.array(False)})[0][0][0]
+             # Append the result
+             affinities.append(self.convert_to_affinity(affinity_normalized))
+
+         return affinities
+
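+ # Illustrative sanity check of convert_to_affinity (example values, not from
+ # the committed code): a normalized output of 0.0 denormalizes to
+ # pKd = 0.0 * 1.5614 + 6.5129 ≈ 6.51, i.e. 1e6 * 10**(-6.51) ≈ 0.31 uM,
+ # while an output of 1.0 gives pKd ≈ 8.07, about 0.0084 uM.
+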
+ class Plapt:
+     def __init__(self, prediction_module_path="models/predictionModule.onnx", device='cuda'):
+         # Set device for computation
+         self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
+
+         # Load protein tokenizer and encoder
+         self.prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
+         self.prot_encoder = BertModel.from_pretrained("Rostlab/prot_bert").to(self.device)
+
+         # Load molecule tokenizer and encoder
+         self.mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
+         self.mol_encoder = RobertaModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(self.device)
+
+         self.cache = {}
+
+         # Load the prediction module ONNX model
+         self.prediction_module = PredictionModule(prediction_module_path)
+
+     def set_prediction_module(self, prediction_module_path):
+         self.prediction_module = PredictionModule(prediction_module_path)
+
+     @staticmethod
+     def preprocess_sequence(seq):
+         # Preprocess protein sequence
+         return " ".join(re.sub(r"[UZOB]", "X", seq))
+
+     def tokenize(self, prot_seqs, mol_smiles):
+         # Tokenize and encode protein sequences
+         prot_tokens = self.prot_tokenizer([self.preprocess_sequence(seq) for seq in prot_seqs],
+                                           padding=True,
+                                           max_length=3200,
+                                           truncation=True,
+                                           return_tensors='pt')
+
+         # Tokenize and encode molecules
+         mol_tokens = self.mol_tokenizer(mol_smiles,
+                                         padding=True,
+                                         max_length=278,
+                                         truncation=True,
+                                         return_tensors='pt')
+         return prot_tokens, mol_tokens
+
+     # Define the batch functions
+     @staticmethod
+     def make_batches(iterable, n=1):
+         length = len(iterable)
+         for ndx in range(0, length, n):
+             yield iterable[ndx:min(ndx + n, length)]
+
+     def predict_affinity(self, prot_seqs, mol_smiles, batch_size=2):
+         input_strs = list(zip(prot_seqs, mol_smiles))
+         affinities = []
+         for batch in self.make_batches(input_strs, batch_size):
+             batch_key = str(batch)  # Convert batch to a string to use as a dictionary key
+
+             if batch_key in self.cache:
+                 # Use cached features if available
+                 features = self.cache[batch_key]
+             else:
+                 # Tokenize and encode the batch, then cache the results
+                 prot_tokens, mol_tokens = self.tokenize(*zip(*batch))
+                 with torch.no_grad():
+                     prot_representations = self.prot_encoder(**prot_tokens.to(self.device)).pooler_output.cpu()
+                     mol_representations = self.mol_encoder(**mol_tokens.to(self.device)).pooler_output.cpu()
+
+                 features = [torch.cat((prot, mol), dim=0) for prot, mol in zip(prot_representations, mol_representations)]
+                 self.cache[batch_key] = features
+
+             affinities.extend(self.prediction_module.predict(features))
+
+         return affinities
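
For reference, a minimal usage sketch of the module added above (the protein sequence and SMILES string are illustrative placeholders, not test data from this repo):

from plapt import Plapt

# Loads the ProtBERT and ChemBERTa encoders plus the ONNX prediction head.
plapt = Plapt(device='cuda')

# One protein/ligand pair; predict_affinity accepts parallel lists.
results = plapt.predict_affinity(
    prot_seqs=["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"],
    mol_smiles=["CC(=O)Oc1ccccc1C(=O)O"]
)

# Each entry is a dict with "neg_log10_affinity_M" and "affinity_uM" keys.
print(results)

Note that predict_affinity caches encoder features per batch (keyed by str(batch)), so repeated calls with identical batches skip the transformer forward passes and only rerun the ONNX head.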