File size: 4,654 Bytes
cf004a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bc1244
cf004a6
2bc1244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf004a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bc1244
 
 
 
 
 
cf004a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
This module provides a simple predict function for the MHNfs model.
It loads the model from the provided checkpoint, creates the necessary helper inputs,
and makes predictions for a list of molecules.
"""

#---------------------------------------------------------------------------------------
# Dependencies
import pickle
from pathlib import Path

import pandas as pd
import pytorch_lightning as pl
import streamlit as st

from src.data_preprocessing.create_model_inputs import (create_query_input,
                                                    create_support_set_input)
from src.mhnfs.model import MHNfs

#---------------------------------------------------------------------------------------
# Define predictor class

class ActivityPredictor:
    """Wrapper that loads the MHNfs checkpoint and runs predictions with it.

    Attributes:
        model: The MHNfs model returned by ``MHNfs.load_from_checkpoint``,
            switched to eval mode.
        query_molecules: The ``query_smiles`` argument of the most recent
            ``predict`` call, or ``None`` if ``predict`` has not been called.
    """

    def __init__(self, streamlit=True):
        """Load the model, optionally through Streamlit's resource cache.

        Args:
            streamlit: If true, the loader is wrapped with
                ``st.cache_resource`` before it is called; otherwise the
                loader is called directly.
        """
        loader = self._load_model
        if streamlit:
            # Caching for streamlit
            loader = st.cache_resource(loader)

        # Load model
        self.model = loader()

        # Initiate query mol storage
        self.query_molecules = None

    @staticmethod
    def _load_model():
        """Seed the RNGs, load the MHNfs checkpoint and put it in eval mode.

        The checkpoint is looked up at
        ``assets/mhnfs_data/mhnfs_checkpoint.ckpt`` relative to the directory
        two levels above this file.
        """
        pl.seed_everything(1234)

        # pathlib instead of splitting on "/" so the lookup also works with
        # other path separators and with a relative ``__file__``.
        base_dir = Path(__file__).resolve().parents[1]
        checkpoint_path = (
            base_dir / "assets" / "mhnfs_data" / "mhnfs_checkpoint.ckpt"
        )

        model = MHNfs.load_from_checkpoint(str(checkpoint_path))
        model._update_context_set_embedding()
        model.eval()

        return model

    def predict(self, query_smiles, support_activces_smiles,
                support_inactives_smiles, *, inputs_dump_path=None):
        """Run the model on query molecules given a support set.

        Args:
            query_smiles: Query molecules; stored on the instance and passed
                unchanged to ``create_query_input``.
            support_activces_smiles: Active support-set molecules, passed
                unchanged to ``create_support_set_input``. (The misspelled
                name is kept so existing keyword callers keep working.)
            support_inactives_smiles: Inactive support-set molecules, passed
                unchanged to ``create_support_set_input``.
            inputs_dump_path: Optional file path. When given, the prepared
                model inputs are pickled to this path before the forward
                pass (debugging aid). Nothing is written by default.

        Returns:
            The model output, detached, moved to CPU, converted to a NumPy
            array and flattened.
        """
        # Create model inputs
        # Query input
        self.query_molecules = query_smiles
        query_input = create_query_input(query_smiles)

        # Active support set input
        support_actives_input, support_actives_size = create_support_set_input(
            support_activces_smiles
        )

        # Inactive support set input
        support_inactives_input, support_inactives_size = create_support_set_input(
            support_inactives_smiles
        )

        # Optionally save the inputs; previously this always wrote to a
        # hard-coded, machine-specific path and failed everywhere else.
        if inputs_dump_path is not None:
            with open(inputs_dump_path, "wb") as f:
                pickle.dump(
                    (
                        query_input,
                        support_actives_input,
                        support_inactives_input,
                        support_actives_size,
                        support_inactives_size,
                    ),
                    f,
                )

        # Make predictions
        predictions = self.model(
            query_input,
            support_actives_input,
            support_inactives_input,
            support_actives_size,
            support_inactives_size,
        )

        # .cpu() is a no-op for CPU tensors and makes .numpy() valid when the
        # output lives on another device.
        return predictions.detach().cpu().numpy().flatten()

    def _return_query_mols_as_list(self):
        """Return the stored query molecules as a list of SMILES strings.

        A list is returned as is, a string is split on commas with each piece
        stripped of surrounding whitespace, and a DataFrame contributes its
        ``smiles`` column.

        Raises:
            ValueError: If no query molecules have been stored yet.
            TypeError: If the stored object is not a list, str or DataFrame.
        """
        molecules = self.query_molecules

        if molecules is None:
            raise ValueError("No query molecules have been stored yet. "
                             "Run predict-function first.")
        if isinstance(molecules, list):
            return molecules
        if isinstance(molecules, str):
            return [smiles.strip() for smiles in molecules.split(",")]
        if isinstance(molecules, pd.DataFrame):
            return molecules.smiles.tolist()

        raise TypeError("Type of query molecules not recognized. "
                        "Please check input type.")
            
#---------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: run the predictor once on a toy task in which every
    # molecule is the same ring SMILES.
    ring_smiles = "C1CCCCC1"

    # Build the predictor (default settings)
    activity_predictor = ActivityPredictor()

    # Four query molecules, two actives and two inactives in the support set
    scores = activity_predictor.predict(
        [ring_smiles] * 4,
        [ring_smiles] * 2,
        [ring_smiles] * 2,
    )

    print(scores)