libokj commited on
Commit
e13b250
·
1 Parent(s): 6872415

Upload 174 files

Browse files
Files changed (38) hide show
  1. .gitattributes +2 -0
  2. app.py +978 -210
  3. data/drug_libraries/drugbank_human_py_annot.csv +3 -0
  4. data/target_libraries/ChEMBL33_all_spe_single_prot_info.csv +0 -0
  5. deepscreen/__init__.py +2 -2
  6. deepscreen/__pycache__/__init__.cpython-311.pyc +0 -0
  7. deepscreen/__pycache__/train.cpython-311.pyc +0 -0
  8. deepscreen/data/__pycache__/dti.cpython-311.pyc +0 -0
  9. deepscreen/data/dti.py +67 -23
  10. deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc +0 -0
  11. deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc +0 -0
  12. deepscreen/data/featurizers/__pycache__/graph.cpython-311.pyc +0 -0
  13. deepscreen/data/featurizers/__pycache__/token.cpython-311.pyc +0 -0
  14. deepscreen/data/featurizers/categorical.py +15 -15
  15. deepscreen/data/featurizers/monn.py +1 -1
  16. deepscreen/data/featurizers/token.py +18 -14
  17. deepscreen/data/utils/__pycache__/collator.cpython-311.pyc +0 -0
  18. deepscreen/data/utils/__pycache__/label.cpython-311.pyc +0 -0
  19. deepscreen/data/utils/__pycache__/split.cpython-311.pyc +0 -0
  20. deepscreen/data/utils/collator.py +94 -43
  21. deepscreen/data/utils/label.py +1 -0
  22. deepscreen/gui/test.py +114 -0
  23. deepscreen/models/__pycache__/dti.cpython-311.pyc +0 -0
  24. deepscreen/models/dti.py +1 -1
  25. deepscreen/models/loss/__pycache__/multitask_loss.cpython-311.pyc +0 -0
  26. deepscreen/models/metrics/bedroc.py +3 -0
  27. deepscreen/models/metrics/ci.py +39 -0
  28. deepscreen/models/metrics/ef.py +4 -1
  29. deepscreen/models/metrics/hit_rate.py +3 -0
  30. deepscreen/models/metrics/rie.py +9 -6
  31. deepscreen/models/predictors/drug_vqa.py +4 -1
  32. deepscreen/models/predictors/transformer_cpi.py +26 -66
  33. deepscreen/models/predictors/transformer_cpi_2.py +2 -3
  34. deepscreen/utils/__pycache__/hydra.cpython-311.pyc +0 -0
  35. deepscreen/utils/hydra.py +46 -36
  36. resources/checkpoints/deep_dta-binary-general.ckpt +3 -0
  37. resources/checkpoints/deep_dta-binary-general.ckpt.bak +3 -0
  38. resources/vocabs/drug_vqa/combinedVoc-wholeFour.voc +0 -1
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/drug_libraries/drugbank_human_py_annot.csv filter=lfs diff=lfs merge=lfs -text
37
+ resources/checkpoints/deep_dta-binary-general.ckpt.bak filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,53 +1,207 @@
1
- import hydra
 
 
 
 
 
 
 
2
  import os
3
  import pathlib
4
  from pathlib import Path
5
  import sys
6
 
 
 
7
  import gradio as gr
 
8
  import pandas as pd
 
 
 
9
  from rdkit import Chem
10
- from rdkit.Chem import RDConfig, Descriptors, Lipinski, Crippen
 
 
11
 
 
 
 
 
12
  from deepscreen.predict import predict
13
 
14
  sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
15
  import sascorer
16
 
17
  ROOT = Path.cwd()
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # TODO refactor caching with LRU
20
- # MOL_MAP = {}
21
- # def cached_mol(smiles):
22
- # if smiles not in MOL_MAP:
23
- # MOL_MAP.update({smiles: Chem.MolFromSmiles(smiles)})
24
- # return MOL_MAP.get(smiles)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  def sa_score(row):
28
- return sascorer.calculateScore(Chem.MolFromSmiles(row['X1']))
 
29
 
30
  def mw(row):
31
- return Chem.Descriptors.MolWt(Chem.MolFromSmiles(row['X1']))
 
32
 
33
  def hbd(row):
34
- return Lipinski.NumHDonors(Chem.MolFromSmiles(row['X1']))
 
35
 
36
  def hba(row):
37
- return Lipinski.NumHAcceptors(Chem.MolFromSmiles(row['X1']))
 
38
 
39
  def logp(row):
40
- return Crippen.MolLogP(Chem.MolFromSmiles(row['X1']))
 
41
 
42
  SCORE_MAP = {
43
  'SAscore': sa_score,
44
- 'RAscore': None, # https://github.com/reymond-group/RAscore
45
- 'SCScore': None, # https://pubs.acs.org/doi/10.1021/acs.jcim.7b00622
46
- 'LogP': logp, # https://www.rdkit.org/docs/source/rdkit.Chem.Crippen.html
47
- 'MW': mw, # https://www.rdkit.org/docs/source/rdkit.Chem.Descriptors.html
48
- 'HBD': hbd, # https://www.rdkit.org/docs/source/rdkit.Chem.Lipinski.html
49
- 'HBA': hba, # https://www.rdkit.org/docs/source/rdkit.Chem.Lipinski.html
50
- 'TopoPSA': None, # http://mordred-descriptor.github.io/documentation/master/api/mordred.TopoPSA.html
51
  }
52
 
