Merge branch 'main' of https://huggingface.co/spaces/ml-jku/mhnfs
- README.md +3 -7
- app.py +1 -1
- pre-requirements.txt +1 -0
- pubchem_experiment/data_preprocess.py +197 -0
- pubchem_experiment/make_predictions.py +172 -0
- pubchem_experiment/metrics.py +163 -0
- src/app/constants.py +4 -1
- src/app/layout.py +14 -6
README.md
CHANGED
@@ -71,9 +71,9 @@ For your screening, load the model, i.e. the **Activity Predictor**, into your python session:
 from src.prediction_pipeline import ActivityPredictor
 
 # Define inputs
-query_smiles = ["C1CCCCC1", "C1CCCCC1", "C1CCCCC1", "C1CCCCC1"]
-support_actives_smiles = ["C1CCCCC1", "C1CCCCC1"]
-support_inactives_smiles = ["C1CCCCC1", "C1CCCCC1"]
+query_smiles = ["C1CCCCC1", "C1CCCCC1", "C1CCCCC1", "C1CCCCC1"]  # Replace with your data
+support_actives_smiles = ["C1CCCCC1", "C1CCCCC1"]  # Replace with your data
+support_inactives_smiles = ["C1CCCCC1", "C1CCCCC1"]  # Replace with your data
 
 # Make predictions
 predictions = predictor.predict(query_smiles, support_actives_smiles, support_inactives_smiles)
@@ -92,10 +92,6 @@ cd .../whatever_your_dir_name_is/ # Replace with your path
 # Run streamlit app
 python -m streamlit run
 ```
-
-
-## 🤗 Hugging face app
-Explore our hugging-face app here:
 
 ## 📚 Cite us
 
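Note: the README snippet above calls `predictor.predict(...)` without showing where `predictor` comes from. A minimal end-to-end sketch, assuming `ActivityPredictor` takes no required constructor arguments (consistent with its use in pubchem_experiment/make_predictions.py below):

from src.prediction_pipeline import ActivityPredictor

# Assumed no-argument constructor, as in make_predictions.py below
predictor = ActivityPredictor()

query_smiles = ["C1CCCCC1", "C1CCCCC1"]  # Replace with your data
support_actives_smiles = ["C1CCCCC1"]    # Replace with your data
support_inactives_smiles = ["C1CCCCC1"]  # Replace with your data

# predict() returns one activity score per query molecule
predictions = predictor.predict(query_smiles,
                                support_actives_smiles,
                                support_inactives_smiles)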
app.py
CHANGED
@@ -1,5 +1,5 @@
 """
-This script runs the streamlit app for MHNfs
+This script runs the streamlit app for MHNfs.
 
 MHNfs: Few-shot method for drug discovery activity predictions
 (https://openreview.net/pdf?id=XrMWUuEevr)
pre-requirements.txt
ADDED
@@ -0,0 +1 @@
+pip==24.0
pubchem_experiment/data_preprocess.py
ADDED
@@ -0,0 +1,197 @@
+import json
+import pandas as pd
+import tqdm
+import swifter
+from rdkit import Chem
+
+# Disable RDKit informational and warning messages
+from rdkit import RDLogger
+RDLogger.DisableLog('rdApp.*')
+
+PUBCHEM_DIR = ''      # TODO: pubchem_path + 'pubchem24/'
+FSMOL_UID_PATH = ''   # TODO: fsmol_path + '/fsmol/fsmol_train_accession_keys.json'
+PROT_CLASS_PATH = ''  # TODO: chembl_path + 'chembl33/uniprot_pclass_mapping.csv'
+MHNFS_PATH = ''       # TODO: mhnfs_path + '/mhnfs'
+
+import sys
+sys.path.append(MHNFS_PATH)
+from src.data_preprocessing.utils import Standardizer
+
+class PubChemFilter:
+
+    def __init__(self, pubchem_dir, fsmol_uid_path, prot_class_path, mhnfs_path, debug=False):
+        self.pubchem_dir = pubchem_dir
+        self.fsmol_uid_path = fsmol_uid_path
+        self.prot_class_path = prot_class_path
+        self.mhnfs_path = mhnfs_path
+        self.debug = debug
+
+    def load_and_filter_assays(self):
+        """
+        Load PubChem assay data from file and filter it:
+        1. Drop all assays without protein accession keys
+        2. Drop all assays linked to multiple accession keys
+        3. Drop all assays with accession keys in the FSmol training data
+
+        The filtered table is stored in self.df_assays (pd.DataFrame).
+        """
+
+        print('Load assays...')
+        df_assays = pd.read_table(f'{self.pubchem_dir}/bioassays.tsv.gz', usecols=['AID', 'UniProts IDs']).rename(columns={'UniProts IDs': 'UID'})
+
+        # Load FSmol training data accession keys
+        with open(self.fsmol_uid_path, 'r') as f:
+            fs_train_targets = json.load(f).values()
+        fs_train_targets = list(set([key for sublist in fs_train_targets for key in sublist]))
+
+        print('Filter assays...')
+        df_assays = df_assays.dropna(subset=['UID'])
+        df_assays = df_assays[~df_assays['UID'].str.contains(r'\|')]
+        df_assays = df_assays[~df_assays['UID'].str.contains('|'.join(fs_train_targets))]
+        self.df_assays = df_assays
+
+    def load_and_filter_bioactivities(self, chunk_size=10_000_000):
+        """
+        Load bioactivity data in chunks and filter out datapoints with
+        1. assay not in aids
+        2. outcome not 'Active'/'Inactive'
+        """
+
+        print('Load bioactivities...')
+        aids = self.df_assays.AID.tolist()
+        filtered_chunks = []
+        for chunk in pd.read_csv(f'{self.pubchem_dir}/bioactivities.tsv.gz', sep='\t', chunksize=chunk_size, usecols=['AID', 'CID', 'Activity Outcome']):
+            filtered_chunk = chunk[chunk['AID'].isin(aids)]
+            filtered_chunk = filtered_chunk[filtered_chunk['Activity Outcome'].isin(['Inactive', 'Active'])]
+            filtered_chunks.append(filtered_chunk)
+            if self.debug:
+                break  # For debugging
+        df_bio = pd.concat(filtered_chunks)
+        df_bio = df_bio[df_bio.CID.notna()]
+        df_bio['Activity'] = df_bio['Activity Outcome'].swifter.apply(lambda x: 1 if x == 'Active' else 0)
+        self.df_bio = df_bio.drop('Activity Outcome', axis=1).astype(int)
+
+    def merge_assay_and_activity_data(self):
+        print('Merge...')
+        self.df = self.df_bio.merge(self.df_assays, on='AID', how='left')
+        convert_dict = {col: 'int32' if col != 'UID' else 'str' for col in self.df.columns}
+        self.df = self.df.astype(convert_dict)
+        del self.df_assays, self.df_bio
+
+    def drop_hts_assays(self):
+        print('Drop HTS assays...')
+        aid_counts = self.df.groupby('AID').size()
+        filtered_aids = aid_counts[aid_counts <= 100_000].index
+        self.df = self.df[self.df['AID'].isin(filtered_aids)]
+
+    def drop_targets_with_limited_data(self, na_min=50, ni_min=50):
+        print('Drop targets with not enough datapoints...')
+        unique_uids = self.df['UID'].sort_values().unique()  # Sorted unique targets
+        activity_counts = self.df.groupby('UID')['Activity'].value_counts().unstack().fillna(0)  # Counts matrix: rows = sorted targets, columns = activity value (0 = inactives, 1 = actives)
+        mask = ((activity_counts[1] >= na_min) & (activity_counts[0] >= ni_min))  # Both n_actives and n_inactives at or above the minimum
+        self.df = self.df[self.df['UID'].isin(unique_uids[mask])]
+
+    def drop_conflicting_bioactivity_measures(self, target_col='UID', compound_col='CID'):
+        """
+        Check that each target-compound pair is associated with a unique activity value,
+        i.e. every measurement is either active or inactive. If not, drop the pair.
+        """
+
+        def process_group(group):
+            if group['Activity'].nunique() == 1:
+                return group.head(1)
+            else:
+                return None
+
+        print('Drop conflicting datapoints...')
+        # Get unique UID-CID pairs and duplicated ones
+        df_uniques = self.df.drop_duplicates(subset=[target_col, compound_col], keep=False)
+        df_duplicates = self.df[~self.df.index.isin(df_uniques.index)]
+
+        # Check duplicated pairs
+        groups = df_duplicates.groupby([target_col, compound_col])
+        rows = []
+        for _, group in tqdm.tqdm(groups):
+            rows.append(process_group(group))
+        df_rows = pd.concat([row for row in rows if row is not None])
+        self.df = pd.concat([df_uniques, df_rows])
+
+    def add_smiles(self, chunk_size=10_000_000):
+        print('Retrieve SMILES...')
+        cids = self.df.CID.astype(int).unique()
+        filtered_chunks = []
+        for chunk in pd.read_table(f'{self.pubchem_dir}/smiles.tsv.gz', chunksize=chunk_size, names=['CID', 'SMILES']):
+            filtered_chunk = chunk[chunk['CID'].isin(cids)]
+            filtered_chunks.append(filtered_chunk)
+            if self.debug:
+                break
+        df_smiles = pd.concat(filtered_chunks)
+
+        def cleanup(smiles):
+            sm = Standardizer(metal_disconnect=True, canon_taut=True)
+            mol = Chem.MolFromSmiles(smiles)
+            try:
+                standardized_mol, _ = sm.standardize_mol(mol)
+                return Chem.MolToSmiles(standardized_mol)
+            except Exception:
+                print(smiles)
+                return None
+
+        df_smiles['SMILES'] = df_smiles['SMILES'].swifter.apply(lambda smi: cleanup(smi))
+        df_smiles.dropna(inplace=True)
+
+        self.df = self.df.merge(df_smiles, on='CID', how='left').dropna(subset=['SMILES'])
+
+    def print_stats(self):
+        nassays = self.df['AID'].nunique()
+        ntargets = self.df["UID"].nunique()
+        ncompounds = self.df["CID"].nunique()
+        nactivities = self.df.shape[0]
+        print(f'{ntargets: >5,} targets | {nassays: >6,} assays | {ncompounds: >9,} compounds | {nactivities: >10,} activity data points')
+
+    def save(self, fname='data/pubchem24_preprocessed.csv.gz'):
+        print(f'Save to {fname}...')
+        self.df.to_csv(fname, index=False)
+
+    def load(self, fname):
+        print(f'Load from {fname}...')
+        self.df = pd.read_csv(fname)
+
+    def add_protein_classifications(self):
+        """
+        Retrieve protein classifications
+        """
+        print('Retrieve protein classifications...')
+        protein_class = pd.read_csv(self.prot_class_path)
+        print(protein_class)
+        # protein_class['UID'] = protein_class['target_id'].swifter.apply(lambda x: x.split('_')[0])
+        self.df = self.df.merge(protein_class[['UID', 'Organism', 'L1', 'L2']], on='UID', how='left')
+
+if __name__ == '__main__':
+    # Create an instance of the PubChemFilter class
+    pubchem_filter = PubChemFilter(PUBCHEM_DIR, FSMOL_UID_PATH, PROT_CLASS_PATH, MHNFS_PATH, False)
+
+    # Call methods of the class as needed
+    pubchem_filter.load_and_filter_assays()
+    pubchem_filter.load_and_filter_bioactivities()
+    pubchem_filter.merge_assay_and_activity_data()
+    pubchem_filter.print_stats()
+    pubchem_filter.drop_hts_assays()
+    pubchem_filter.print_stats()
+    pubchem_filter.drop_targets_with_limited_data()
+    pubchem_filter.print_stats()
+    pubchem_filter.drop_conflicting_bioactivity_measures()
+    pubchem_filter.print_stats()
+    pubchem_filter.drop_targets_with_limited_data()
+    pubchem_filter.print_stats()
+    pubchem_filter.add_smiles()
+    pubchem_filter.print_stats()
+    pubchem_filter.drop_conflicting_bioactivity_measures(compound_col='SMILES')
+    pubchem_filter.print_stats()
+    pubchem_filter.drop_targets_with_limited_data()
+    pubchem_filter.print_stats()
+    pubchem_filter.add_protein_classifications()
+    pubchem_filter.save(fname='data/pubchem24/preprocessed.csv.gz')
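The drop_conflicting_bioactivity_measures step above keeps a target-compound pair only if all of its replicate measurements agree. A self-contained toy illustration of the same drop_duplicates/groupby pattern, using hypothetical data and a groupby filter in place of the explicit loop:

import pandas as pd

df = pd.DataFrame({
    'UID': ['P1', 'P1', 'P1', 'P2'],
    'CID': [10, 10, 11, 10],
    'Activity': [1, 0, 1, 1],  # P1/10 was measured both active and inactive
})

# Rows whose (UID, CID) pair occurs exactly once
uniques = df.drop_duplicates(subset=['UID', 'CID'], keep=False)

# Replicated pairs: keep one row per pair only if all replicates agree
dups = df[~df.index.isin(uniques.index)]
kept = (dups.groupby(['UID', 'CID'])
            .filter(lambda g: g['Activity'].nunique() == 1)
            .drop_duplicates(subset=['UID', 'CID']))

result = pd.concat([uniques, kept])
print(result)  # the conflicting P1/10 pair is dropped; P1/11 and P2/10 remain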
pubchem_experiment/make_predictions.py
ADDED
@@ -0,0 +1,172 @@
+import os
+import warnings
+
+import pandas as pd
+from rdkit import Chem
+# from rdkit.Chem import AllChem
+from rdkit.Chem import rdFingerprintGenerator
+from sklearn.ensemble import RandomForestClassifier
+from tqdm.auto import tqdm
+import numpy as np
+import clamp
+import torch
+
+warnings.filterwarnings("ignore")
+
+
+def generate_morgan_fingerprints(smiles_list, radius=4, n_bits=4048):
+    """
+    Generate Morgan fingerprints for a list of SMILES.
+    """
+    mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits)
+    mols = [Chem.MolFromSmiles(smi) for smi in smiles_list]
+    fps = []
+    for smiles, mol in zip(smiles_list, mols):
+        if mol is None:
+            print(smiles)
+            fps.append(None)
+        else:
+            fps.append(mfpgen.GetFingerprintAsNumPy(mol))
+    # np.array([mfpgen.GetFingerprintAsNumPy(mol) for mol in mols])
+    return fps
+
+def rf(df, train_smiles, test_smiles):
+    """
+    Train and test the RF baseline model.
+
+    Parameters:
+        df : pd.DataFrame with 'SMILES' and 'Activity' columns
+        train_smiles : list of training set SMILES
+        test_smiles : list of test set SMILES
+    Returns:
+        preds : predicted activity probabilities for the test set
+    """
+    train_df = df[df['SMILES'].isin(train_smiles)]
+    test_df = df[df['SMILES'].isin(test_smiles)]
+
+    # Generate Morgan fingerprints for training and test sets
+    X_train = generate_morgan_fingerprints(train_df['SMILES'])
+    X_test = generate_morgan_fingerprints(test_df['SMILES'])
+
+    # Extract labels
+    y_train = train_df['Activity'].values
+
+    # Train a Random Forest classifier
+    clf = RandomForestClassifier(n_estimators=200, random_state=82)
+    clf.fit(X_train, y_train)
+
+    # Make predictions on the test set
+    try:
+        preds = clf.predict_proba(X_test)[:, 1]
+    except Exception as e:
+        print(e)
+        print(test_df)
+        print(X_test)
+        raise
+
+    return preds
+
+def fh(smiles_list):
+    df = pd.read_csv('data/fh_predictions.csv')
+    preds = df[df['SMILES'].isin(smiles_list)]['Prediction'].tolist()
+    return preds
+
+def drop_assays_with_limited_data(df, na_min=50, ni_min=100):
+    print('Drop assays with not enough datapoints...')
+    unique_aids = df['AID'].sort_values().unique()  # Sorted unique assays
+    activity_counts = df.groupby('AID')['Activity'].value_counts().unstack().fillna(0)  # Counts matrix: rows = sorted assays, columns = activity value (0 = inactives, 1 = actives)
+    mask = ((activity_counts[1] >= na_min) & (activity_counts[0] >= ni_min))  # Both n_actives and n_inactives at or above the minimum
+    df = df[df['AID'].isin(unique_aids[mask])]
+    return df
+
+def run(
+    n_actives : int,
+    n_inactives : int,
+    model : str = 'MHNfs',
+    task : str = 'UID',
+    input_file : str = '',  # todo add path
+    output_dir : str = '',  # todo add path
+    n_repeats : int = 3,
+    seed : int = 42
+    ):
+
+    # Load data
+    data = pd.read_csv(input_file)
+
+    if task == 'AID':
+        data = drop_assays_with_limited_data(data, 30, 30)
+
+    # Output dir
+    output_dir = os.path.join(output_dir, model, task, f'{n_actives}+{n_inactives}x{n_repeats}')
+    print(output_dir)
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Tasks
+    tasks = data[task].value_counts(ascending=True).index.tolist()
+    # print(tasks)
+
+    if model == 'MHNfs':
+        predictor = ActivityPredictor()
+
+    # Iterate over tasks
+    for t in tqdm(tasks):
+
+        # Output file
+        output_file = os.path.join(output_dir, f'{t}.csv')
+        if os.path.exists(output_file):
+            continue
+
+        # Data for task
+        df = data[data[task] == t]
+
+        # Iterate over replicates
+        results = []
+        for i in range(n_repeats):
+            # Select support sets and test molecules
+            actives = df.loc[df['Activity'] == 1, 'SMILES'].sample(n=n_actives, random_state=seed+i).tolist()
+            inactives = df.loc[df['Activity'] == 0, 'SMILES'].sample(n=n_inactives, random_state=seed+i).tolist()
+            test_smiles = df[~df.SMILES.isin(actives+inactives)].SMILES.tolist()
+
+            if model == 'RF':
+                preds = rf(df, actives+inactives, test_smiles)
+            else:
+                if len(test_smiles) > 10_000:
+                    # MHNfs breaks for over 20_000 datapoints -> use chunks to make predictions
+                    chunk_size = 10_000
+                    chunks = [test_smiles[j:j + chunk_size] for j in range(0, len(test_smiles), chunk_size)]
+                    preds = []
+                    for chunk in chunks:
+                        preds.extend(predictor.predict(chunk, actives, inactives))
+                else:
+                    preds = predictor.predict(test_smiles, actives, inactives)
+
+            d = {
+                'SMILES' : test_smiles,
+                'Label' : df[df.SMILES.isin(test_smiles)].Activity,
+                'Prediction' : preds,
+                'Fold' : [i] * len(test_smiles)
+            }
+            results.append(pd.DataFrame(d))
+
+        results = pd.concat(results)
+        results.to_csv(output_file, index=False)
+
+if __name__ == '__main__':
+
+    mhnfs_path = ''      # TODO: mhnfs_path + '/mhnfs'
+    benchmark_path = ''  # TODO: project_path
+
+    import sys
+    sys.path.append(mhnfs_path)
+    from src.prediction_pipeline import ActivityPredictor
+
+    support_sets = [(1,7), (2,6), (4,4)]
+    models = ['RF', 'MHNfs']
+    tasks = ['AID', 'UID']
+
+    input_file = ''  # TODO: preprocessed_data_path + '/pubchem24_preprocessed_2.csv.gz'
+
+    for support_set in support_sets:
+        for model in models:
+            for task in tasks:
+                run(*support_set, task=task, model=model, input_file=input_file)
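run() works around MHNfs failing on very large query sets by slicing the test SMILES into 10,000-molecule chunks and concatenating the per-chunk predictions. The pattern in isolation, with a dummy scoring function standing in for predictor.predict:

def predict_in_chunks(predict_fn, smiles, chunk_size=10_000):
    # Slice the query list into fixed-size chunks and concatenate
    # the per-chunk predictions, preserving input order.
    preds = []
    for start in range(0, len(smiles), chunk_size):
        preds.extend(predict_fn(smiles[start:start + chunk_size]))
    return preds

# Usage with a dummy scorer (stand-in for predictor.predict with fixed support sets):
scores = predict_in_chunks(lambda batch: [0.5] * len(batch), ['C1CCCCC1'] * 25_000)
assert len(scores) == 25_000  # three chunks: 10,000 + 10,000 + 5,000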
pubchem_experiment/metrics.py
ADDED
@@ -0,0 +1,163 @@
+import os
+import math
+
+import numpy as np
+import pandas as pd
+
+from tqdm.auto import tqdm
+from rdkit.ML.Scoring.Scoring import CalcBEDROC
+from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score, average_precision_score, \
+    matthews_corrcoef, precision_score, recall_score, f1_score, confusion_matrix
+
+MAIN_DIR = ''  # todo add project dir
+
+def specificity_score(true_labels, predicted_labels):
+    tn, fp, _, _ = confusion_matrix(true_labels, predicted_labels).ravel()
+    specificity = tn / (tn + fp)
+    return specificity
+
+def balanced_mcc_score(sensitivity, specificity, prevalence):
+    """Returns the Matthews correlation coefficient at the given
+    sensitivity, specificity and prevalence.
+
+    Parameters
+    ----------
+    sensitivity : float
+        The sensitivity of the model
+    specificity : float
+        The specificity of the model
+    prevalence : float
+        The prevalence of the test set
+
+    Returns
+    -------
+    float
+        Matthews correlation coefficient as a float
+    """
+    numerator = sensitivity + specificity - 1
+    denominator_first_term = sensitivity + (1 - specificity) * (1 - prevalence) / prevalence
+    denominator_second_term = specificity + (1 - sensitivity) * prevalence / (1 - prevalence)
+    denominator = math.sqrt(denominator_first_term * denominator_second_term)
+
+    if sensitivity == 1 and specificity == 0:
+        denominator = 1.
+    if sensitivity == 0 and specificity == 1:
+        denominator = 1.
+
+    return numerator / denominator
+
+def ef_top_per(predictions, prevalence, top_frac=0.01):
+    n = int(len(predictions) * top_frac)
+    predictions = sorted(predictions, reverse=True)[:n]
+    f = np.sum(np.round(predictions)) / n
+    return f / prevalence
+
+def compute_metrics(df):
+    """
+    Compute a set of classification metrics for a single set of predictions.
+
+    Args:
+        df : dataframe with true labels in the 'Label' column and probabilistic predictions in the 'Prediction' column
+
+    Returns:
+        df_metrics : dataframe with metric names in the 'Metric' column and values in the 'Value' column
+    """
+    true_labels = df['Label']
+    prevalence = sum(true_labels) / len(true_labels)
+    predictions = df['Prediction']
+
+    # print(true_labels.value_counts())
+    # print(predictions.max())
+
+    acc = accuracy_score(true_labels, predictions.round())
+    bacc = balanced_accuracy_score(true_labels, predictions.round())
+    precision = precision_score(true_labels, predictions.round(), zero_division=0.0)
+    recall = recall_score(true_labels, predictions.round())
+    specificity = specificity_score(true_labels, predictions.round())
+    mcc = matthews_corrcoef(true_labels, predictions.round())
+    bmcc = balanced_mcc_score(recall, specificity, prevalence)
+    f1 = f1_score(true_labels, predictions.round())
+
+    auc = roc_auc_score(true_labels, predictions)
+    ap = average_precision_score(true_labels, predictions)
+    dap = ap - prevalence
+    scores = df.sort_values(by='Prediction', ascending=False)[['Label', 'Prediction']].values
+    bedroc = CalcBEDROC(scores, 0, 20)
+    ef = ef_top_per(predictions, prevalence, 0.01)
+
+    metrics_dict = {'ACC': acc, 'BACC': bacc, 'MCC': mcc, 'BMCC': bmcc, 'Precision': precision, 'Recall': recall, 'F1-score': f1,
+                    'AUC': auc, 'dAP': dap, 'BEDROC': bedroc, 'EF-1%': ef}
+    df_metrics = pd.DataFrame(metrics_dict.items(), columns=['Metric', 'Value'])
+
+    return df_metrics
+
+
+def get_metrics(
+    tasks : list[str] = ['AID', 'UID'],
+    models : list[str] = ['MHNfs', 'RF'],
+    settings : list[str] = ['1+1x3', '1+3x3', '1+7x3', '2+2x3', '2+6x3', '2+14x3', '4+4x3', '4+12x3', '4+28x3', '8+8x3', '8+24x3', '8+56x3'],
+    overwrite : bool = False):
+
+    """
+    Computes classification metrics for each task/model/setting combination.
+    """
+
+    file = f'{MAIN_DIR}/results_used.csv.gz'
+
+    if overwrite:
+        df = pd.DataFrame()
+    else:
+        df = pd.read_csv(file)
+
+    path_preprocessed = ""  # todo
+    df_pubchem = pd.read_csv(path_preprocessed)
+
+    for task in tasks:
+        for model in models:
+            for setting in settings:
+                pred_dir = f'{MAIN_DIR}/predictions/{model}/{task}/{setting}'
+                try:
+                    targets = [x[:-4] for x in os.listdir(pred_dir)]
+                    pubchem_targets = df_pubchem[task].astype(str).unique().tolist()
+
+                    for target in tqdm(targets, desc=f'{task} - {model} - {setting}'):
+
+                        if target not in pubchem_targets:
+                            continue
+
+                        # Skip already computed targets
+                        if not overwrite and any((df['Model'] == model) & (df['Setting'] == setting) & (df['Task'] == task) & (df['TID'] == target)):
+                            continue
+
+                        # Load predictions
+                        df_task = pd.read_csv(f'{pred_dir}/{target}.csv')
+
+                        # Retrieve organism and L1 protein classification
+                        try:
+                            org = df_pubchem.loc[df_pubchem[task] == target, 'Organism'].values[0]
+                            l1 = df_pubchem.loc[df_pubchem[task] == target, 'L1'].values[0]
+                        except Exception:
+                            org = df_pubchem.loc[df_pubchem[task] == int(target), 'Organism'].values[0]
+                            l1 = df_pubchem.loc[df_pubchem[task] == int(target), 'L1'].values[0]
+                        if l1 is None:
+                            print(target, l1)
+
+                        # Compute metrics for each fold
+                        for fold in df_task.Fold.unique():
+                            metrics = (compute_metrics(df_task[df_task.Fold == fold]).assign(
+                                Model=model, Task=task, TID=target, Organism=org, L1=l1, Setting=setting, Fold=fold,
+                                )
+                            ).rename(columns={'Target': task})
+                            df = pd.concat([df, metrics], ignore_index=True)
+                except Exception as e:
+                    print(e)
+                    raise
+
+    df.to_csv(file, index=False)
+
+if __name__ == '__main__':
+    # get_metrics()
+    get_metrics(settings=['1+7x3', '2+6x3', '4+4x3', '2+14x3', '4+12x3', '8+8x3'], overwrite=True)
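As a quick sanity check on balanced_mcc_score: at prevalence 0.5 the formula reduces to the ordinary MCC of the corresponding confusion matrix. A small hypothetical example with sensitivity 0.8 and specificity 0.9 on a balanced set of 100 actives and 100 inactives:

from math import sqrt
from sklearn.metrics import matthews_corrcoef

# Hypothetical balanced test set: 100 actives (80 recovered), 100 inactives (90 recovered)
y_true = [1] * 100 + [0] * 100
y_pred = [1] * 80 + [0] * 20 + [0] * 90 + [1] * 10

sens, spec, prev = 0.8, 0.9, 0.5
numerator = sens + spec - 1
denominator = sqrt((sens + (1 - spec) * (1 - prev) / prev)
                   * (spec + (1 - sens) * prev / (1 - prev)))

print(round(numerator / denominator, 4))            # 0.7035
print(round(matthews_corrcoef(y_true, y_pred), 4))  # 0.7035, identical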
src/app/constants.py
CHANGED
@@ -17,7 +17,10 @@ summary_text = ('''
 Hit **Predict** and explore the predictions!
 
 For more **information** about the **model** and **how to provide the
-molecules**, please visit the **Additional Information** tab.
+molecules**, please visit the **Additional Information** tab.
+
+If you encounter any problems, we would be glad if you could report them
+to us: **[email protected]**.
 ''')
 
 mhnfs_text =('''
src/app/layout.py
CHANGED
@@ -159,8 +159,16 @@ class LayoutMaker():
             label="SMILES input for query molecules",
             label_visibility="hidden",
             key="query_textbox",
-            value="
-
+            value="Cc1nc(N2CCN(Cc3ccccc3)CC2)c(C#N)c(=O)n1CC(=O)O, "
+                  "N#Cc1c(-c2ccccc2)nc(-c2cccc3c(Br)cccc23)n(CC(=O)O)c1=O, "
+                  "Cc1nc(N2CCC(Cc3ccccc3)CC2)c(C#N)c(=O)n1CC(=O)O, "
+                  "CC(C)Sc1nc(C(C)(C)C)nc(OCC(=O)O)c1C#N, "
+                  "Cc1nc(NCc2cccnc2)cc(=O)n1CC(=O)O, "
+                  "COC(=O)c1c(SC)nc(C2CCCCC2)n(CC(=O)O)c1=O, "
+                  "Cc1nc(NCc2cccnc2)c(C#N)c(=O)n1CC(=O)O, "
+                  "CC(C)c1nc(SCc2ccccc2)c(C#N)c(=O)n1CC(=O)O, "
+                  "N#Cc1c(OCC(=O)O)nc(-c2cccc3ccccc23)nc1-c1ccccc1, "
+                  "COc1ccc2c(C(=S)N(C)CC(=O)O)cccc2c1C(F)(F)F"
         )
     elif input_choice == "CSV upload":
         query_file = st.file_uploader(key="query_csv",
@@ -194,8 +202,8 @@ class LayoutMaker():
             label="SMILES input for active support set molecules",
             label_visibility="hidden",
             key="active_textbox",
-            value="
-
+            value="CC(C)(C)c1nc(OCC(=O)O)c(C#N)c(SCC2CCCCC2)n1, "
+                  "Cc1nc(NCC2CCCCC2)c(C#N)c(=O)n1CC(=O)O"
         )
     elif active_input_choice == "CSV upload":
         support_active_file = st.file_uploader(
@@ -224,8 +232,8 @@ class LayoutMaker():
             label="SMILES input for inactive support set molecules",
             label_visibility="hidden",
             key="inactive_textbox",
-            value="CSc1nc(
-
+            value="CSc1nc(C2CCCCC2)n(CC(=O)O)c(=O)c1S(=O)(=O)c1ccccc1, "
+                  "CSc1nc(C)nc(OCC(=O)O)c1C#N"
         )
     elif inactive_input_choice == "CSV upload":
        support_inactive_file = st.file_uploader(
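The multi-line value= defaults added above rely on Python's implicit concatenation of adjacent string literals inside the call parentheses. A quick illustration with the two active support set SMILES:

# Adjacent string literals fuse into one comma-separated default string
value = ("CC(C)(C)c1nc(OCC(=O)O)c(C#N)c(SCC2CCCCC2)n1, "
         "Cc1nc(NCC2CCCCC2)c(C#N)c(=O)n1CC(=O)O")

print(value.split(', '))  # two SMILES entries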