import re

import numpy as np
import onnxruntime
import torch
from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel

# Keep PyTorch single-threaded for inference (e.g., so it does not contend
# with onnxruntime's own intra-op thread pool).
torch.set_num_threads(1)


def flatten_list(nested_list):
    """Recursively flatten an arbitrarily nested list.

    Example: flatten_list([1, [2, [3]], 4]) -> [1, 2, 3, 4]
    """
    flat_list = []
    for element in nested_list:
        if isinstance(element, list):
            flat_list.extend(flatten_list(element))
        else:
            flat_list.append(element)
    return flat_list


class PredictionModule:
    """Runs the ONNX affinity predictor and de-normalizes its output."""

    def __init__(self, model_path="models/affinity_predictor0734-seed2101.onnx"):
        self.session = onnxruntime.InferenceSession(model_path)
        self.input_name = self.session.get_inputs()[0].name

        # Normalization constants from training; predictions are mapped back
        # to affinity units with the inverse transform below.
        self.mean = 6.51286529169358
        self.scale = 1.5614094578916633

    def convert_to_affinity(self, normalized):
        neg_log10_affinity_M = (normalized * self.scale) + self.mean
        return {
            "neg_log10_affinity_M": neg_log10_affinity_M,
            # Convert -log10(affinity in M) to micromolar.
            "affinity_uM": (10 ** 6) * (10 ** (-neg_log10_affinity_M)),
        }

    def predict(self, batch_data):
        """Run predictions on a batch of feature tensors."""
        batch_data = np.array([t.numpy() for t in batch_data])

        affinities = []
        for feature in batch_data:
            # The exported graph takes a leading batch dimension plus an
            # explicit TrainingMode flag, so score one feature at a time.
            # onnxruntime expects numpy arrays as inputs, hence expand_dims
            # rather than wrapping the feature in a Python list.
            affinity_normalized = self.session.run(
                None,
                {self.input_name: np.expand_dims(feature, axis=0),
                 'TrainingMode': np.array(False)},
            )[0][0][0]
            affinities.append(self.convert_to_affinity(affinity_normalized))

        return affinities
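

# Note: with the default encoders in Plapt below, each feature vector handed
# to PredictionModule.predict is the concatenation of a ProtBERT pooled
# embedding (1024 values) and a ChemBERTa pooled embedding (768 values),
# i.e. a 1792-dimensional tensor.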


class Plapt:
    def __init__(self, prediction_module_path="models/affinity_predictor0734-seed2101.onnx",
                 caching=True, device='cuda'):
        # Fall back to CPU when CUDA is unavailable.
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # ProtBERT encoder for protein sequences.
        self.prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        self.prot_encoder = BertModel.from_pretrained("Rostlab/prot_bert").to(self.device)

        # ChemBERTa encoder for molecule SMILES strings.
        self.mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        self.mol_encoder = RobertaModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(self.device)

        # Optional cache of concatenated (protein, molecule) feature batches,
        # keyed by the stringified SMILES batch.
        self.caching = caching
        self.cache = {}

        self.prediction_module = PredictionModule(prediction_module_path)
    def set_prediction_module(self, prediction_module_path):
        self.prediction_module = PredictionModule(prediction_module_path)

    @staticmethod
    def preprocess_sequence(seq):
        # ProtBERT expects space-separated residues, with rare amino acids
        # (U, Z, O, B) mapped to X.
        return " ".join(re.sub(r"[UZOB]", "X", seq))
    def tokenize(self, mol_smiles):
        return self.mol_tokenizer(mol_smiles,
                                  padding=True,
                                  max_length=278,
                                  truncation=True,
                                  return_tensors='pt')

    def tokenize_prot(self, prot_seq):
        return self.prot_tokenizer(self.preprocess_sequence(prot_seq),
                                   padding=True,
                                   max_length=3200,
                                   truncation=True,
                                   return_tensors='pt')
    @staticmethod
    def make_batches(iterable, n=1):
        """Yield successive slices of at most n items from iterable."""
        length = len(iterable)
        for ndx in range(0, length, n):
            yield iterable[ndx:min(ndx + n, length)]
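    # Example: list(Plapt.make_batches(['a', 'b', 'c'], n=2)) -> [['a', 'b'], ['c']]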
    def predict_affinity(self, prot_seq, mol_smiles, batch_size=2):
        # Encode the protein once and reuse its representation for every
        # molecule batch.
        prot_tokens = self.tokenize_prot(prot_seq)
        with torch.no_grad():
            prot_representations = self.prot_encoder(**prot_tokens.to(self.device)).pooler_output.cpu()
        prot_representations = prot_representations.squeeze(0)

        # Replicate the protein representation up to the batch size; zip()
        # below truncates automatically when the final batch is smaller.
        prot_representations = [prot_representations for _ in range(batch_size)]

        affinities = []
        for batch in self.make_batches(mol_smiles, batch_size):
            # Note: the cache key covers only the molecule batch, so call
            # clear_cache() before switching to a different protein.
            batch_key = str(batch)

            if self.caching and batch_key in self.cache:
                features = self.cache[batch_key]
            else:
                mol_tokens = self.tokenize(batch)
                with torch.no_grad():
                    mol_representations = self.mol_encoder(**mol_tokens.to(self.device)).pooler_output.cpu()
                mol_representations = [mol_representations[i, :] for i in range(mol_representations.shape[0])]

                # One concatenated (protein, molecule) feature vector per pair.
                features = [torch.cat((prot, mol), dim=0)
                            for prot, mol in zip(prot_representations, mol_representations)]

                if self.caching:
                    self.cache[batch_key] = features

            affinities.extend(self.prediction_module.predict(features))

        return affinities
    def score_candidates(self, target_protein, mol_smiles, batch_size=2):
        # batch_size is currently unused; candidates are scored one at a time.
        target_tokens = self.prot_tokenizer([self.preprocess_sequence(target_protein)],
                                            padding=True,
                                            max_length=3200,
                                            truncation=True,
                                            return_tensors='pt')
        with torch.no_grad():
            target_representation = self.prot_encoder(**target_tokens.to(self.device)).pooler_output.cpu()

        affinities = []
        for mol in mol_smiles:
            mol_tokens = self.mol_tokenizer(mol,
                                            padding=True,
                                            max_length=278,
                                            truncation=True,
                                            return_tensors='pt')
            with torch.no_grad():
                mol_representations = self.mol_encoder(**mol_tokens.to(self.device)).pooler_output.cpu()

            features = torch.cat((target_representation[0], mol_representations[0]), dim=0)
            affinities.extend(self.prediction_module.predict([features]))

        return affinities
    def get_cached_features(self):
        """Return all cached feature tensors as plain Python lists."""
        return [tensor.tolist() for tensor in flatten_list(list(self.cache.values()))]

    def clear_cache(self):
        self.cache = {}
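

# Minimal usage sketch. Assumptions: the default ONNX predictor file exists on
# disk and the encoder weights can be fetched from the Hugging Face Hub; the
# protein sequence and SMILES strings below are illustrative placeholders,
# not benchmark inputs.
if __name__ == "__main__":
    plapt = Plapt(device='cuda')  # falls back to CPU if CUDA is unavailable

    protein = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"    # placeholder sequence
    molecules = ["CC(=O)OC1=CC=CC=C1C(=O)O", "CCO"]  # aspirin, ethanol

    for smiles, result in zip(molecules, plapt.predict_affinity(protein, molecules)):
        print(smiles, result["neg_log10_affinity_M"], result["affinity_uM"])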