53
  FILTER_MAP = {
@@ -64,36 +218,36 @@ TASK_MAP = {
64
 
65
  PRESET_MAP = {
66
  'DeepDTA': 'deep_dta',
67
- 'GraphDTA': 'graph_dta'
 
 
 
 
 
 
 
 
68
  }
69
 
70
  TARGET_FAMILY_MAP = {
71
- 'Auto-detect': 'detect',
72
- 'Manually-labelled': 'labelled',
73
- 'Library-labelled': 'labelled',
74
- 'Kinases': 'kinases',
75
- 'Non-kinase enzymes': 'non-kinase_enzymes',
76
- 'Membrane receptors': 'membrane_receptors',
77
- 'Nuclear receptors': 'nuclear_receptors',
78
- 'Ion channels': 'ion_channels',
79
  'Other protein targets': 'other_protein_targets',
80
- 'Kinases (auto-detected)': 'kinases',
81
- 'Non-kinase enzymes (auto-detected)': 'non-kinase_enzymes',
82
- 'Membrane receptors (auto-detected)': 'membrane_receptors',
83
- 'Nuclear receptors (auto-detected)': 'nuclear_receptors',
84
- 'Ion channels (auto-detected)': 'ion_channels',
85
- 'Other protein targets (auto-detected)': 'other_protein_targets',
86
- 'Indiscriminate': 'indiscriminate'
87
  }
88
 
89
  TARGET_LIBRARY_MAP = {
90
- 'STITCH': 'stitch.csv',
91
- 'Drug Repurposing Hub': 'drug_repurposing_hub.csv',
 
92
  }
93
 
94
  DRUG_LIBRARY_MAP = {
95
- 'ChEMBL': 'chembl.csv',
96
- 'DrugBank': 'drug_bank.csv',
97
  }
98
 
99
  MODE_LIST = [
@@ -102,182 +256,796 @@ MODE_LIST = [
102
  'Drug-target pair'
103
  ]
104
 
105
- def predictions_to_df(predictions):
106
- predictions = [pd.DataFrame(prediction) for prediction in predictions]
107
- prediction_df = pd.concat(predictions, ignore_index=True)
108
- return prediction_df
109
-
110
-
111
- def submit_predict(predict_data, task, preset, target_family):
112
- task = TASK_MAP[task]
113
- preset = PRESET_MAP[preset]
114
- target_family = TARGET_FAMILY_MAP[target_family]
115
-
116
- match target_family:
117
- case 'labelled':
118
- pass # target_family_list = ...
119
- case 'detect':
120
- pass # target_family_list = ...
121
- case _:
122
- target_family_list = [target_family]
123
-
124
- prediction_df = pd.DataFrame()
125
- for target_family in target_family_list:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  with hydra.initialize(version_base="1.3", config_path="configs", job_name="webserver_inference"):
127
  cfg = hydra.compose(
128
  config_name="webserver_inference",
129
- overrides=[
130
- f"task={task}",
131
- f"preset={preset}",
132
- f"ckpt_path=resources/checkpoints/{preset}-{task}-{target_family}.ckpt",
133
- f"data.data_file='{str(predict_data)}'",
134
- ]
135
- )
136
-
137
- predictions, _ = predict(cfg)
138
- prediction_df = pd.concat([prediction_df, predictions_to_df(predictions)])
139
-
140
- return [gr.DataFrame(value=prediction_df, visible=True), gr.Tabs(selected=1)]
141
-
142
-
143
- # Define a function that takes a CSV output and a list of analytical utility functions as inputs
144
- def submit_report(df, score_list, filter_list):
145
- # Loop through the list of functions and apply them to the dataframe
146
- for filter_name in filter_list:
147
- gr.Info(f'Applying {filter_name}...')
148
-
149
- for score_name in score_list:
150
- gr.Info(f'Calculating {score_name}...')
151
- # Apply the function to the dataframe and assign the result to a new column
152
- df[score_name] = df.apply(SCORE_MAP[score_name], axis=1)
153
- # Return the dataframe as a table
154
- return [gr.DataFrame(visible=False), gr.DataFrame(value=df, visible=True)]
155
-
156
-
157
- def change_layout(mode):
158
- match mode:
159
- case "Drug screening":
160
- return [
161
- gr.Row(visible=True),
162
- gr.Row(visible=False),
163
- gr.Row(visible=False),
164
- gr.Dropdown(choices=[
165
- 'Auto-detect',
166
- 'Kinases',
167
- 'Non-kinase enzymes',
168
- 'Membrane receptors',
169
- 'Nuclear receptors',
170
- 'Ion channels',
171
- 'Other protein targets',
172
- 'Indiscriminate'
173
- ])
174
- ]
175
- case "Drug repurposing":
176
- return [
177
- gr.Row(visible=False),
178
- gr.Row(visible=True),
179
- gr.Row(visible=False),
180
- gr.Dropdown(choices=[
181
- 'Library-labelled',
182
- 'Indiscriminate'
183
- ])
184
- ]
185
- case "Drug-target pair":
186
- return [
187
- gr.Row(visible=False),
188
- gr.Row(visible=False),
189
- gr.Row(visible=True),
190
- gr.Dropdown(choices=[
191
- 'Auto-detect',
192
- 'Manually-labelled',
193
- 'Indiscriminate'
194
- ])
195
- ]
196
-
197
-
198
-
199
- with gr.Blocks(theme=gr.themes.Soft(spacing_size="sm", text_size='md'), title='DeepScreen') as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  with gr.Tabs() as tabs:
201
- with gr.TabItem(label='Inference', id=0) as inference:
202
  gr.Markdown('''
203
- # <center>DeepScreen Inference Service</center>
204
-
205
- DeepScreen for predicting drug-target interaction/binding affinity.
206
- ''')
207
-
208
- mode = gr.Radio(label='Mode', choices=MODE_LIST, value='Drug screening')
209
-
210
- with gr.Row(visible=True) as drug_screening:
211
- with gr.Column():
212
- target = gr.Textbox(label='Target FASTA sequence')
213
- drug_library = gr.Dropdown(label='Drug library', choices=DRUG_LIBRARY_MAP.keys())
214
-
215
- # Modify the pd df directly with df['X2'] = target
216
-
217
- with gr.Row(visible=False) as drug_repurposing:
218
- with gr.Column():
219
- drug = gr.Textbox(label='Drug SMILES sequence')
220
- target_library = gr.Dropdown(label='Target library', choices=TARGET_LIBRARY_MAP.keys())
221
-
222
- # Modify the pd df directly with df['X1'] = drug
223
-
224
-
225
- with gr.Row(visible=False) as drug_target_pair:
226
- predict_data = gr.File(label='Prediction dataset file', file_count="single", type='filepath', height=50)
227
-
228
- with gr.Row(visible=True):
229
- task = gr.Dropdown(list(TASK_MAP.keys()), label='Task')
230
- preset = gr.Dropdown(list(PRESET_MAP.keys()), label='Preset')
231
- target_family = gr.Dropdown(choices=[
232
- 'Auto-detect',
233
- 'Kinases',
234
- 'Non-kinase enzymes',
235
- 'Membrane receptors',
236
- 'Nuclear receptors',
237
- 'Ion channels',
238
- 'Other protein targets',
239
- 'Indiscriminate'
240
- ], label='Target family')
241
-
242
- with gr.Row(visible=True):
243
- predict_btn = gr.Button("Predict", variant="primary")
244
-
245
- with gr.TabItem(label='Report', id=1) as report:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  gr.Markdown('''
247
- # <center>DeepScreen Virtual Screening Report</center>
248
-
249
- Analytic report for virtual screening predictions.
250
- ''')
251
- with gr.Row():
252
- scores = gr.CheckboxGroup(SCORE_MAP.keys(), label='Scores')
253
- filters = gr.CheckboxGroup(FILTER_MAP.keys(), label='Filters')
254
-
255
- with gr.Row():
256
- df_original = gr.Dataframe(type="pandas", interactive=False, height=500, visible=False)
257
- df_report = gr.Dataframe(type="pandas", interactive=False, height=500, visible=False)
258
- with gr.Row():
259
- clear_btn = gr.ClearButton()
260
- analyze_btn = gr.Button("Report", variant="primary")
261
 
262
- mode.change(change_layout, mode, [drug_screening, drug_repurposing, drug_target_pair, target_family], show_progress=False)
263
- predict_btn.click(fn=submit_predict, inputs=[predict_data, task, preset, target_family], outputs=[df_original, tabs])
264
- analyze_btn.click(fn=submit_report, inputs=[df_original, scores, filters], outputs=[df_original, df_report])
265
-
266
-
267
- # js = """function () {
268
- # gradioURL = window.location.href
269
- # if (!gradioURL.endsWith('?__theme=light')) {
270
- # window.location.replace(gradioURL + '?__theme=light');
271
- # }
272
- # }"""
273
- js="""
274
- () => {
275
- document.body.classList.remove('dark');
276
- document.querySelector('gradio-app').style.backgroundColor = 'var(--color-background-primary)'
277
- }
278
- """
279
- demo.load(None, None, None, js=js)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
 
 
281
 
282
- demo.close()
283
- demo.launch(debug=True)
 
 
 
 
 
 
1
+ import hashlib
2
+ import json
3
+ import textwrap
4
+ import threading
5
+ from math import pi
6
+ from uuid import uuid4
7
+
8
+ import io
9
  import os
10
  import pathlib
11
  from pathlib import Path
12
  import sys
13
 
14
+ from Bio import AlignIO, SeqIO
15
+ # from email_validator import validate_email
16
  import gradio as gr
17
+ import hydra
18
  import pandas as pd
19
+ import plotly.express as px
20
+ import requests
21
+ from requests.adapters import HTTPAdapter, Retry
22
  from rdkit import Chem
23
+ from rdkit.Chem import RDConfig, Descriptors, Draw, Lipinski, Crippen, PandasTools
24
+ from rdkit.Chem.Scaffolds import MurckoScaffold
25
+ import seaborn as sns
26
 
27
+ import swifter
28
+ from tqdm.auto import tqdm
29
+
30
+ from deepscreen.data.dti import rdkit_canonicalize, validate_seq_str, FASTA_PAT, SMILES_PAT
31
  from deepscreen.predict import predict
32
 
33
  sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
34
  import sascorer
35
 
36
  ROOT = Path.cwd()
37
+ DATA_PATH = Path("./") # Path("/data")
38
+
39
+ DF_FOR_REPORT = pd.DataFrame()
40
+
41
+ pd.set_option('display.float_format', '{:.3f}'.format)
42
+ PandasTools.molRepresentation = 'svg'
43
+ PandasTools.drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
44
+ PandasTools.drawOptions.clearBackground = False
45
+ PandasTools.drawOptions.bondLineWidth = 1.5
46
+ PandasTools.drawOptions.explicitMethyl = True
47
+ PandasTools.drawOptions.singleColourWedgeBonds = True
48
+ PandasTools.drawOptions.useCDKAtomPalette()
49
+ PandasTools.molSize = (128, 128)
50
 
51
+ SESSION = requests.Session()
52
+ ADAPTER = HTTPAdapter(max_retries=Retry(total=5, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504]))
53
+ SESSION.mount('http://', ADAPTER)
54
+ SESSION.mount('https://', ADAPTER)
55
+
56
+ # SCHEDULER = BackgroundScheduler()
57
+
58
+ UNIPROT_ENDPOINT = 'https://rest.uniprot.org/uniprotkb/{query}'
59
+ CSS = """
60
+ .help-tip {
61
+ position: absolute;
62
+ display: block;
63
+ top: 0px;
64
+ right: 0px;
65
+ text-align: center;
66
+ background-color: #29b6f6;
67
+ border-radius: 50%;
68
+ width: 24px;
69
+ height: 24px;
70
+ font-size: 12px;
71
+ line-height: 26px;
72
+ cursor: default;
73
+ transition: all 0.5s cubic-bezier(0.55, 0, 0.1, 1);
74
+ }
75
+
76
+ .help-tip:hover {
77
+ cursor: pointer;
78
+ background-color: #ccc;
79
+ }
80
+
81
+ .help-tip:before {
82
+ content: '?';
83
+ font-weight: 700;
84
+ color: #fff;
85
+ z-index: 100;
86
+ }
87
+
88
+ .help-tip p {
89
+ visibility: hidden;
90
+ opacity: 0;
91
+ text-align: left;
92
+ background-color: #039be5;
93
+ padding: 20px;
94
+ width: 300px;
95
+ position: absolute;
96
+ border-radius: 4px;
97
+ right: -4px;
98
+ color: #fff;
99
+ font-size: 13px;
100
+ line-height: normal;
101
+ transform: scale(0.7);
102
+ transform-origin: 100% 0%;
103
+ transition: all 0.5s cubic-bezier(0.55, 0, 0.1, 1);
104
+ z-index: 100;
105
+ }
106
+
107
+ .help-tip:hover p {
108
+ cursor: default;
109
+ visibility: visible;
110
+ opacity: 1;
111
+ transform: scale(1.0);
112
+ }
113
+
114
+ .help-tip p:before {
115
+ position: absolute;
116
+ content: '';
117
+ width: 0;
118
+ height: 0;
119
+ border: 6px solid transparent;
120
+ border-bottom-color: #039be5;
121
+ right: 10px;
122
+ top: -12px;
123
+ }
124
+
125
+ .help-tip p:after {
126
+ width: 100%;
127
+ height: 40px;
128
+ content: '';
129
+ position: absolute;
130
+ top: -5px;
131
+ left: 0;
132
+ }
133
+
134
+ .help-tip a {
135
+ color: #fff;
136
+ font-weight: 700;
137
+ }
138
+
139
+ .help-tip a:hover, .help-tip a:focus {
140
+ color: #fff;
141
+ text-decoration: underline;
142
+ }
143
+
144
+ .upload_button {
145
+ background-color: #008000;
146
+ }
147
+
148
+ .absolute {
149
+ position: absolute;
150
+ }
151
+
152
+ #example {
153
+ padding: 0;
154
+ background: none;
155
+ border: none;
156
+ text-decoration: underline;
157
+ box-shadow: none;
158
+ text-align: left !important;
159
+ display: inline-block !important;
160
+ }
161
+
162
+ footer {
163
+ visibility: hidden
164
+ }
165
+
166
+ """
167
+
168
+
169
+ class HelpTip:
170
+ def __new__(cls, text):
171
+ return gr.HTML(elem_classes="help-tip",
172
+ value=f'<p>{text}</p>'
173
+ )
174
 
175
 
176
  def sa_score(row):
177
+ return sascorer.calculateScore((row['Compound']))
178
+
179
 
180
  def mw(row):
181
+ return Chem.Descriptors.MolWt((row['Compound']))
182
+
183
 
184
  def hbd(row):
185
+ return Lipinski.NumHDonors((row['Compound']))
186
+
187
 
188
  def hba(row):
189
+ return Lipinski.NumHAcceptors((row['Compound']))
190
+
191
 
192
  def logp(row):
193
+ return Crippen.MolLogP((row['Compound']))
194
+
195
 
196
  SCORE_MAP = {
197
  'SAscore': sa_score,
198
+ 'RAscore': None, # https://github.com/reymond-group/RAscore
199
+ 'SCScore': None, # https://pubs.acs.org/doi/10.1021/acs.jcim.7b00622
200
+ 'LogP': logp, # https://www.rdkit.org/docs/source/rdkit.Chem.Crippen.html
201
+ 'MW': mw, # https://www.rdkit.org/docs/source/rdkit.Chem.Descriptors.html
202
+ 'HBD': hbd, # https://www.rdkit.org/docs/source/rdkit.Chem.Lipinski.html
203
+ 'HBA': hba, # https://www.rdkit.org/docs/source/rdkit.Chem.Lipinski.html
204
+ 'TopoPSA': None, # http://mordred-descriptor.github.io/documentation/master/api/mordred.TopoPSA.html
205
  }
206
 
207
  FILTER_MAP = {
 
218
 
219
  PRESET_MAP = {
220
  'DeepDTA': 'deep_dta',
221
+ 'DeepConvDTI': 'deep_conv_dti',
222
+ 'GraphDTA': 'graph_dta',
223
+ 'MGraphDTA': 'm_graph_dta',
224
+ 'HyperAttentionDTI': 'hyper_attention_dti',
225
+ 'MolTrans': 'mol_trans',
226
+ 'TransformerCPI': 'transfomer_cpi',
227
+ 'TransformerCPI2': 'transformer_cpi_2',
228
+ 'DrugBAN': 'drug_ban',
229
+ 'DrugVQA(Seq)': 'drug_vqa'
230
  }
231
 
232
  TARGET_FAMILY_MAP = {
233
+ 'General': 'general',
234
+ 'Kinase': 'kinases',
235
+ 'Non-kinase enzyme': 'non-kinase_enzymes',
236
+ 'Membrane receptor': 'membrane_receptors',
237
+ 'Nuclear receptor': 'nuclear_receptors',
238
+ 'Ion channel': 'ion_channels',
 
 
239
  'Other protein targets': 'other_protein_targets',
 
 
 
 
 
 
 
240
  }
241
 
242
  TARGET_LIBRARY_MAP = {
243
+ # 'STITCH': 'stitch.csv',
244
+ 'ChEMBL33 (all species)': 'ChEMBL33_all_spe_single_prot_info.csv',
245
+ 'DrugBank (Human)': 'drugbank_human_py_annot.csv',
246
  }
247
 
248
  DRUG_LIBRARY_MAP = {
249
+ # 'ChEMBL': 'chembl.csv',
250
+ 'DrugBank (Human)': 'drugbank_human_py_annot.csv',
251
  }
252
 
253
  MODE_LIST = [
 
256
  'Drug-target pair'
257
  ]
258
 
259
+ COLUMN_ALIASES = {
260
+ 'X1': 'Drug SMILES',
261
+ 'X2': 'Target FASTA',
262
+ 'ID1': 'Drug ID',
263
+ 'ID2': 'Target ID',
264
+ }
265
+
266
+ URL = "https://ciddr-lab.ac.cn/deepseqreen"
267
+
268
+
269
def validate_columns(df, mandatory_cols):
    """Ensure that every mandatory column is present in ``df``.

    Args:
        df: DataFrame whose ``columns`` are checked.
        mandatory_cols: Iterable of column names that must be present.

    Raises:
        gr.Error: If at least one mandatory column is absent. The message
            lists exactly the columns that are missing.
    """
    missing_cols = [col for col in mandatory_cols if col not in df.columns]
    if missing_cols:
        # Report the columns that are actually missing rather than a fixed list,
        # so the message stays correct for any ``mandatory_cols``.
        error_message = (f"The following mandatory columns are missing "
                         f"in the uploaded dataset: {str(missing_cols).strip('[]')}.")
        raise gr.Error(error_message)
275
+
276
+
277
+ def send_email(receiver, msg):
278
+ pass
279
+
280
+
281
+ def submit_predict(predict_filepath, task, preset, target_family, flag, progress=gr.Progress(track_tqdm=True)):
282
+ if flag:
283
+ job_id = flag
284
+ global COLUMN_ALIASES
285
+ task = TASK_MAP[task]
286
+ preset = PRESET_MAP[preset]
287
+ target_family = TARGET_FAMILY_MAP[target_family]
288
+ # email_hash = hashlib.sha256(email.encode()).hexdigest()
289
+ COLUMN_ALIASES = COLUMN_ALIASES | {
290
+ 'Y': 'Actual interaction' if task == 'binary' else 'Actual affinity',
291
+ 'Y^': 'Predicted interaction' if task == 'binary' else 'Predicted affinity'
292
+ }
293
+
294
+ # target_family_list = [target_family]
295
+ # for family in target_family_list:
296
+
297
+ # try:
298
+ prediction_df = pd.DataFrame()
299
  with hydra.initialize(version_base="1.3", config_path="configs", job_name="webserver_inference"):
300
  cfg = hydra.compose(
301
  config_name="webserver_inference",
302
+ overrides=[f"task={task}",
303
+ f"preset={preset}",
304
+ f"ckpt_path=resources/checkpoints/{preset}-{task}-{target_family}.ckpt",
305
+ f"data.data_file='{str(predict_filepath)}'"])
306
+
307
+ predictions, _ = predict(cfg)
308
+ predictions = [pd.DataFrame(prediction) for prediction in predictions]
309
+ prediction_df = pd.concat([prediction_df, pd.concat(predictions, ignore_index=True)])
310
+
311
+ predictions_file = f'{job_id}_predictions.csv'
312
+ prediction_df.to_csv(predictions_file)
313
+
314
+ return [gr.Markdown(visible=True),
315
+ gr.File(predictions_file),
316
+ gr.State(False)]
317
+ #
318
+ # except Exception as e:
319
+ # raise gr.Error(str(e))
320
+
321
+ # email_lock = Path(f"outputs/{email_hash}.lock")
322
+ # with open(email_lock, "w") as file:
323
+ # record = {
324
+ # "email": email,
325
+ # "job_id": job_id
326
+ # }
327
+ # json.dump(record, file)
328
+ # def run_predict():
329
+ # TODO per-user submit usage
330
+ # # email_lock = Path(f"outputs/{email_hash}.lock")
331
+ # # with open(email_lock, "w") as file:
332
+ # # record = {
333
+ # # "email": email,
334
+ # # "job_id": job_id
335
+ # # }
336
+ # # json.dump(record, file)
337
+ #
338
+ # job_lock = DATA_PATH / f"outputs/{job_id}.lock"
339
+ # with open(job_lock, "w") as file:
340
+ # pass
341
+ #
342
+ # try:
343
+ # prediction_df = pd.DataFrame()
344
+ # for family in target_family_list:
345
+ # with hydra.initialize(version_base="1.3", config_path="configs", job_name="webserver_inference"):
346
+ # cfg = hydra.compose(
347
+ # config_name="webserver_inference",
348
+ # overrides=[f"task={task}",
349
+ # f"preset={preset}",
350
+ # f"ckpt_path=resources/checkpoints/{preset}-{task}-{family}.ckpt",
351
+ # f"data.data_file='{str(predict_dataset)}'"])
352
+ #
353
+ # predictions, _ = predict(cfg)
354
+ # predictions = [pd.DataFrame(prediction) for prediction in predictions]
355
+ # prediction_df = pd.concat([prediction_df, pd.concat(predictions, ignore_index=True)])
356
+ # prediction_df.to_csv(f'outputs/{job_id}.csv')
357
+ # # email_lock.unlink()
358
+ # job_lock.unlink()
359
+ #
360
+ # msg = (f'Your DeepSEQcreen prediction job (id: {job_id}) completed successfully. You may retrieve the '
361
+ # f'results and generate an analytical report at {URL} using the job id within 48 hours.')
362
+ # gr.Info(msg)
363
+ # except Exception as e:
364
+ # msg = (f'Your DeepSEQcreen prediction job (id: {job_id}) failed due to an error: "{str(e)}." You may '
365
+ # f'reach out to the author about the error through email ([email protected]).')
366
+ # raise gr.Error(str(e))
367
+ # finally:
368
+ # send_email(email, msg)
369
+ #
370
+ # # Run "predict" asynchronously
371
+ # threading.Thread(target=run_predict).start()
372
+ #
373
+ # msg = (f'Your DeepSEQcreen prediction job (id: {job_id}) started running. You may retrieve the results '
374
+ # f'and generate an analytical report at {URL} using the job id once the job is done. Only one job '
375
+ # f'per user is allowed at the same time.')
376
+ # send_email(email, msg)
377
+
378
+ # # Return the job id first
379
+ # return [
380
+ # gr.Blocks(visible=False),
381
+ # gr.Markdown(f"Your prediction job is running... "
382
+ # f"You may stay on this page or come back later to retrieve the results "
383
+ # f"Once you receive our email notification."),
384
+ # ]
385
+
386
+
387
+ def update_df(file, progress=gr.Progress(track_tqdm=True)):
388
+ global DF_FOR_REPORT
389
+ if file is not None:
390
+ df = pd.read_csv(file)
391
+ if df['X1'].nunique() > 1:
392
+ df['Scaffold SMILES'] = df['X1'].swifter.progress_bar(
393
+ desc=f"Calculating scaffold...").apply(MurckoScaffold.MurckoScaffoldSmilesFromSmiles)
394
+ # Add a new column with RDKit molecule objects
395
+ PandasTools.AddMoleculeColumnToFrame(df, smilesCol='X1', molCol='Compound',
396
+ includeFingerprints=False)
397
+ PandasTools.AddMoleculeColumnToFrame(df, smilesCol='Scaffold SMILES', molCol='Scaffold',
398
+ includeFingerprints=False)
399
+ DF_FOR_REPORT = df.copy()
400
+
401
+ pie_chart = None
402
+ value = None
403
+ if 'Y^' in DF_FOR_REPORT.columns:
404
+ value = 'Y^'
405
+ elif 'Y' in DF_FOR_REPORT.columns:
406
+ value = 'Y'
407
+
408
+ if value:
409
+ if DF_FOR_REPORT['X1'].nunique() > 1 >= DF_FOR_REPORT['X2'].nunique():
410
+ pie_chart = create_pie_chart(DF_FOR_REPORT, category='Scaffold SMILES', value=value, top_k=100)
411
+ elif DF_FOR_REPORT['X2'].nunique() > 1 >= DF_FOR_REPORT['X1'].nunique():
412
+ pie_chart = create_pie_chart(DF_FOR_REPORT, category='Target family', value=value, top_k=100)
413
+
414
+ return create_html_report(DF_FOR_REPORT), pie_chart
415
+ else:
416
+ return gr.HTML(''), gr.Plot()
417
+
418
+
419
def create_html_report(df, progress=gr.Progress(track_tqdm=True)):
    """Render ``df`` as a scrollable HTML table.

    Columns are reordered so that identifier/label/structure columns come
    first and the raw ``X1``/``X2`` columns come last, the ``X2`` text is
    line-wrapped with ``wrap_text``, and column names are replaced by their
    ``COLUMN_ALIASES`` display names.

    Args:
        df: DataFrame to render. It is not modified; a reordered copy is used.
        progress: Gradio progress tracker (unused directly in this function).

    Returns:
        An HTML string: the table wrapped in a fixed-height scrollable ``div``.
    """
    cols_left = ['ID2', 'Y', 'Y^', 'ID1', 'Compound', 'Scaffold', 'Scaffold SMILES', ]
    cols_right = ['X1', 'X2']
    cols_left = [col for col in cols_left if col in df.columns]
    cols_right = [col for col in cols_right if col in df.columns]
    # Work on an explicit copy so the assignments below never write into a
    # slice of the caller's frame.
    df = df[cols_left + (df.columns.drop(cols_left + cols_right).tolist()) + cols_right].copy()
    # 'X2' is optional here (it was filtered for presence above), so only wrap
    # it when it exists.
    if 'X2' in df.columns:
        df['X2'] = df['X2'].apply(wrap_text)
    # ``rename`` with a positional mapper targets the index; the aliases are
    # column names, so the ``columns=`` keyword is required.
    df.rename(columns=COLUMN_ALIASES, inplace=True)

    # NOTE(review): ``styled_df`` is built but the HTML below is produced from
    # ``df`` itself, so the gradients are never rendered. Left as-is because
    # switching to the Styler output would change how the table is rendered -
    # confirm the intended behaviour.
    styled_df = df.style
    # styled_df = df.style.format("{:.2f}")
    colors = sns.color_palette('husl', len(df.columns))
    for i, col in enumerate(df.columns):
        if pd.api.types.is_numeric_dtype(df[col]):
            styled_df = styled_df.background_gradient(subset=col, cmap=sns.light_palette(colors[i], as_cmap=True))

    # Return the DataFrame as HTML
    PandasTools.RenderImagesInAllDataFrames(images=True)

    html = df.to_html()
    return f'<div style="overflow:auto; height: 500px;">{html}</div>'
440
+ # return gr.HTML(pn.widgets.Tabulator(df).embed())
441
+
442
+
443
+ # def create_pie_chart(df, category, value, top_k):
444
+ # df.rename(COLUMN_ALIASES, inplace=True)
445
+ # # Select the top_k records based on the value_col
446
+ # top_k_df = df.nlargest(top_k, value)
447
+ #
448
+ # # Count the frequency of each unique value in the category_col column
449
+ # category_counts = top_k_df[category].value_counts()
450
+ #
451
+ # # Convert the counts to a DataFrame
452
+ # data = pd.DataFrame({category: category_counts.index, 'value': category_counts.values})
453
+ #
454
+ # # Calculate the angle for each category
455
+ # data['angle'] = data['value']/data['value'].sum() * 2*pi
456
+ #
457
+ # # Assign colors
458
+ # data['color'] = Spectral11[0:len(category_counts)]
459
+ #
460
+ # # Create the plot
461
+ # p = figure(height=350, title="Pie Chart", toolbar_location=None,
462
+ # tools="hover", tooltips="@{}: @value".format(category), x_range=(-0.5, 1.0))
463
+ #
464
+ # p.wedge(x=0, y=1, radius=0.4,
465
+ # start_angle=cumsum('angle', include_zero=True), end_angle=cumsum('angle'),
466
+ # line_color="white", fill_color='color', legend_field=category, source=data)
467
+ #
468
+ # p.axis.axis_label = None
469
+ # p.axis.visible = False
470
+ # p.grid.grid_line_color = None
471
+ #
472
+ # return p
473
+
474
def create_pie_chart(df, category, value, top_k):
    """Build a pie chart of category frequencies among the top-scoring rows.

    Args:
        df: DataFrame holding at least the ``category`` and ``value`` columns
            (under their original or aliased names). It is not modified.
        category: Column whose value frequencies form the pie slices.
        value: Column used to rank rows; the ``top_k`` largest are kept.
        top_k: Number of top-ranked rows to include.

    Returns:
        The Plotly figure produced by ``px.pie``.
    """
    # ``rename`` with a positional mapper targets the index; the aliases are
    # column names, so ``columns=`` is required. It also returns a new frame,
    # leaving the caller's data untouched.
    df = df.rename(columns=COLUMN_ALIASES)
    # Translate both lookups so they match the renamed columns.
    category = COLUMN_ALIASES.get(category, category)
    value = COLUMN_ALIASES.get(value, value)
    # Select the top_k records based on the value column
    top_k_df = df.nlargest(top_k, value)

    # Count the frequency of each unique value in the category column
    category_counts = top_k_df[category].value_counts()

    # Convert the counts to a DataFrame
    data = pd.DataFrame({category: category_counts.index, 'value': category_counts.values})

    # Create the plot
    fig = px.pie(data, values='value', names=category, title=f'Top-{top_k} {category} in {value}')
    fig.update_traces(textposition='inside', textinfo='percent+label')

    return fig
492
+
493
+
494
def submit_report(score_list, filter_list, progress=gr.Progress(track_tqdm=True)):
    """Compute the requested scores on the report data and re-render it.

    Works on a copy of the module-level ``DF_FOR_REPORT``. Each name in
    ``score_list`` is looked up in ``SCORE_MAP`` and applied row-wise, with
    the result stored in a column of the same name.

    Args:
        score_list: Names of scores to calculate (keys of ``SCORE_MAP``).
        filter_list: Names of filters selected by the user.
            NOTE: filters are accepted but not applied yet.
        progress: Gradio progress tracker that follows the tqdm bars.

    Returns:
        A tuple ``(html_report, pie_chart)``; ``pie_chart`` is ``None`` when
        no value column is present or neither chart condition holds.

    Raises:
        gr.Error: Wraps any exception raised while scoring or rendering.
    """
    df = DF_FOR_REPORT.copy()
    try:
        for score_name in score_list:
            scorer = SCORE_MAP[score_name]
            # Several SCORE_MAP entries are placeholders set to None; fail with
            # a readable message instead of passing None to ``apply``.
            if scorer is None:
                raise ValueError(f"{score_name} is not available yet.")
            df[score_name] = df.swifter.progress_bar(desc=f"Calculating {score_name}").apply(
                scorer, axis=1)

        pie_chart = None
        value = None
        if 'Y^' in df.columns:
            value = 'Y^'
        elif 'Y' in df.columns:
            value = 'Y'

        if value:
            if df['X1'].nunique() > 1 >= df['X2'].nunique():
                # Many drugs against at most one target: summarise by scaffold.
                pie_chart = create_pie_chart(df, category='Scaffold SMILES', value=value, top_k=100)
            elif df['X2'].nunique() > 1 >= df['X1'].nunique():
                # Many targets against at most one drug: summarise by family.
                # NOTE(review): assumes the data carries a 'Target family'
                # column - confirm against the target libraries.
                pie_chart = create_pie_chart(df, category='Target family', value=value, top_k=100)

        return create_html_report(df), pie_chart

    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(str(e)) from e
521
+
522
+
523
def check_job_status(job_id):
    """Report whether a prediction job is running, finished, or unknown.

    Looks under ``DATA_PATH`` for ``<job_id>.lock`` and ``<job_id>.csv``.

    Args:
        job_id: Identifier of the job to look up.

    Returns:
        A three-item list ``[status_markdown, tabs_update, results_file]``.
        The last two items are ``None`` unless the results file exists.
    """
    job_lock = DATA_PATH / f"{job_id}.lock"
    job_file = DATA_PATH / f"{job_id}.csv"
    # Ordered lists are returned (not set literals): a set has no stable order
    # and would collapse the two ``None`` placeholders into one element.
    if job_lock.is_file():
        return [gr.Markdown(f"Your job ({job_id}) is still running... "
                            f"You may stay on this page or come back later to retrieve the results "
                            f"Once you receive our email notification."),
                None,
                None]
    if job_file.is_file():
        # Hand back the results CSV; the lock file is not the job output.
        return [gr.Markdown(f"Your job ({job_id}) is done! Redirecting you to generate reports..."),
                gr.Tabs(selected=3),
                gr.File(str(job_file))]
    # Neither file exists: still return a value for every output slot.
    return [gr.Markdown(f"No job with id ({job_id}) was found."),
            None,
            None]
537
+
538
+
539
+ def wrap_text(text, line_length=60):
540
+ wrapper = textwrap.TextWrapper(width=line_length)
541
+ if text.startswith('>'):
542
+ sections = text.split('>')
543
+ wrapped_sections = []
544
+ for section in sections:
545
+ if not section:
546
+ continue
547
+ lines = section.split('\n')
548
+ seq_header = lines[0]
549
+ wrapped_seq = wrapper.fill(''.join(lines[1:]))
550
+ wrapped_sections.append(f">{seq_header}\n{wrapped_seq}")
551
+ return '\n'.join(wrapped_sections)
552
+ else:
553
+ return wrapper.fill(text)
554
+
555
+
556
def unwrap_text(text):
    """Return ``text`` stripped of surrounding whitespace with all newlines removed.

    Args:
        text: String that may span several lines.

    Returns:
        The text joined onto a single line.
    """
    return text.strip().replace('\n', '')
558
+
559
+
560
+ def smiles_from_sdf(sdf_path):
561
+ with Chem.SDMolSupplier(sdf_path) as suppl:
562
+ return Chem.MolToSmiles(suppl[0])
563
+
564
+
565
+ theme = gr.themes.Base(spacing_size="sm", text_size='md').set(
566
+ background_fill_primary='#dfe6f0',
567
+ background_fill_secondary='#dfe6f0',
568
+ checkbox_label_background_fill='#dfe6f0',
569
+ checkbox_label_background_fill_hover='#dfe6f0',
570
+ checkbox_background_color='white',
571
+ checkbox_border_color='#4372c4',
572
+ border_color_primary='#4372c4',
573
+ border_color_accent='#4372c4',
574
+ button_primary_background_fill='#4372c4',
575
+ button_primary_text_color='white',
576
+ button_secondary_border_color='#4372c4',
577
+ body_text_color='#4372c4',
578
+ block_title_text_color='#4372c4',
579
+ block_label_text_color='#4372c4',
580
+ block_info_text_color='#505358',
581
+ block_border_color=None,
582
+ input_border_color='#4372c4',
583
+ panel_border_color='#4372c4',
584
+ input_background_fill='white',
585
+ code_background_fill='white',
586
+ )
587
+
588
+ with (gr.Blocks(theme=theme, title='DeepScreen', css=CSS) as demo):
589
+ run_state = gr.State(value=False)
590
+ screen_flag = gr.State(value=False)
591
+ identify_flag = gr.State(value=False)
592
+ infer_flag = gr.State(value=False)
593
+
594
  with gr.Tabs() as tabs:
595
+ with gr.TabItem(label='Drug hit screening', id=0):
596
  gr.Markdown('''
597
+ # <center>DeepSEQreen Drug Hit Screening</center>
598
+ <center>
599
+ To predict interactions/binding affinities of a single target against a library of drugs.
600
+ </center>
601
+ ''')
602
+ with gr.Blocks() as screen_block:
603
+ with gr.Column() as screen_page:
604
+ with gr.Row():
605
+ with gr.Column(scale=4, variant='panel'):
606
+ target_fasta = gr.Code(label='Target sequence FASTA',
607
+ interactive=True, lines=5)
608
+ example_target = gr.Button(value='Example: Human MAPK14', elem_id='example')
609
+ with gr.Row():
610
+ with gr.Column(scale=1):
611
+ with gr.Group():
612
+ with gr.Row():
613
+ target_input_type = gr.Radio(label='Target input type',
614
+ choices=['Sequence', 'UniProt ID', 'Gene symbol'],
615
+ value='Sequence')
616
+ target_query = gr.Textbox(label='UniProt ID/Accession',
617
+ visible=False, interactive=True)
618
+ target_upload_btn = gr.UploadButton(label='Upload a FASTA file',
619
+ type='binary',
620
+ visible=True, variant='primary',
621
+ size='lg', elem_classes="upload_button")
622
+ target_query_btn = gr.Button(value='Query the sequence', variant='primary',
623
+ elem_classes='upload_button', visible=False)
624
+
625
+ with gr.Column(scale=1):
626
+ with gr.Row():
627
+ with gr.Group():
628
+ drug_screen_target_family = gr.Dropdown(
629
+ choices=list(TARGET_FAMILY_MAP.keys()),
630
+ value='General',
631
+ label='Target family', interactive=True)
632
+ # with gr.Column(scale=1, min_width=24):
633
+ auto_detect_btn = gr.Button(value='Auto-detect', variant='primary')
634
+ HelpTip(
635
+ "Target amino acid sequence in the FASTA format. Alternatively, you may use a "
636
+ "UniProt ID/accession to query UniProt database for the sequence of your target"
637
+ "of interest. You can also search on databases like UniProt, RCSB PDB, "
638
+ "NCBI Protein for the FASTA string representing your target of interest. If "
639
+ "the input FASTA contains multiple entities, only the first one will be used."
640
+ )
641
+
642
+ with gr.Column(variant='panel'):
643
+ with gr.Group():
644
+ drug_library = gr.Radio(label='Drug library',
645
+ choices=list(DRUG_LIBRARY_MAP.keys()) + ['Upload a drug library'])
646
+ drug_library_upload = gr.File(label='Custom drug library file', visible=True)
647
+
648
+ with gr.Row(variant='panel'):
649
+ drug_screen_task = gr.Radio(list(TASK_MAP.keys()), label='Task',
650
+ value='Drug-target interaction')
651
+
652
+ with gr.Column(scale=2):
653
+ with gr.Group():
654
+ drug_screen_preset = gr.Dropdown(list(PRESET_MAP.keys()), label='Model')
655
+ recommend_btn = gr.Button(value='Recommend a model', variant='primary')
656
+ HelpTip("We recommend the appropriate model for your use case based on model performance "
657
+ "in drug-target interaction or binding affinity prediction "
658
+ "benchmarked on different target families and real-world data scenarios.")
659
+
660
+ # drug_screen_email = gr.Textbox(
661
+ # label='Email (optional)',
662
+ # info="Your email will be used to send you notifications when your job finishes."
663
+ # )
664
+
665
+ with gr.Row(visible=True):
666
+ drug_screen_clr_btn = gr.ClearButton()
667
+ drug_screen_btn = gr.Button(value='SCREEN', variant='primary')
668
+ # TODO Modify the pd df directly with df['X2'] = target
669
+
670
+ screen_data_for_predict = gr.File(visible=False, file_count="single", type='filepath')
671
+ screen_waiting = gr.Markdown("""
672
+ <center>Your job is running... It might take a few minutes.
673
+ When it's done, you will be redirected to the report page.
674
+ Meanwhile, please leave the page on.</center>
675
+ """, visible=False)
676
+
677
+ with gr.TabItem(label='Target protein identification', id=1):
678
  gr.Markdown('''
679
+ # <center>DeepSEQreen Target Protein Identification</center>
 
 
 
 
 
 
 
 
 
 
 
 
 
680
 
681
+ <center>
682
+ To predict interactions/binding affinities of a single drug against a library of targets.
683
+ </center>
684
+ ''')
685
+ with gr.Blocks() as identify_block:
686
+ with gr.Column() as identify_page:
687
+ with gr.Row():
688
+ with gr.Group():
689
+ drug_type = gr.Dropdown(label='Drug input type',
690
+ choices=['SMILES', 'SDF'],
691
+ value='SMILES',
692
+ scale=1,
693
+ interactive=True)
694
+ drug_upload = gr.UploadButton(label='⤒ Upload a file')
695
+ drug_smiles = gr.Code(label='Drug canonical SMILES', interactive=True, scale=5, lines=5)
696
+ with gr.Column(scale=1):
697
+ HelpTip(
698
+ """Drug molecule in the SMILES format. You may search on databases like
699
+ NCBI PubChem, ChEMBL, and DrugBank for the SMILES strings
700
+ representing your drugs of interest.
701
+ """
702
+ )
703
+ example_drug = gr.Button(value='Example: Aspirin', elem_id='example')
704
+
705
+ with gr.Column(variant='panel'):
706
+ with gr.Group():
707
+ target_library = gr.Radio(label='Target library',
708
+ choices=list(TARGET_LIBRARY_MAP.keys()) + ['Upload a target library'])
709
+ target_library_upload = gr.File(label='Custom target library file', visible=True)
710
+
711
+ with gr.Row(visible=True):
712
+ target_identify_task = gr.Dropdown(list(TASK_MAP.keys()), label='Task')
713
+ HelpTip("Choose a preset model for making the predictions.")
714
+ target_identify_preset = gr.Dropdown(list(PRESET_MAP.keys()), label='Preset')
715
+ HelpTip("Choose the protein family of your target.")
716
+ target_identify_target_family = gr.Dropdown(choices=['General'],
717
+ value='General',
718
+ label='Target family')
719
+
720
+ # with gr.Row():
721
+ # target_identify_email = gr.Textbox(
722
+ # label='Email (optional)',
723
+ # info="Your email will be used to send you notifications when your job finishes."
724
+ # )
725
+
726
+ with gr.Row(visible=True):
727
+ target_identify_clr_btn = gr.ClearButton()
728
+ target_identify_btn = gr.Button(value='IDENTIFY', variant='primary')
729
+
730
+ identify_data_for_predict = gr.File(visible=False, file_count="single", type='filepath')
731
+ identify_waiting = gr.Markdown(f"Your job is running... It might take a few minutes."
732
+ f"When it's done, you will be redirected to the report page. "
733
+ f"Meanwhile, please leave the page on.",
734
+ visible=False)
735
+ with gr.TabItem(label='Interaction pair inference', id=2):
736
+ gr.Markdown('''
737
+ # <center>DeepSEQreen Interaction Pair Inference</center>
738
+ <center>
739
+ To predict interactions/binding affinities between any drug-target pairs.
740
+ </center>
741
+ ''')
742
+ with gr.Blocks() as infer_block:
743
+ with gr.Column() as infer_page:
744
+ HelpTip("Upload a custom drug-target pair dataset. See the documentation for details.")
745
+ infer_data_for_predict = gr.File(
746
+ label='Prediction dataset file', file_count="single", type='filepath')
747
+ # TODO example dataset
748
+ # TODO download example dataset
749
+
750
+ with gr.Row(visible=True):
751
+ pair_infer_task = gr.Dropdown(list(TASK_MAP.keys()), label='Task')
752
+ HelpTip("Choose a preset model for making the predictions.")
753
+ pair_infer_preset = gr.Dropdown(list(PRESET_MAP.keys()), label='Preset')
754
+ HelpTip("Choose the protein family of your target.")
755
+ pair_infer_target_family = gr.Dropdown(choices=['General'],
756
+ label='Target family',
757
+ value='General')
758
+
759
+ # with gr.Row():
760
+ # pair_infer_email = gr.Textbox(
761
+ # label='Email (optional)',
762
+ # info="Your email will be used to send you notifications when your job finishes."
763
+ # )
764
+
765
+ with gr.Row(visible=True):
766
+ pair_infer_clr_btn = gr.ClearButton()
767
+ pair_infer_btn = gr.Button(value='INFER', variant='primary')
768
+
769
+ infer_waiting = gr.Markdown(f"Your job is running... It might take a few minutes."
770
+ f"When it's done, you will be redirected to the report page. "
771
+ f"Meanwhile, please leave the page on.",
772
+ visible=False)
773
+
774
+ with gr.TabItem(label='Chemical property report', id=3):
775
+ with gr.Blocks() as report:
776
+ gr.Markdown('''
777
+ # <center>DeepSEQreen Chemical Property Report</center>
778
+ <center>
779
+ To compute chemical properties for the predictions of drug hit screening,
780
+ target protein identification, and interaction pair inference. You may also upload
781
+ your own dataset.
782
+ </center>
783
+ ''')
784
+ with gr.Row():
785
+ file_for_report = gr.File(interactive=True, type='filepath')
786
+ # df_original = gr.Dataframe(type="pandas", interactive=False, visible=False)
787
+ scores = gr.CheckboxGroup(list(SCORE_MAP.keys()), label='Scores')
788
+ filters = gr.CheckboxGroup(list(FILTER_MAP.keys()), label='Filters')
789
+
790
+ with gr.Row():
791
+ clear_btn = gr.ClearButton()
792
+ analyze_btn = gr.Button('REPORT', variant='primary')
793
+
794
+ with gr.Row():
795
+ with gr.Column(scale=3):
796
+ html_report = gr.HTML() # label='Results', visible=True)
797
+ ranking_pie_chart = gr.Plot(visible=False)
798
+
799
+ with gr.Row():
800
+ csv_download_btn = gr.Button('Download report (HTML)', variant='primary')
801
+ html_download_btn = gr.Button('Download raw data (CSV)', variant='primary')
802
+
803
+
804
def target_input_type_select(input_type):
    """
    Build the component updates that switch the target input widgets.

    Args:
        input_type: One of ``'Sequence'``, ``'UniProt ID'`` or ``'Gene symbol'``.

    Returns:
        ``[upload button update, query textbox update, query button update]``;
        ``None`` for any other ``input_type``.
    """
    if input_type == 'Sequence':
        # Direct sequence entry: only the FASTA upload button is shown.
        return [gr.UploadButton(visible=True),
                gr.Textbox(visible=False), gr.Button(visible=False)]

    # Query modes differ only in how the (emptied) query textbox is labelled.
    query_box_settings = {
        'UniProt ID': {'label': 'UniProt ID/accession', 'info': None},
        'Gene symbol': {'label': 'Gene symbol/name', 'info': 'Organism: human'},
    }
    settings = query_box_settings.get(input_type)
    if settings is None:
        return None

    return [gr.UploadButton(visible=False),
            gr.Textbox(visible=True, value='', **settings),
            gr.Button(visible=True)]
817
+
818
+
819
+ target_input_type.select(fn=target_input_type_select,
820
+ inputs=target_input_type, outputs=[target_upload_btn, target_query, target_query_btn],
821
+ show_progress=False)
822
+
823
+
824
+ def uniprot_query(query, input_type):
825
+ fasta_seq = ''
826
+ query = query.strip()
827
+
828
+ match input_type:
829
+ case 'UniProt ID':
830
+ query = f"{query.strip()}.fasta"
831
+ case 'Gene symbol':
832
+ query = f'search?query=organism_id:9606+AND+gene:{query}&format=fasta'
833
+
834
+ try:
835
+ fasta = SESSION.get(UNIPROT_ENDPOINT.format(query=query))
836
+ fasta.raise_for_status()
837
+ fasta_seq = fasta.text
838
+ except Exception as e:
839
+ raise gr.Warning(f"Failed to query FASTA from UniProt due to {str(e)}")
840
+ finally:
841
+ return fasta_seq
842
+
843
+
844
+ target_upload_btn.upload(fn=lambda x: x.decode(), inputs=target_upload_btn, outputs=target_fasta)
845
+ target_query_btn.click(uniprot_query, inputs=[target_query, target_input_type], outputs=target_fasta)
846
+
847
+ target_fasta.focus(fn=wrap_text, inputs=target_fasta, outputs=target_fasta, show_progress=False)
848
+ target_fasta.blur(fn=wrap_text, inputs=target_fasta, outputs=target_fasta, show_progress=False)
849
+ drug_smiles.focus(fn=wrap_text, inputs=drug_smiles, outputs=drug_smiles, show_progress=False)
850
+ drug_smiles.blur(fn=wrap_text, inputs=drug_smiles, outputs=drug_smiles, show_progress=False)
851
+
852
+
853
def example_fill(input_type):
    """
    Fill the target inputs with the human MAPK14 example.

    Args:
        input_type: Currently selected target input type; decides which example
            query (UniProt accession or gene symbol) goes into the query textbox.

    Returns:
        A dict mapping the ``target_query`` and ``target_fasta`` components to
        their example values.
    """
    example_queries = {'UniProt ID': 'Q16539', 'Gene symbol': 'MAPK14'}

    # Built line by line so the value starts directly with the '>' header and
    # carries no leading newline or source-code indentation.
    example_fasta = '\n'.join([
        '>sp|Q16539|MK14_HUMAN Mitogen-activated protein kinase 14 OS=Homo sapiens OX=9606 GN=MAPK14 PE=1 SV=3',
        'MSQERPTFYRQELNKTIWEVPERYQNLSPVGSGAYGSVCAAFDTKTGLRVAVKKLSRPFQ',
        'SIIHAKRTYRELRLLKHMKHENVIGLLDVFTPARSLEEFNDVYLVTHLMGADLNNIVKCQ',
        'KLTDDHVQFLIYQILRGLKYIHSADIIHRDLKPSNLAVNEDCELKILDFGLARHTDDEMT',
        'GYVATRWYRAPEIMLNWMHYNQTVDIWSVGCIMAELLTGRTLFPGTDHIDQLKLILRLVG',
        'TPGAELLKKISSESARNYIQSLTQMPKMNFANVFIGANPLAVDLLEKMLVLDSDKRITAA',
        'QALAHAYFAQYHDPDDEPVADPYDQSFESRDLLIDEWKSLTYDEVISFVPPPLDQEEMES',
    ])

    return {target_query: example_queries.get(input_type, ''),
            target_fasta: example_fasta}
871
+
872
+
873
+ example_target.click(fn=example_fill, inputs=target_input_type,
874
+ outputs=[target_query, target_fasta], show_progress=False)
875
+ example_drug.click(fn=lambda: 'CC(=O)Oc1ccccc1C(=O)O', outputs=drug_smiles, show_progress=False)
876
+
877
+
878
def drug_screen_validate(fasta, library, library_upload, state):
    """
    Validate the drug hit screening inputs and stage the prediction dataset.

    Args:
        fasta: Target sequence as FASTA text or a bare sequence. Only the first
            record is used.
        library: Name of the selected drug library (a key of ``DRUG_LIBRARY_MAP``)
            or the "upload" choice.
        library_upload: Uploaded custom drug library CSV; must contain an ``X1``
            column.
        state: Current run state; truthy while another job is running in the session.

    Returns:
        A dict of component updates. On success it carries the staged CSV path and
        the new job id for ``screen_flag`` and ``run_state``; when another job is
        running it only resets ``screen_flag`` to ``False``.

    Raises:
        gr.Error: If the sequence is invalid, no drug library is available, or the
            dataset could not be staged.
    """
    if state:
        gr.Warning('You have another prediction job '
                   '(drug hit screening, target protein identification, or interation pair inference) '
                   'running in the session right now. '
                   'Please submit another job when your current job has finished.')
        return {screen_flag: False}

    # Strip every line so indentation and Windows line endings ('\r') do not end up
    # inside the sequence, then keep only the first record.
    lines = [line.strip() for line in (fasta or '').strip().splitlines()]
    if lines and lines[0].startswith('>'):
        lines = lines[1:]
    fasta = ''.join(lines).split('>')[0]

    err = validate_seq_str(fasta, FASTA_PAT)
    if err:
        raise gr.Error(f'Found error(s) in your target fasta input: {err}')

    if library in DRUG_LIBRARY_MAP:
        screen_df = pd.read_csv(Path('data/drug_libraries', DRUG_LIBRARY_MAP[library]))
    elif library_upload is not None:
        screen_df = pd.read_csv(library_upload)
        validate_columns(screen_df, ['X1'])
    else:
        raise gr.Error('Please select a drug library or upload a custom drug library file.')

    # Pair every drug in the library with the single query target.
    screen_df['X2'] = fasta

    job_id = uuid4()
    temp_file = Path(f'{job_id}_temp.csv').resolve()
    screen_df.to_csv(temp_file)
    if not temp_file.is_file():
        raise gr.Error('Failed to prepare the screening dataset for prediction.')

    return {screen_data_for_predict: str(temp_file),
            screen_flag: job_id,
            run_state: job_id}
913
+
914
def target_identify_validate(smiles, library, library_upload, state):
    """
    Validate the target identification inputs and stage the prediction dataset.

    Args:
        smiles: SMILES string of the query drug. Whitespace, including line breaks
            introduced by display wrapping, is removed before validation.
        library: Name of the selected target library (a key of
            ``TARGET_LIBRARY_MAP``) or the "upload" choice.
        library_upload: Uploaded custom target library CSV; must contain an ``X2``
            column.
        state: Current run state; truthy while another job is running in the session.

    Returns:
        A dict of component updates. On success it carries the staged CSV path and
        the new job id for ``identify_flag`` and ``run_state``; when another job is
        running it only resets ``identify_flag`` to ``False``.

    Raises:
        gr.Error: If the SMILES is invalid, no target library is available, or the
            dataset could not be staged.
    """
    if state:
        gr.Warning('You have another prediction job '
                   '(drug hit screening, target protein identification, or interation pair inference) '
                   'running in the session right now. '
                   'Please submit another job when your current job has finished.')
        return {identify_flag: False}

    # The SMILES box may contain hard-wrapped text; whitespace is never part of a SMILES.
    smiles = ''.join((smiles or '').split())
    err = validate_seq_str(smiles, SMILES_PAT)
    if err:
        raise gr.Error(f'Found error(s) in your compound SMILES input: {err}')

    if library in TARGET_LIBRARY_MAP:
        # Look up the *selected* library (the literal key 'target_library' never matched).
        # NOTE(review): assumes map values are file names under data/target_libraries,
        # mirroring how drug libraries are resolved -- confirm.
        identify_df = pd.read_csv(Path('data/target_libraries', TARGET_LIBRARY_MAP[library]))
    elif library_upload is not None:
        identify_df = pd.read_csv(library_upload)
        validate_columns(identify_df, ['X2'])
    else:
        raise gr.Error('Please select a target library or upload a custom target library file.')

    # Pair every target in the library with the single query drug.
    identify_df['X1'] = smiles

    job_id = uuid4()
    temp_file = Path(f'{job_id}_temp.csv').resolve()
    identify_df.to_csv(temp_file)
    if not temp_file.is_file():
        raise gr.Error('Failed to prepare the identification dataset for prediction.')

    # State values are returned directly, as drug_screen_validate does.
    return {identify_data_for_predict: str(temp_file),
            identify_flag: job_id,
            run_state: job_id}
942
+
943
+
944
def pair_infer_validate(drug_target_pair_upload, run_state):
    """
    Validate an uploaded drug-target pair dataset before inference.

    Args:
        drug_target_pair_upload: Uploaded CSV with ``X1`` (SMILES) and ``X2``
            (protein sequence) columns.
        run_state: Current run state *value*; truthy while another job is running.
            Note that this parameter shadows the ``run_state`` component inside the
            function, so the component cannot be used as a dict key here.

    Returns:
        ``[infer_flag value, run_state value]``, in the order of the ``outputs`` of
        the ``pair_infer_btn.click`` event. On success both are the new job id; when
        another job is running the flag is ``False`` and the run state is unchanged.

    Raises:
        gr.Error: If no file was uploaded or any SMILES/sequence is invalid.
    """
    if run_state:
        gr.Warning('You have another prediction job '
                   '(drug hit screening, target protein identification, or interation pair inference) '
                   'running in the session right now. '
                   'Please submit another job when your current job has finished.')
        return [False, run_state]

    if drug_target_pair_upload is None:
        raise gr.Error('Please upload a drug-target pair dataset file.')

    df = pd.read_csv(drug_target_pair_upload)
    validate_columns(df, ['X1', 'X2'])
    df['X1_ERR'] = df['X1'].swifter.apply(
        validate_seq_str, regex=SMILES_PAT)
    df['X2_ERR'] = df['X2'].swifter.apply(
        validate_seq_str, regex=FASTA_PAT)

    # validate_seq_str yields None for valid entries, so any non-NaN cell is an error.
    if not df['X1_ERR'].isna().all():
        raise gr.Error(f"Encountered invalid SMILES:\n{df[~df['X1_ERR'].isna()][['X1', 'X1_ERR']]}")
    if not df['X2_ERR'].isna().all():
        raise gr.Error(f"Encountered invalid FASTA:\n{df[~df['X2_ERR'].isna()][['X2', 'X2_ERR']]}")

    job_id = uuid4()
    return [job_id, job_id]
968
+
969
+
970
+ drug_screen_btn.click(
971
+ fn=drug_screen_validate,
972
+ inputs=[target_fasta, drug_library, drug_library_upload, run_state], # , drug_screen_email],
973
+ outputs=[screen_data_for_predict, screen_flag, run_state]
974
+ ).then(
975
+ fn=lambda: [gr.Column(visible=False), gr.Markdown(visible=True)],
976
+ outputs=[screen_page, screen_waiting]
977
+ ).then(
978
+ fn=submit_predict,
979
+ inputs=[screen_data_for_predict, drug_screen_task, drug_screen_preset,
980
+ drug_screen_target_family, screen_flag], # , drug_screen_email],
981
+ outputs=[file_for_report, run_state]
982
+ ).then(
983
+ fn=lambda: [gr.Column(visible=True), gr.Markdown(visible=False)],
984
+ outputs=[screen_page, screen_waiting]
985
+ )
986
+
987
+ target_identify_btn.click(
988
+ fn=target_identify_validate,
989
+ inputs=[drug_smiles, target_library, target_library_upload, run_state], # , drug_screen_email],
990
+ outputs=[identify_data_for_predict, identify_flag, run_state]
991
+ ).then(
992
+ fn=lambda: [gr.Column(visible=False), gr.Markdown(visible=True)],
993
+ outputs=[identify_page, identify_waiting]
994
+ ).then(
995
+ fn=submit_predict,
996
+ inputs=[identify_data_for_predict, target_identify_task, target_identify_preset,
997
+ target_identify_target_family, identify_flag], # , target_identify_email],
998
+ outputs=[file_for_report, run_state]
999
+ ).then(
1000
+ fn=lambda: [gr.Column(visible=True), gr.Markdown(visible=False)],
1001
+ outputs=[identify_page, identify_waiting]
1002
+ )
1003
+
1004
+ pair_infer_btn.click(
1005
+ fn=pair_infer_validate,
1006
+ inputs=[infer_data_for_predict, run_state], # , drug_screen_email],
1007
+ outputs=[infer_flag, run_state]
1008
+ ).then(
1009
+ fn=lambda: [gr.Column(visible=False), gr.Markdown(visible=True)],
1010
+ outputs=[infer_page, infer_waiting]
1011
+ ).then(
1012
+ fn=submit_predict,
1013
+ inputs=[infer_data_for_predict, pair_infer_task, pair_infer_preset,
1014
+ pair_infer_target_family, infer_flag], # , pair_infer_email],
1015
+ outputs=[file_for_report, run_state]
1016
+ ).then(
1017
+ fn=lambda: [gr.Column(visible=True), gr.Markdown(visible=False)],
1018
+ outputs=[infer_page, infer_waiting]
1019
+ )
1020
+
1021
+ # TODO background job from these 3 pipelines to update file_for_report
1022
+
1023
+ file_for_report.change(fn=update_df, inputs=file_for_report, outputs=[html_report, ranking_pie_chart])
1024
+
1025
+ analyze_btn.click(fn=submit_report, inputs=[scores, filters], outputs=[html_report, ranking_pie_chart])
1026
+
1027
+ # screen_waiting.change(fn=check_job_status, inputs=run_state, outputs=[pair_waiting, tabs, file_for_report],
1028
+ # every=5)
1029
+ # identify_waiting.change(fn=check_job_status, inputs=run_state, outputs=[identify_waiting, tabs, file_for_report],
1030
+ # every=5)
1031
+ # pair_waiting.change(fn=check_job_status, inputs=run_state, outputs=[pair_waiting, tabs, file_for_report],
1032
+ # every=5)
1033
+
1034
+ # demo.load(None, None, None, js="() => {document.body.classList.remove('dark')}")
1035
+
1036
+ if __name__ == "__main__":
1037
+ screen_block.queue(max_size=2)
1038
+ identify_block.queue(max_size=2)
1039
+ infer_block.queue(max_size=2)
1040
+ report.queue(max_size=20)
1041
 
1042
+ # SCHEDULER.add_job(func=file_cleanup(), trigger="interval", seconds=60)
1043
+ # SCHEDULER.start()
1044
 
1045
+ demo.launch(
1046
+ # debug=True,
1047
+ show_api=False,
1048
+ # favicon_path=,
1049
+ # inline=False
1050
+ debug=True
1051
+ )
data/drug_libraries/drugbank_human_py_annot.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e9e965d0fe672b2d9299bbe507c74eba610b2aaf89326424991ba1c46fdabb3
3
+ size 11047747
data/target_libraries/ChEMBL33_all_spe_single_prot_info.csv ADDED
The diff for this file is too large to render. See raw diff
 
deepscreen/__init__.py CHANGED
@@ -20,9 +20,9 @@ OmegaConf.register_new_resolver("eval", eval)
20
 
21
  def sanitize_path(path_str: str):
22
  """
23
- Sanitize a string for path creation by replacing unsafe characters.
24
  """
25
- return path_str.replace("/", ".").replace("\\", ".").replace(":", "-")
26
 
27
 
28
  OmegaConf.register_new_resolver("sanitize_path", sanitize_path)
 
20
 
21
  def sanitize_path(path_str: str):
22
  """
23
+ Sanitize a string for path creation by replacing unsafe characters and cutting length to 255 (OS limitation).
24
  """
25
+ return path_str.replace("/", ".").replace("\\", ".").replace(":", "-")[:255]
26
 
27
 
28
  OmegaConf.register_new_resolver("sanitize_path", sanitize_path)
deepscreen/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/deepscreen/__pycache__/__init__.cpython-311.pyc and b/deepscreen/__pycache__/__init__.cpython-311.pyc differ
 
deepscreen/__pycache__/train.cpython-311.pyc CHANGED
Binary files a/deepscreen/__pycache__/train.cpython-311.pyc and b/deepscreen/__pycache__/train.cpython-311.pyc differ
 
deepscreen/data/__pycache__/dti.cpython-311.pyc CHANGED
Binary files a/deepscreen/data/__pycache__/dti.cpython-311.pyc and b/deepscreen/data/__pycache__/dti.cpython-311.pyc differ
 
deepscreen/data/dti.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from functools import partial
2
  from numbers import Number
3
  from pathlib import Path
@@ -5,6 +6,7 @@ from typing import Any, Dict, Optional, Sequence, Union, Literal
5
 
6
  from lightning import LightningDataModule
7
  import pandas as pd
 
8
  from sklearn.preprocessing import LabelEncoder
9
  from torch.utils.data import Dataset, DataLoader
10
 
@@ -13,9 +15,33 @@ from deepscreen.utils import get_logger
13
 
14
  log = get_logger(__name__)
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # TODO: save a list of corrupted records
18
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  class DTIDataset(Dataset):
21
  def __init__(
@@ -27,6 +53,7 @@ class DTIDataset(Dataset):
27
  protein_featurizer: callable,
28
  thresholds: Optional[Union[Number, Sequence[Number]]] = None,
29
  discard_intermediate: Optional[bool] = False,
 
30
  ):
31
  df = pd.read_csv(
32
  data_path,
@@ -58,40 +85,43 @@ class DTIDataset(Dataset):
58
  # Forward-fill all non-label columns
59
  df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0)
60
 
 
 
 
61
  if 'Y' in df:
62
- log.info(f"Performing pre-transformation target validation.")
63
  # TODO: check sklearn.utils.multiclass.check_classification_targets
64
  match task:
65
  case 'regression':
66
- assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \
67
  f"""`Y` must be numeric for `regression` task,
68
- but it has {set(df['Y'].apply(type))}."""
69
 
70
  case 'binary':
71
  if all(df['Y'].isin([0, 1])):
72
  assert not thresholds, \
73
  f"""`Y` is already 0 or 1 for `binary` (classification) `task`,
74
- but still got `thresholds` {thresholds}.
75
- Double check your choices of `task` and `thresholds` and records in the `Y` column."""
76
  else:
77
  assert thresholds, \
78
  f"""`Y` must be 0 or 1 for `binary` (classification) `task`,
79
- but it has {pd.unique(df['Y'])}.
80
- You must set `thresholds` to discretize continuous labels."""
81
 
82
  case 'multiclass':
83
  assert num_classes >= 3, f'`num_classes` for `task=multiclass` must be at least 3.'
84
 
85
- if all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)):
86
  assert not thresholds, \
87
  f"""`Y` is already non-negative integers for
88
- `multiclass` (classification) `task`, but still got `thresholds` {thresholds}.
89
  Double check your choice of `task`, `thresholds` and records in the `Y` column."""
90
  else:
91
  assert thresholds, \
92
  f"""`Y` must be non-negative integers for
93
  `multiclass` (classification) 'task',but it has {pd.unique(df['Y'])}.
94
- You must set `thresholds` to discretize continuous labels."""
95
 
96
  if 'U' in df.columns:
97
  units = df['U']
@@ -107,37 +137,51 @@ class DTIDataset(Dataset):
107
  # Filter out rows with a NaN in Y (missing values)
108
  df.dropna(subset=['Y'], inplace=True)
109
 
110
- log.info(f"Performing post-transformation target validation.")
111
  match task:
112
  case 'regression':
113
  df['Y'] = df['Y'].astype('float32')
114
- assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \
115
  f"""`Y` must be numeric for `regression` task,
116
- but after transformation it still has {set(df['Y'].apply(type))}.
117
  Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
118
-
119
  case 'binary':
120
  df['Y'] = df['Y'].astype('int')
121
  assert all(df['Y'].isin([0, 1])), \
122
  f"""`Y` must be 0 or 1 for `task=binary`, "
123
  but after transformation it still has {pd.unique(df['Y'])}.
124
  Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
125
-
126
  case 'multiclass':
127
  df['Y'] = df['Y'].astype('int')
128
- assert all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)), \
129
  f"""Y must be non-negative integers for `task=multiclass`
130
  but after transformation it still has {pd.unique(df['Y'])}.
131
  Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
132
-
133
  target_n_unique = df['Y'].nunique()
134
  assert target_n_unique == num_classes, \
135
  f"""You have set `num_classes` for `task=multiclass` to {num_classes},
