import json
import re

import numpy as np
import onnxruntime
import torch
from transformers import BertTokenizer, RobertaTokenizer

def init():
    """Load the ONNX affinity model and the protein/molecule tokenizers once at startup."""
    global session, prot_tokenizer, mol_tokenizer, input_name
    session = onnxruntime.InferenceSession("models/affinity_predictor0734-seed2101.onnx")
    input_name = session.get_inputs()[0].name
    prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

def run(raw_data):
    """Score one protein/ligand pair given as a JSON string with 'protein' and 'smiles' keys."""
    try:
        data = json.loads(raw_data)
        prot_seq = data['protein']
        mol_smiles = data['smiles']

        # Tokenize the protein sequence (space-separated residues, rare amino
        # acids mapped to X) and drop the batch dimension to get a 1-D vector
        # of token IDs.
        prot_tokens = prot_tokenizer(preprocess_sequence(prot_seq),
                                     padding=True,
                                     max_length=3200,
                                     truncation=True,
                                     return_tensors='pt')
        prot_ids = prot_tokens['input_ids'].squeeze(0)

        # Tokenize the SMILES string the same way.
        mol_tokens = mol_tokenizer(mol_smiles,
                                   padding=True,
                                   max_length=278,
                                   truncation=True,
                                   return_tensors='pt')
        mol_ids = mol_tokens['input_ids'].squeeze(0)

        # Concatenate protein and molecule token IDs into a single feature vector.
        features = torch.cat((prot_ids, mol_ids), dim=0)

        # Run inference; expand_dims adds the batch dimension the model expects.
        inputs = {input_name: np.expand_dims(features.numpy(), axis=0),
                  'TrainingMode': np.array(False)}
        affinity_normalized = session.run(None, inputs)[0][0][0]

        # Map the normalized prediction back to affinity units.
        return convert_to_affinity(affinity_normalized)
    except Exception as e:
        return json.dumps({"error": str(e)})

def preprocess_sequence(seq):
    # ProtBert expects space-separated residues, with rare amino acids (U, Z, O, B) mapped to X.
    return " ".join(re.sub(r"[UZOB]", "X", seq))

def convert_to_affinity(normalized):
    """Undo the training-time standardization; also report the affinity in micromolar."""
    mean = 6.51286529169358
    scale = 1.5614094578916633
    neg_log10_affinity_M = float(normalized) * scale + mean
    return {
        "neg_log10_affinity_M": neg_log10_affinity_M,
        "affinity_uM": (10 ** 6) * (10 ** (-neg_log10_affinity_M))
    }

if __name__ == "__main__":
    # Local smoke test: the scoring entry point expects a JSON string.
    init()
    print(run(json.dumps({"protein": "MILK", "smiles": "CCO"})))
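
# Illustrative client call, a minimal sketch rather than part of the scoring
# script: assuming this file is deployed behind an HTTP scoring endpoint that
# forwards the request body to run(), a client could query it as below.
# The endpoint URL and key are hypothetical placeholders.
#
#   import requests
#
#   payload = json.dumps({"protein": "MILK", "smiles": "CCO"})
#   headers = {"Content-Type": "application/json",
#              "Authorization": "Bearer <endpoint-key>"}
#   resp = requests.post("https://<scoring-endpoint>/score", data=payload, headers=headers)
#   print(resp.json())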