File size: 2,485 Bytes
04579ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import numpy as np
import json
import onnxruntime
from transformers import BertTokenizer, RobertaTokenizer
import torch
def init():
    """Load the ONNX model and both tokenizers into module-level globals.

    ``run()`` reads ``session``, ``input_name``, ``prot_tokenizer`` and
    ``mol_tokenizer`` from module scope, so this must be called first.
    """
    global session, prot_tokenizer, mol_tokenizer, input_name
    # Relative path: resolved against the current working directory.
    session = onnxruntime.InferenceSession(
        "models/affinity_predictor0734-seed2101.onnx"
    )
    # run() feeds the token ids under the model's first declared input.
    first_input = session.get_inputs()[0]
    input_name = first_input.name
    # Protein tokenizer; do_lower_case=False keeps residue letters as given.
    prot_tokenizer = BertTokenizer.from_pretrained(
        "Rostlab/prot_bert",
        do_lower_case=False,
    )
    # SMILES tokenizer for the ligand.
    mol_tokenizer = RobertaTokenizer.from_pretrained(
        "seyonec/ChemBERTa-zinc-base-v1"
    )
def run(raw_data):
    """Predict binding affinity for one protein / ligand pair.

    Args:
        raw_data: JSON text (``str``/``bytes``/``bytearray``) or an already
            decoded ``dict`` with keys ``"protein"`` (amino-acid sequence)
            and ``"smiles"`` (ligand SMILES string).

    Returns:
        On success, the dict built by ``convert_to_affinity`` (computed from a
        plain Python float, so it is JSON-serializable). On any failure, a
        JSON *string* of the form ``{"error": "<message>"}``.
        NOTE(review): success and error paths return different types
        (dict vs. str); kept as-is for compatibility with existing callers.

    ``init()`` must have been called first: this reads the module globals
    ``session``, ``input_name``, ``prot_tokenizer`` and ``mol_tokenizer``.
    """
    try:
        # json.loads raises TypeError on a dict, so accept pre-decoded
        # payloads directly as well as JSON text.
        data = raw_data if isinstance(raw_data, dict) else json.loads(raw_data)
        prot_seq = data['protein']
        mol_smiles = data['smiles']

        # NOTE(review): padding=True pads only to the longest sequence in the
        # batch (a no-op for a single input). If the ONNX model expects fixed
        # lengths of 3200 / 278 tokens this should be padding='max_length'.
        prot_tokens = prot_tokenizer(preprocess_sequence(prot_seq),
                                     padding=True,
                                     max_length=3200,
                                     truncation=True,
                                     return_tensors='pt')
        # A single input with return_tensors='pt' yields a (1, len) batch;
        # take row 0 to get the 1-D token-id vector.
        prot_ids = prot_tokens['input_ids'][0]

        mol_tokens = mol_tokenizer(mol_smiles,
                                   padding=True,
                                   max_length=278,
                                   truncation=True,
                                   return_tensors='pt')
        mol_ids = mol_tokens['input_ids'][0]

        # Concatenate the 1-D id vectors along the token axis. Concatenating
        # the (1, len) tensors along dim 0 (as before) fails whenever the two
        # token counts differ.
        features = torch.cat((prot_ids, mol_ids), dim=0)

        # Add the leading batch axis: shape (1, prot_len + mol_len).
        # NOTE(review): assumes the model input is (batch, tokens); confirm
        # against the ONNX graph.
        batch = features.numpy()[np.newaxis, :]
        outputs = session.run(None, {input_name: batch,
                                     'TrainingMode': np.array(False)})
        # First output tensor, element [0][0]; cast the numpy scalar to a
        # Python float so the returned dict can be JSON-serialized.
        affinity_normalized = float(outputs[0][0][0])

        return convert_to_affinity(affinity_normalized)
    except Exception as e:
        # Top-level service boundary: report the failure to the caller.
        return json.dumps({"error": str(e)})
def preprocess_sequence(seq):
    """Format an amino-acid string for the ProtBert tokenizer.

    Each of the letters ``U``, ``Z``, ``O`` and ``B`` is replaced by ``X``.
    All characters are then separated by single spaces, so the tokenizer
    sees one residue per token. Example: ``"MUK"`` -> ``"M X K"``.
    Matching is case-sensitive; lowercase letters pass through unchanged.
    """
    rare_to_x = str.maketrans("UZOB", "XXXX")
    cleaned = seq.translate(rare_to_x)
    return " ".join(cleaned)
def convert_to_affinity(normalized):
    """Undo the output normalization and express the affinity in two units.

    ``value = normalized * scale + mean`` is reported as
    ``"neg_log10_affinity_M"``. The same value is also reported as
    ``1e6 * 10**(-value)`` under ``"affinity_uM"``. Works element-wise if
    ``normalized`` is a numpy array.
    """
    # Hard-coded de-normalization constants; presumably the label mean/std
    # used at training time. Confirm they match the deployed model.
    target_mean = 6.51286529169358
    target_scale = 1.5614094578916633

    neg_log10_m = (normalized * target_scale) + target_mean
    molar = 10**(-neg_log10_m)
    return {
        "neg_log10_affinity_M": neg_log10_m,
        "affinity_uM": (10**6) * molar,
    }
if __name__ == "__main__":
    # Local smoke test only; guarded so importing this module (e.g. by a
    # serving host) does not execute a prediction as a side effect.
    # init() must run first because run() reads the globals it sets, and the
    # payload is sent as JSON text, which is what run() parses.
    init()
    print(run(json.dumps({"protein": "MILK", "smiles": "CCO"})))