jinysun committed
Commit 36c5570 · 1 Parent(s): 6edf9da

Upload 10 files
Files changed (10)
  1. .gitignore +160 -0
  2. 15data.h5 +3 -0
  3. LICENSE +201 -0
  4. abcBERT.py +96 -0
  5. app.py +37 -0
  6. compound_constants.py +156 -0
  7. dataset.py +497 -0
  8. model.py +280 -0
  9. requirements.txt +10 -0
  10. utils.py +696 -0
.gitignore ADDED
@@ -0,0 +1,160 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
15data.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2ec80795633fe96e7226a7e63909138e6f4fc37654dcff6831627b1670986497
+ size 17610752
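
This is a Git LFS pointer, not the HDF5 weights themselves; the 17 MB 15data.h5 that abcBERT.py loads lives in LFS storage. After cloning, it can be materialized with the standard LFS commands:

    git lfs install
    git lfs pull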
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
abcBERT.py ADDED
@@ -0,0 +1,96 @@
+ # -*- coding: utf-8 -*-
+ """
+ Created on Thu Mar 2 15:05:03 2023
+
+ @author: BM109X32G-10GPU-02
+ """
+
+ import os
+
+ import numpy as np
+ import pandas as pd
+ import tensorflow as tf
+ import tensorflow.keras as keras
+ from sklearn.metrics import r2_score, roc_auc_score
+
+ from dataset import predict_smiles
+ from model import PredictModel, BertModel
+
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+ os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
+
+
+ def main(smiles):
+     keras.backend.clear_session()
+     os.environ['CUDA_VISIBLE_DEVICES'] = "-1"  # force CPU inference
+
+     # Architecture presets (num_layers / num_heads / d_model).
+     small = {'name': 'Small', 'num_layers': 3, 'num_heads': 4, 'd_model': 128, 'path': 'small_weights', 'addH': True}
+     medium = {'name': 'Medium', 'num_layers': 6, 'num_heads': 8, 'd_model': 256, 'path': 'medium_weights2', 'addH': True}
+     medium3 = {'name': 'Medium', 'num_layers': 8, 'num_heads': 8, 'd_model': 256, 'path': 'medium_weights2',
+                'addH': True}
+     large = {'name': 'Large', 'num_layers': 12, 'num_heads': 12, 'd_model': 576, 'path': 'large_weights', 'addH': True}
+     medium_without_H = {'name': 'Medium', 'num_layers': 6, 'num_heads': 8, 'd_model': 256, 'path': 'weights_without_H', 'addH': False}
+     medium_without_pretrain = {'name': 'Medium', 'num_layers': 6, 'num_heads': 8, 'd_model': 256, 'path': 'medium_without_pretraining_weights', 'addH': True}
+
+     arch = medium3  # small: 3/4/128, medium: 6/8/256, large: 12/12/576
+
+     pretraining = False
+     pretraining_str = 'pretraining' if pretraining else ''
+
+     trained_epoch = 80
+     task = 'data'
+     seed = 14
+     num_layers = arch['num_layers']
+     num_heads = arch['num_heads']
+     d_model = arch['d_model']
+     addH = arch['addH']
+     dff = d_model * 2
+     vocab_size = 60
+     dropout_rate = 0.1
+
+     tf.random.set_seed(seed=seed)
+     graph_dataset = predict_smiles(smiles, addH=addH)
+     # graph_dataset = Graph_Regression_Dataset('data/reg/{}.csv', smiles_field='SMILES',
+     #                                          label_field='PCE', addH=addH)
+     test_dataset = graph_dataset.get_data()
+
+     # value_range = graph_dataset.value_range()
+
+     # Run one batch through the model so the variables exist before loading weights.
+     x, adjoin_matrix, y = next(iter(test_dataset.take(1)))
+     seq = tf.cast(tf.math.equal(x, 0), tf.float32)  # padding mask: 1 where x is <pad>
+     mask = seq[:, tf.newaxis, tf.newaxis, :]
+
+     model = PredictModel(num_layers=num_layers, d_model=d_model, dff=dff, num_heads=num_heads, vocab_size=vocab_size,
+                          dense_dropout=0.2)
+     preds = model(x, mask=mask, adjoin_matrix=adjoin_matrix, training=False)
+     model.load_weights('{}.h5'.format('15data'))
+
+     class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
+         """Linear warm-up then linear decay (kept from training; unused at inference)."""
+
+         def __init__(self, d_model, total_steps=4000):
+             super(CustomSchedule, self).__init__()
+             self.d_model = tf.cast(d_model, tf.float32)
+             self.total_step = total_steps
+             self.warmup_steps = total_steps * 0.10
+
+         def __call__(self, step):
+             arg1 = step / self.warmup_steps
+             arg2 = 1 - (step - self.warmup_steps) / (self.total_step - self.warmup_steps)
+             return 10e-5 * tf.math.minimum(arg1, arg2)
+
+     steps_per_epoch = len(test_dataset)
+     value_range = 1
+     y_true = []
+     y_preds = []
+
+     for x, adjoin_matrix, y in test_dataset:
+         seq = tf.cast(tf.math.equal(x, 0), tf.float32)
+         mask = seq[:, tf.newaxis, tf.newaxis, :]
+         preds = model(x, mask=mask, adjoin_matrix=adjoin_matrix, training=False)
+         y_true.append(y.numpy())
+         y_preds.append(preds.numpy())
+     y_true = np.concatenate(y_true, axis=0).reshape(-1)
+     y_preds = np.concatenate(y_preds, axis=0).reshape(-1)
+
+     return y_preds
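
For orientation, abcBERT.main is the inference entry point that app.py calls: it tokenizes one SMILES string via predict_smiles, rebuilds the medium3 encoder, loads the 15data.h5 weights from the working directory, and returns the predicted PCE values as a NumPy array. A minimal usage sketch (the SMILES string below is a hypothetical example, not from this repository):

    import abcBERT

    # One forward pass over a single molecule; requires 15data.h5 next to the script.
    pce = abcBERT.main('CC1=CC=C(C=C1)C1=CC=CC=C1')
    print(float(pce[0]))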
app.py ADDED
@@ -0,0 +1,37 @@
+ import streamlit as st
+ import pandas as pd
+ import rdkit
+ from streamlit_ketcher import st_ketcher
+
+ import abcBERT
+
+ # Page setup
+ st.set_page_config(page_title="DeepAcceptor", page_icon="🔋", layout="wide")
+ st.title("DeepAcceptor")
+
+ # Read the public Google Sheet as CSV
+ url1 = r"https://docs.google.com/spreadsheets/d/1YOEIg0nMTSPkAOr8wkqxQRLuUhys3-J0I-KPEpmzPLw/gviz/tq?tqx=out:csv&sheet=accept"
+ url = r"https://docs.google.com/spreadsheets/d/1YOEIg0nMTSPkAOr8wkqxQRLuUhys3-J0I-KPEpmzPLw/gviz/tq?tqx=out:csv&sheet=111"
+ df1 = pd.read_csv(url1, dtype=str, encoding='utf-8')
+
+ text_search = st.text_input("Search papers or molecules", value="")
+ m1 = df1["name"].str.contains(text_search, na=False)
+ m2 = df1["reference"].str.contains(text_search, na=False)
+ df_search = df1[m1 | m2]
+ if text_search:
+     st.write(df_search)
+     st.download_button("Download search results as .csv", df_search.to_csv(), "df_search.csv", use_container_width=True)
+ edited_df = st.data_editor(df1, num_rows="dynamic")
+ # edited_df.to_csv(url)  # disabled: pandas cannot write back to a Google Sheets URL
+ st.download_button(
+     "⬇️ Download edited files as .csv", edited_df.to_csv(), "edited_df.csv", use_container_width=True
+ )
+
+ molecule = st.text_input("Molecule")
+ smile_code = st_ketcher(molecule)
+ st.markdown(f"SMILES: ``{smile_code}``")
+ try:
+     pce = abcBERT.main(str(smile_code))
+     st.markdown(f"PCE: ``{pce}``")
+ except Exception:
+     st.markdown("PCE: None")
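
Assuming the dependencies in requirements.txt are installed, this is a standard single-file Streamlit app and can be served locally with Streamlit's runner:

    streamlit run app.py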
compound_constants.py ADDED
@@ -0,0 +1,156 @@
+ # -*- coding: utf-8 -*-
+ """
+ Created on Thu Jul 28 21:20:20 2022
+
+ @author: BM109X32G-10GPU-02
+ """
+
+
+ """
+ | Compound constants.
+ """
+
+
+ # functional groups from https://www.daylight.com/dayhtml_tutorials/languages/smarts/smarts_examples.html
+ DAY_LIGHT_FG_SMARTS_LIST = [
+     # C
+     "[CX4]",
+     "[$([CX2](=C)=C)]",
+     "[$([CX3]=[CX3])]",
+     "[$([CX2]#C)]",
+     # C & O
+     "[CX3]=[OX1]",
+     "[$([CX3]=[OX1]),$([CX3+]-[OX1-])]",
+     "[CX3](=[OX1])C",
+     "[OX1]=CN",
+     "[CX3](=[OX1])O",
+     "[CX3](=[OX1])[F,Cl,Br,I]",
+     "[CX3H1](=O)[#6]",
+     "[CX3](=[OX1])[OX2][CX3](=[OX1])",
+     "[NX3][CX3](=[OX1])[#6]",
+     "[NX3][CX3]=[NX3+]",
+     "[NX3,NX4+][CX3](=[OX1])[OX2,OX1-]",
+     "[NX3][CX3](=[OX1])[OX2H0]",
+     "[NX3,NX4+][CX3](=[OX1])[OX2H,OX1-]",
+     "[CX3](=O)[O-]",
+     "[CX3](=[OX1])(O)O",
+     "[CX3](=[OX1])([OX2])[OX2H,OX1H0-1]",
+     "C[OX2][CX3](=[OX1])[OX2]C",
+     "[CX3](=O)[OX2H1]",
+     "[CX3](=O)[OX1H0-,OX2H1]",
+     "[NX3][CX2]#[NX1]",
+     "[#6][CX3](=O)[OX2H0][#6]",
+     "[#6][CX3](=O)[#6]",
+     "[OD2]([#6])[#6]",
+     # H
+     "[H]",
+     "[!#1]",
+     "[H+]",
+     "[+H]",
+     "[!H]",
+     # N
+     "[NX3;H2,H1;!$(NC=O)]",
+     "[NX3][CX3]=[CX3]",
+     "[NX3;H2;!$(NC=[!#6]);!$(NC#[!#6])][#6]",
+     "[NX3;H2,H1;!$(NC=O)].[NX3;H2,H1;!$(NC=O)]",
+     "[NX3][$(C=C),$(cc)]",
+     "[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[O,N]",
+     "[NX3H2,NH3X4+][CX4H]([*])[CX3](=[OX1])[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-]",
+     "[$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-,N]",
+     "[CH3X4]",
+     "[CH2X4][CH2X4][CH2X4][NHX3][CH0X3](=[NH2X3+,NHX2+0])[NH2X3]",
+     "[CH2X4][CX3](=[OX1])[NX3H2]",
+     "[CH2X4][CX3](=[OX1])[OH0-,OH]",
+     "[CH2X4][SX2H,SX1H0-]",
+     "[CH2X4][CH2X4][CX3](=[OX1])[OH0-,OH]",
+     "[$([$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H2][CX3](=[OX1])[OX2H,OX1-,N])]",
+     "[CH2X4][#6X3]1:[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]:\
+ [$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]1",
+     "[CHX4]([CH3X4])[CH2X4][CH3X4]",
+     "[CH2X4][CHX4]([CH3X4])[CH3X4]",
+     "[CH2X4][CH2X4][CH2X4][CH2X4][NX4+,NX3+0]",
+     "[CH2X4][CH2X4][SX2][CH3X4]",
+     "[CH2X4][cX3]1[cX3H][cX3H][cX3H][cX3H][cX3H]1",
+     "[$([NX3H,NX4H2+]),$([NX3](C)(C)(C))]1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[OX2H,OX1-,N]",
+     "[CH2X4][OX2H]",
+     "[NX3][CX3]=[SX1]",
+     "[CHX4]([CH3X4])[OX2H]",
+     "[CH2X4][cX3]1[cX3H][nX3H][cX3]2[cX3H][cX3H][cX3H][cX3H][cX3]12",
+     "[CH2X4][cX3]1[cX3H][cX3H][cX3]([OHX2,OH0X1-])[cX3H][cX3H]1",
+     "[CHX4]([CH3X4])[CH3X4]",
+     "N[CX4H2][CX3](=[OX1])[O,N]",
+     "N1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[O,N]",
+     "[$(*-[NX2-]-[NX2+]#[NX1]),$(*-[NX2]=[NX2+]=[NX1-])]",
+     "[$([NX1-]=[NX2+]=[NX1-]),$([NX1]#[NX2+]-[NX1-2])]",
+     "[#7]",
+     "[NX2]=N",
+     "[NX2]=[NX2]",
+     "[$([NX2]=[NX3+]([O-])[#6]),$([NX2]=[NX3+0](=[O])[#6])]",
+     "[$([#6]=[N+]=[N-]),$([#6-]-[N+]#[N])]",
+     "[$([nr5]:[nr5,or5,sr5]),$([nr5]:[cr5]:[nr5,or5,sr5])]",
+     "[NX3][NX3]",
+     "[NX3][NX2]=[*]",
+     "[CX3;$([C]([#6])[#6]),$([CH][#6])]=[NX2][#6]",
+     "[$([CX3]([#6])[#6]),$([CX3H][#6])]=[$([NX2][#6]),$([NX2H])]",
+     "[NX3+]=[CX3]",
+     "[CX3](=[OX1])[NX3H][CX3](=[OX1])",
+     "[CX3](=[OX1])[NX3H0]([#6])[CX3](=[OX1])",
+     "[CX3](=[OX1])[NX3H0]([NX3H0]([CX3](=[OX1]))[CX3](=[OX1]))[CX3](=[OX1])",
+     "[$([NX3](=[OX1])(=[OX1])O),$([NX3+]([OX1-])(=[OX1])O)]",
+     "[$([OX1]=[NX3](=[OX1])[OX1-]),$([OX1]=[NX3+]([OX1-])[OX1-])]",
+     "[NX1]#[CX2]",
+     "[CX1-]#[NX2+]",
+     "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]",
+     "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8].[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]",
+     "[NX2]=[OX1]",
+     "[$([#7+][OX1-]),$([#7v5]=[OX1]);!$([#7](~[O])~[O]);!$([#7]=[#7])]",
+     # O
+     "[OX2H]",
+     "[#6][OX2H]",
+     "[OX2H][CX3]=[OX1]",
+     "[OX2H]P",
+     "[OX2H][#6X3]=[#6]",
+     "[OX2H][cX3]:[c]",
+     "[OX2H][$(C=C),$(cc)]",
+     "[$([OH]-*=[!#6])]",
+     "[OX2,OX1-][OX2,OX1-]",
+     # P
+     "[$(P(=[OX1])([$([OX2H]),$([OX1-]),$([OX2]P)])([$([OX2H]),$([OX1-]),\
+ $([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)]),$([P+]([OX1-])([$([OX2H]),$([OX1-])\
+ ,$([OX2]P)])([$([OX2H]),$([OX1-]),$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)])]",
+     "[$(P(=[OX1])([OX2][#6])([$([OX2H]),$([OX1-]),$([OX2][#6])])[$([OX2H]),\
+ $([OX1-]),$([OX2][#6]),$([OX2]P)]),$([P+]([OX1-])([OX2][#6])([$([OX2H]),$([OX1-]),\
+ $([OX2][#6])])[$([OX2H]),$([OX1-]),$([OX2][#6]),$([OX2]P)])]",
+     # S
+     "[S-][CX3](=S)[#6]",
+     "[#6X3](=[SX1])([!N])[!N]",
+     "[SX2]",
+     "[#16X2H]",
+     "[#16!H0]",
+     "[#16X2H0]",
+     "[#16X2H0][!#16]",
+     "[#16X2H0][#16X2H0]",
+     "[#16X2H0][!#16].[#16X2H0][!#16]",
+     "[$([#16X3](=[OX1])[OX2H0]),$([#16X3+]([OX1-])[OX2H0])]",
+     "[$([#16X3](=[OX1])[OX2H,OX1H0-]),$([#16X3+]([OX1-])[OX2H,OX1H0-])]",
+     "[$([#16X4](=[OX1])=[OX1]),$([#16X4+2]([OX1-])[OX1-])]",
+     "[$([#16X4](=[OX1])(=[OX1])([#6])[#6]),$([#16X4+2]([OX1-])([OX1-])([#6])[#6])]",
+     "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H,OX1H0-]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H,OX1H0-])]",
+     "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H0]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H0])]",
+     "[$([#16X4]([NX3])(=[OX1])(=[OX1])[#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[#6])]",
+     "[SX4](C)(C)(=O)=N",
+     "[$([SX4](=[OX1])(=[OX1])([!O])[NX3]),$([SX4+2]([OX1-])([OX1-])([!O])[NX3])]",
+     "[$([#16X3]=[OX1]),$([#16X3+][OX1-])]",
+     "[$([#16X3](=[OX1])([#6])[#6]),$([#16X3+]([OX1-])([#6])[#6])]",
+     "[$([#16X4](=[OX1])(=[OX1])([OX2H,OX1H0-])[OX2][#6]),$([#16X4+2]([OX1-])([OX1-])([OX2H,OX1H0-])[OX2][#6])]",
+     "[$([SX4](=O)(=O)(O)O),$([SX4+2]([O-])([O-])(O)O)]",
+     "[$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6]),$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6])]",
+     "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2][#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2][#6])]",
+     "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2H,OX1H0-]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2H,OX1H0-])]",
+     "[#16X2][OX2H,OX1H0-]",
+     "[#16X2][OX2H0]",
+     # X
+     "[#6][F,Cl,Br,I]",
+     "[F,Cl,Br,I]",
+     "[F,Cl,Br,I].[F,Cl,Br,I].[F,Cl,Br,I]",
+ ]
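
Since DAY_LIGHT_FG_SMARTS_LIST is a plain list of SMARTS strings, it can be paired directly with RDKit's substructure matching. A minimal sketch (the helper name is illustrative, not part of this repository):

    from rdkit import Chem

    from compound_constants import DAY_LIGHT_FG_SMARTS_LIST

    def functional_group_flags(smiles):
        # Return one 0/1 flag per SMARTS pattern for a single molecule.
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        flags = []
        for smarts in DAY_LIGHT_FG_SMARTS_LIST:
            patt = Chem.MolFromSmarts(smarts)
            # A pattern that fails to parse is simply counted as absent.
            flags.append(int(patt is not None and mol.HasSubstructMatch(patt)))
        return flags

    print(functional_group_flags('CC(=O)O'))  # acetic acid: the carbonyl/carboxyl patterns fire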
dataset.py ADDED
@@ -0,0 +1,497 @@
+ import pandas as pd
+ import numpy as np
+ from utils import mol_to_geognn_graph_data_MMFF3d as smiles2adjoin
+ import tensorflow as tf
+
+ str2num = {'<pad>': 0, 'H': 1, 'C': 2, 'N': 3, 'O': 4, 'F': 5, 'S': 6, 'Cl': 7, 'P': 8, 'Br': 9,
+            'B': 10, 'I': 11, 'Si': 12, 'Se': 13, '<unk>': 14, '<mask>': 15, '<global>': 16}
+
+ num2str = {i: j for j, i in str2num.items()}
+
+
+ class Graph_Bert_Dataset(object):
+     def __init__(self, path, smiles_field=['0'], adj=['1'], addH=True):
+         if path.endswith('.txt') or path.endswith('.tsv'):
+             self.df = pd.read_csv(path, sep='\t')  # fixed: sep='\n\t' is not a valid field separator
+         else:
+             self.df = pd.read_csv(path)
+         self.smiles_field = smiles_field
+         self.adj = adj
+         self.vocab = str2num
+         self.devocab = num2str
+         self.addH = addH
+
+     def get_data(self):
+         data = self.df
+
+         # Random 90/10 split.
+         train_idx = []
+         idx = data.sample(frac=0.9).index
+         train_idx.extend(idx)
+
+         data1 = data[data.index.isin(train_idx)]
+         data2 = data[~data.index.isin(train_idx)]
+
+         self.dataset1 = tf.data.Dataset.from_tensor_slices((data1[self.smiles_field], data1[self.adj]))
+         self.dataset1 = self.dataset1.map(self.tf_numerical_smiles).padded_batch(256, padded_shapes=(
+             tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([None]), tf.TensorShape([None]))).prefetch(50)
+
+         self.dataset2 = tf.data.Dataset.from_tensor_slices((data2[self.smiles_field], data2[self.adj]))
+         self.dataset2 = self.dataset2.map(self.tf_numerical_smiles).padded_batch(512, padded_shapes=(
+             tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([None]),
+             tf.TensorShape([None]))).prefetch(50)
+         return self.dataset1, self.dataset2
+
+     def numerical_smiles(self, atom, adj):
+         # The atom column stores a stringified array of element symbols, e.g. "['C' 'N' 'O']".
+         atom = np.array(atom)
+         atom = atom[0].decode()
+         atom = atom.replace('\n', '')
+         atom = atom.replace('[', ' ')
+         atom = atom.replace(']', ' ')
+         atom = atom.split("'")
+
+         atoms_list = []
+         for i in atom:
+             if i not in [' ']:
+                 atoms_list.append(i)
+
+         # The adj column stores the path of a saved .npy adjacency matrix.
+         adj = np.array(adj)[0].decode()
+         adjoin_matrix = np.load(adj)
+
+         atoms_list = ['<global>'] + atoms_list
+         nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]
+         temp = np.ones((len(nums_list), len(nums_list)))
+         temp[1:, 1:] = adjoin_matrix
+         temp[np.where(temp == 0)] = -1e9
+         adjoin_matrix = temp
+         # adjoin_matrix = (1 - temp) * (-1e9)
+
+         # BERT-style masking: 15% of positions; 80% -> <mask>, 10% -> random atom, 10% unchanged.
+         choices = np.random.permutation(len(nums_list) - 1)[:max(int(len(nums_list) * 0.15), 1)] + 1
+         y = np.array(nums_list).astype('int64')
+         weight = np.zeros(len(nums_list))
+         for i in choices:
+             rand = np.random.rand()
+             weight[i] = 1
+             if rand < 0.8:
+                 nums_list[i] = str2num['<mask>']
+             elif rand < 0.9:
+                 nums_list[i] = int(np.random.rand() * 14 + 1)
+
+         x = np.array(nums_list).astype('int64')
+         weight = weight.astype('float32')
+         return x, adjoin_matrix, y, weight
+
+     def tf_numerical_smiles(self, atom, adj):
+         x, adjoin_matrix, y, weight = tf.py_function(self.numerical_smiles, (atom, adj),
+                                                      [tf.int64, tf.float32, tf.int64, tf.float32])
+         x.set_shape([None])
+         adjoin_matrix.set_shape([None, None])
+         y.set_shape([None])
+         weight.set_shape([None])
+         return x, adjoin_matrix, y, weight
+
+
+ class Graph_Regression_Dataset_test(object):
+     def __init__(self, path, smiles_field='SMILES', label_field='PCE', normalize=False, max_len=1000, addH=True):
+         if path.endswith('.txt') or path.endswith('.tsv'):
+             self.df = pd.read_csv(path.format('test'), sep='\t')
+         else:
+             self.df = pd.read_csv(path.format('test'))
+
+         self.smiles_field = smiles_field
+         self.label_field = label_field
+         self.vocab = str2num
+         self.devocab = num2str
+         self.df = self.df[self.df[smiles_field].str.len() <= max_len]
+         self.addH = addH
+         if normalize:
+             self.max = self.df[self.label_field].max()
+             self.min = self.df[self.label_field].min()
+             self.df[self.label_field] = (self.df[self.label_field] - self.min) / (self.max - self.min) - 0.5
+             self.value_range = self.max - self.min
+
+     def get_data(self):
+         train_data = self.df
+         self.dataset1 = tf.data.Dataset.from_tensor_slices((train_data[self.smiles_field], train_data[self.label_field]))
+         self.dataset1 = self.dataset1.map(self.tf_numerical_smiles).padded_batch(64, padded_shapes=(
+             tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([1])))
+         return self.dataset1
+
+     def numerical_smiles(self, smiles, label):
+         smiles = smiles.numpy().decode()
+         atoms_list, adjoin_matrix = smiles2adjoin(smiles)  # fixed: was the undefined name smiles2adjoins
+         atoms_list = list(atoms_list)
+         atoms_list = ['<global>'] + atoms_list
+         nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]
+         temp = np.ones((len(nums_list), len(nums_list)))
+         temp[1:, 1:] = adjoin_matrix
+         temp[np.where(temp == 0)] = -1e9
+         adjoin_matrix = temp
+         x = np.array(nums_list).astype('int64')
+         y = np.array([label]).astype('float32')
+         return x, adjoin_matrix, y
+
+     def tf_numerical_smiles(self, smiles, label):
+         x, adjoin_matrix, y = tf.py_function(self.numerical_smiles, [smiles, label], [tf.int64, tf.float32, tf.float32])
+         x.set_shape([None])
+         adjoin_matrix.set_shape([None, None])
+         y.set_shape([None])
+         return x, adjoin_matrix, y
+
+
+ class predict_smiles(object):
+     def __init__(self, smiles, normalize=False, max_len=1000, addH=True):
+         self.smiles_field = smiles
+         self.label_field = float(0)  # dummy label for inference
+         self.vocab = str2num
+         self.devocab = num2str
+         self.addH = addH
+         if normalize:
+             # Note: this branch expects a DataFrame (self.df) and is unused for single-SMILES prediction.
+             self.max = self.df[self.label_field].max()
+             self.min = self.df[self.label_field].min()
+             self.df[self.label_field] = (self.df[self.label_field] - self.min) / (self.max - self.min) - 0.5
+             self.value_range = self.max - self.min
+
+     def numerical_smiles(self, atoms_list, adj, label):
+         atom = np.array(atoms_list)
+         atoms_list = []
+         for i in atom:
+             # tf.py_function delivers strings as bytes; decode so the vocabulary lookup works.
+             i = i.decode() if isinstance(i, bytes) else i
+             if i not in [' ']:
+                 atoms_list.append(i)
+         label = np.array(label)
+
+         adj = np.array(adj)
+         adjoin_matrix = adj
+
+         atoms_list = ['<global>'] + atoms_list
+         nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]
+
+         temp = np.ones((len(nums_list), len(nums_list)))
+         temp[1:, 1:] = adjoin_matrix
+         temp[np.where(temp == 0)] = -1e9
+         adjoin_matrix = temp
+
+         x = np.array(nums_list).astype('int64')
+         y = np.array([label]).astype('float32')
+         return x, adjoin_matrix, y
+
+     def get_data(self):
+         atom, adj = smiles2adjoin(self.smiles_field)
+         atom = np.array(atom)
+         atoms_list = []
+         for i in atom:
+             if i not in [' ']:
+                 atoms_list.append(i)
+         adj = np.array(adj)
+         adjoin_matrix = adj
+         self.dataset1 = tf.data.Dataset.from_tensors((atoms_list, adjoin_matrix, self.label_field))
+         self.dataset1 = self.dataset1.map(self.tf_numerical_smiles).cache().padded_batch(1, padded_shapes=(
+             tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([1])))
+         return self.dataset1
+
+     def tf_numerical_smiles(self, atoms_list, adj, label):
+         x, adjoin_matrix, y = tf.py_function(self.numerical_smiles, (atoms_list, adj, label), [tf.int64, tf.float32, tf.float32])
+         x.set_shape([None])
+         adjoin_matrix.set_shape([None, None])
+         y.set_shape([None])
+         return x, adjoin_matrix, y
+
+
+ class Graph_Regression_test(object):
+     def __init__(self, path, smiles_field=['0'], adj=['1'], label_field=['2'], normalize=False, max_len=1000, addH=True):
+         if path.endswith('.txt') or path.endswith('.tsv'):
+             # self.df = pd.read_csv(path.format('train3'), sep='\t')
+             # self.dt = pd.read_csv(path.format('test3'), sep='\t')
+             self.dv = pd.read_csv(path.format('val3'), sep='\t')
+         else:
+             # self.df = pd.read_csv(path.format('train/train'))
+             # self.dt = pd.read_csv(path.format('test/test'))
+             self.dv = pd.read_csv(path.format('val/val'))
+         self.smiles_field = smiles_field
+         self.adj = adj
+         self.label_field = label_field
+         self.vocab = str2num
+         self.devocab = num2str
+         self.addH = addH
+         if normalize:
+             self.max = self.df[self.label_field].max()
+             self.min = self.df[self.label_field].min()
+             self.df[self.label_field] = (self.df[self.label_field] - self.min) / (self.max - self.min) - 0.5
+             self.value_range = self.max - self.min
+
+     def get_data(self):
+         train_data = self.dv
+
+         self.dataset1 = tf.data.Dataset.from_tensor_slices((train_data[self.smiles_field], train_data[self.adj], train_data[self.label_field]))
+         self.dataset1 = self.dataset1.map(self.tf_numerical_smiles).cache().padded_batch(64, padded_shapes=(
+             tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([1]))).prefetch(100)
+         return self.dataset1
+
+     def numerical_smiles(self, atom, adj, label):
+         atom = np.array(atom)
+         atom = atom[0].decode()
+         atom = atom.replace('\n', '')
+         atom = atom.replace('[', ' ')
+         atom = atom.replace(']', ' ')
+         atom = atom.split("'")
+
+         atoms_list = []
+         for i in atom:
+             if i not in [' ']:
+                 atoms_list.append(i)
+         label = np.array(label)[0]
+
+         adj = np.array(adj)[0].decode()
+         adjoin_matrix = np.load(adj)
+
+         atoms_list = ['<global>'] + atoms_list
+         nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]
+
+         temp = np.ones((len(nums_list), len(nums_list)))
+         temp[1:, 1:] = adjoin_matrix
+         temp[np.where(temp == 0)] = -1e9
+         adjoin_matrix = temp
+
+         x = np.array(nums_list).astype('int64')
+         y = np.array([label]).astype('float32')
+         return x, adjoin_matrix, y
+
+     def tf_numerical_smiles(self, smiles, adj, label):
+         x, adjoin_matrix, y = tf.py_function(self.numerical_smiles, (smiles, adj, label), [tf.int64, tf.float32, tf.float32])
+         x.set_shape([None])
+         adjoin_matrix.set_shape([None, None])
+         y.set_shape([None])
+         return x, adjoin_matrix, y
+
+
+ class Graph_Regression(object):
+     def __init__(self, path, smiles_field=['0'], adj=['1'], label_field=['2'], normalize=False, max_len=1000, addH=True):
+         if path.endswith('.txt') or path.endswith('.tsv'):
+             self.df = pd.read_csv(path.format('train3'), sep='\t')
+             self.dt = pd.read_csv(path.format('test3'), sep='\t')
+             # self.dv = pd.read_csv(path.format('val3'), sep='\t')
+         else:
+             self.df = pd.read_csv(path.format('train/train'))
+             self.dt = pd.read_csv(path.format('test/test'))
+             # self.dv = pd.read_csv(path.format('val3'))
+         self.smiles_field = smiles_field
+         self.adj = adj
+         self.label_field = label_field
+         self.vocab = str2num
+         self.devocab = num2str
+         self.addH = addH
+         if normalize:
+             self.max = self.df[self.label_field].max()
+             self.min = self.df[self.label_field].min()
+             self.df[self.label_field] = (self.df[self.label_field] - self.min) / (self.max - self.min) - 0.5
+             self.value_range = self.max - self.min
+
+     def get_data(self):
+         train_data = self.df
+         test_data = self.dt
+         data2 = test_data
+
+         self.dataset1 = tf.data.Dataset.from_tensor_slices((train_data[self.smiles_field], train_data[self.adj], train_data[self.label_field]))
+         self.dataset1 = self.dataset1.map(self.tf_numerical_smiles).cache().padded_batch(64, padded_shapes=(
+             tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([1]))).prefetch(100)
+
+         self.dataset2 = tf.data.Dataset.from_tensor_slices((test_data[self.smiles_field], test_data[self.adj], test_data[self.label_field]))
+         self.dataset2 = self.dataset2.map(self.tf_numerical_smiles).padded_batch(64, padded_shapes=(
+             tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([1]))).cache().prefetch(100)
+
+         self.dataset3 = tf.data.Dataset.from_tensor_slices((data2[self.smiles_field], test_data[self.adj], data2[self.label_field]))
+         self.dataset3 = self.dataset3.map(self.tf_numerical_smiles).padded_batch(64, padded_shapes=(
+             tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([1]))).cache().prefetch(100)
+
+         return self.dataset1, self.dataset2, self.dataset3
+
+     def numerical_smiles(self, atom, adj, label):
+         atom = np.array(atom)
+         atom = atom[0].decode()
+         atom = atom.replace('\n', '')
+         atom = atom.replace('[', ' ')
+         atom = atom.replace(']', ' ')
+         atom = atom.split("'")
+
+         atoms_list = []
+         for i in atom:
+             if i not in [' ']:
+                 atoms_list.append(i)
+         label = np.array(label)[0]
+
+         adj = np.array(adj)[0].decode()
+         adjoin_matrix = np.load(adj)
+
+         atoms_list = ['<global>'] + atoms_list
+         nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]
+
+         temp = np.ones((len(nums_list), len(nums_list)))
+         temp[1:, 1:] = adjoin_matrix
+         temp[np.where(temp == 0)] = -1e9
+         adjoin_matrix = temp
+
+         x = np.array(nums_list).astype('int64')
+         y = np.array([label]).astype('float32')
+         return x, adjoin_matrix, y
+
+     def tf_numerical_smiles(self, smiles, adj, label):
+         x, adjoin_matrix, y = tf.py_function(self.numerical_smiles, (smiles, adj, label), [tf.int64, tf.float32, tf.float32])
+         x.set_shape([None])
+         adjoin_matrix.set_shape([None, None])
+         y.set_shape([None])
+         return x, adjoin_matrix, y
+
+
+ class Inference_Dataset(object):
+     # NOTE: superseded by the second Inference_Dataset below, which shadows this class.
+     def __init__(self, path, smiles_field='Smiles', addH=True):
+         if path.endswith('.txt') or path.endswith('.tsv'):
+             self.df = pd.read_csv(path, sep='\t')
+         else:
+             self.df = pd.read_csv(path)
+         self.smiles_field = smiles_field
+         self.vocab = str2num
+         self.devocab = num2str
+         self.addH = addH
+
+     def get_data(self):
+         data = self.df
+
+         train_idx = []
+         idx = data.sample(frac=0.9).index
+         train_idx.extend(idx)
+
+         data1 = data[data.index.isin(train_idx)]
+         data2 = data[~data.index.isin(train_idx)]
+         print(len(data1))
+         self.dataset1 = tf.data.Dataset.from_tensor_slices(data1[self.smiles_field].tolist())
+         self.dataset1 = self.dataset1.map(self.tf_numerical_smiles).padded_batch(1, padded_shapes=(
+             tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([None]), tf.TensorShape([None]))).prefetch(50)
+         print(self.dataset1)
+         self.dataset2 = tf.data.Dataset.from_tensor_slices(data2[self.smiles_field].tolist())
+         self.dataset2 = self.dataset2.map(self.tf_numerical_smiles).padded_batch(1, padded_shapes=(
+             tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([None]),
+             tf.TensorShape([None]))).prefetch(50)
+         return self.dataset1, self.dataset2
+
+     def numerical_smiles(self, smiles):
+         smiles = smiles.numpy().decode()
+         atoms_list, adjoin_matrix = smiles2adjoin(smiles, explicit_hydrogens=self.addH)  # fixed: was smiles2adjoins
+         print(atoms_list)
+         atoms_list = ['<global>'] + atoms_list
+         nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]
+         temp = np.ones((len(nums_list), len(nums_list)))
+         temp[1:, 1:] = adjoin_matrix
+         temp[np.where(temp == 0)] = -1e9
+         adjoin_matrix = temp
+         x = np.array(nums_list).astype('int64')
+         return x, adjoin_matrix, [smiles], atoms_list
+
+     def tf_numerical_smiles(self, data):
+         # fixed: unpack the four outputs actually returned by numerical_smiles
+         x, adjoin_matrix, smiles, atom_list = tf.py_function(self.numerical_smiles, [data],
+                                                              [tf.int64, tf.float32, tf.string, tf.string])
+         x.set_shape([None])
+         adjoin_matrix.set_shape([None, None])
+         smiles.set_shape([1])
+         atom_list.set_shape([None])
+         return x, adjoin_matrix, smiles, atom_list
+
+
+ class Inference_Dataset(object):
+     def __init__(self, sml_list, max_len=1000, addH=True):
+         self.vocab = str2num
+         self.devocab = num2str
+         self.sml_list = [i for i in sml_list if len(i) < max_len]
+         self.addH = addH
+
+     def get_data(self):
+         self.dataset = tf.data.Dataset.from_tensor_slices((self.sml_list,))
+         self.dataset = self.dataset.map(self.tf_numerical_smiles).padded_batch(64, padded_shapes=(
+             tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([1]), tf.TensorShape([None]))).cache().prefetch(20)
+         return self.dataset
+
+     def numerical_smiles(self, smiles):
+         smiles_origin = smiles
+         smiles = smiles.numpy().decode()
+         atoms_list, adjoin_matrix = smiles2adjoin(smiles)  # fixed: was smiles2adjoins
+         atoms_list = ['<global>'] + atoms_list
+         nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]
+         temp = np.ones((len(nums_list), len(nums_list)))
+         temp[1:, 1:] = adjoin_matrix
+         adjoin_matrix = (1 - temp) * (-1e9)
+         x = np.array(nums_list).astype('int64')
+         return x, adjoin_matrix, [smiles], atoms_list
+
+     def tf_numerical_smiles(self, smiles):
+         x, adjoin_matrix, smiles, atom_list = tf.py_function(self.numerical_smiles, [smiles], [tf.int64, tf.float32, tf.string, tf.string])
+         x.set_shape([None])
+         adjoin_matrix.set_shape([None, None])
+         smiles.set_shape([1])
+         atom_list.set_shape([None])
+         return x, adjoin_matrix, smiles, atom_list
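
Every class above applies the same transform before batching: a <global> token is prepended, the molecular adjacency matrix is embedded into a matrix of ones (so the global node connects to every atom), and zero entries become -1e9 so they vanish after the attention softmax in model.py. A small self-contained illustration with NumPy:

    import numpy as np

    # Toy 3-atom chain: bonds 0-1 and 1-2, self-loops on the diagonal.
    adjoin_matrix = np.array([[1., 1., 0.],
                              [1., 1., 1.],
                              [0., 1., 1.]])

    n = adjoin_matrix.shape[0] + 1   # +1 row/column for the <global> token
    temp = np.ones((n, n))           # the global token attends to (and is seen by) everything
    temp[1:, 1:] = adjoin_matrix     # embed the molecular graph
    temp[temp == 0] = -1e9           # non-bonded pairs get a huge negative attention logit
    print(temp)

Added to the pre-softmax logits in scaled_dot_product_attention, this matrix restricts each atom's attention to its bonded neighbours plus the global node, which is how the graph structure enters the otherwise order-free Transformer encoder.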
model.py ADDED
@@ -0,0 +1,280 @@
1
+ import tensorflow as tf
2
+
3
+ import time
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+
7
+
8
+
9
+ def gelu(x):
10
+ return 0.5 * x * (1.0 + tf.math.erf(x / tf.sqrt(2.)))
11
+
12
+ def scaled_dot_product_attention(q, k, v, mask,adjoin_matrix):
13
+ """Calculate the attention weights.
14
+ q, k, v must have matching leading dimensions.
15
+ k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
16
+ The mask has different shapes depending on its type(padding or look ahead)
17
+ but it must be broadcastable for addition.
18
+
19
+ Args:
20
+ q: query shape == (..., seq_len_q, depth)
21
+ k: key shape == (..., seq_len_k, depth)
22
+ v: value shape == (..., seq_len_v, depth_v)
23
+ mask: Float tensor with shape broadcastable
24
+ to (..., seq_len_q, seq_len_k). Defaults to None.
25
+
26
+ Returns:
27
+ output, attention_weights
28
+ """
29
+
30
+ matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k)
31
+
32
+ # scale matmul_qk
33
+ dk = tf.cast(tf.shape(k)[-1], tf.float32)
34
+ scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
35
+
36
+ # add the mask to the scaled tensor.
37
+ if mask is not None:
38
+ scaled_attention_logits += (mask * -1e9)
39
+ if adjoin_matrix is not None:
40
+ #adjoin_matrix1 =tf.where(adjoin_matrix>0,0.0,-1e9)
41
+ #scaled_attention_logits += adjoin_matrix1
42
+ #scaled_attention_logits = scaled_attention_logits * adjoin_matrix
43
+ scaled_attention_logits += adjoin_matrix
44
+
45
+ # softmax is normalized on the last axis (seq_len_k) so that the scores
46
+ # add up to 1.
47
+ attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k)
48
+
49
+ output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v)
50
+
51
+ return output, attention_weights
52
+
53
+
54
+ class MultiHeadAttention(tf.keras.layers.Layer):
55
+ def __init__(self, d_model, num_heads):
56
+ super(MultiHeadAttention, self).__init__()
57
+ self.num_heads = num_heads
58
+ self.d_model = d_model
59
+
60
+ assert d_model % self.num_heads == 0
61
+
62
+ self.depth = d_model // self.num_heads
63
+
64
+ self.wq = tf.keras.layers.Dense(d_model)
65
+ self.wk = tf.keras.layers.Dense(d_model)
66
+ self.wv = tf.keras.layers.Dense(d_model)
67
+
68
+ self.dense = tf.keras.layers.Dense(d_model)
69
+
70
+ def split_heads(self, x, batch_size):
71
+ """Split the last dimension into (num_heads, depth).
72
+ Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
73
+ """
74
+ x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
75
+ return tf.transpose(x, perm=[0, 2, 1, 3])
76
+
77
+ def call(self, v, k, q, mask,adjoin_matrix):
78
+ batch_size = tf.shape(q)[0]
79
+
80
+ q = self.wq(q) # (batch_size, seq_len, d_model)
81
+ k = self.wk(k) # (batch_size, seq_len, d_model)
82
+ v = self.wv(v) # (batch_size, seq_len, d_model)
83
+
84
+ q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth)
85
+ k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth)
86
+ v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth)
87
+
88
+ # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
89
+ # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
90
+ scaled_attention, attention_weights = scaled_dot_product_attention(
91
+ q, k, v, mask,adjoin_matrix)
92
+
93
+ scaled_attention = tf.transpose(scaled_attention,
94
+ perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth)
95
+
96
+ concat_attention = tf.reshape(scaled_attention,
97
+ (batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model)
98
+
99
+ output = self.dense(concat_attention) # (batch_size, seq_len_q, d_model)
100
+
101
+ return output, attention_weights
102
+
103
+ def point_wise_feed_forward_network(d_model, dff):
104
+ return tf.keras.Sequential([
105
+ tf.keras.layers.Dense(dff, activation=gelu), # (batch_size, seq_len, dff)tf.keras.layers.LeakyReLU(0.01)
106
+ tf.keras.layers.Dense(d_model) # (batch_size, seq_len, d_model)
107
+ ])
108
+
109
+
110
+ class EncoderLayer(tf.keras.layers.Layer):
111
+ def __init__(self, d_model, num_heads, dff, rate=0.1):
112
+ super(EncoderLayer, self).__init__()
113
+
114
+ self.mha = MultiHeadAttention(d_model, num_heads)
115
+ self.ffn = point_wise_feed_forward_network(d_model, dff)
116
+
117
+ self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
118
+ self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
119
+
120
+ self.dropout1 = tf.keras.layers.Dropout(rate)
121
+ self.dropout2 = tf.keras.layers.Dropout(rate)
122
+
123
+ def call(self, x, training, mask,adjoin_matrix):
124
+ attn_output, attention_weights = self.mha(x, x, x, mask,adjoin_matrix) # (batch_size, input_seq_len, d_model)
125
+ attn_output = self.dropout1(attn_output, training=training)
126
+ out1 = self.layernorm1(x + attn_output) # (batch_size, input_seq_len, d_model)
127
+
128
+ ffn_output = self.ffn(out1) # (batch_size, input_seq_len, d_model)
129
+ ffn_output = self.dropout2(ffn_output, training=training)
130
+ out2 = self.layernorm2(out1 + ffn_output) # (batch_size, input_seq_len, d_model)
131
+
132
+ return out2,attention_weights
133
+
134
+
135
+ class Encoder(tf.keras.Model):
136
+ def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
137
+ maximum_position_encoding, rate=0.1):
138
+ super(Encoder, self).__init__()
139
+
140
+ self.d_model = d_model
141
+ self.num_layers = num_layers
142
+
143
+ self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
144
+ # self.pos_encoding = positional_encoding(maximum_position_encoding,
145
+ # self.d_model)
146
+
147
+ self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
148
+ for _ in range(num_layers)]
149
+
150
+ self.dropout = tf.keras.layers.Dropout(rate)
151
+
152
+ def call(self, x, training, mask,adjoin_matrix):
153
+ seq_len = tf.shape(x)[1]
154
+ adjoin_matrix = adjoin_matrix[:,tf.newaxis,:,:]
155
+ # adding embedding and position encoding.
156
+ x = self.embedding(x) # (batch_size, input_seq_len, d_model)
157
+ x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
158
+
159
+ x = self.dropout(x, training=training)
160
+
161
+ for i in range(self.num_layers):
162
+ x,attention_weights = self.enc_layers[i](x, training, mask,adjoin_matrix)
163
+ return x # (batch_size, input_seq_len, d_model)
164
+
165
+ class Encoder_test(tf.keras.Model):
+     def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
+                  maximum_position_encoding, rate=0.1):
+         super(Encoder_test, self).__init__()
+
+         self.d_model = d_model
+         self.num_layers = num_layers
+
+         self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
+         # self.pos_encoding = positional_encoding(maximum_position_encoding,
+         #                                         self.d_model)
+
+         self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
+                            for _ in range(num_layers)]
+
+         self.dropout = tf.keras.layers.Dropout(rate)
+
+     def call(self, x, training, mask, adjoin_matrix):
+         seq_len = tf.shape(x)[1]  # only needed if positional encoding is re-enabled
+         adjoin_matrix = adjoin_matrix[:, tf.newaxis, :, :]  # add a head axis for broadcasting
+         # embedding only; positional encoding is intentionally disabled
+         x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
+         x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
+         # x += self.pos_encoding[:, :seq_len, :]
+
+         x = self.dropout(x, training=training)
+         attention_weights_list = []
+         xs = []
+
+         for i in range(self.num_layers):
+             x, attention_weights = self.enc_layers[i](x, training, mask, adjoin_matrix)
+             attention_weights_list.append(attention_weights)
+             xs.append(x)
+
+         return x, attention_weights_list, xs
+
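Encoder_test is the introspection twin of Encoder: the same forward pass, but it also returns the per-layer attention maps and hidden states. A sketch of pulling one head out for inspection, reusing the hypothetical inputs from the sketch above:

    enc = Encoder_test(num_layers=6, d_model=256, num_heads=8, dff=512,
                       input_vocab_size=17, maximum_position_encoding=200)
    h, att_list, states = enc(tokens, training=False, mask=mask, adjoin_matrix=adjoin)
    first_layer_head0 = att_list[0][:, 0, :, :]  # attention of head 0 in layer 1
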
+ class BertModel_test(tf.keras.Model):
+     def __init__(self, num_layers=6, d_model=256, dff=512, num_heads=8, vocab_size=17, dropout_rate=0.1):
+         super(BertModel_test, self).__init__()
+         self.encoder = Encoder_test(num_layers=num_layers, d_model=d_model,
+                                     num_heads=num_heads, dff=dff, input_vocab_size=vocab_size,
+                                     maximum_position_encoding=200, rate=dropout_rate)
+         self.fc1 = tf.keras.layers.Dense(d_model, activation=gelu)
+         self.layernorm = tf.keras.layers.LayerNormalization(axis=-1)
+         self.fc2 = tf.keras.layers.Dense(vocab_size)
+
+     def call(self, x, adjoin_matrix, mask, training=False):
+         x, att, xs = self.encoder(x, training=training, mask=mask, adjoin_matrix=adjoin_matrix)
+         x = self.fc1(x)
+         x = self.layernorm(x)
+         x = self.fc2(x)  # per-position logits over the atom vocabulary
+         return x, att, xs
+
+
+ class BertModel(tf.keras.Model):
+     def __init__(self, num_layers=6, d_model=256, dff=512, num_heads=8, vocab_size=17, dropout_rate=0.1):
+         super(BertModel, self).__init__()
+         self.encoder = Encoder(num_layers=num_layers, d_model=d_model,
+                                num_heads=num_heads, dff=dff, input_vocab_size=vocab_size,
+                                maximum_position_encoding=200, rate=dropout_rate)
+         self.fc1 = tf.keras.layers.Dense(d_model, activation=gelu)
+         self.layernorm = tf.keras.layers.LayerNormalization(axis=-1)
+         self.fc2 = tf.keras.layers.Dense(vocab_size)
+
+     def call(self, x, adjoin_matrix, mask, training=False):
+         x = self.encoder(x, training=training, mask=mask, adjoin_matrix=adjoin_matrix)
+         x = self.fc1(x)
+         x = self.layernorm(x)
+         x = self.fc2(x)  # per-position logits over the atom vocabulary
+         return x
+
+
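BertModel is the pretraining head: every position is mapped back to vocabulary logits, so a masked-atom objective can be trained on top. A minimal sketch reusing the hypothetical inputs above (the loss wiring is an assumption; it is not part of this file):

    model = BertModel(num_layers=6, d_model=256, dff=512, num_heads=8, vocab_size=17)
    logits = model(tokens, adjoin_matrix=adjoin, mask=mask, training=True)  # (2, 30, 17)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(tokens, logits)
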
+ class PredictModel(tf.keras.Model):
+     def __init__(self, num_layers=8, d_model=256, dff=512, num_heads=8, vocab_size=17, dropout_rate=0.1, dense_dropout=0.1):
+         super(PredictModel, self).__init__()
+         self.encoder = Encoder(num_layers=num_layers, d_model=d_model,
+                                num_heads=num_heads, dff=dff, input_vocab_size=vocab_size,
+                                maximum_position_encoding=200, rate=dropout_rate)
+
+         self.fc1 = tf.keras.layers.Dense(256, activation=tf.keras.layers.LeakyReLU(0.25))
+         self.fc2 = tf.keras.layers.Dense(256, activation=tf.keras.layers.LeakyReLU(0.25))
+         self.dropout = tf.keras.layers.Dropout(dense_dropout)
+         self.fc3 = tf.keras.layers.Dense(1)
+
+     def call(self, x, adjoin_matrix, mask, training=False):
+         x = self.encoder(x, training=training, mask=mask, adjoin_matrix=adjoin_matrix)
+         x = x[:, 0, :]  # pool the first (global) token as the molecule embedding
+         x = self.fc1(x)
+         x = self.dropout(x, training=training)
+         x = self.fc2(x)
+         x = self.fc3(x)  # single regression output
+         return x
+
+
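PredictModel reuses the same encoder but reads only the first token's state, treating it as a CLS-style molecule summary, and regresses a single value (here, PCE). A usage sketch with the hypothetical inputs from above:

    reg = PredictModel(num_layers=8, d_model=256, dff=512, num_heads=8)
    y_hat = reg(tokens, adjoin_matrix=adjoin, mask=mask, training=False)  # (2, 1)
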
+ class PredictModel_test(tf.keras.Model):
+     def __init__(self, num_layers=6, d_model=256, dff=512, num_heads=8, vocab_size=17, dropout_rate=0.1, dense_dropout=0.5):
+         super(PredictModel_test, self).__init__()
+         self.encoder = Encoder_test(num_layers=num_layers, d_model=d_model,
+                                     num_heads=num_heads, dff=dff, input_vocab_size=vocab_size,
+                                     maximum_position_encoding=200, rate=dropout_rate)
+
+         self.fc1 = tf.keras.layers.Dense(256, activation=tf.keras.layers.LeakyReLU(0.1))
+         self.dropout = tf.keras.layers.Dropout(dense_dropout)
+         self.fc2 = tf.keras.layers.Dense(1)
+
+     def call(self, x, adjoin_matrix, mask, training=False):
+         x, att, xs = self.encoder(x, training=training, mask=mask, adjoin_matrix=adjoin_matrix)
+         x = x[:, 0, :]  # pool the first (global) token as the molecule embedding
+         x = self.fc1(x)
+         x = self.dropout(x, training=training)
+         x = self.fc2(x)
+         return x, att, xs
+
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ altair
+ streamlit
+ streamlit-ketcher
+ tensorflow
+ pandas
+ rdkit
+ scikit-learn
+ matplotlib
+
+
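The pins are unversioned, so any recent release of each package is assumed to work; a typical setup is simply:

    pip install -r requirements.txt
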
utils.py ADDED
@@ -0,0 +1,696 @@
+ # -*- coding: utf-8 -*-
+ """
+ Created on Thu Jul 28 14:40:59 2022
+
+ @author: BM109X32G-10GPU-02
+ """
+
+ import os
+ from collections import OrderedDict
+
+ import numpy as np
+ from rdkit import Chem
+ from rdkit.Chem import AllChem
+ from rdkit.Chem import rdchem
+
+ from compound_constants import DAY_LIGHT_FG_SMARTS_LIST
+
+
+ def get_gasteiger_partial_charges(mol, n_iter=12):
+     """
+     Calculate the Gasteiger partial charge of each atom in a mol object.
+     Args:
+         mol: RDKit mol object.
+         n_iter(int): number of iterations. Default 12.
+     Returns:
+         list of computed partial charges, one per atom.
+     """
+     Chem.rdPartialCharges.ComputeGasteigerCharges(mol, nIter=n_iter,
+                                                   throwOnParamFailure=True)
+     partial_charges = [float(a.GetProp('_GasteigerCharge'))
+                        for a in mol.GetAtoms()]
+     return partial_charges
+
+
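A minimal sketch of the charge helper on a small molecule (ethanol is an arbitrary choice):

    mol = Chem.MolFromSmiles('CCO')
    charges = get_gasteiger_partial_charges(mol)  # one float per atom (heavy atoms only, since no Hs were added)
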
+ def create_standardized_mol_id(smiles):
+     """
+     Args:
+         smiles: SMILES string.
+     Returns:
+         InChI of the (largest) molecule, or None if the SMILES is invalid.
+     """
+     if check_smiles_validity(smiles):
+         # remove stereochemistry
+         smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles),
+                                      isomericSmiles=False)
+         mol = Chem.AddHs(AllChem.MolFromSmiles(smiles))
+
+         if mol is not None:  # to catch weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21
+             if '.' in smiles:  # if multiple species, pick largest molecule
+                 mol_species_list = split_rdkit_mol_obj(mol)
+                 largest_mol = get_largest_mol(mol_species_list)
+                 inchi = AllChem.MolToInchi(largest_mol)
+             else:
+                 inchi = AllChem.MolToInchi(mol)
+             return inchi
+         else:
+             return
+     else:
+         return
+
+
+ def check_smiles_validity(smiles):
+     """
+     Check whether a SMILES string can be parsed into an RDKit mol object.
+     """
+     try:
+         m = Chem.MolFromSmiles(smiles)
+         if m:
+             return True
+         else:
+             return False
+     except Exception:
+         return False
+
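These two helpers usually travel together; a minimal sketch of the intended flow:

    if check_smiles_validity('CCO'):
        mol_id = create_standardized_mol_id('CCO')  # stereochemistry-stripped InChI
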
+
+ def split_rdkit_mol_obj(mol):
+     """
+     Split an RDKit mol object containing one or more species into a list of
+     mol objects, one per species.
+     Args:
+         mol: RDKit mol object.
+     """
+     smiles = AllChem.MolToSmiles(mol, isomericSmiles=True)
+     smiles_list = smiles.split('.')
+     mol_species_list = []
+     for s in smiles_list:
+         if check_smiles_validity(s):
+             mol_species_list.append(AllChem.MolFromSmiles(s))
+     return mol_species_list
+
+
+ def get_largest_mol(mol_list):
+     """
+     Given a list of RDKit mol objects, return the mol with the largest number
+     of atoms; ties go to the first such mol.
+     Args:
+         mol_list(list): a list of RDKit mol objects.
+     Returns:
+         the largest mol.
+     """
+     num_atoms_list = [len(m.GetAtoms()) for m in mol_list]
+     largest_mol_idx = num_atoms_list.index(max(num_atoms_list))
+     return mol_list[largest_mol_idx]
+
+
+ def rdchem_enum_to_list(values):
+     """e.g. values = {0: rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
+                       1: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
+                       2: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
+                       3: rdkit.Chem.rdchem.ChiralType.CHI_OTHER}
+     """
+     return [values[i] for i in range(len(values))]
+
+
+ def safe_index(alist, elem):
+     """
+     Return the index of elem in alist; if elem is not present, return the
+     last index (the out-of-vocabulary slot).
+     """
+     try:
+         return alist.index(elem)
+     except ValueError:
+         return len(alist) - 1
+
+
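safe_index is what makes the vocabularies below tolerant of unseen values: anything not in the list lands on the final 'misc' entry. A tiny sketch:

    vocab = [1, 6, 7, 8, 'misc']
    safe_index(vocab, 6)    # -> 1
    safe_index(vocab, 92)   # -> 4, the 'misc' slot
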
+ def get_atom_feature_dims(list_acquired_feature_names):
+     """Return the vocabulary size of each requested atom feature."""
+     return list(map(len, [CompoundKit.atom_vocab_dict[name] for name in list_acquired_feature_names]))
+
+
+ def get_bond_feature_dims(list_acquired_feature_names):
+     """Return the vocabulary size of each requested bond feature."""
+     list_bond_feat_dim = list(map(len, [CompoundKit.bond_vocab_dict[name] for name in list_acquired_feature_names]))
+     # +1 for self loop edges
+     return [_l + 1 for _l in list_bond_feat_dim]
+
+
+ class CompoundKit(object):
+     """
+     CompoundKit
+     """
+     atom_vocab_dict = {
+         "atomic_num": list(range(1, 119)) + ['misc'],
+         "chiral_tag": rdchem_enum_to_list(rdchem.ChiralType.values),
+     }
+     bond_vocab_dict = {
+         "bond_dir": rdchem_enum_to_list(rdchem.BondDir.values),
+         "bond_type": rdchem_enum_to_list(rdchem.BondType.values),
+     }
+     # float features
+     atom_float_names = ["van_der_waals_radis", "partial_charge", 'mass']
+     # bond_float_feats = ["bond_length", "bond_angle"]  # optional
+
+     ### functional groups
+     day_light_fg_smarts_list = DAY_LIGHT_FG_SMARTS_LIST
+     day_light_fg_mo_list = [Chem.MolFromSmarts(smarts) for smarts in day_light_fg_smarts_list]
+
+     morgan_fp_N = 200
+     morgan2048_fp_N = 2048
+     maccs_fp_N = 167
+
+     period_table = Chem.GetPeriodicTable()
+
+     ### atom
+
+     @staticmethod
+     def get_atom_value(atom, name):
+         """get atom values"""
+         if name == 'atomic_num':
+             return atom.GetAtomicNum()
+         elif name == 'chiral_tag':
+             return atom.GetChiralTag()
+         elif name == 'degree':
+             return atom.GetDegree()
+         elif name == 'explicit_valence':
+             return atom.GetExplicitValence()
+         elif name == 'formal_charge':
+             return atom.GetFormalCharge()
+         elif name == 'hybridization':
+             return atom.GetHybridization()
+         elif name == 'implicit_valence':
+             return atom.GetImplicitValence()
+         elif name == 'is_aromatic':
+             return int(atom.GetIsAromatic())
+         elif name == 'mass':
+             return int(atom.GetMass())
+         elif name == 'total_numHs':
+             return atom.GetTotalNumHs()
+         elif name == 'num_radical_e':
+             return atom.GetNumRadicalElectrons()
+         elif name == 'atom_is_in_ring':
+             return int(atom.IsInRing())
+         elif name == 'valence_out_shell':
+             return CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())
+         else:
+             raise ValueError(name)
+
+     @staticmethod
+     def get_atom_feature_id(atom, name):
+         """get atom feature id"""
+         assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
+         return safe_index(CompoundKit.atom_vocab_dict[name], CompoundKit.get_atom_value(atom, name))
+
+     @staticmethod
+     def get_atom_feature_size(name):
+         """get atom feature size"""
+         assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
+         return len(CompoundKit.atom_vocab_dict[name])
+
+     ### bond
+
+     @staticmethod
+     def get_bond_value(bond, name):
+         """get bond values"""
+         if name == 'bond_dir':
+             return bond.GetBondDir()
+         elif name == 'bond_type':
+             return bond.GetBondType()
+         elif name == 'is_in_ring':
+             return int(bond.IsInRing())
+         elif name == 'is_conjugated':
+             return int(bond.GetIsConjugated())
+         elif name == 'bond_stereo':
+             return bond.GetStereo()
+         else:
+             raise ValueError(name)
+
+     @staticmethod
+     def get_bond_feature_id(bond, name):
+         """get bond feature id"""
+         assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
+         return safe_index(CompoundKit.bond_vocab_dict[name], CompoundKit.get_bond_value(bond, name))
+
+     @staticmethod
+     def get_bond_feature_size(name):
+         """get bond feature size"""
+         assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
+         return len(CompoundKit.bond_vocab_dict[name])
+
+     ### fingerprint
+
+     @staticmethod
+     def get_morgan_fingerprint(mol, radius=2):
+         """get morgan fingerprint"""
+         nBits = CompoundKit.morgan_fp_N
+         mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
+         return [int(b) for b in mfp.ToBitString()]
+
+     @staticmethod
+     def get_morgan2048_fingerprint(mol, radius=2):
+         """get morgan2048 fingerprint"""
+         nBits = CompoundKit.morgan2048_fp_N
+         mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
+         return [int(b) for b in mfp.ToBitString()]
+
+     @staticmethod
+     def get_maccs_fingerprint(mol):
+         """get maccs fingerprint"""
+         fp = AllChem.GetMACCSKeysFingerprint(mol)
+         return [int(b) for b in fp.ToBitString()]
+
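A quick sketch of the fingerprint helpers; the bit counts follow the class constants above:

    mol = Chem.MolFromSmiles('c1ccccc1O')
    morgan = CompoundKit.get_morgan_fingerprint(mol)  # 200 bits
    maccs = CompoundKit.get_maccs_fingerprint(mol)    # 167 bits
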
+     ### functional groups
+
+     @staticmethod
+     def get_daylight_functional_group_counts(mol):
+         """get daylight functional group counts"""
+         fg_counts = []
+         for fg_mol in CompoundKit.day_light_fg_mo_list:
+             sub_structs = Chem.Mol.GetSubstructMatches(mol, fg_mol, uniquify=True)
+             fg_counts.append(len(sub_structs))
+         return fg_counts
+
+     @staticmethod
+     def get_ring_size(mol):
+         """return an (N, 6) list: for each atom, the number of rings of size 3..8 it sits in"""
+         rings = mol.GetRingInfo()
+         rings_info = []
+         for r in rings.AtomRings():
+             rings_info.append(r)
+         ring_list = []
+         for atom in mol.GetAtoms():
+             atom_result = []
+             for ringsize in range(3, 9):
+                 num_of_ring_at_ringsize = 0
+                 for r in rings_info:
+                     if len(r) == ringsize and atom.GetIdx() in r:
+                         num_of_ring_at_ringsize += 1
+                 if num_of_ring_at_ringsize > 8:
+                     num_of_ring_at_ringsize = 9
+                 atom_result.append(num_of_ring_at_ringsize)
+
+             ring_list.append(atom_result)
+         return ring_list
+
+     @staticmethod
+     def atom_to_feat_vector(atom):
+         """map one atom to a dict of vocabulary indices (currently atomic_num only)"""
+         atom_names = {
+             "atomic_num": safe_index(CompoundKit.atom_vocab_dict["atomic_num"], atom.GetAtomicNum()),
+         }
+         return atom_names
+
+     @staticmethod
+     def get_atom_names(mol):
+         """get atom name list
+         TODO: to be removed in the future
+         NOTE: the in_num_ring_with_size* entries referenced below are not in the
+         trimmed atom_vocab_dict above; they must be restored before calling this.
+         """
+         atom_features_dicts = []
+         Chem.rdPartialCharges.ComputeGasteigerCharges(mol)
+         for i, atom in enumerate(mol.GetAtoms()):
+             atom_features_dicts.append(CompoundKit.atom_to_feat_vector(atom))
+
+         ring_list = CompoundKit.get_ring_size(mol)
+         for i, atom in enumerate(mol.GetAtoms()):
+             atom_features_dicts[i]['in_num_ring_with_size3'] = safe_index(
+                 CompoundKit.atom_vocab_dict['in_num_ring_with_size3'], ring_list[i][0])
+             atom_features_dicts[i]['in_num_ring_with_size4'] = safe_index(
+                 CompoundKit.atom_vocab_dict['in_num_ring_with_size4'], ring_list[i][1])
+             atom_features_dicts[i]['in_num_ring_with_size5'] = safe_index(
+                 CompoundKit.atom_vocab_dict['in_num_ring_with_size5'], ring_list[i][2])
+             atom_features_dicts[i]['in_num_ring_with_size6'] = safe_index(
+                 CompoundKit.atom_vocab_dict['in_num_ring_with_size6'], ring_list[i][3])
+             atom_features_dicts[i]['in_num_ring_with_size7'] = safe_index(
+                 CompoundKit.atom_vocab_dict['in_num_ring_with_size7'], ring_list[i][4])
+             atom_features_dicts[i]['in_num_ring_with_size8'] = safe_index(
+                 CompoundKit.atom_vocab_dict['in_num_ring_with_size8'], ring_list[i][5])
+
+         return atom_features_dicts
+
+     @staticmethod
+     def check_partial_charge(atom):
+         """clamp a Gasteiger partial charge to a finite, usable value"""
+         pc = atom.GetDoubleProp('_GasteigerCharge')
+         if pc != pc:
+             # unsupported atom, replace nan with 0
+             pc = 0
+         if pc == float('inf'):
+             # max is around 4 for other atoms; cap at 10 if inf is returned
+             pc = 10
+         return pc
+
+
+ class Compound3DKit(object):
+     """the 3D kit of Compound"""
+     @staticmethod
+     def get_atom_poses(mol, conf):
+         """extract (x, y, z) positions from a conformer; any dummy atom zeroes the whole list"""
+         atom_poses = []
+         for i, atom in enumerate(mol.GetAtoms()):
+             if atom.GetAtomicNum() == 0:
+                 return [[0.0, 0.0, 0.0]] * len(mol.GetAtoms())
+             pos = conf.GetAtomPosition(i)
+             atom_poses.append([pos.x, pos.y, pos.z])
+         return atom_poses
+
+     @staticmethod
+     def get_MMFF_atom_poses(mol, numConfs=None, return_energy=False):
+         """the atoms of mol will be changed in some cases."""
+         try:
+             new_mol = Chem.AddHs(mol)
+             res = AllChem.EmbedMultipleConfs(new_mol, numConfs=numConfs)
+             ### MMFF optimizes all embedded conformations; keep the lowest-energy one
+             res = AllChem.MMFFOptimizeMoleculeConfs(new_mol)
+             # new_mol = Chem.RemoveHs(new_mol)
+             index = np.argmin([x[1] for x in res])
+             energy = res[index][1]
+             conf = new_mol.GetConformer(id=int(index))
+         except Exception:
+             # embedding/optimization failed; fall back to 2D coordinates
+             new_mol = Chem.AddHs(mol)
+             AllChem.Compute2DCoords(new_mol)
+             energy = 0
+             conf = new_mol.GetConformer()
+
+         atom_poses = Compound3DKit.get_atom_poses(new_mol, conf)
+         if return_energy:
+             return new_mol, atom_poses, energy
+         else:
+             return new_mol, atom_poses
+
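A sketch of conformer generation: embed, MMFF-optimize, and keep the lowest-energy conformer, with a 2D fallback on failure (numConfs=10 mirrors the call used later in this file):

    mol = Chem.MolFromSmiles('CCO')
    mol_h, poses = Compound3DKit.get_MMFF_atom_poses(mol, numConfs=10)
    # poses is a per-atom [x, y, z] list for the H-added molecule
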
+     @staticmethod
+     def get_2d_atom_poses(mol):
+         """get 2d atom poses"""
+         AllChem.Compute2DCoords(mol)
+         conf = mol.GetConformer()
+         atom_poses = Compound3DKit.get_atom_poses(mol, conf)
+         return atom_poses
+
+     @staticmethod
+     def get_bond_lengths(edges, atom_poses):
+         """get bond lengths"""
+         bond_lengths = []
+         for src_node_i, tar_node_j in edges:
+             bond_lengths.append(np.linalg.norm(atom_poses[tar_node_j] - atom_poses[src_node_i]))
+         bond_lengths = np.array(bond_lengths, 'float32')
+         return bond_lengths
+
+     @staticmethod
+     def get_superedge_angles(edges, atom_poses, dir_type='HT'):
+         """get superedge angles"""
+         def _get_vec(atom_poses, edge):
+             return atom_poses[edge[1]] - atom_poses[edge[0]]
+
+         def _get_angle(vec1, vec2):
+             norm1 = np.linalg.norm(vec1)
+             norm2 = np.linalg.norm(vec2)
+             if norm1 == 0 or norm2 == 0:
+                 return 0
+             vec1 = vec1 / (norm1 + 1e-5)  # 1e-5: prevent numerical errors
+             vec2 = vec2 / (norm2 + 1e-5)
+             angle = np.arccos(np.dot(vec1, vec2))
+             return angle
+
+         E = len(edges)
+         edge_indices = np.arange(E)
+         super_edges = []
+         bond_angles = []
+         bond_angle_dirs = []
+         for tar_edge_i in range(E):
+             tar_edge = edges[tar_edge_i]
+             if dir_type == 'HT':
+                 src_edge_indices = edge_indices[edges[:, 1] == tar_edge[0]]
+             elif dir_type == 'HH':
+                 src_edge_indices = edge_indices[edges[:, 1] == tar_edge[1]]
+             else:
+                 raise ValueError(dir_type)
+             for src_edge_i in src_edge_indices:
+                 if src_edge_i == tar_edge_i:
+                     continue
+                 src_edge = edges[src_edge_i]
+                 src_vec = _get_vec(atom_poses, src_edge)
+                 tar_vec = _get_vec(atom_poses, tar_edge)
+                 super_edges.append([src_edge_i, tar_edge_i])
+                 angle = _get_angle(src_vec, tar_vec)
+                 bond_angles.append(angle)
+                 bond_angle_dirs.append(src_edge[1] == tar_edge[0])  # H -> H or H -> T
+
+         if len(super_edges) == 0:
+             super_edges = np.zeros([0, 2], 'int64')
+             bond_angles = np.zeros([0, ], 'float32')
+         else:
+             super_edges = np.array(super_edges, 'int64')
+             bond_angles = np.array(bond_angles, 'float32')
+         return super_edges, bond_angles, bond_angle_dirs
+
+
+
+ def new_smiles_to_graph_data(smiles, **kwargs):
+     """
+     Convert a SMILES string to graph data.
+     """
+     mol = AllChem.MolFromSmiles(smiles)
+     if mol is None:
+         return None
+     mol = Chem.AddHs(mol)  # parse first, then add Hs, so invalid SMILES return None
+     data = new_mol_to_graph_data(mol)
+     return data
+
+
+ def new_mol_to_graph_data(mol):
+     """
+     mol_to_graph_data
+     Returns a dict of atom features, edge features, Morgan and MACCS
+     fingerprints, and daylight functional group counts.
+     NOTE: relies on get_atom_names(), so every name in atom_id_names must be
+     produced by the (currently trimmed) atom feature helpers above.
+     """
+     if len(mol.GetAtoms()) == 0:
+         return None
+
+     atom_id_names = list(CompoundKit.atom_vocab_dict.keys()) + CompoundKit.atom_float_names
+     bond_id_names = list(CompoundKit.bond_vocab_dict.keys())
+
+     ### atom features
+     data = {name: [] for name in atom_id_names}
+
+     raw_atom_feat_dicts = CompoundKit.get_atom_names(mol)
+     for atom_feat in raw_atom_feat_dicts:
+         for name in atom_id_names:
+             data[name].append(atom_feat[name])
+
+     ### bond and bond features
+     for name in bond_id_names:
+         data[name] = []
+     data['edges'] = []
+
+     for bond in mol.GetBonds():
+         i = bond.GetBeginAtomIdx()
+         j = bond.GetEndAtomIdx()
+         # i->j and j->i
+         data['edges'] += [(i, j), (j, i)]
+         for name in bond_id_names:
+             bond_feature_id = CompoundKit.get_bond_feature_id(bond, name)
+             data[name] += [bond_feature_id] * 2
+
+     #### self loop
+     N = len(data[atom_id_names[0]])
+     for i in range(N):
+         data['edges'] += [(i, i)]
+     for name in bond_id_names:
+         bond_feature_id = get_bond_feature_dims([name])[0] - 1  # self loop: value = len - 1
+         data[name] += [bond_feature_id] * N
+
+     ### make ndarray and check length
+     for name in list(CompoundKit.atom_vocab_dict.keys()):
+         data[name] = np.array(data[name], 'int64')
+     for name in CompoundKit.atom_float_names:
+         data[name] = np.array(data[name], 'float32')
+     for name in bond_id_names:
+         data[name] = np.array(data[name], 'int64')
+     data['edges'] = np.array(data['edges'], 'int64')
+
+     ### morgan fingerprint
+     data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
+     # data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64')
+     data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
+     data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
+     return data
+
+
+ def mol_to_graph_data(mol):
+     """
+     mol_to_graph_data
+     Returns:
+         dict with atom feature ids ('atomic_num'), atom masses, bond feature
+         ids ('bond_dir', 'bond_type'), the edge list, and atom symbols.
+     """
+     if len(mol.GetAtoms()) == 0:
+         return None
+
+     atom_id_names = [
+         "atomic_num"
+     ]
+     bond_id_names = [
+         "bond_dir", "bond_type"
+     ]
+
+     data = {}
+     for name in atom_id_names:
+         data[name] = []
+     data['mass'] = []
+     for name in bond_id_names:
+         data[name] = []
+     data['edges'] = []
+
+     ### atom features
+     for i, atom in enumerate(mol.GetAtoms()):
+         if atom.GetAtomicNum() == 0:
+             return None
+         for name in atom_id_names:
+             data[name].append(CompoundKit.get_atom_feature_id(atom, name) + 1)  # 0: OOV
+         data['mass'].append(CompoundKit.get_atom_value(atom, 'mass') * 0.01)
+
+     ### bond features
+     for bond in mol.GetBonds():
+         i = bond.GetBeginAtomIdx()
+         j = bond.GetEndAtomIdx()
+         # i->j and j->i
+         data['edges'] += [(i, j), (j, i)]
+         for name in bond_id_names:
+             bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + 1  # 0: OOV
+             data[name] += [bond_feature_id] * 2
+     num_atoms = mol.GetNumAtoms()
+     atoms_list = []
+     for i in range(num_atoms):
+         atom = mol.GetAtomWithIdx(i)
+         atoms_list.append(atom.GetSymbol())
+
+     ### self loop (+2)
+     N = len(data[atom_id_names[0]])
+     for i in range(N):
+         data['edges'] += [(i, i)]
+     for name in bond_id_names:
+         bond_feature_id = CompoundKit.get_bond_feature_size(name) + 2  # N + 2: self loop
+         data[name] += [bond_feature_id] * N
+
+     ### check whether edge exists
+     if len(data['edges']) == 0:  # mol has no bonds
+         for name in bond_id_names:
+             data[name] = np.zeros((0,), dtype="int64")
+         data['edges'] = np.zeros((0, 2), dtype="int64")
+
+     ### make ndarray and check length
+     for name in atom_id_names:
+         data[name] = np.array(data[name], 'int64')
+     data['mass'] = np.array(data['mass'], 'float32')
+     for name in bond_id_names:
+         data[name] = np.array(data[name], 'int64')
+     data['edges'] = np.array(data['edges'], 'int64')
+     data['atoms'] = np.array(atoms_list)
+     ### fingerprints (disabled for this pipeline)
+     # data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
+     # data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64')
+     # data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
+     # data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
+     # return data['bonds_dir'], data['adj_angle']
+     return data
+
+
+ def mol_to_geognn_graph_data(mol, atom_poses, dir_type):
+     """
+     mol: RDKit molecule
+     dir_type: direction type for the bond_angle graph
+     Returns:
+         (atom symbols, distance-weighted adjacency matrix)
+     """
+     if len(mol.GetAtoms()) == 0:
+         return None
+
+     data = mol_to_graph_data(mol)
+
+     data['atom_pos'] = np.array(atom_poses, 'float32')
+     data['bond_length'] = Compound3DKit.get_bond_lengths(data['edges'], data['atom_pos'])
+     # BondAngleGraph_edges, bond_angles, bond_angle_dirs = \
+     #     Compound3DKit.get_superedge_angles(data['edges'], data['atom_pos'])
+     # data['BondAngleGraph_edges'] = BondAngleGraph_edges
+     # data['bond_angle'] = np.array(bond_angles, 'float32')
+     data['adj_node'] = gen_adj(len(data['atoms']), data['edges'], data['bond_length'])
+     # data['adj_edge'] = gen_adj(len(data['bond_dir']), data['BondAngleGraph_edges'], data['bond_angle'])
+     return data['atoms'], data['adj_node']
+
+
+ def mol_to_geognn_graph_data_MMFF3d(smiles):
+     """build graph data from a SMILES string, using MMFF 3D poses when feasible"""
+     mol = Chem.AddHs(AllChem.MolFromSmiles(smiles))
+     if len(mol.GetAtoms()) <= 400:
+         mol, atom_poses = Compound3DKit.get_MMFF_atom_poses(mol, numConfs=10)
+     else:
+         atom_poses = Compound3DKit.get_2d_atom_poses(mol)  # too large for a conformer search
+     return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT')
+
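End-to-end, a single SMILES becomes an atom list plus a bond-length-weighted adjacency matrix; a minimal sketch:

    atoms, adj = mol_to_geognn_graph_data_MMFF3d('CCO')
    # atoms: array of element symbols (H-added molecule); adj: (N, N) matrix,
    # zero for non-bonded pairs, bond length on bonded pairs
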
+ def mol_to_geognn_graph_data_raw3d(mol):
+     """build graph data from a mol that already carries a 3D conformer"""
+     atom_poses = Compound3DKit.get_atom_poses(mol, mol.GetConformer())
+     return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT')
+
+
+ def gen_adj(shape, edges, length):
+     """build a (shape, shape) adjacency matrix weighted by edge length,
+     skipping self loops"""
+     adj = edges
+     e = shape
+     adj_matrix = np.zeros([e, e])
+
+     for i in range(len(length)):
+         if adj[i, 0] != adj[i, 1]:  # skip self loops
+             adj_matrix[adj[i, 0], adj[i, 1]] = round(float(length[i]), 3)
+
+     return adj_matrix
+
+
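A tiny numeric sketch of gen_adj with one bond plus the two self loops appended by mol_to_graph_data (the values are illustrative):

    edges = np.array([(0, 1), (1, 0), (0, 0), (1, 1)], 'int64')
    lengths = np.array([1.54, 1.54, 0.0, 0.0], 'float32')
    gen_adj(2, edges, lengths)
    # -> [[0.    1.54]
    #     [1.54  0.  ]]
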
+ if __name__ == "__main__":
+     import pandas as pd
+     from tqdm import tqdm
+
+     # Precompute atom lists and adjacency matrices for the train/test/val splits.
+     f = pd.read_csv(r"data/reg/train3.csv")
+     re = []
+     pce = f['PCE']
+     for ind, smile in enumerate(f.iloc[:, 1]):
+         print(ind)
+         atom, adj = mol_to_geognn_graph_data_MMFF3d(smile)
+         np.save('data/reg/train/adj' + str(ind) + '.npy', np.array(adj))
+         re.append([atom, 'data/reg/train/adj' + str(ind) + '.npy', pce[ind]])
+     r = pd.DataFrame(re)
+     r.to_csv('data/reg/train/train.csv')
+     re = []
+
+     f = pd.read_csv(r'data/reg/test3.csv')
+     re = []
+     pce = f['PCE']
+
+     for ind, smile in enumerate(f.iloc[:, 1]):
+         print(ind)
+         atom, adj = mol_to_geognn_graph_data_MMFF3d(smile)
+         np.save('data/reg/test/adj' + str(ind) + '.npy', np.array(adj))
+         re.append([atom, 'data/reg/test/adj' + str(ind) + '.npy', pce[ind]])
+     r = pd.DataFrame(re)
+     r.to_csv('data/reg/test/test.csv')
+
+     f = pd.read_csv(r'val.csv')
+     re = []
+     pce = f['PCE']
+
+     for ind, smile in enumerate(f.iloc[:, 1]):
+         print(ind)
+         atom, adj = mol_to_geognn_graph_data_MMFF3d(smile)
+         np.save('data/reg/val/adj' + str(ind) + '.npy', np.array(adj))
+         re.append([atom, 'data/reg/val/adj' + str(ind) + '.npy', pce[ind]])
+     r = pd.DataFrame(re)
+     r.to_csv('data/reg/val/val.csv')