Tschoui commited on
Commit
2bc1244
·
1 Parent(s): cf004a6
Files changed (4) hide show
  1. .gitattributes +1 -1
  2. .gitignore +2 -0
  3. inputs.yml +0 -0
  4. src/prediction_pipeline.py +31 -13
.gitattributes CHANGED
@@ -34,4 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.csv filter=lfs diff=lfs merge=lfs -text
37
- *.png filter=lfs diff=lfs merge=lfs -text
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.csv filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__/
2
+ *.pyc
inputs.yml ADDED
File without changes
src/prediction_pipeline.py CHANGED
@@ -19,19 +19,32 @@ from src.mhnfs.model import MHNfs
19
 
20
  class ActivityPredictor:
21
 
22
- def __init__(self):
23
 
24
- @st.cache_resource # Caching for streamlit
25
- def load_model():
26
- pl.seed_everything(1234)
27
- current_loc = __file__.rsplit("/",2)[0]
28
- model = MHNfs.load_from_checkpoint(current_loc +
29
- "/assets/mhnfs_data/"
30
- "mhnfs_checkpoint.ckpt")
31
- model._update_context_set_embedding()
32
- model.eval()
33
-
34
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  # Load model
37
  self.model = load_model()
@@ -55,7 +68,12 @@ class ActivityPredictor:
55
  support_inactives_input, support_inactives_size = create_support_set_input(
56
  support_inactives_smiles
57
  )
58
-
 
 
 
 
 
59
  # Make predictions
60
  predictions = self.model(
61
  query_input,
 
19
 
20
  class ActivityPredictor:
21
 
22
+ def __init__(self, streamlit=True):
23
 
24
+ if streamlit:
25
+ @st.cache_resource # Caching for streamlit
26
+ def load_model():
27
+ pl.seed_everything(1234)
28
+ current_loc = __file__.rsplit("/",2)[0]
29
+ model = MHNfs.load_from_checkpoint(current_loc +
30
+ "/assets/mhnfs_data/"
31
+ "mhnfs_checkpoint.ckpt")
32
+ model._update_context_set_embedding()
33
+ model.eval()
34
+
35
+ return model
36
+ else:
37
+ def load_model():
38
+ pl.seed_everything(1234)
39
+ current_loc = __file__.rsplit("/",2)[0]
40
+ model = MHNfs.load_from_checkpoint(current_loc +
41
+ "/assets/mhnfs_data/"
42
+ "mhnfs_checkpoint.ckpt")
43
+ model._update_context_set_embedding()
44
+ model.eval()
45
+
46
+ return model
47
+
48
 
49
  # Load model
50
  self.model = load_model()
 
68
  support_inactives_input, support_inactives_size = create_support_set_input(
69
  support_inactives_smiles
70
  )
71
+
72
+ # save inputs
73
+ import pickle
74
+ with open("/system/user/publicwork/luukkonen/mhnfs-benchmark/js_code/preprocess_data/ap_inputs.pkl", "wb") as f:
75
+ pickle.dump((query_input, support_actives_input, support_inactives_input, support_actives_size, support_inactives_size), f)
76
+
77
  # Make predictions
78
  predictions = self.model(
79
  query_input,