Tschoui commited on
Commit
9023ae4
Β·
2 Parent(s): 13e9c8e 3299ee7

Merge branch 'main' of https://huggingface.co/spaces/ml-jku/mhnfs

Browse files
README.md CHANGED
@@ -71,9 +71,9 @@ For your screening, load the model, i.e. the **Activity Predictor** into your py
71
  from src.prediction_pipeline load ActivityPredictor
72
 
73
  # Define inputs
74
- query_smiles = ["C1CCCCC1", "C1CCCCC1", "C1CCCCC1", "C1CCCCC1"] # Replace with your data
75
- support_actives_smiles = ["C1CCCCC1", "C1CCCCC1"] # Replace with your data
76
- support_inactives_smiles = ["C1CCCCC1", "C1CCCCC1"] # Replace with your data
77
 
78
  # Make predictions
79
  predictions = predictor.predict(query_smiles, support_actives_smiles support_inactives_smiles)
@@ -92,10 +92,6 @@ cd .../whatever_your_dir_name_is/ # Replace with your path
92
  # Run streamlit app
93
  python -m streamlit run
94
  ```
95
-
96
-
97
- ## πŸ€— Hugging face app
98
- Explore our hugging-face app here:
99
 
100
  ## πŸ“š Cite us
101
 
 
71
  from src.prediction_pipeline load ActivityPredictor
72
 
73
  # Define inputs
74
+ query_smiles = ["C1CCCCC1", "C1CCCCC1", "C1CCCCC1", "C1CCCCC1"] # Replace with your data
75
+ support_actives_smiles = ["C1CCCCC1", "C1CCCCC1"] # Replace with your data
76
+ support_inactives_smiles = ["C1CCCCC1", "C1CCCCC1"] # Replace with your data
77
 
78
  # Make predictions
79
  predictions = predictor.predict(query_smiles, support_actives_smiles support_inactives_smiles)
 
92
  # Run streamlit app
93
  python -m streamlit run
94
  ```
 
 
 
 
95
 
96
  ## πŸ“š Cite us
97
 
app.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- This script runs the streamlit app for MHNfs
3
 
4
  MHNfs: Few-shot method for drug discovery activity predictions
