import re

import numpy as np
import onnxruntime
import torch
from transformers import BertModel, BertTokenizer, RobertaModel, RobertaTokenizer

class PredictionModule:
    def __init__(self, model_path="models/predictionModule.onnx"):
        """Initialize the PredictionModule with the given ONNX model."""
        self.session = onnxruntime.InferenceSession(model_path)
        self.input_name = self.session.get_inputs()[0].name
        # Normalization parameters from the training set; the ONNX model
        # predicts affinity on this normalized scale
        self.mean = 6.51286529169358
        self.scale = 1.5614094578916633

    def convert_to_affinity(self, normalized):
        """Map a normalized model output back to -log10(affinity in M) and µM."""
        neg_log10_affinity_m = (normalized * self.scale) + self.mean
        return {
            "neg_log10_affinity_M": neg_log10_affinity_m,
            "affinity_uM": (10**6) * (10**(-neg_log10_affinity_m)),
        }
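
    # Worked example (values rounded): a normalized output of 0.0 denormalizes
    # to the dataset mean, so neg_log10_affinity_M ≈ 6.513 and
    # affinity_uM ≈ 1e6 * 10**(-6.513) ≈ 0.307 µM.
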
    def predict(self, batch_data):
        """Run predictions on a batch of feature vectors."""
        # Ensure the data is a numpy array of the dtype the model expects
        batch_data = np.array(batch_data).astype(np.float32)
        # Process each feature vector in the batch individually
        affinities = []
        for feature in batch_data:
            # Reshape to (1, n_features) to match the model's expected input shape
            feature = feature.reshape(1, -1)
            # The exported graph also takes a TrainingMode flag; set it to
            # False for inference
            outputs = self.session.run(
                None, {self.input_name: feature, "TrainingMode": np.array(False)}
            )
            affinity_normalized = outputs[0][0][0]
            # Denormalize and append the result
            affinities.append(self.convert_to_affinity(affinity_normalized))
        return affinities
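
# Minimal usage sketch for PredictionModule on its own. The feature width of
# 1792 is an assumption (a 1024-d prot_bert pooler output concatenated with a
# 768-d ChemBERTa pooler output); adjust it to whatever the ONNX graph expects:
#
#     pm = PredictionModule("models/predictionModule.onnx")
#     dummy_features = np.random.rand(2, 1792).astype(np.float32)
#     print(pm.predict(dummy_features))
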
class Plapt:
    def __init__(self, prediction_module_path="models/predictionModule.onnx", device='cuda'):
        # Set device for computation, falling back to CPU if CUDA is unavailable
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        # Load protein tokenizer and encoder
        self.prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        self.prot_encoder = BertModel.from_pretrained("Rostlab/prot_bert").to(self.device)
        # Load molecule tokenizer and encoder
        self.mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        self.mol_encoder = RobertaModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(self.device)
        # Cache of encoded features, keyed by the stringified input batch
        self.cache = {}
        # Load the prediction module ONNX model
        self.prediction_module = PredictionModule(prediction_module_path)

    def set_prediction_module(self, prediction_module_path):
        """Swap in a different prediction module ONNX model."""
        self.prediction_module = PredictionModule(prediction_module_path)
    @staticmethod
    def preprocess_sequence(seq):
        # prot_bert expects residues to be space-separated, with the rare
        # amino acids U, Z, O, and B mapped to X
        return " ".join(re.sub(r"[UZOB]", "X", seq))
    def tokenize(self, prot_seqs, mol_smiles):
        """Tokenize protein sequences and molecule SMILES strings."""
        # Tokenize protein sequences, truncating to 3200 tokens
        prot_tokens = self.prot_tokenizer(
            [self.preprocess_sequence(seq) for seq in prot_seqs],
            padding=True,
            max_length=3200,
            truncation=True,
            return_tensors='pt',
        )
        # Tokenize molecule SMILES, truncating to 278 tokens
        mol_tokens = self.mol_tokenizer(
            mol_smiles,
            padding=True,
            max_length=278,
            truncation=True,
            return_tensors='pt',
        )
        return prot_tokens, mol_tokens
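
    # Both return values are Hugging Face BatchEncoding mappings; e.g.
    # prot_tokens["input_ids"] has shape (batch, seq_len) with seq_len <= 3200,
    # and mol_tokens["input_ids"] has seq_len <= 278.
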
    @staticmethod
    def make_batches(iterable, n=1):
        """Yield successive batches of size n (the last may be smaller)."""
        length = len(iterable)
        for ndx in range(0, length, n):
            yield iterable[ndx:min(ndx + n, length)]
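
    # Example: list(Plapt.make_batches([1, 2, 3, 4, 5], 2)) -> [[1, 2], [3, 4], [5]]
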
    def predict_affinity(self, prot_seqs, mol_smiles, batch_size=2):
        """Predict binding affinity for paired protein sequences and SMILES strings."""
        input_strs = list(zip(prot_seqs, mol_smiles))
        affinities = []
        for batch in self.make_batches(input_strs, batch_size):
            # Stringify the batch to use it as a cache key
            batch_key = str(batch)
            if batch_key in self.cache:
                # Use cached features if available
                features = self.cache[batch_key]
            else:
                # Tokenize and encode the batch, then cache the results
                prot_tokens, mol_tokens = self.tokenize(*zip(*batch))
                with torch.no_grad():
                    prot_representations = self.prot_encoder(**prot_tokens.to(self.device)).pooler_output.cpu()
                    mol_representations = self.mol_encoder(**mol_tokens.to(self.device)).pooler_output.cpu()
                # Concatenate each protein embedding with its molecule embedding
                features = [torch.cat((prot, mol), dim=0) for prot, mol in zip(prot_representations, mol_representations)]
                self.cache[batch_key] = features
            affinities.extend(self.prediction_module.predict(features))
        return affinities
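
# Minimal end-to-end sketch. The sequence and SMILES below are illustrative
# placeholders, and "models/predictionModule.onnx" must exist locally:
if __name__ == "__main__":
    plapt = Plapt()
    seqs = ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAKWKR"]
    smiles = ["CC(=O)OC1=CC=CC=C1C(=O)O"]  # aspirin
    print(plapt.predict_affinity(seqs, smiles))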