136
  but after transformation Y still has {target_n_unique} unique labels.
137
  Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
138
 
139
- # Indexed protein/FASTA for retrieval metrics
140
- df['IDX'] = LabelEncoder().fit_transform(df['X2'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  self.df = df
143
  self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x)
@@ -151,13 +195,13 @@ class DTIDataset(Dataset):
151
  return {
152
  'N': i,
153
  'X1': sample['X1'],
154
- 'X1^': self.drug_featurizer(sample['X1']),
155
- 'ID1': sample.get('ID1', sample['X1']),
156
  'X2': sample['X2'],
157
  'X2^': self.protein_featurizer(sample['X2']),
158
- 'ID2': sample.get('ID2', sample['X2']),
159
  'Y': sample.get('Y'),
160
- 'IDX': sample['IDX'],
161
  }
162
 
163
 
 
1
+ import re
2
  from functools import partial
3
  from numbers import Number
4
  from pathlib import Path
 
6
 
7
  from lightning import LightningDataModule
8
  import pandas as pd
9
+ import swifter
10
  from sklearn.preprocessing import LabelEncoder
11
  from torch.utils.data import Dataset, DataLoader
12
 
 
15
 
16
  log = get_logger(__name__)
17
 
18
+ SMILES_PAT = r"[^A-Za-z0-9=#:+\-\[\]<>()/\\@%,.*]"
19
+ FASTA_PAT = r"[^A-Z*\-]"
20
+
21
+
22
def validate_seq_str(seq, regex):
    """
    Check a sequence string for characters matched by ``regex``.

    Args:
        seq: The sequence (e.g. SMILES or protein sequence) to check.
        regex: Pattern matching the *disallowed* characters.

    Returns:
        ``None`` if ``seq`` is a non-empty string with no match; ``'Empty string'``
        if ``seq`` is empty or not a string (e.g. ``None`` or a NaN cell); otherwise
        the distinct offending matches, sorted and joined with ``', '``.
    """
    # Missing values read from a CSV arrive as float NaN, which is truthy and would
    # make re.findall raise a TypeError.
    if not isinstance(seq, str) or not seq:
        return 'Empty string'

    err_charset = set(re.findall(regex, seq))
    if not err_charset:
        return None
    # Sorted so the message does not depend on set iteration order.
    return ', '.join(sorted(err_charset))
31
+
32
 
33
  # TODO: save a list of corrupted records
34
 
35
def rdkit_canonicalize(smiles):
    """
    Canonicalize a SMILES string with RDKit, falling back to the input on failure.

    Args:
        smiles: The SMILES string to canonicalize.

    Returns:
        The output of ``Chem.MolToSmiles`` for the parsed molecule, or ``smiles``
        unchanged (after logging a warning) if parsing or conversion raises.
    """
    # Imported lazily inside the function, as in the original implementation.
    from rdkit import Chem
    try:
        return Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
    except Exception as e:
        log.warning(f'Failed to canonicalize SMILES using RDKIT due to {str(e)}. Returning original SMILES: {smiles}')
        return smiles
44
+
45
 
46
  class DTIDataset(Dataset):
47
  def __init__(
 
53
  protein_featurizer: callable,
54
  thresholds: Optional[Union[Number, Sequence[Number]]] = None,
55
  discard_intermediate: Optional[bool] = False,
56
+ query: Optional[str] = 'X2'
57
  ):
58
  df = pd.read_csv(
59
  data_path,
 
85
  # Forward-fill all non-label columns
86
  df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0)
87
 
88
+ # TODO potentially allow running through the whole data validation process
89
+ # error = False
90
+
91
  if 'Y' in df:
92
+ log.info(f"Validating labels (`Y`)...")
93
  # TODO: check sklearn.utils.multiclass.check_classification_targets
94
  match task:
95
  case 'regression':
96
+ assert all(df['Y'].swifter.apply(lambda x: isinstance(x, Number))), \
97
  f"""`Y` must be numeric for `regression` task,
98
+ but it has {set(df['Y'].swifter.apply(type))}."""
99
 
100
  case 'binary':
101
  if all(df['Y'].isin([0, 1])):
102
  assert not thresholds, \
103
  f"""`Y` is already 0 or 1 for `binary` (classification) `task`,
104
+ but still got `thresholds` ({thresholds}).
105
+ Double check your choices of `task` and `thresholds`, and records in the `Y` column."""
106
  else:
107
  assert thresholds, \
108
  f"""`Y` must be 0 or 1 for `binary` (classification) `task`,
109
+ but it has {pd.unique(df['Y'])}.
110
+ You may set `thresholds` to discretize continuous labels.""" # TODO print err idx instead
111
 
112
  case 'multiclass':
113
  assert num_classes >= 3, f'`num_classes` for `task=multiclass` must be at least 3.'
114
 
115
+ if all(df['Y'].swifter.apply(lambda x: x.is_integer() and x >= 0)):
116
  assert not thresholds, \
117
  f"""`Y` is already non-negative integers for
118
+ `multiclass` (classification) `task`, but still got `thresholds` ({thresholds}).
119
  Double check your choice of `task`, `thresholds` and records in the `Y` column."""
120
  else:
121
  assert thresholds, \
122
  f"""`Y` must be non-negative integers for
123
  `multiclass` (classification) 'task',but it has {pd.unique(df['Y'])}.
124
+ You must set `thresholds` to discretize continuous labels.""" # TODO print err idx instead
125
 
126
  if 'U' in df.columns:
127
  units = df['U']
 
137
  # Filter out rows with a NaN in Y (missing values)
138
  df.dropna(subset=['Y'], inplace=True)
139
 
 
140
  match task:
141
  case 'regression':
142
  df['Y'] = df['Y'].astype('float32')
143
+ assert all(df['Y'].swifter.apply(lambda x: isinstance(x, Number))), \
144
  f"""`Y` must be numeric for `regression` task,
145
+ but after transformation it still has {set(df['Y'].swifter.apply(type))}.
146
  Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
147
+ # TODO print err idx instead
148
  case 'binary':
149
  df['Y'] = df['Y'].astype('int')
150
  assert all(df['Y'].isin([0, 1])), \
151
  f"""`Y` must be 0 or 1 for `task=binary`, "
152
  but after transformation it still has {pd.unique(df['Y'])}.
153
  Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
154
+ # TODO print err idx instead
155
  case 'multiclass':
156
  df['Y'] = df['Y'].astype('int')
157
+ assert all(df['Y'].swifter.apply(lambda x: x.is_integer() and x >= 0)), \
158
  f"""Y must be non-negative integers for `task=multiclass`
159
  but after transformation it still has {pd.unique(df['Y'])}.
160
  Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
161
+ # TODO print err idx instead
162
  target_n_unique = df['Y'].nunique()
163
  assert target_n_unique == num_classes, \
164
  f"""You have set `num_classes` for `task=multiclass` to {num_classes},
165
  but after transformation Y still has {target_n_unique} unique labels.
166
  Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns."""
167
 
168
+ log.info("Validating SMILES (`X1`)...")
169
+ df['X1_ERR'] = df['X1'].swifter.progress_bar(
170
+ desc="Validating SMILES...").apply(validate_seq_str, regex=SMILES_PAT)
171
+ if not df['X1_ERR'].isna().all():
172
+ raise Exception(f"Encountered invalid SMILES:\n{df[~df['X1_ERR'].isna()][['X1', 'X1_ERR']]}")
173
+ df['X1^'] = df['X1'].apply(rdkit_canonicalize) # swifter
174
+
175
+ log.info("Validating FASTA (`X2`)...")
176
+ df['X2'] = df['X2'].str.upper()
177
+ df['X2_ERR'] = df['X2'].swifter.progress_bar(
178
+ desc="Validating FASTA...").apply(validate_seq_str, regex=FASTA_PAT)
179
+ if not df['X2_ERR'].isna().all():
180
+ raise Exception(f"Encountered invalid FASTA:\n{df[~df['X2_ERR'].isna()][['X2', 'X2_ERR']]}")
181
+
182
+ # FASTA/SMILES indices as query for retrieval metrics like enrichment factor and hit rate
183
+ if query:
184
+ df['ID^'] = LabelEncoder().fit_transform(df[query])
185
 
186
  self.df = df
187
  self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x)
 
195
  return {
196
  'N': i,
197
  'X1': sample['X1'],
198
+ 'X1^': self.drug_featurizer(sample['X1^']),
199
+ 'ID1': sample.get('ID1'),
200
  'X2': sample['X2'],
201
  'X2^': self.protein_featurizer(sample['X2']),
202
+ 'ID2': sample.get('ID2'),
203
  'Y': sample.get('Y'),
204
+ 'ID^': sample.get('ID^'),
205
  }
206
 
207
 
deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc and b/deepscreen/data/featurizers/__pycache__/__init__.cpython-311.pyc differ
 
deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc CHANGED
Binary files a/deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc and b/deepscreen/data/featurizers/__pycache__/categorical.cpython-311.pyc differ
 
deepscreen/data/featurizers/__pycache__/graph.cpython-311.pyc CHANGED
Binary files a/deepscreen/data/featurizers/__pycache__/graph.cpython-311.pyc and b/deepscreen/data/featurizers/__pycache__/graph.cpython-311.pyc differ
 
deepscreen/data/featurizers/__pycache__/token.cpython-311.pyc CHANGED
Binary files a/deepscreen/data/featurizers/__pycache__/token.cpython-311.pyc and b/deepscreen/data/featurizers/__pycache__/token.cpython-311.pyc differ
 
deepscreen/data/featurizers/categorical.py CHANGED
@@ -2,20 +2,20 @@ import numpy as np
2
 
3
  # Sets of KNOWN characters in SMILES and FASTA sequences
4
  # Use list instead of set to preserve character order
5
- SMILES_CHARSET = ('#', '%', ')', '(', '+', '-', '.', '1', '0', '3', '2', '5', '4',
6
- '7', '6', '9', '8', '=', 'A', 'C', 'B', 'E', 'D', 'G', 'F', 'I',
7
- 'H', 'K', 'M', 'L', 'O', 'N', 'P', 'S', 'R', 'U', 'T', 'W', 'V',
8
- 'Y', '[', 'Z', ']', '_', 'a', 'c', 'b', 'e', 'd', 'g', 'f', 'i',
9
- 'h', 'm', 'l', 'o', 'n', 's', 'r', 'u', 't', 'y')
10
- FASTA_CHARSET = ('A', 'C', 'B', 'E', 'D', 'G', 'F', 'I', 'H', 'K', 'M', 'L', 'O',
11
- 'N', 'Q', 'P', 'S', 'R', 'U', 'T', 'W', 'V', 'Y', 'X', 'Z')
12
 
13
  # Check uniqueness, create character-index dicts, and add '?' for unknown characters as index 0
14
- assert len(SMILES_CHARSET) == len(set(SMILES_CHARSET)), 'SMILES_CHARSET has duplicate characters.'
15
- SMILES_CHARSET_IDX = {character: index+1 for index, character in enumerate(SMILES_CHARSET)} | {'?': 0}
16
 
17
- assert len(FASTA_CHARSET) == len(set(FASTA_CHARSET)), 'FASTA_CHARSET has duplicate characters.'
18
- FASTA_CHARSET_IDX = {character: index+1 for index, character in enumerate(FASTA_CHARSET)} | {'?': 0}
19
 
20
 
21
  def sequence_to_onehot(sequence: str, charset, max_sequence_length: int):
@@ -40,7 +40,7 @@ def sequence_to_label(sequence: str, charset, max_sequence_length: int):
40
  return label
41
 
42
 
43
- def smiles_to_onehot(smiles: str, smiles_charset=SMILES_CHARSET, max_sequence_length: int = 100): # , in_channels: int = len(SMILES_CHARSET)
44
  # assert len(SMILES_CHARSET) == len(set(SMILES_CHARSET)), 'SMILES_CHARSET has duplicate characters.'
45
  # onehot = np.zeros((max_sequence_length, len(SMILES_CHARSET_IDX)))
46
  # for index, character in enumerate(smiles[:max_sequence_length]):
@@ -49,7 +49,7 @@ def smiles_to_onehot(smiles: str, smiles_charset=SMILES_CHARSET, max_sequence_le
49
  return sequence_to_onehot(smiles, smiles_charset, max_sequence_length)
50
 
51
 
52
- def smiles_to_label(smiles: str, smiles_charset=SMILES_CHARSET, max_sequence_length: int = 100): # , in_channels: int = len(SMILES_CHARSET)
53
  # label = np.zeros(max_sequence_length)
54
  # for index, character in enumerate(smiles[:max_sequence_length]):
55
  # label[index] = SMILES_CHARSET_IDX.get(character, 0)
@@ -57,7 +57,7 @@ def smiles_to_label(smiles: str, smiles_charset=SMILES_CHARSET, max_sequence_len
57
  return sequence_to_label(smiles, smiles_charset, max_sequence_length)
58
 
59
 
60
- def fasta_to_onehot(fasta: str, fasta_charset=FASTA_CHARSET, max_sequence_length: int = 1000): # in_channels: int = len(FASTA_CHARSET)
61
  # onehot = np.zeros((max_sequence_length, len(FASTA_CHARSET_IDX)))
62
  # for index, character in enumerate(fasta[:max_sequence_length]):
63
  # onehot[index, FASTA_CHARSET_IDX.get(character, 0)] = 1
@@ -65,7 +65,7 @@ def fasta_to_onehot(fasta: str, fasta_charset=FASTA_CHARSET, max_sequence_length
65
  return sequence_to_onehot(fasta, fasta_charset, max_sequence_length)
66
 
67
 
68
- def fasta_to_label(fasta: str, fasta_charset=FASTA_CHARSET, max_sequence_length: int = 1000): # in_channels: int = len(FASTA_CHARSET)
69
  # label = np.zeros(max_sequence_length)
70
  # for index, character in enumerate(fasta[:max_sequence_length]):
71
  # label[index] = FASTA_CHARSET_IDX.get(character, 0)
 
2
 
3
  # Sets of KNOWN characters in SMILES and FASTA sequences
4
  # Use tuples instead of sets to preserve character order
5
# Known SMILES characters; position in the tuple (plus one) is the label index.
SMILES_VOCAB = ('#', '%', ')', '(', '+', '-', '.', '1', '0', '3', '2', '5', '4',
                '7', '6', '9', '8', '=', 'A', 'C', 'B', 'E', 'D', 'G', 'F', 'I',
                'H', 'K', 'M', 'L', 'O', 'N', 'P', 'S', 'R', 'U', 'T', 'W', 'V',
                'Y', '[', 'Z', ']', '_', 'a', 'c', 'b', 'e', 'd', 'g', 'f', 'i',
                'h', 'm', 'l', 'o', 'n', 's', 'r', 'u', 't', 'y')
# Known FASTA characters; position in the tuple (plus one) is the label index.
FASTA_VOCAB = ('A', 'C', 'B', 'E', 'D', 'G', 'F', 'I', 'H', 'K', 'M', 'L', 'O',
               'N', 'Q', 'P', 'S', 'R', 'U', 'T', 'W', 'V', 'Y', 'X', 'Z')

# Check uniqueness, create character-index dicts, and add '?' for unknown characters as index 0
# (a duplicated character would silently shift or overwrite indices in the dicts below).
assert len(SMILES_VOCAB) == len(set(SMILES_VOCAB)), 'SMILES_VOCAB has duplicate characters.'
SMILES_CHARSET_IDX = {character: index + 1 for index, character in enumerate(SMILES_VOCAB)} | {'?': 0}

assert len(FASTA_VOCAB) == len(set(FASTA_VOCAB)), 'FASTA_VOCAB has duplicate characters.'
FASTA_CHARSET_IDX = {character: index + 1 for index, character in enumerate(FASTA_VOCAB)} | {'?': 0}
19
 
20
 
21
  def sequence_to_onehot(sequence: str, charset, max_sequence_length: int):
 
40
  return label
41
 
42
 
43
+ def smiles_to_onehot(smiles: str, smiles_charset=SMILES_VOCAB, max_sequence_length: int = 100): # , in_channels: int = len(SMILES_CHARSET)
44
  # assert len(SMILES_CHARSET) == len(set(SMILES_CHARSET)), 'SMILES_CHARSET has duplicate characters.'
45
  # onehot = np.zeros((max_sequence_length, len(SMILES_CHARSET_IDX)))
46
  # for index, character in enumerate(smiles[:max_sequence_length]):
 
49
  return sequence_to_onehot(smiles, smiles_charset, max_sequence_length)
50
 
51
 
52
+ def smiles_to_label(smiles: str, smiles_charset=SMILES_VOCAB, max_sequence_length: int = 100): # , in_channels: int = len(SMILES_CHARSET)
53
  # label = np.zeros(max_sequence_length)
54
  # for index, character in enumerate(smiles[:max_sequence_length]):
55
  # label[index] = SMILES_CHARSET_IDX.get(character, 0)
 
57
  return sequence_to_label(smiles, smiles_charset, max_sequence_length)
58
 
59
 
60
+ def fasta_to_onehot(fasta: str, fasta_charset=FASTA_VOCAB, max_sequence_length: int = 1000): # in_channels: int = len(FASTA_CHARSET)
61
  # onehot = np.zeros((max_sequence_length, len(FASTA_CHARSET_IDX)))
62
  # for index, character in enumerate(fasta[:max_sequence_length]):
63
  # onehot[index, FASTA_CHARSET_IDX.get(character, 0)] = 1
 
65
  return sequence_to_onehot(fasta, fasta_charset, max_sequence_length)
66
 
67
 
68
+ def fasta_to_label(fasta: str, fasta_charset=FASTA_VOCAB, max_sequence_length: int = 1000): # in_channels: int = len(FASTA_CHARSET)
69
  # label = np.zeros(max_sequence_length)
70
  # for index, character in enumerate(fasta[:max_sequence_length]):
71
  # label[index] = FASTA_CHARSET_IDX.get(character, 0)
deepscreen/data/featurizers/monn.py CHANGED
@@ -1,7 +1,7 @@
1
  import numpy as np
2
  from rdkit.Chem import MolFromSmiles
3
 
4
- from deepscreen.data.featurizers.categorical import FASTA_CHARSET, fasta_to_label
5
  from deepscreen.data.featurizers.graph import atom_features, bond_features
6
 
7
 
 
1
  import numpy as np
2
  from rdkit.Chem import MolFromSmiles
3
 
4
+ from deepscreen.data.featurizers.categorical import FASTA_VOCAB, fasta_to_label
5
  from deepscreen.data.featurizers.graph import atom_features, bond_features
6
 
7
 
deepscreen/data/featurizers/token.py CHANGED
@@ -7,13 +7,12 @@ from typing import Optional, List
7
  import numpy as np
8
  from transformers import BertTokenizer
9
 
10
- SMI_REGEX_PATTERN = r"""(
11
- \[[^\]]+\] # match anything inside square brackets
12
- |Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p # match elements
13
- |\(|\) # match parentheses
14
- |\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2} # match various symbols
15
- |[0-9] # match digits
16
- )"""
17
 
18
 
19
  def sequence_to_kmers(sequence, k=3):
@@ -30,17 +29,21 @@ def sequence_to_kmers(sequence, k=3):
30
 
31
  def sequence_to_word_embedding(sequence, model):
32
  """Get protein embedding, infer a list of 3-mers to (num_word, 100) matrix"""
33
- vec = np.zeros((len(sequence), 100))
 
34
  i = 0
35
- for word in sequence:
36
- vec[i,] = model.wv[word]
 
 
 
37
  i += 1
38
  return vec
39
 
40
 
41
  def sequence_to_token_ids(sequence, tokenizer):
42
  token_ids = tokenizer.encode(sequence)
43
- return token_ids
44
 
45
 
46
  # def sequence_to_token_ids(sequence, tokenizer, max_length: int):
@@ -59,14 +62,14 @@ class SmilesTokenizer(BertTokenizer):
59
 
60
  Creates the SmilesTokenizer class. The tokenizer heavily inherits from the BertTokenizer
61
  implementation found in Huggingface's transformers library. It runs a WordPiece tokenization
62
- algorithm over SMILES strings using the tokenisation SMILES regex developed by Schwaller et. al.
63
 
64
  Please see https://github.com/huggingface/transformers
65
  and https://github.com/rxn4chemistry/rxnfp for more details.
66
 
67
  Examples
68
  --------
69
- >>> tokenizer = SmilesTokenizer(vocab_path)
70
  >>> print(tokenizer.encode("CC(=O)OC1=CC=CC=C1C(=O)O"))
71
  [12, 16, 16, 17, 22, 19, 18, 19, 16, 20, 22, 16, 16, 22, 16, 16, 22, 16, 20, 16, 17, 22, 19, 18, 19, 13]
72
 
@@ -81,9 +84,10 @@ class SmilesTokenizer(BertTokenizer):
81
  ----
82
  This class requires huggingface's transformers and tokenizers libraries to be installed.
83
  """
 
