# PLAPT_V1 / plapt.py
import torch
from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel
import re
import onnxruntime
import numpy as np
torch.set_num_threads(1)

def flatten_list(nested_list):
    """Recursively flatten an arbitrarily nested list into a single flat list."""
    flat_list = []
    for element in nested_list:
        if isinstance(element, list):
            flat_list.extend(flatten_list(element))
        else:
            flat_list.append(element)
    return flat_list
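
# Example: flatten_list([1, [2, [3]], 4]) -> [1, 2, 3, 4]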

class PredictionModule:
    """Wraps the ONNX affinity-prediction head that runs on concatenated
    protein + molecule embeddings."""

    def __init__(self, model_path="models/affinity_predictor0734-seed2101.onnx"):
        self.session = onnxruntime.InferenceSession(model_path)
        self.input_name = self.session.get_inputs()[0].name
        # Normalization parameters used to de-standardize the model output
        self.mean = 6.51286529169358
        self.scale = 1.5614094578916633

    def convert_to_affinity(self, normalized):
        # Undo the normalization, then convert -log10(affinity in M) to µM
        return {
            "neg_log10_affinity_M": (normalized * self.scale) + self.mean,
            "affinity_uM": (10**6) * (10**(-((normalized * self.scale) + self.mean))),
        }
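
    # Worked example of the conversion (the normalized value is illustrative):
    #   normalized = 0.0
    #   neg_log10_affinity_M = 0.0 * 1.5614... + 6.5129... ≈ 6.513
    #   affinity_uM = 1e6 * 10**(-6.513) ≈ 0.307 µM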

    def predict(self, batch_data):
        """Run the ONNX predictor on a batch of concatenated feature tensors."""
        # Convert each tensor to a numpy array and stack into one array
        batch_data = np.array([t.numpy() for t in batch_data])
        # Process each feature in the batch individually and store results
        affinities = []
        for feature in batch_data:
            # Run the model on a single feature, with an explicit batch dimension
            affinity_normalized = self.session.run(
                None,
                {self.input_name: np.expand_dims(feature, axis=0),
                 'TrainingMode': np.array(False)},
            )[0][0][0]
            affinities.append(self.convert_to_affinity(affinity_normalized))
        return affinities

class Plapt:
    def __init__(self, prediction_module_path="models/affinity_predictor0734-seed2101.onnx",
                 caching=True, device='cuda'):
        # Set device for computation, falling back to CPU when CUDA is unavailable
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Load protein tokenizer and encoder (ProtBert)
        self.prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        self.prot_encoder = BertModel.from_pretrained("Rostlab/prot_bert").to(self.device)

        # Load molecule tokenizer and encoder (ChemBERTa)
        self.mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        self.mol_encoder = RobertaModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(self.device)

        # Cache of previously computed feature batches
        self.caching = caching
        self.cache = {}

        # Load the prediction module ONNX model
        self.prediction_module = PredictionModule(prediction_module_path)

    def set_prediction_module(self, prediction_module_path):
        self.prediction_module = PredictionModule(prediction_module_path)

    @staticmethod
    def preprocess_sequence(seq):
        # ProtBert expects space-separated residues; map rare amino acids
        # (U, Z, O, B) to the unknown token X
        return " ".join(re.sub(r"[UZOB]", "X", seq))

    def tokenize(self, mol_smiles):
        # Tokenize molecule SMILES strings
        mol_tokens = self.mol_tokenizer(mol_smiles,
                                        padding=True,
                                        max_length=278,
                                        truncation=True,
                                        return_tensors='pt')
        return mol_tokens

    def tokenize_prot(self, prot_seq):
        # Tokenize protein sequences
        prot_tokens = self.prot_tokenizer(self.preprocess_sequence(prot_seq),
                                          padding=True,
                                          max_length=3200,
                                          truncation=True,
                                          return_tensors='pt')
        return prot_tokens

    # Define the batch helper
    @staticmethod
    def make_batches(iterable, n=1):
        length = len(iterable)
        for ndx in range(0, length, n):
            yield iterable[ndx:min(ndx + n, length)]
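
    # Example: list(Plapt.make_batches([1, 2, 3, 4, 5], n=2)) -> [[1, 2], [3, 4], [5]]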

    def predict_affinity(self, prot_seq, mol_smiles, batch_size=2):
        # Encode the protein once and reuse its representation for every batch
        prot_tokens = self.tokenize_prot(prot_seq)
        with torch.no_grad():
            prot_representations = self.prot_encoder(**prot_tokens.to(self.device)).pooler_output.cpu()
        prot_representations = prot_representations.squeeze(0)
        # Repeat the protein representation so it zips with each molecule in a batch
        prot_representations = [prot_representations for _ in range(batch_size)]

        affinities = []
        for batch in self.make_batches(mol_smiles, batch_size):
            # Key the cache on the protein as well as the molecules: the cached
            # features embed the protein representation, so keying on the
            # molecule batch alone would reuse features for the wrong protein
            batch_key = str((prot_seq, batch))
            if self.caching and batch_key in self.cache:
                # Use cached features if available
                features = self.cache[batch_key]
            else:
                # Tokenize and encode the batch, then cache the results
                mol_tokens = self.tokenize(batch)
                with torch.no_grad():
                    mol_representations = self.mol_encoder(**mol_tokens.to(self.device)).pooler_output.cpu()
                mol_representations = [mol_representations[i, :] for i in range(mol_representations.shape[0])]
                features = [torch.cat((prot, mol), dim=0) for prot, mol in
                            zip(prot_representations, mol_representations)]
                if self.caching:
                    self.cache[batch_key] = features
            affinities.extend(self.prediction_module.predict(features))
        return affinities

    def score_candidates(self, target_protein, mol_smiles, batch_size=2):
        # (batch_size is kept for API compatibility; molecules are scored one at a time)
        # Encode the target protein once
        target_tokens = self.tokenize_prot(target_protein)
        with torch.no_grad():
            target_representation = self.prot_encoder(**target_tokens.to(self.device)).pooler_output.cpu()

        affinities = []
        for mol in mol_smiles:
            # Encode each candidate molecule and concatenate with the target
            mol_tokens = self.tokenize(mol)
            with torch.no_grad():
                mol_representations = self.mol_encoder(**mol_tokens.to(self.device)).pooler_output.cpu()
            features = torch.cat((target_representation[0], mol_representations[0]), dim=0)
            affinities.extend(self.prediction_module.predict([features]))
        return affinities

    def get_cached_features(self):
        return [tensor.tolist() for tensor in flatten_list(list(self.cache.values()))]

    def clear_cache(self):
        self.cache = {}
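

# Minimal usage sketch (not part of the module API): assumes the ONNX weights
# exist at the default path and that the HuggingFace models can be downloaded.
# The protein sequence and SMILES strings below are illustrative placeholders.
if __name__ == "__main__":
    plapt = Plapt(device='cuda')  # falls back to CPU if CUDA is unavailable
    smiles = ["CCO", "c1ccccc1O"]  # ethanol, phenol
    results = plapt.predict_affinity(
        prot_seq="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",  # toy sequence
        mol_smiles=smiles,
    )
    for mol, result in zip(smiles, results):
        print(mol, result["neg_log10_affinity_M"], result["affinity_uM"])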