5
  (https://openreview.net/pdf?id=XrMWUuEevr)
 
1
  """
2
+ This script runs the streamlit app for MHNfs.
3
 
4
  MHNfs: Few-shot method for drug discovery activity predictions
5
  (https://openreview.net/pdf?id=XrMWUuEevr)
pre-requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ pip==24.0
pubchem_experiment/data_preprocess.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pandas as pd
3
+ import tqdm
4
+ import swifter
5
+ from rdkit import Chem
6
+
7
+ # Disable RDKit informational and warning messages
8
+ from rdkit import RDLogger
9
+ RDLogger.DisableLog('rdApp.*')
10
+
11
+ PUBCHEM_DIR = # pubchem_path + 'pubchem24/'
12
+ FSMOL_UID_PATH = # fsmol_path + '/fsmol/fsmol_train_accession_keys.json'
13
+ PROT_CLASS_PATH = # chembl_path + 'chembl33/uniprot_pclass_mapping.csv'
14
+ MHNFS_PATH = # mhnfs_path + '/mhnfs'
15
+
16
+ import sys
17
+ sys.path.append(MHNFS_PATH)
18
+ from src.data_preprocessing.utils import Standardizer
19
+
20
+ class PubChemFilter:
21
+
22
+ def __init__(self, pubchem_dir, fsmol_uid_path, prot_class_path, mhnfs_path, debug = False):
23
+ self.pubchem_dir = pubchem_dir
24
+ self.fsmol_uid_path = fsmol_uid_path
25
+ self.prot_class_path = prot_class_path
26
+ self.mhnfs_path = mhnfs_path
27
+ self.debug = debug
28
+
29
+ def load_and_filter_assays(self):
30
+ """
31
+ Load PubChem Assay data from file and filter them:
32
+ 1. Drop all assays without protein accession keys
33
+ 2. Drop all assays linked to multiple accession keys
34
+ 3. Drop all assays with accession keys in FSmol training data
35
+
36
+ Returns:
37
+ df_assays (pd.Dataframe)
38
+ """
39
+
40
+ print('Load assays...')
41
+ df_assays = pd.read_table(f'{self.pubchem_dir}/bioassays.tsv.gz', usecols=['AID', 'UniProts IDs'] ).rename(columns={'UniProts IDs' : 'UID'})
42
+
43
+ # Load FSmol training data accession keys
44
+ with open(self.fsmol_uid_path, 'r') as f:
45
+ fs_train_targets = json.load(f).values()
46
+ fs_train_targets = list(set([key for sublist in fs_train_targets for key in sublist]))
47
+
48
+ print('Filter assays...')
49
+ df_assays = df_assays.dropna(subset=['UID'])
50
+ df_assays = df_assays[~df_assays['UID'].str.contains('\|')]
51
+ df_assays = df_assays[~df_assays['UID'].str.contains('|'.join(fs_train_targets))]
52
+ self.df_assays = df_assays
53
+
54
+ def load_and_filter_bioactivities(self, chunk_size=10_000_000):
55
+ """
56
+ Load bioactivity data in chucks and filter out datapoints with
57
+ 1. assay not in aids
58
+ 2. outcome not 'Active'/'Inactive'
59
+ """
60
+
61
+ print('Load bioactivities...')
62
+ aids = self.df_assays.AID.tolist()
63
+ filtered_chunks = []
64
+ chunk_size = 10_000_000
65
+ for chunk in pd.read_csv(f'{self.pubchem_dir}/bioactivities.tsv.gz', sep='\t', chunksize=chunk_size, usecols=['AID', 'CID', 'Activity Outcome']):
66
+ filtered_chunk = chunk[chunk['AID'].isin(aids)]
67
+ filtered_chunk = filtered_chunk[filtered_chunk['Activity Outcome'].isin(['Inactive','Active'])]
68
+ filtered_chunks.append(filtered_chunk)
69
+ if self.debug:
70
+ break # For debugging
71
+ df_bio = pd.concat(filtered_chunks)
72
+ df_bio = df_bio[df_bio.CID.notna()]
73
+ df_bio['Activity'] = df_bio['Activity Outcome'].swifter.apply(lambda x : 1 if x == 'Active' else 0)
74
+ self.df_bio = df_bio.drop('Activity Outcome', axis=1).astype(int)
75
+
76
+ def merge_assay_and_activity_data(self):
77
+ print('Merge...')
78
+ self.df = self.df_bio.merge(self.df_assays, on='AID', how='left')
79
+ convert_dict = {col: 'int32' if col != 'UID' else 'str' for col in self.df.columns }
80
+ self.df = self.df.astype(convert_dict)
81
+ del self.df_assays, self.df_bio
82
+
83
+ def drop_hts_assays(self):
84
+ print('Drop HTS assays...')
85
+ aid_counts = self.df.groupby('AID').size()
86
+ filtered_aids = aid_counts[aid_counts <= 100_000].index
87
+ self.df = self.df[self.df['AID'].isin(filtered_aids)]
88
+
89
+ def drop_targets_with_limited_data(self, na_min=50, ni_min=50):
90
+ print('Drop targets with not enough datapoints...')
91
+ unique_uids = self.df['UID'].sort_values().unique() # Sorted unique targets
92
+ activity_counts = self.df.groupby('UID')['Activity'].value_counts().unstack().fillna(0) # matrix: rows=sorted targets, columns=nactive, ninactives
93
+ mask = ((activity_counts[1] >= na_min) & (activity_counts[0] >= ni_min) ) # Both nactives and ninactives above nmin
94
+ self.df = self.df[self.df['UID'].isin(unique_uids[mask])]
95
+
96
+ def drop_conflicting_bioactivity_measures(self, target_col='UID', compound_col='CID'):
97
+ """
98
+ Check if each target-compound pair is associated to an unique activity value,
99
+ i.e. every measure either active or inactive. If not, drop it.
100
+ """
101
+
102
+ def process_group(group):
103
+ if group['Activity'].nunique() == 1:
104
+ return group.head(1)
105
+ else:
106
+ return None
107
+
108
+ print('Drop conflicting datapoints...')
109
+ # Get unique UID-CID pairs and duplicated ones
110
+ df_uniques = self.df.drop_duplicates(subset=[target_col, compound_col], keep=False)
111
+ df_duplicates = self.df[~self.df.index.isin(df_uniques.index)]
112
+
113
+ # Check duplicated pairs
114
+ groups = df_duplicates.groupby([target_col, compound_col])
115
+ rows = []
116
+ for _, group in tqdm.tqdm(groups):
117
+ rows.append(process_group(group))
118
+ df_rows = pd.concat([row for row in rows if row is not None])
119
+ self.df = pd.concat([df_uniques, df_rows])
120
+
121
+ def add_smiles(self, chunk_size=10_000_000):
122
+ print('Retrieve SMILES...')
123
+ cids = self.df.CID.astype(int).unique()
124
+ filtered_chunks = []
125
+ for chunk in pd.read_table(f'{self.pubchem_dir}/smiles.tsv.gz', chunksize=chunk_size, names=['CID', 'SMILES']):
126
+ filtered_chunk = chunk[chunk['CID'].isin(cids)]
127
+ filtered_chunks.append(filtered_chunk)
128
+ if self.debug:
129
+ break
130
+ df_smiles = pd.concat(filtered_chunks)
131
+
132
+ def cleanup(smiles):
133
+ sm = Standardizer(metal_disconnect=True, canon_taut=True)
134
+ mol = Chem.MolFromSmiles(smiles)
135
+ try:
136
+ standardized_mol, _ = sm.standardize_mol(mol)
137
+ return Chem.MolToSmiles(standardized_mol)
138
+ except:
139
+ print(smiles)
140
+ return None
141
+
142
+ df_smiles['SMILES'] = df_smiles['SMILES'].swifter.apply(lambda smi: cleanup(smi))
143
+ df_smiles.dropna(inplace=True)
144
+
145
+ self.df = self.df.merge(df_smiles, on='CID', how='left').dropna(subset=['SMILES'])
146
+
147
+ def print_stats(self):
148
+ nassays = self.df['AID'].nunique()
149
+ ntargets = self.df["UID"].nunique()
150
+ ncompounds = self.df["CID"].nunique()
151
+ nactvities = self.df.shape[0]
152
+ print(f'{ntargets: >5,} targets | {nassays: >6,} assays | {ncompounds: >9,} compounds | {nactvities: >10,} activity data points')
153
+
154
+ def save(self, fname='data/pubchem24_preprocessed.csv.gz'):
155
+ print(f'Save to {fname}...')
156
+ self.df.to_csv(fname, index=False)
157
+
158
+ def load(self, fname):
159
+ print(f'Load from {fname}...')
160
+ self.df = pd.read_csv(fname)
161
+
162
+ def add_protein_classifications(self):
163
+ """
164
+ Retrieve protein classification
165
+ """
166
+ print('Retrieve protein classifications...')
167
+ protein_class = pd.read_csv(self.prot_class_path)
168
+ print(protein_class)
169
+ # protein_class['UID'] = protein_class['target_id'].swifter.apply(lambda x: x.split('_')[0])
170
+ self.df = self.df.merge(protein_class[['UID', 'Organism', 'L1', 'L2']], on='UID', how='left')
171
+
172
+ if __name__ == '__main__':
173
+ # Create an instance of PubChemFilter class
174
+ pubchem_filter = PubChemFilter(PUBCHEM_DIR, FSMOL_UID_PATH, PROT_CLASS_PATH, MHNFS_PATH, False)
175
+
176
+ # Call methods of the class as needed
177
+ pubchem_filter.load_and_filter_assays()
178
+ pubchem_filter.load_and_filter_bioactivities()
179
+ pubchem_filter.merge_assay_and_activity_data()
180
+ pubchem_filter.print_stats()
181
+ pubchem_filter.drop_hts_assays()
182
+ pubchem_filter.print_stats()
183
+ pubchem_filter.drop_targets_with_limited_data()
184
+ pubchem_filter.print_stats()
185
+ pubchem_filter.drop_conflicting_bioactivity_measures()
186
+ pubchem_filter.print_stats()
187
+ pubchem_filter.drop_targets_with_limited_data()
188
+ pubchem_filter.print_stats()
189
+ pubchem_filter.add_smiles()
190
+ pubchem_filter.print_stats()
191
+ pubchem_filter.drop_conflicting_bioactivity_measures(compound_col='SMILES')
192
+ pubchem_filter.print_stats()
193
+ pubchem_filter.drop_targets_with_limited_data()
194
+ pubchem_filter.print_stats()
195
+ pubchem_filter.add_protein_classifications()
196
+ pubchem_filter.save(fname='data/pubchem24/preprocessed.csv.gz')
197
+
pubchem_experiment/make_predictions.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chunk
2
+ import os
3
+ import warnings
4
+
5
+ import pandas as pd
6
+ from rdkit import Chem
7
+ # from rdkit.Chem import AllChem
8
+ from rdkit.Chem import rdFingerprintGenerator
9
+ from sklearn.ensemble import RandomForestClassifier
10
+ from tqdm.auto import tqdm
11
+ import numpy as np
12
+ import clamp
13
+ import torch
14
+
15
+ warnings.filterwarnings("ignore")
16
+
17
+
18
+ def generate_morgan_fingerprints(smiles_list, radius=4, n_bits=4048):
19
+ """
20
+ Generate Morgan fingerprints for a list of SMILES.
21
+ """
22
+ mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=radius,fpSize=n_bits)
23
+ mols = [Chem.MolFromSmiles(smi) for smi in smiles_list]
24
+ fps = []
25
+ for smiles, mol in zip(smiles_list, mols):
26
+ if mol is None:
27
+ print(smiles)
28
+ fps.append(None)
29
+ else:
30
+ fps.append(mfpgen.GetFingerprintAsNumPy(mol))
31
+ # np.array([mfpgen.GetFingerprintAsNumPy(mol) for mol in mols])
32
+ return fps
33
+
34
+ def rf(df, train_smiles, test_smiles):
35
+ """
36
+ Train and test RF baseline model.
37
+
38
+ Parameters:
39
+ df : pd.DataFrame with 'SMILES' and 'Activity_label' columns
40
+ train_smiles : list of training set smiles
41
+ test_smiles : list of test set smiles
42
+ Returns:
43
+ preds : list of predicted labels for the test set
44
+ """
45
+ train_df = df[df['SMILES'].isin(train_smiles)]
46
+ test_df = df[df['SMILES'].isin(test_smiles)]
47
+
48
+ # Generate Morgan fingerprints for training and test sets
49
+ X_train = generate_morgan_fingerprints(train_df['SMILES'])
50
+ X_test = generate_morgan_fingerprints(test_df['SMILES'])
51
+
52
+ # Extract labels
53
+ y_train = train_df['Activity'].values
54
+
55
+ # Train a Random Forest Classifier
56
+ clf = RandomForestClassifier(n_estimators=200, random_state=82)
57
+ clf.fit(X_train, y_train)
58
+
59
+ # Make predictions on the test set
60
+ try:
61
+ preds = clf.predict_proba(X_test)[:,1]
62
+ except Exception as e:
63
+ print(e)
64
+ print(test_df)
65
+ print(X_test)
66
+
67
+ return preds
68
+
69
+ def fh(smiles_list):
70
+ df = pd.read_csv('data/fh_predictions.csv')
71
+ preds = df[df['SMILES'].isin(smiles_list)]['Prediction'].tolist()
72
+ return preds
73
+
74
+ def drop_assays_with_limited_data(df, na_min=50, ni_min=100):
75
+ print('Drop assays with not enough datapoints...')
76
+ unique_uids = df['AID'].sort_values().unique() # Sorted unique targets
77
+ activity_counts = df.groupby('AID')['Activity'].value_counts().unstack().fillna(0) # matrix: rows=sorted targets, columns=nactive, ninactives
78
+ mask = ((activity_counts[1] >= na_min) & (activity_counts[0] >= ni_min) ) # Both nactives and ninactives above nmin
79
+ df = df[df['AID'].isin(unique_uids[mask])]
80
+ return df
81
+
82
+ def run(
83
+ n_actives : int,
84
+ n_inactives : int,
85
+ model : str = 'MHNfs',
86
+ task : str = 'UID',
87
+ input_file : str = '', # todo add path
88
+ output_dir : str = '', # todo add path
89
+ n_repeats : int = 3,
90
+ seed : int = 42
91
+ ):
92
+
93
+ # Load data
94
+ data = pd.read_csv(input_file)
95
+
96
+ if task == 'AID':
97
+ data = drop_assays_with_limited_data(data, 30, 30)
98
+
99
+ # Output dir
100
+ output_dir = os.path.join(output_dir, model, task, f'{n_actives}+{n_inactives}x{n_repeats}')
101
+ print(output_dir)
102
+ os.makedirs(output_dir, exist_ok=True)
103
+
104
+ # Tasks
105
+ tasks = data[task].value_counts(ascending=True).index.tolist()
106
+ # print(tasks)
107
+
108
+ if model == 'MHNfs':
109
+ predictor = ActivityPredictor()
110
+
111
+ # Iterate over tasks
112
+ for t in tqdm(tasks):
113
+
114
+ # Output file
115
+ output_file = os.path.join(output_dir, f'{t}.csv')
116
+ if os.path.exists(output_file):
117
+ continue
118
+
119
+ # Data for task
120
+ df = data[data[task] == t]
121
+
122
+ # Iterate over replicates
123
+ results = []
124
+ for i in range(n_repeats):
125
+ # Select support sets and test molecules
126
+ actives = df.loc[df['Activity'] == 1, 'SMILES'].sample(n=n_actives, random_state=seed+i).tolist()
127
+ inactives = df.loc[df['Activity'] == 0, 'SMILES'].sample(n=n_inactives, random_state=seed+i).tolist()
128
+ test_smiles = df[~df.SMILES.isin(actives+inactives)].SMILES.tolist()
129
+
130
+ if model == 'RF':
131
+ preds = rf(df, actives+inactives, test_smiles)
132
+ else:
133
+ if len(test_smiles) > 10_000:
134
+ # MHNfs breaks for over 20_000 datapoints -> Use chunks to make predictions
135
+ chunk_size = 10_000
136
+ chunks = [test_smiles[i:i + chunk_size] for i in range(0, len(test_smiles), chunk_size)]
137
+ preds = []
138
+ for chunk in chunks:
139
+ preds.extend( predictor.predict(chunk, actives, inactives))
140
+ else:
141
+ preds = predictor.predict(test_smiles, actives, inactives)
142
+
143
+ d = {
144
+ 'SMILES' : test_smiles,
145
+ 'Label' : df[df.SMILES.isin(test_smiles)].Activity,
146
+ 'Prediction' : preds,
147
+ 'Fold' : [i] * len(test_smiles)
148
+ }
149
+ results.append(pd.DataFrame(d))
150
+
151
+ results = pd.concat(results)
152
+ results.to_csv(output_file, index=False)
153
+
154
+ if __name__ == '__main__':
155
+
156
+ mhnfs_path = # mhnfs_path + '/mhnfs'
157
+ benchmark_path = # project_path
158
+
159
+ import sys
160
+ sys.path.append(mhnfs_path)
161
+ from src.prediction_pipeline import ActivityPredictor
162
+
163
+ support_sets = [(1,7), (2,6), (4,4)]
164
+ models = ['RF', 'MHNfs']
165
+ tasks = ['AID', 'UID']
166
+
167
+ input_file = # preprocessed_data path + '/pubchem24_preprocessed_2.csv.gz'
168
+
169
+ for support_set in support_sets:
170
+ for model in models:
171
+ for task in tasks:
172
+ run(*support_set, task=task, model=model, input_file=input_file)
pubchem_experiment/metrics.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+
8
+ from tqdm.auto import tqdm
9
+ from rdkit.ML.Scoring.Scoring import CalcBEDROC
10
+ from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score, average_precision_score, \
11
+ matthews_corrcoef, precision_score, recall_score, f1_score, confusion_matrix
12
+
13
+ def specificity_score(true_labels, predicted_labels):
14
+ tn, fp, _, _ = confusion_matrix(true_labels, predicted_labels).ravel()
15
+ specificity = tn / (tn + fp)
16
+ return specificity
17
+
18
+ MAIN_DIR = '' # todo add project dir
19
+
20
+ def balanced_mcc_score(sensitivity, specificity, prevalence):
21
+ """Returns the Matthews' correlation coefficient at the given
22
+ sensitivity, specificity and prevalence.
23
+
24
+ Parameters
25
+ ----------
26
+ sensitivity : float
27
+ The sensitivity of the model
28
+ specificity : float
29
+ The specificity of the model
30
+ prevalence : float
31
+ The prevalence of the test set
32
+
33
+ Returns
34
+ ------
35
+ float
36
+ Matthews' correlation coefficient as a float
37
+ """
38
+ numerator = sensitivity + specificity - 1
39
+ denominatorFirstTerm = sensitivity + (1 - specificity)*(1 - prevalence) / prevalence
40
+ denominatorSecondTerm = specificity + (1 -sensitivity)*prevalence/(1 - prevalence)
41
+ denominator = math.sqrt(denominatorFirstTerm * denominatorSecondTerm)
42
+
43
+ if sensitivity == 1 and specificity == 0:
44
+ denominator = 1
45
+ if sensitivity == 0 and specificity == 1:
46
+ denominator = 1.
47
+
48
+ return(numerator / denominator)
49
+
50
+ def ef_top_per(predictions, prevalance, top_frac=0.01):
51
+
52
+ n = int(len(predictions) * top_frac)
53
+ predictions = sorted(predictions, reverse=True)[:n]
54
+ f = np.sum(np.round(predictions)) / n
55
+ return f / prevalance
56
+
57
+ def compute_metrics(df):
58
+ """
59
+ Compute a set of classification metric for single set of predictions.
60
+
61
+ Args:
62
+ df : dataframe with true labels in 'Label' column and probabilistic predictions in 'Prediction' column
63
+
64
+ Returns:
65
+ df_metrics: dataframe with metrics in 'Metric' column and values in 'Value' column
66
+ """
67
+ true_labels = df['Label']
68
+ prevalance = sum(true_labels) / len(true_labels)
69
+ predictions = df['Prediction']
70
+
71
+ # print(true_labels.value_counts())
72
+ # print(predictions.max())
73
+
74
+ acc = accuracy_score(true_labels, predictions.round())
75
+ bacc = balanced_accuracy_score(true_labels, predictions.round())
76
+ precision = precision_score(true_labels, predictions.round(), zero_division=0.0)
77
+ recall = recall_score(true_labels, predictions.round())
78
+ specificity = specificity_score(true_labels, predictions.round())
79
+ mcc = matthews_corrcoef(true_labels, predictions.round())
80
+ bmcc = balanced_mcc_score(recall, specificity, prevalance)
81
+ f1 = f1_score(true_labels, predictions.round())
82
+
83
+ auc = roc_auc_score(true_labels, predictions)
84
+ ap = average_precision_score(true_labels, predictions)
85
+ dap = ap - prevalance
86
+ scores = df.sort_values(by='Prediction', ascending=False)[['Label', 'Prediction']].values
87
+ bedroc = CalcBEDROC(scores, 0, 20)
88
+ ef = ef_top_per(predictions, prevalance, 0.01)
89
+
90
+ metrics_dict = {'ACC': acc, 'BACC': bacc, 'MCC': mcc, 'BMCC': bmcc, 'Precision': precision, 'Recall': recall, 'F1-score': f1,
91
+ 'AUC': auc, 'dAP': dap, 'BEDROC': bedroc, 'EF-1%' : ef}
92
+ df_metrics = pd.DataFrame(metrics_dict.items(), columns=['Metric', 'Value'])
93
+
94
+
95
+ return df_metrics
96
+
97
+
98
+ def get_metrics(
99
+ tasks : list[str] = ['AID', 'UID'],
100
+ models : list[str] = ['MHNfs', 'RF'],
101
+ settings : list[str] = ['1+1x3', '1+3x3', '1+7x3', '2+2x3', '2+6x3', '2+14x3', '4+4x3', '4+12x3', '4+28x3', '8+8x3', '8+24x3', '8+56x3'],
102
+ overwrite: bool = False):
103
+
104
+ """
105
+ Computes classification metrics for each combination.
106
+ """
107
+
108
+ file = f'{MAIN_DIR}/results_used.csv.gz'
109
+
110
+ if overwrite:
111
+ df = pd.DataFrame()
112
+ else:
113
+ df = pd.read_csv(file)
114
+
115
+ path_preprocessed = "" # todo
116
+ df_pubchem = pd.read_csv(path_preprocessed)
117
+
118
+ for task in tasks:
119
+ for model in models:
120
+ for setting in settings:
121
+ dir = f'{MAIN_DIR}/predictions/{model}/{task}/{setting}'
122
+ try:
123
+ targets = [x[:-4] for x in os.listdir(dir)]
124
+ pubchem_targets = df_pubchem[task].astype(str).unique().tolist()
125
+
126
+ for target in tqdm(targets, desc=f'{task} - {model} - {setting}'):
127
+
128
+ if target not in pubchem_targets:
129
+ continue
130
+
131
+ # Skip already computed targets
132
+ if not overwrite and any((df['Model'] == model) & (df['Setting'] == setting) & (df['Task'] == task) & (df['TID'] == target)):
133
+ continue
134
+
135
+ # Load predictions
136
+ df_task = pd.read_csv(f'{dir}/{target}.csv')
137
+
138
+ # Retrieve oragnism and L1 protein classification
139
+ try:
140
+ org = df_pubchem.loc[df_pubchem[task] == target, 'Organism'].values[0]
141
+ l1 = df_pubchem.loc[df_pubchem[task] == target, 'L1'].values[0]
142
+ except:
143
+ org = df_pubchem.loc[df_pubchem[task] == int(target), 'Organism'].values[0]
144
+ l1 = df_pubchem.loc[df_pubchem[task] == int(target), 'L1'].values[0]
145
+ if l1 == None:
146
+ print(target, l1)
147
+
148
+ # Compute metrics for each fold
149
+ for fold in df_task.Fold.unique():
150
+ metrics = (compute_metrics(df_task[df_task.Fold == fold]).assign(
151
+ Model=model, Task=task, TID=target, Organism=org, L1=l1, Setting=setting, Fold=fold,
152
+ )
153
+ ).rename(columns={'Target' : task})
154
+ df = pd.concat([df, metrics], ignore_index=True)
155
+ except Exception as e:
156
+ print(e)
157
+ raise e
158
+
159
+ df.to_csv(file, index=False)
160
+
161
+ if __name__ == '__main__':
162
+ #get_metrics()
163
+ get_metrics(settings=['1+7x3', '2+6x3', '4+4x3', '2+14x3', '4+12x3','8+8x3'], overwrite=True)
src/app/constants.py CHANGED
@@ -17,7 +17,10 @@ summary_text = ('''
17
  Hit **Predict** and explore the predictions!
18
 
19
  For more **information** about the **model** and **how to provide the
20
- molecules**, please visit the **Additional Information** tab.
 
 
 
21
  ''')
22
 
23
  mhnfs_text =('''
 
17
  Hit **Predict** and explore the predictions!
18
 
19
  For more **information** about the **model** and **how to provide the
20
+ molecules**, please visit the **Additional Information** tab.
21
+
22
+ If you encounter any problems, we would be glad if you could report them
23
+ to us: **[email protected]**.
24
  ''')
25
 
26
  mhnfs_text =('''
src/app/layout.py CHANGED
@@ -159,8 +159,16 @@ class LayoutMaker():
159
  label="SMILES input for query molecules",
160
  label_visibility="hidden",
161
  key="query_textbox",
162
- value="CC(C)Sc1nc(C(C)(C)C)nc(OCC(=O)O)c1C#N, "
163
- "Cc1nc(NCc2cccnc2)cc(=O)n1CC(=O)O",
 
 
 
 
 
 
 
 
164
  )
165
  elif input_choice == "CSV upload":
166
  query_file = st.file_uploader(key="query_csv",
@@ -194,8 +202,8 @@ class LayoutMaker():
194
  label="SMILES input for active support set molecules",
195
  label_visibility="hidden",
196
  key="active_textbox",
197
- value="Cc1nc(NCC2CCCCC2)c(C#N)c(=O)n1CC(=O)O, "
198
- "CSc1nc(C(C)C)nc(OCC(=O)O)c1C#N"
199
  )
200
  elif active_input_choice == "CSV upload":
201
  support_active_file = st.file_uploader(
@@ -224,8 +232,8 @@ class LayoutMaker():
224
  label="SMILES input for inactive support set molecules",
225
  label_visibility="hidden",
226
  key="inactive_textbox",
227
- value="CSc1nc(C)nc(OCC(=O)O)c1C#N, "
228
- "CSc1nc(C)n(CC(=O)O)c(=O)c1C#N"
229
  )
230
  elif inactive_input_choice == "CSV upload":
231
  support_inactive_file = st.file_uploader(
 
159
  label="SMILES input for query molecules",
160
  label_visibility="hidden",
161
  key="query_textbox",
162
+ value= "Cc1nc(N2CCN(Cc3ccccc3)CC2)c(C#N)c(=O)n1CC(=O)O, "
163
+ "N#Cc1c(-c2ccccc2)nc(-c2cccc3c(Br)cccc23)n(CC(=O)O)c1=O, "
164
+ "Cc1nc(N2CCC(Cc3ccccc3)CC2)c(C#N)c(=O)n1CC(=O)O, "
165
+ "CC(C)Sc1nc(C(C)(C)C)nc(OCC(=O)O)c1C#N, "
166
+ "Cc1nc(NCc2cccnc2)cc(=O)n1CC(=O)O, "
167
+ "COC(=O)c1c(SC)nc(C2CCCCC2)n(CC(=O)O)c1=O, "
168
+ "Cc1nc(NCc2cccnc2)c(C#N)c(=O)n1CC(=O)O, "
169
+ "CC(C)c1nc(SCc2ccccc2)c(C#N)c(=O)n1CC(=O)O, "
170
+ "N#Cc1c(OCC(=O)O)nc(-c2cccc3ccccc23)nc1-c1ccccc1, "
171
+ "COc1ccc2c(C(=S)N(C)CC(=O)O)cccc2c1C(F)(F)F"
172
  )
173
  elif input_choice == "CSV upload":
174
  query_file = st.file_uploader(key="query_csv",
 
202
  label="SMILES input for active support set molecules",
203
  label_visibility="hidden",
204
  key="active_textbox",
205
+ value="CC(C)(C)c1nc(OCC(=O)O)c(C#N)c(SCC2CCCCC2)n1, "
206
+ "Cc1nc(NCC2CCCCC2)c(C#N)c(=O)n1CC(=O)O"
207
  )
208
  elif active_input_choice == "CSV upload":
209
  support_active_file = st.file_uploader(
 
232
  label="SMILES input for inactive support set molecules",
233
  label_visibility="hidden",
234
  key="inactive_textbox",
235
+ value="CSc1nc(C2CCCCC2)n(CC(=O)O)c(=O)c1S(=O)(=O)c1ccccc1, "
236
+ "CSc1nc(C)nc(OCC(=O)O)c1C#N"
237
  )
238
  elif inactive_input_choice == "CSV upload":
239
  support_inactive_file = st.file_uploader(