84
  def __init__(
85
  self,
86
- vocab_file: str = '',
87
  regex_pattern: str = SMI_REGEX_PATTERN,
88
  # unk_token="[UNK]",
89
  # sep_token="[SEP]",
 
7
  import numpy as np
8
  from transformers import BertTokenizer
9
 
10
+ SMI_REGEX_PATTERN = r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""
11
+ # \[[^\]]+\] # match anything inside square brackets
12
+ # |Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p # match elements
13
+ # |\(|\) # match parentheses
14
+ # |\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2} # match various symbols
15
+ # |[0-9] # match digits
 
16
 
17
 
18
  def sequence_to_kmers(sequence, k=3):
 
29
 
30
  def sequence_to_word_embedding(sequence, model):
31
  """Get protein embedding, infer a list of 3-mers to (num_word, 100) matrix"""
32
+ kmers = sequence_to_kmers(sequence)
33
+ vec = np.zeros((len(kmers), 100))
34
  i = 0
35
+ for word in kmers:
36
+ try:
37
+ vec[i,] = model.wv[word]
38
+ except KeyError:
39
+ pass
40
  i += 1
41
  return vec
42
 
43
 
44
  def sequence_to_token_ids(sequence, tokenizer):
45
  token_ids = tokenizer.encode(sequence)
46
+ return np.array(token_ids)
47
 
48
 
49
  # def sequence_to_token_ids(sequence, tokenizer, max_length: int):
 
62
 
63
  Creates the SmilesTokenizer class. The tokenizer heavily inherits from the BertTokenizer
64
  implementation found in Huggingface's transformers library. It runs a WordPiece tokenization
65
+ algorithm over SMILES strings using the tokenization SMILES regex developed by Schwaller et al.
66
 
67
  Please see https://github.com/huggingface/transformers
68
  and https://github.com/rxn4chemistry/rxnfp for more details.
69
 
70
  Examples
71
  --------
72
+ >>> tokenizer = SmilesTokenizer(vocab_path, regex_pattern)
73
  >>> print(tokenizer.encode("CC(=O)OC1=CC=CC=C1C(=O)O"))
74
  [12, 16, 16, 17, 22, 19, 18, 19, 16, 20, 22, 16, 16, 22, 16, 16, 22, 16, 20, 16, 17, 22, 19, 18, 19, 13]
75
 
 
84
  ----
85
  This class requires huggingface's transformers and tokenizers libraries to be installed.
86
  """
87
+
88
  def __init__(
89
  self,
90
+ vocab_file: str = 'resources/vocabs/smiles.txt',
91
  regex_pattern: str = SMI_REGEX_PATTERN,
92
  # unk_token="[UNK]",
93
  # sep_token="[SEP]",
deepscreen/data/utils/__pycache__/collator.cpython-311.pyc CHANGED
Binary files a/deepscreen/data/utils/__pycache__/collator.cpython-311.pyc and b/deepscreen/data/utils/__pycache__/collator.cpython-311.pyc differ
 
deepscreen/data/utils/__pycache__/label.cpython-311.pyc CHANGED
Binary files a/deepscreen/data/utils/__pycache__/label.cpython-311.pyc and b/deepscreen/data/utils/__pycache__/label.cpython-311.pyc differ
 
deepscreen/data/utils/__pycache__/split.cpython-311.pyc CHANGED
Binary files a/deepscreen/data/utils/__pycache__/split.cpython-311.pyc and b/deepscreen/data/utils/__pycache__/split.cpython-311.pyc differ
 
deepscreen/data/utils/collator.py CHANGED
@@ -72,46 +72,97 @@ def collate_fn(batch, automatic_padding=False, padding_value=0):
72
  return collate(batch, collate_fn_map=COLLATE_FN_MAP)
73
 
74
 
75
- class VariableLengthSequence(torch.Tensor):
76
- """
77
- A custom PyTorch Tensor class that is similar to PackedSequence, except it can be directly used as a batch tensor,
78
- and it has an attribute called lengths, which signifies the length of each original sequence in the batch.
79
- """
80
-
81
- def __new__(cls, data, lengths):
82
- """
83
- Creates a new VariableLengthSequence object from the given data and lengths.
84
- Args:
85
- data (torch.Tensor): The batch collated tensor of shape (batch_size, max_length, *).
86
- lengths (torch.Tensor): The lengths of each original sequence in the batch of shape (batch_size,).
87
- Returns:
88
- VariableLengthSequence: A new VariableLengthSequence object.
89
- """
90
- # Check the validity of the inputs
91
- assert isinstance(data, torch.Tensor), "data must be a torch.Tensor"
92
- assert isinstance(lengths, torch.Tensor), "lengths must be a torch.Tensor"
93
- assert data.dim() >= 2, "data must have at least two dimensions"
94
- assert lengths.dim() == 1, "lengths must have one dimension"
95
- assert data.size(0) == lengths.size(0), "data and lengths must have the same batch size"
96
- assert lengths.min() > 0, "lengths must be positive"
97
- assert lengths.max() <= data.size(1), "lengths must not exceed the max length of data"
98
-
99
- # Create a new tensor object from data
100
- obj = super().__new__(cls, data)
101
-
102
- # Set the lengths attribute
103
- obj.lengths = lengths
104
-
105
- return obj
106
-
107
- def __repr__(self, *, tensor_contents=None):
108
- """
109
- Returns a string representation of the VariableLengthSequence object.
110
- """
111
- return f"VariableLengthSequence(data={self.data}, lengths={self.lengths})"
112
-
113
- def __reduce_ex__(self, proto):
114
- """
115
- Enables pickling of the VariableLengthSequence object.
116
- """
117
- return type(self), (self.data, self.lengths)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  return collate(batch, collate_fn_map=COLLATE_FN_MAP)
73
 
74
 
75
+ # class VariableLengthSequence(torch.Tensor):
76
+ # """
77
+ # A custom PyTorch Tensor class that is similar to PackedSequence, except it can be directly used as a batch tensor,
78
+ # and it has an attribute called lengths, which signifies the length of each original sequence in the batch.
79
+ # """
80
+ #
81
+ # def __new__(cls, data, lengths):
82
+ # """
83
+ # Creates a new VariableLengthSequence object from the given data and lengths.
84
+ # Args:
85
+ # data (torch.Tensor): The batch collated tensor of shape (batch_size, max_length, *).
86
+ # lengths (torch.Tensor): The lengths of each original sequence in the batch of shape (batch_size,).
87
+ # Returns:
88
+ # VariableLengthSequence: A new VariableLengthSequence object.
89
+ # """
90
+ # # Check the validity of the inputs
91
+ # assert isinstance(data, torch.Tensor), "data must be a torch.Tensor"
92
+ # assert isinstance(lengths, torch.Tensor), "lengths must be a torch.Tensor"
93
+ # assert data.dim() >= 2, "data must have at least two dimensions"
94
+ # assert lengths.dim() == 1, "lengths must have one dimension"
95
+ # assert data.size(0) == lengths.size(0), "data and lengths must have the same batch size"
96
+ # assert lengths.min() > 0, "lengths must be positive"
97
+ # assert lengths.max() <= data.size(1), "lengths must not exceed the max length of data"
98
+ #
99
+ # # Create a new tensor object from data
100
+ # obj = super().__new__(cls, data)
101
+ #
102
+ # # Set the lengths attribute
103
+ # obj.lengths = lengths
104
+ #
105
+ # return obj
106
+
107
+
108
+ # class VariableLengthSequence(torch.Tensor):
109
+ # _lengths = torch.Tensor()
110
+ #
111
+ # def __new__(cls, data, lengths, *args, **kwargs):
112
+ # self = super().__new__(cls, data, *args, **kwargs)
113
+ # self.lengths = lengths
114
+ # return self
115
+ #
116
+ # def clone(self, *args, **kwargs):
117
+ # return VariableLengthSequence(super().clone(*args, **kwargs), self.lengths.clone())
118
+ #
119
+ # def new_empty(self, *size):
120
+ # return VariableLengthSequence(super().new_empty(*size), self.lengths)
121
+ #
122
+ # def to(self, *args, **kwargs):
123
+ # return VariableLengthSequence(super().to(*args, **kwargs), self.lengths.to(*args, **kwargs))
124
+ #
125
+ # def __format__(self, format_spec):
126
+ # # Convert self to a string or a number here, depending on what you need
127
+ # return self.item().__format__(format_spec)
128
+ #
129
+ # @property
130
+ # def lengths(self):
131
+ # return self._lengths
132
+ #
133
+ # @lengths.setter
134
+ # def lengths(self, lengths):
135
+ # self._lengths = lengths
136
+ #
137
+ # def cpu(self, *args, **kwargs):
138
+ # return VariableLengthSequence(super().cpu(*args, **kwargs), self.lengths.cpu(*args, **kwargs))
139
+ #
140
+ # def cuda(self, *args, **kwargs):
141
+ # return VariableLengthSequence(super().cuda(*args, **kwargs), self.lengths.cuda(*args, **kwargs))
142
+ #
143
+ # def pin_memory(self):
144
+ # return VariableLengthSequence(super().pin_memory(), self.lengths.pin_memory())
145
+ #
146
+ # def share_memory_(self):
147
+ # super().share_memory_()
148
+ # self.lengths.share_memory_()
149
+ # return self
150
+ #
151
+ # def detach_(self, *args, **kwargs):
152
+ # super().detach_(*args, **kwargs)
153
+ # self.lengths.detach_(*args, **kwargs)
154
+ # return self
155
+ #
156
+ # def detach(self, *args, **kwargs):
157
+ # return VariableLengthSequence(super().detach(*args, **kwargs), self.lengths.detach(*args, **kwargs))
158
+ #
159
+ # def record_stream(self, *args, **kwargs):
160
+ # super().record_stream(*args, **kwargs)
161
+ # self.lengths.record_stream(*args, **kwargs)
162
+ # return self
163
+
164
+
165
+ # @classmethod
166
+ # def __torch_function__(cls, func, types, args=(), kwargs=None):
167
+ # return super().__torch_function__(func, types, args, kwargs) \
168
+ # if cls.lengths is not None else torch.Tensor.__torch_function__(func, types, args, kwargs)
deepscreen/data/utils/label.py CHANGED
@@ -19,6 +19,7 @@ MOLARITY_TO_POTENCY = {
19
  }
20
 
21
 
 
22
  def molar_to_p(labels, units):
23
  assert units in MOLARITY_TO_POTENCY, f"Allowed units: {', '.join(MOLARITY_TO_POTENCY)}."
24
 
 
19
  }
