import json
import re

import numpy as np
import onnxruntime
import torch
from transformers import BertTokenizer, RobertaTokenizer

def init():
    """Load the ONNX affinity predictor and the two tokenizers once per worker."""
    global session, prot_tokenizer, mol_tokenizer, input_name
    session = onnxruntime.InferenceSession("models/affinity_predictor0734-seed2101.onnx")
    input_name = session.get_inputs()[0].name
    # ProtBERT tokenizer: expects uppercase, space-separated amino-acid sequences.
    prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    # ChemBERTa tokenizer for SMILES strings.
    mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

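# Illustrative debugging aid (not part of the original scoring script): run() below
# feeds the concatenated token ids as the first graph input and hard-codes a second
# input named 'TrainingMode'. If the graph's expected inputs are ever unclear, they
# can be listed with onnxruntime's standard introspection API:
def _inspect_onnx_io(path="models/affinity_predictor0734-seed2101.onnx"):
    sess = onnxruntime.InferenceSession(path)
    for node in sess.get_inputs():
        print("input: ", node.name, node.shape, node.type)
    for node in sess.get_outputs():
        print("output:", node.name, node.shape, node.type)
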
def run(raw_data):
    try:
        data = json.loads(raw_data)
        prot_seq = data['protein']
        mol_smiles = data['smiles']

        # Tokenize the protein sequence (space-separated residues, rare residues mapped to X).
        prot_tokens = prot_tokenizer(preprocess_sequence(prot_seq),
                                     padding=True,
                                     max_length=3200,
                                     truncation=True,
                                     return_tensors='pt')
        # return_tensors='pt' already yields tensors; drop the batch dim to get 1-D ids.
        prot_representations = prot_tokens['input_ids'].squeeze(0)

        # Tokenize the SMILES string.
        mol_tokens = mol_tokenizer(mol_smiles,
                                   padding=True,
                                   max_length=278,
                                   truncation=True,
                                   return_tensors='pt')
        mol_representations = mol_tokens['input_ids'].squeeze(0)

        # Concatenate the two 1-D id vectors into a single feature vector.
        features = torch.cat((prot_representations, mol_representations), dim=0)

        # Add a leading batch dimension and run the ONNX predictor in inference mode.
        affinity_normalized = session.run(
            None,
            {input_name: features.unsqueeze(0).numpy(), 'TrainingMode': np.array(False)}
        )[0][0][0]

        # Map the normalized output back to interpretable affinity values.
        affinity = convert_to_affinity(affinity_normalized)
        return affinity
    except Exception as e:
        return json.dumps({"error": str(e)})

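# For reference, the JSON contract implied by run() above (values are placeholders,
# not real model output):
#   request body: {"protein": "<amino-acid sequence>", "smiles": "<SMILES string>"}
#   success:      {"neg_log10_affinity_M": <float>, "affinity_uM": <float>}
#   failure:      a JSON string of the form {"error": "<message>"}
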
def preprocess_sequence(seq):
    """Format a protein sequence for ProtBERT: replace rare residues (U, Z, O, B)
    with X and separate every residue with a space."""
    return " ".join(re.sub(r"[UZOB]", "X", seq))

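# Worked example (hypothetical input; result follows directly from the code above):
#   preprocess_sequence("MKTUZ") -> "M K T X X"
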
def convert_to_affinity(normalized):
    """Map the standardized model output back to -log10(affinity in M) using the
    fixed mean and scale, and also report the affinity in micromolar."""
    mean = 6.51286529169358
    scale = 1.5614094578916633
    return {
        "neg_log10_affinity_M": (normalized * scale) + mean,
        "affinity_uM": (10**6) * (10**(-((normalized * scale) + mean)))
    }
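
# Worked example (pure arithmetic from the function above): for normalized = 0.0,
#   neg_log10_affinity_M = 0.0 * 1.5614... + 6.5129... ≈ 6.513
#   affinity_uM          = 1e6 * 10**(-6.513)          ≈ 0.307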

if __name__ == "__main__":
    # Local smoke test: init() must run first, and run() expects a JSON string.
    init()
    print(run(json.dumps({"protein": "MILK", "smiles": "CCO"})))