20
 
21
 
22
+ # TODO rewrite for swifter.apply
23
  def molar_to_p(labels, units):
24
  assert units in MOLARITY_TO_POTENCY, f"Allowed units: {', '.join(MOLARITY_TO_POTENCY)}."
25
 
deepscreen/gui/test.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import gradio as gr
4
+
5
+ # Use this in a notebook
6
+ root = Path.cwd()
7
+
8
+
9
+ drug_encoder_list = [f.stem for f in root.parent.joinpath("configs/model/drug_encoder").iterdir() if f.suffix == ".yaml"]
10
+
11
+ drug_featurizer_list = [f.stem for f in root.parent.joinpath("configs/model/drug_featurizer").iterdir() if f.suffix == ".yaml"]
12
+
13
+ protein_encoder_list = [f.stem for f in root.parent.joinpath("configs/model/protein_encoder").iterdir() if f.suffix == ".yaml"]
14
+
15
+ protein_featurizer_list = [f.stem for f in root.parent.joinpath("configs/model/protein_featurizer").iterdir() if f.suffix == ".yaml"]
16
+
17
+ classifier_list = [f.stem for f in root.parent.joinpath("configs/model/classifier").iterdir() if f.suffix == ".yaml"]
18
+
19
+ preset_list = [f.stem for f in root.parent.joinpath("configs/model/preset").iterdir() if f.suffix == ".yaml"]
20
+
21
+
22
+ from typing import Optional
23
+
24
def drug_target_interaction(
    binary: bool,
    drug_encoder,
    drug_featurizer,
    protein_encoder,
    protein_featurizer,
    classifier,
    preset,
) -> Optional[float]:
    """Placeholder callback for the drug-target interaction demo.

    None of the arguments are read yet; the function always returns the
    constant ``1`` so that the GUI wiring can be exercised.
    """
    return 1
35
+
36
def drug_encoder(
    binary: bool,
    drug_encoder,
    drug_featurizer,
    protein_encoder,
    protein_featurizer,
    classifier,
    preset,
):
    """Placeholder callback; ignores every argument and returns ``None``."""
    return None
46
+
47
def protein_encoder(
    binary: bool,
    drug_encoder,
    drug_featurizer,
    protein_encoder,
    protein_featurizer,
    classifier,
    preset,
):
    """Placeholder callback; ignores every argument and returns ``None``."""
    return None
57
+
58
+ # demo = gr.Interface(
59
+ # fn=drug_target_interaction,
60
+ # inputs=[
61
+ # gr.Radio(["True", "False"]),
62
+ # gr.Dropdown(drug_encoder_list),
63
+ # gr.Dropdown(drug_featurizer_list),
64
+ # gr.Dropdown(protein_encoder_list),
65
+ # gr.Dropdown(protein_featurizer_list),
66
+ # gr.Dropdown(classifier_list),
67
+ # gr.Dropdown(preset_list),
68
+ # ],
69
+ # outputs=["number"],
70
+ # show_error=True,
71
+ #
72
+ # )
73
+ #
74
+ # demo.launch()
75
+
76
+
77
+ from omegaconf import DictConfig, OmegaConf
78
+
79
+ type_to_component_map = {list: gr.Text, int: gr.Number, float: gr.Number}
80
+
81
+
82
def get_config_choices(config_path: str):
    """List the stems of the ``.yaml`` files directly inside a config group.

    Args:
        config_path: Sub-directory of ``../../configs/``. The base path is
            relative, so it is resolved against the current working directory.

    Returns:
        File names without the ``.yaml`` extension, in directory-listing order.
    """
    config_dir = Path("../../configs/", config_path)
    choices = []
    for entry in config_dir.iterdir():
        if entry.suffix == ".yaml":
            choices.append(entry.stem)
    return choices
84
+
85
+
86
def create_blocks_from_config(cfg: DictConfig):
    """Build a ``gr.Blocks`` layout with one component per config entry.

    Each key/value pair of ``cfg`` is rendered as exactly one of:

    * ``gr.Number`` when the value is an ``int`` or ``float``;
    * a ``gr.Tab`` labelled with the key, containing the layout built
      recursively from the nested mapping;
    * ``gr.Text`` for anything else.

    Args:
        cfg: Mapping of option names to values (may be nested).

    Returns:
        The ``gr.Blocks`` container holding the generated components.
    """
    with gr.Blocks() as blocks:
        for key, value in cfg.items():
            # Exact type check on purpose: ``bool`` is a subclass of ``int``
            # but should keep being rendered through the fallback branch.
            if type(value) in (int, float):
                gr.Number(value=value, label=key, interactive=True)
            # ``elif`` (not a second ``if``): otherwise numeric values also hit
            # the ``else`` branch and get a duplicate text box.
            elif isinstance(value, (dict, DictConfig)):
                with gr.Tab(label=key):
                    create_blocks_from_config(value)
            else:
                gr.Text(value=value, label=key, interactive=True)
    return blocks
97
+
98
+
99
def create_interface_from_config(fn: callable, cfg: DictConfig):
    """Create a ``gr.Interface`` whose inputs mirror the top-level config keys.

    The config is converted to plain Python objects first; each value's type is
    looked up in ``type_to_component_map`` (falling back to ``gr.Text``) to
    pick the input component, which is pre-filled with the value and labelled
    with the key.

    Args:
        fn: Callback invoked by the interface.
        cfg: Config whose top-level entries become the inputs.

    Returns:
        A ``gr.Interface`` wired to ``fn`` with a ``"label"`` output.
    """
    plain_cfg = OmegaConf.to_object(cfg)
    inputs = [
        type_to_component_map.get(type(value), gr.Text)(value=value, label=key, interactive=True)
        for key, value in plain_cfg.items()
    ]
    return gr.Interface(fn=fn, inputs=inputs, outputs="label")
109
+
110
+
111
+ import hydra
112
+
113
+ with hydra.initialize(version_base=None, config_path="../../configs/"):
114
+ cfg = hydra.compose("train")
deepscreen/models/__pycache__/dti.cpython-311.pyc CHANGED
Binary files a/deepscreen/models/__pycache__/dti.cpython-311.pyc and b/deepscreen/models/__pycache__/dti.cpython-311.pyc differ
 
deepscreen/models/dti.py CHANGED
@@ -66,7 +66,7 @@ class DTILightningModule(LightningModule):
66
  def forward(self, batch):
67
  output = self.predictor(batch['X1^'], batch['X2^'])
68
  target = batch.get('Y')
69
- indexes = batch.get('IDX')
70
  preds = None
71
  loss = None
72
 
 
66
  def forward(self, batch):
67
  output = self.predictor(batch['X1^'], batch['X2^'])
68
  target = batch.get('Y')
69
+ indexes = batch.get('ID^')
70
  preds = None
71
  loss = None
72
 
deepscreen/models/loss/__pycache__/multitask_loss.cpython-311.pyc CHANGED
Binary files a/deepscreen/models/loss/__pycache__/multitask_loss.cpython-311.pyc and b/deepscreen/models/loss/__pycache__/multitask_loss.cpython-311.pyc differ
 
deepscreen/models/metrics/bedroc.py CHANGED
@@ -40,3 +40,6 @@ class BEDROC(RetrievalMetric):
40
  rie_max = (1 - exp_a ** (-r_a)) / (r_a * (1 - exp_a ** (-1)))
41
 
42
  return (rie - rie_min) / (rie_max - rie_min)
 
 
 
 
40
  rie_max = (1 - exp_a ** (-r_a)) / (r_a * (1 - exp_a ** (-1)))
41
 
42
  return (rie - rie_min) / (rie_max - rie_min)
43
+
44
+ def plot(self, val=None, ax=None):
45
+ return self._plot(val, ax)
deepscreen/models/metrics/ci.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchmetrics import Metric
3
+ from torchmetrics.utilities.checks import _check_same_shape
4
+ from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
5
+
6
+ if not _MATPLOTLIB_AVAILABLE:
7
+ __doctest_skip__ = ["ConcordanceIndex.plot"]
8
+
9
+
10
class ConcordanceIndex(Metric):
    """Concordance index accumulated over ``update`` calls.

    Within one ``update`` call, an index pair ``(i, j)`` with ``i > j`` is
    *valid* when ``target[i] > target[j]``. A valid pair scores 1 if
    ``preds[i] > preds[j]``, 0.5 if the two predictions are equal and 0
    otherwise. ``compute`` returns the accumulated score divided by the
    accumulated number of valid pairs, or 0.0 when there were none.

    Pairs are only formed inside a single ``update`` call, never across calls.
    """
    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False
    plot_lower_bound: float = 0.5
    plot_upper_bound: float = 1.0

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        # Floating-point state: tied predictions contribute 0.5, which an
        # integer accumulator would truncate on every update.
        self.add_state("num_concordant", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("num_valid", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate pair statistics for one batch.

        Args:
            preds: Predicted scores.
            target: Ground-truth values, same shape as ``preds``.
        """
        # NOTE(review): the pairwise broadcasting below assumes 1-D inputs - confirm with callers.
        _check_same_shape(preds, target)

        # pair_score[i, j]: 1 if preds[i] > preds[j], 0.5 on a tie, 0 otherwise
        pred_diff = preds.unsqueeze(-1) - preds
        pair_score = (pred_diff == 0) * 0.5 + (pred_diff > 0)

        # valid[i, j]: target[i] > target[j], restricted to the lower triangle
        valid = (target.unsqueeze(-1) - target) > 0
        valid = torch.tril(valid, diagonal=0)

        # Match the state's dtype/device without discarding the fractional tie credit.
        self.num_concordant += torch.sum(torch.mul(pair_score, valid)).to(self.num_concordant)
        self.num_valid += torch.sum(valid).long()

    def compute(self):
        """Return the concordance index, or 0.0 if no valid pair was seen."""
        return torch.where(self.num_valid == 0, 0.0, self.num_concordant / self.num_valid)

    def plot(self, val=None, ax=None):
        """Plot the metric by delegating to the base-class ``_plot`` helper."""
        return self._plot(val, ax)
deepscreen/models/metrics/ef.py CHANGED
@@ -5,7 +5,7 @@ from torchmetrics.retrieval.base import RetrievalMetric
5
  from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
6
 
7
 
8
- class EF(RetrievalMetric):
9
  is_differentiable: bool = False
10
  higher_is_better: bool = True
11
  full_state_update: bool = False
@@ -29,3 +29,6 @@ class EF(RetrievalMetric):
29
  hits_total = target.sum()
30
 
31
  return hits_sampled / (hits_total * self.alpha)
 
 
 
 
5
  from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
6
 
7
 
8
+ class EnrichmentFactor(RetrievalMetric):
9
  is_differentiable: bool = False
10
  higher_is_better: bool = True
11
  full_state_update: bool = False
 
29
  hits_total = target.sum()
30
 
31
  return hits_sampled / (hits_total * self.alpha)
32
+
33
+ def plot(self, val=None, ax=None):
34
+ return self._plot(val, ax)
deepscreen/models/metrics/hit_rate.py CHANGED
@@ -31,3 +31,6 @@ class HitRate(RetrievalMetric):
31
  hits_sampled = target[idx].sum()
32
 
33
  return hits_sampled / n_sampled
 
 
 
 
31
  hits_sampled = target[idx].sum()
32
 
33
  return hits_sampled / n_sampled
34
+
35
+ def plot(self, val=None, ax=None):
36
+ return self._plot(val, ax)
deepscreen/models/metrics/rie.py CHANGED
@@ -4,6 +4,13 @@ from torchmetrics.retrieval.base import RetrievalMetric
4
  from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
5
 
6
 
 
 
 
 
 
 
 
7
  class RIE(RetrievalMetric):
8
  is_differentiable: bool = False
9
  higher_is_better: bool = True
@@ -33,9 +40,5 @@ class RIE(RetrievalMetric):
33
 
34
  return calc_rie(n_total, active_ranks, r_a, exp_a)
35
 
36
-
37
- def calc_rie(n_total, active_ranks, r_a, exp_a):
38
- numerator = (exp_a ** (- active_ranks / n_total)).sum()
39
- denominator = (1 - exp_a ** (-1)) / (exp_a ** (1 / n_total) - 1)
40
-
41
- return numerator / (r_a * denominator)
 
4
  from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
5
 
6
 
7
def calc_rie(n_total, active_ranks, r_a, exp_a):
    """Evaluate the RIE (robust initial enhancement) expression.

    Computes ``sum(exp_a ** (-rank / n_total))`` over ``active_ranks`` and
    divides it by ``r_a * (1 - exp_a ** -1) / (exp_a ** (1 / n_total) - 1)``.

    Args:
        n_total: Total number of ranked items.
        active_ranks: Ranks of the active items; must support unary minus,
            division and ``.sum()`` (e.g. a tensor).
        r_a: Factor multiplying the normalising term (presumably the fraction
            of actives - confirm against the caller).
        exp_a: Base of the exponential weighting.

    Returns:
        The value of the expression above.
    """
    weighted_hits = (exp_a ** (- active_ranks / n_total)).sum()
    normaliser = (1 - exp_a ** (-1)) / (exp_a ** (1 / n_total) - 1)
    return weighted_hits / (r_a * normaliser)
12
+
13
+
14
  class RIE(RetrievalMetric):
15
  is_differentiable: bool = False
16
  higher_is_better: bool = True
 
40
 
41
  return calc_rie(n_total, active_ranks, r_a, exp_a)
42
 
43
+ def plot(self, val=None, ax=None):
44
+ return self._plot(val, ax)
 
 
 
 
deepscreen/models/predictors/drug_vqa.py CHANGED
@@ -1,10 +1,11 @@
1
  from math import floor
 
2
  from typing import Literal
3
 
 
4
  import torch.nn as nn
5
  import torch
6
  import torch.nn.functional as F
7
- # from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
8
 
9
 
10
  def conv(in_channels, out_channels, kernel_size, conv_dim, stride=1):
@@ -170,6 +171,8 @@ class DrugVQA(nn.Module):
170
  return nn.Sequential(*layers)
171
 
172
  def forward(self, enc_drug, enc_protein):
 
 
173
  smile_embed = self.embeddings(enc_drug.long())
174
  # self.hidden_state = tuple(hidden_state.to(smile_embed).detach() for hidden_state in self.hidden_state)
175
  outputs, hidden_state = self.lstm(smile_embed)
 
1
  from math import floor
2
+ import re
3
  from typing import Literal
4
 
5
+ import numpy as np
6
  import torch.nn as nn
7
  import torch
8
  import torch.nn.functional as F
 
9
 
10
 
11
  def conv(in_channels, out_channels, kernel_size, conv_dim, stride=1):
 
171
  return nn.Sequential(*layers)
172
 
173
  def forward(self, enc_drug, enc_protein):
174
+ enc_drug, _ = enc_drug
175
+ enc_protein, _ = enc_protein
176
  smile_embed = self.embeddings(enc_drug.long())
177
  # self.hidden_state = tuple(hidden_state.to(smile_embed).detach() for hidden_state in self.hidden_state)
178
  outputs, hidden_state = self.lstm(smile_embed)
deepscreen/models/predictors/transformer_cpi.py CHANGED
@@ -9,8 +9,7 @@ class TransformerCPI(nn.Module):
9
  super().__init__()
10
 
11
  self.encoder = Encoder(protein_dim, hidden_dim, n_layers, kernel_size, dropout)
12
- self.decoder = Decoder(atom_dim, hidden_dim, n_layers, n_heads, pf_dim, DecoderLayer, SelfAttention,
13
- PositionwiseFeedforward, dropout)
14
  self.weight = nn.Parameter(torch.FloatTensor(atom_dim, atom_dim))
15
  self.init_weight()
16
 
@@ -23,18 +22,24 @@ class TransformerCPI(nn.Module):
23
  # adj = [batch,num_node, num_node]
24
  support = torch.matmul(input, self.weight)
25
  # support =[batch,num_node,atom_dim]
26
- output = torch.bmm(adj, support)
27
  # output = [batch,num_node,atom_dim]
28
  return output
29
 
30
- def forward(self, compound, adj, protein, atom_num, protein_num):
 
 
 
 
31
  # compound = [batch,atom_num, atom_dim]
32
  # adj = [batch,atom_num, atom_num]
33
  # protein = [batch,protein len, 100]
34
- compound_max_len = compound.shape[1]
35
- protein_max_len = protein.shape[1]
36
- compound_mask, protein_mask = self.make_masks(atom_num, protein_num, compound_max_len, protein_max_len)
37
- compound = self.gcn(compound, adj)
 
 
38
  # compound = torch.unsqueeze(compound, dim=0)
39
  # compound = [batch size=1 ,atom_num, atom_dim]
40
 
@@ -48,54 +53,6 @@ class TransformerCPI(nn.Module):
48
  # out = torch.squeeze(out, dim=0)
49
  return out
50
 
51
- @staticmethod
52
- def make_masks(atom_num, protein_num, compound_max_len, protein_max_len):
53
- n_atom = len(atom_num) # batch size
54
- compound_mask = torch.zeros((n_atom, compound_max_len))
55
- protein_mask = torch.zeros((n_atom, protein_max_len))
56
- for i in range(n_atom):
57
- compound_mask[i, :atom_num[i]] = 1
58
- protein_mask[i, :protein_num[i]] = 1
59
- compound_mask = compound_mask.unsqueeze(1).unsqueeze(3)
60
- protein_mask = protein_mask.unsqueeze(1).unsqueeze(2)
61
- return compound_mask, protein_mask
62
-
63
- @staticmethod
64
- def pack(atoms, adjs, proteins, labels):
65
- atoms_len = 0
66
- proteins_len = 0
67
- N = len(atoms)
68
-
69
- atom_num = []
70
- for atom in atoms:
71
- atom_num.append(atom.shape[0])
72
- if atom.shape[0] >= atoms_len:
73
- atoms_len = atom.shape[0]
74
-
75
- protein_num = []
76
- for protein in proteins:
77
- protein_num.append(protein.shape[0])
78
- if protein.shape[0] >= proteins_len:
79
- proteins_len = protein.shape[0]
80
-
81
- atoms_new = torch.zeros((N, atoms_len, 34))
82
- for i, atom in enumerate(atoms):
83
- a_len = atom.shape[0]
84
- atoms_new[i, :a_len, :] = atom
85
-
86
- adjs_new = torch.zeros((N, atoms_len, atoms_len))
87
- for i, adj in adjs:
88
- a_len = adj.shape[0]
89
- adj = adj + torch.eye(a_len)
90
- adjs_new[i, :a_len, :a_len] = adj
91
-
92
- proteins_new = torch.zeros((N, proteins_len, 100))
93
- for i, protein in enumerate(proteins):
94
- a_len = protein.shape[0]
95
- proteins_new[i, :a_len, :] = protein
96
-
97
- return atoms_new, adjs_new, proteins_new, atom_num, protein_num
98
-
99
 
100
  class SelfAttention(nn.Module):
101
  def __init__(self, hidden_dim, n_heads, dropout):
@@ -114,7 +71,7 @@ class SelfAttention(nn.Module):
114
 
115
  self.do = nn.Dropout(dropout)
116
 
117
- self.scale = torch.sqrt(torch.FloatTensor([hidden_dim // n_heads]))
118
 
119
  def forward(self, query, key, value, mask=None):
120
  bsz = query.shape[0]
@@ -164,7 +121,6 @@ class SelfAttention(nn.Module):
164
 
165
  class Encoder(nn.Module):
166
  """protein feature extraction."""
167
-
168
  def __init__(self, protein_dim, hidden_dim, n_layers, kernel_size, dropout):
169
  super().__init__()
170
 
@@ -176,7 +132,7 @@ class Encoder(nn.Module):
176
  self.dropout = dropout
177
  self.n_layers = n_layers
178
  # self.pos_embedding = nn.Embedding(1000, hidden_dim)
179
- self.scale = torch.sqrt(torch.FloatTensor([0.5]))
180
  self.convs = nn.ModuleList(
181
  [nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size, padding=(kernel_size - 1) // 2) for _ in
182
  range(self.n_layers)]) # convolutional layers
@@ -189,7 +145,7 @@ class Encoder(nn.Module):
189
  # pos = torch.arange(0, protein.shape[1]).unsqueeze(0).repeat(protein.shape[0], 1)
190
  # protein = protein + self.pos_embedding(pos)
191
  # protein = [batch size, protein len,protein_dim]
192
- conv_input = self.fc(protein)
193
  # conv_input=[batch size,protein len,hid dim]
194
  # permute for convolutional layer
195
  conv_input = conv_input.permute(0, 2, 1)
@@ -239,7 +195,9 @@ class PositionwiseFeedforward(nn.Module):
239
 
240
 
241
  class DecoderLayer(nn.Module):
242
- def __init__(self, hidden_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout):
 
 
243
  super().__init__()
244
  self.ln = nn.LayerNorm(hidden_dim)
245
  self.sa = self_attention(hidden_dim, n_heads, dropout)
@@ -262,8 +220,10 @@ class DecoderLayer(nn.Module):
262
  class Decoder(nn.Module):
263
  """ compound feature extraction."""
264
 
265
- def __init__(self, atom_dim, hidden_dim, n_layers, n_heads, pf_dim, decoder_layer, self_attention,
266
- positionwise_feedforward, dropout):
 
 
267
  super().__init__()
268
  self.ln = nn.LayerNorm(hidden_dim)
269
  self.output_dim = atom_dim
@@ -277,12 +237,12 @@ class Decoder(nn.Module):
277
  self.dropout = dropout
278
  self.sa = self_attention(hidden_dim, n_heads, dropout)
279
  self.layers = nn.ModuleList(
280
- [decoder_layer(hidden_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout)
281
  for _ in range(n_layers)])
282
  self.ft = nn.Linear(atom_dim, hidden_dim)
283
  self.do = nn.Dropout(dropout)
284
  self.fc_1 = nn.Linear(hidden_dim, 256)
285
- self.fc_2 = nn.Linear(256, 2)
286
  self.gn = nn.GroupNorm(8, 256)
287
 
288
  def forward(self, trg, src, trg_mask=None, src_mask=None):
@@ -297,7 +257,7 @@ class Decoder(nn.Module):
297
  norm = F.softmax(norm, dim=1) # norm = [batch size,compound len]
298
  # trg = torch.squeeze(trg,dim=0)
299
  # norm = torch.squeeze(norm,dim=0)
300
- sum = torch.zeros((trg.shape[0], self.hidden_dim))
301
  for i in range(norm.shape[0]):
302
  for j in range(norm.shape[1]):
303
  v = trg[i, j,]
 
9
  super().__init__()
10
 
11
  self.encoder = Encoder(protein_dim, hidden_dim, n_layers, kernel_size, dropout)
12
+ self.decoder = Decoder(atom_dim, hidden_dim, n_layers, n_heads, pf_dim, dropout)
 
13
  self.weight = nn.Parameter(torch.FloatTensor(atom_dim, atom_dim))
14
  self.init_weight()
15
 
 
22
  # adj = [batch,num_node, num_node]
23
  support = torch.matmul(input, self.weight)
24
  # support =[batch,num_node,atom_dim]
25
+ output = torch.bmm(adj.float(), support.float())
26
  # output = [batch,num_node,atom_dim]
27
  return output
28
 
29
+ def forward(self, compound, protein):
30
+ compound, adj = compound
31
+ compound, compound_lengths = compound
32
+ adj, _ = adj
33
+ protein, protein_lengths = protein
34
  # compound = [batch,atom_num, atom_dim]
35
  # adj = [batch,atom_num, atom_num]
36
  # protein = [batch,protein len, 100]
37
+ compound_mask = torch.arange(compound.size(1), device=compound.device) >= compound_lengths.unsqueeze(1)
38
+ protein_mask = torch.arange(protein.size(1), device=protein.device) >= protein_lengths.unsqueeze(1)
39
+ compound_mask = compound_mask.unsqueeze(1).unsqueeze(3)
40
+ protein_mask = protein_mask.unsqueeze(1).unsqueeze(2)
41
+
42
+ compound = self.gcn(compound.float(), adj)
43
  # compound = torch.unsqueeze(compound, dim=0)
44
  # compound = [batch size=1 ,atom_num, atom_dim]
45
 
 
53
  # out = torch.squeeze(out, dim=0)
54
  return out
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  class SelfAttention(nn.Module):
58
  def __init__(self, hidden_dim, n_heads, dropout):
 
71
 
72
  self.do = nn.Dropout(dropout)
73
 
74
+ self.scale = (hidden_dim // n_heads) ** 0.5
75
 
76
  def forward(self, query, key, value, mask=None):
77
  bsz = query.shape[0]
 
121
 
122
  class Encoder(nn.Module):
123
  """protein feature extraction."""
 
124
  def __init__(self, protein_dim, hidden_dim, n_layers, kernel_size, dropout):
125
  super().__init__()
126
 
 
132
  self.dropout = dropout
133
  self.n_layers = n_layers
134
  # self.pos_embedding = nn.Embedding(1000, hidden_dim)
135
+ self.scale = 0.5 ** 0.5
136
  self.convs = nn.ModuleList(
137
  [nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size, padding=(kernel_size - 1) // 2) for _ in
138
  range(self.n_layers)]) # convolutional layers
 
145
  # pos = torch.arange(0, protein.shape[1]).unsqueeze(0).repeat(protein.shape[0], 1)
146
  # protein = protein + self.pos_embedding(pos)
147
  # protein = [batch size, protein len,protein_dim]
148
+ conv_input = self.fc(protein.float())
149
  # conv_input=[batch size,protein len,hid dim]
150
  # permute for convolutional layer
151
  conv_input = conv_input.permute(0, 2, 1)
 
195
 
196
 
197
  class DecoderLayer(nn.Module):
198
+ def __init__(self, hidden_dim, n_heads, pf_dim, dropout,
199
+ self_attention=SelfAttention,
200
+ positionwise_feedforward=PositionwiseFeedforward):
201
  super().__init__()
202
  self.ln = nn.LayerNorm(hidden_dim)
203
  self.sa = self_attention(hidden_dim, n_heads, dropout)
 
220
  class Decoder(nn.Module):
221
  """ compound feature extraction."""
222
 
223
+ def __init__(self, atom_dim, hidden_dim, n_layers, n_heads, pf_dim, dropout,
224
+ decoder_layer=DecoderLayer,
225
+ self_attention=SelfAttention,
226
+ positionwise_feedforward=PositionwiseFeedforward):
227
  super().__init__()
228
  self.ln = nn.LayerNorm(hidden_dim)
229
  self.output_dim = atom_dim
 
237
  self.dropout = dropout
238
  self.sa = self_attention(hidden_dim, n_heads, dropout)
239
  self.layers = nn.ModuleList(
240
+ [decoder_layer(hidden_dim, n_heads, pf_dim, dropout, self_attention, positionwise_feedforward)
241
  for _ in range(n_layers)])
242
  self.ft = nn.Linear(atom_dim, hidden_dim)
243
  self.do = nn.Dropout(dropout)
244
  self.fc_1 = nn.Linear(hidden_dim, 256)
245
+ # self.fc_2 = nn.Linear(256, 2)
246
  self.gn = nn.GroupNorm(8, 256)
247
 
248
  def forward(self, trg, src, trg_mask=None, src_mask=None):
 
257
  norm = F.softmax(norm, dim=1) # norm = [batch size,compound len]
258
  # trg = torch.squeeze(trg,dim=0)
259
  # norm = torch.squeeze(norm,dim=0)
260
+ sum = torch.zeros((trg.shape[0], self.hidden_dim), device=trg.device)
261
  for i in range(norm.shape[0]):
262
  for j in range(norm.shape[1]):
263
  v = trg[i, j,]
deepscreen/models/predictors/transformer_cpi_2.py CHANGED
@@ -23,9 +23,8 @@ class TransformerCPI2(nn.Module):
23
  # adj_mat = [batch_size, atom_num, atom_num]
24
  # enc_protein = [batch_size, protein_len, 768]
25
  compound, adj = compound
26
-
27
  compound, compound_lengths = compound
28
- adj, adj_lengths = adj
29
  protein, protein_lengths = protein
30
 
31
  # Add a global/master node to the compound
@@ -99,5 +98,5 @@ class Decoder(nn.Module):
99
  tgt = tgt.permute(1, 0, 2).contiguous() # tgt = [batch_size, compound_len, hid_dim]
100
  x = tgt[:, 0, :]
101
  label = F.relu(self.fc_1(x))
102
- label = self.fc_2(label)
103
  return label
 
23
  # adj_mat = [batch_size, atom_num, atom_num]
24
  # enc_protein = [batch_size, protein_len, 768]
25
  compound, adj = compound
26
+ adj, _ = adj
27
  compound, compound_lengths = compound
 
28
  protein, protein_lengths = protein
29
 
30
  # Add a global/master node to the compound
 
98
  tgt = tgt.permute(1, 0, 2).contiguous() # tgt = [batch_size, compound_len, hid_dim]
99
  x = tgt[:, 0, :]
100
  label = F.relu(self.fc_1(x))
101
+ # label = self.fc_2(label)
102
  return label
deepscreen/utils/__pycache__/hydra.cpython-311.pyc CHANGED
Binary files a/deepscreen/utils/__pycache__/hydra.cpython-311.pyc and b/deepscreen/utils/__pycache__/hydra.cpython-311.pyc differ
 
deepscreen/utils/hydra.py CHANGED
@@ -1,8 +1,11 @@
 
1
  from pathlib import Path
2
  import re
 
3
  from typing import Any, Tuple
4
 
5
  import pandas as pd
 
6
  from hydra.core.hydra_config import HydraConfig
7
  from hydra.core.utils import _save_config
8
  from hydra.experimental.callbacks import Callback
@@ -21,21 +24,24 @@ class CSVExperimentSummary(Callback):
21
  self.filename = filename
22
  self.prefix = prefix if isinstance(prefix, str) else tuple(prefix)
23
  self.input_experiment_summary = None
 
24
 
25
  def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
26
- if config.hydra.get('overrides'):
27
- if config.hydra.overrides.task:
28
- for i, override in enumerate(config.hydra.overrides.task):
29
- if override.startswith("ckpt_path"):
30
- ckpt_path = override.split('=', 1)[1]
31
- if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')):
32
- config.hydra.overrides.task[i] = self.parse_ckpt_path_from_experiment_summary(ckpt_path)
33
- break
34
- elif config.hydra.sweeper.get('params'):
35
- if config.hydra.sweeper.params.get('ckpt_path'):
36
- ckpt_path = str(config.hydra.sweeper.params.ckpt_path).strip("'\"")
37
- if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')):
38
- config.hydra.sweeper.params.ckpt_path = self.parse_ckpt_path_from_experiment_summary(ckpt_path)
 
 
39
 
40
  def on_job_end(self, config: DictConfig, job_return, **kwargs: Any) -> None:
41
  # Skip callback if job is DDP subprocess
@@ -43,6 +49,7 @@ class CSVExperimentSummary(Callback):
43
  return
44
 
45
  try:
 
46
  if config.hydra.mode == RunMode.RUN:
47
  summary_file_path = Path(config.hydra.run.dir) / self.filename
48
  elif config.hydra.mode == RunMode.MULTIRUN:
@@ -56,21 +63,23 @@ class CSVExperimentSummary(Callback):
56
  summary_df = pd.DataFrame()
57
 
58
  # Add job and override info
59
- override_dict = dict(override.split('=', 1) for override in job_return.overrides)
60
- override_dict['job_status'] = job_return.status.name
 
 
 
 
61
 
62
  # Add checkpoint info
63
- if override_dict.get('ckpt_path'):
64
- override_dict['ckpt_path'] = str(override_dict['ckpt_path']).strip("'\"")
65
 
66
- if job_return.cfg.get('ckpt_path'):
67
- ckpt_path = str(job_return.cfg.ckpt_path).strip("'\"")
68
- if Path(ckpt_path).is_file():
69
- if override_dict.get('ckpt_path') and ckpt_path != override_dict['ckpt_path']:
70
- override_dict['previous_ckpt_path'] = override_dict['ckpt_path']
71
- override_dict['ckpt_path'] = ckpt_path
72
-
73
- override_dict['epoch'] = int(re.search(r'epoch_(\d+)', override_dict['ckpt_path']).group(1))
74
 
75
  # Add metrics info
76
  metrics_df = pd.DataFrame()
@@ -79,22 +88,22 @@ class CSVExperimentSummary(Callback):
79
  csv_metrics_path = output_dir / config.logger.csv.name / "metrics.csv"
80
  if csv_metrics_path.is_file():
81
  log.info(f"Summarizing metrics with prefix `{self.prefix}` from {csv_metrics_path}")
82
- # Use only columns that start with the specified prefix
83
  metrics_df = pd.read_csv(csv_metrics_path)
84
- # Find rows where any 'test/' column is not null and reset its epoch to the best model epoch
85
  test_columns = [col for col in metrics_df.columns if col.startswith('test/')]
86
- mask = metrics_df[test_columns].notna().any(axis=1)
87
- metrics_df.loc[mask, 'epoch'] = override_dict['epoch']
 
88
  # Group and filter by best epoch
89
  metrics_df = metrics_df.groupby('epoch').first()
90
- metrics_df = metrics_df[metrics_df.index == override_dict['epoch']]
91
  else:
92
  log.info(f"No metrics.csv found in {output_dir}")
93
 
94
  if metrics_df.empty:
95
- metrics_df = pd.DataFrame(data=override_dict, index=[0])
96
  else:
97
- metrics_df = metrics_df.assign(**override_dict)
98
  metrics_df.index = [0]
99
 
100
  # Add extra info from the input batch experiment summary
@@ -102,7 +111,8 @@ class CSVExperimentSummary(Callback):
102
  orig_meta = self.input_experiment_summary[
103
  self.input_experiment_summary['ckpt_path'] == metrics_df['ckpt_path'][0]
104
  ].head(1)
105
- orig_meta.index = [0]
 
106
  metrics_df = metrics_df.combine_first(orig_meta)
107
 
108
  summary_df = pd.concat([summary_df, metrics_df])
@@ -169,9 +179,8 @@ def checkpoint_rerun_config(config: DictConfig):
169
  ckpt_cfg.data = OmegaConf.masked_copy(ckpt_cfg.data, [
170
  key for key in ckpt_cfg.data.keys() if key not in ['data_file', 'split', 'train_val_test_split']
171
  ])
172
- ckpt_override_keys = ['task',
173
- 'data.drug_featurizer', 'data.protein_featurizer', 'data.collator',
174
- 'model.predictor']
175
 
176
  for key in ckpt_override_keys:
177
  OmegaConf.update(config, key, OmegaConf.select(ckpt_cfg, key), force_add=True)
@@ -183,3 +192,4 @@ def checkpoint_rerun_config(config: DictConfig):
183
  _save_config(config, "config.yaml", hydra_output)
184
 
185
  return config
 
 
1
+ from datetime import timedelta
2
  from pathlib import Path
3
  import re
4
+ from time import time
5
  from typing import Any, Tuple
6
 
7
  import pandas as pd
8
+ from hydra import TaskFunction
9
  from hydra.core.hydra_config import HydraConfig
10
  from hydra.core.utils import _save_config
11
  from hydra.experimental.callbacks import Callback
 
24
  self.filename = filename
25
  self.prefix = prefix if isinstance(prefix, str) else tuple(prefix)
26
  self.input_experiment_summary = None
27
+ self.time = {}
28
 
29
  def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
30
+ if config.hydra.get('overrides') and config.hydra.overrides.get('task'):
31
+ for i, override in enumerate(config.hydra.overrides.task):
32
+ if override.startswith("ckpt_path"):
33
+ ckpt_path = override.split('=', 1)[1]
34
+ if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')):
35
+ config.hydra.overrides.task[i] = self.parse_ckpt_path_from_experiment_summary(ckpt_path)
36
+ break
37
+ if config.hydra.sweeper.get('params'):
38
+ if config.hydra.sweeper.params.get('ckpt_path'):
39
+ ckpt_path = str(config.hydra.sweeper.params.ckpt_path).strip("'\"")
40
+ if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')):
41
+ config.hydra.sweeper.params.ckpt_path = self.parse_ckpt_path_from_experiment_summary(ckpt_path)
42
+
43
+ def on_job_start(self, config: DictConfig, *, task_function: TaskFunction, **kwargs: Any) -> None:
44
+ self.time['start'] = time()
45
 
46
  def on_job_end(self, config: DictConfig, job_return, **kwargs: Any) -> None:
47
  # Skip callback if job is DDP subprocess
 
49
  return
50
 
51
  try:
52
+ self.time['end'] = time()
53
  if config.hydra.mode == RunMode.RUN:
54
  summary_file_path = Path(config.hydra.run.dir) / self.filename
55
  elif config.hydra.mode == RunMode.MULTIRUN:
 
63
  summary_df = pd.DataFrame()
64
 
65
  # Add job and override info
66
+ info_dict = {}
67
+ if job_return.overrides:
68
+ info_dict = dict(override.split('=', 1) for override in job_return.overrides)
69
+ info_dict['job_status'] = job_return.status.name
70
+ info_dict['job_id'] = job_return.hydra_cfg.hydra.job.id
71
+ info_dict['wall_time'] = str(timedelta(self.time['end'] - self.time['start']))
72
 
73
  # Add checkpoint info
74
+ if info_dict.get('ckpt_path'):
75
+ info_dict['ckpt_path'] = str(info_dict['ckpt_path']).strip("'\"")
76
 
77
+ ckpt_path = str(job_return.cfg.ckpt_path).strip("'\"")
78
+ if Path(ckpt_path).is_file():
79
+ if info_dict.get('ckpt_path') and ckpt_path != info_dict['ckpt_path']:
80
+ info_dict['previous_ckpt_path'] = info_dict['ckpt_path']
81
+ info_dict['ckpt_path'] = ckpt_path
82
+ info_dict['best_epoch'] = int(re.search(r'epoch_(\d+)', info_dict['ckpt_path']).group(1))
 
 
83
 
84
  # Add metrics info
85
  metrics_df = pd.DataFrame()
 
88
  csv_metrics_path = output_dir / config.logger.csv.name / "metrics.csv"
89
  if csv_metrics_path.is_file():
90
  log.info(f"Summarizing metrics with prefix `{self.prefix}` from {csv_metrics_path}")
 
91
  metrics_df = pd.read_csv(csv_metrics_path)
92
+ # Find rows where 'test/' columns are not null and reset its epoch to the best model epoch
93
  test_columns = [col for col in metrics_df.columns if col.startswith('test/')]
94
+ if test_columns:
95
+ mask = metrics_df[test_columns].notna().any(axis=1)
96
+ metrics_df.loc[mask, 'epoch'] = info_dict['best_epoch']
97
  # Group and filter by best epoch
98
  metrics_df = metrics_df.groupby('epoch').first()
99
+ metrics_df = metrics_df[metrics_df.index == info_dict['best_epoch']]
100
  else:
101
  log.info(f"No metrics.csv found in {output_dir}")
102
 
103
  if metrics_df.empty:
104
+ metrics_df = pd.DataFrame(data=info_dict, index=[0])
105
  else:
106
+ metrics_df = metrics_df.assign(**info_dict)
107
  metrics_df.index = [0]
108
 
109
  # Add extra info from the input batch experiment summary
 
111
  orig_meta = self.input_experiment_summary[
112
  self.input_experiment_summary['ckpt_path'] == metrics_df['ckpt_path'][0]
113
  ].head(1)
114
+ if not orig_meta.empty:
115
+ orig_meta.index = [0]
116
  metrics_df = metrics_df.combine_first(orig_meta)
117
 
118
  summary_df = pd.concat([summary_df, metrics_df])
 
179
  ckpt_cfg.data = OmegaConf.masked_copy(ckpt_cfg.data, [
180
  key for key in ckpt_cfg.data.keys() if key not in ['data_file', 'split', 'train_val_test_split']
181
  ])
182
+ ckpt_override_keys = ['task', 'data.drug_featurizer', 'data.protein_featurizer', 'data.collator',
183
+ 'model.predictor', 'model.out', 'model.loss', 'model.activation', 'model.metrics']
 
184
 
185
  for key in ckpt_override_keys:
186
  OmegaConf.update(config, key, OmegaConf.select(ckpt_cfg, key), force_add=True)
 
192
  _save_config(config, "config.yaml", hydra_output)
193
 
194
  return config
195
+
resources/checkpoints/deep_dta-binary-general.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d652a18dda549aa3b16a47cfbe930a1db4aea79c6ecb5294013fe2225dec313a
3
+ size 16906032
resources/checkpoints/deep_dta-binary-general.ckpt.bak ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:321073e4f30845920a2ec6fa8f18a31e8190d9cfe4a1ad264886084de8d8a0ee
3
+ size 16888959
resources/vocabs/drug_vqa/combinedVoc-wholeFour.voc CHANGED
@@ -1,4 +1,3 @@
1
- [PAD]
2
  [102Ru]
3
  [80Se]
4
  [N-]
 
 
1
  [102Ru]
2
  [80Se]
3
  [N-]