Upload 10 files
- .gitignore +160 -0
- 15data.h5 +3 -0
- LICENSE +201 -0
- abcBERT.py +96 -0
- app.py +37 -0
- compound_constants.py +156 -0
- dataset.py +497 -0
- model.py +280 -0
- requirements.txt +10 -0
- utils.py +696 -0
.gitignore
ADDED
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
#   .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
15data.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2ec80795633fe96e7226a7e63909138e6f4fc37654dcff6831627b1670986497
size 17610752
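Note: 15data.h5 is tracked with Git LFS, so what is committed is the three-line pointer above; the 17.6 MB weight blob itself lives in LFS storage and is materialized on clone (or via `git lfs pull`). abcBERT.py loads it with `model.load_weights('15data.h5')`.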
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
abcBERT.py
ADDED
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Mar  2 15:05:03 2023

@author: BM109X32G-10GPU-02
"""

import os

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
from tensorflow.keras.constraints import max_norm
import pandas as pd
import numpy as np
import sys
from sklearn.metrics import r2_score, roc_auc_score

from dataset import predict_smiles
from model import PredictModel, BertModel

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"


def main(smiles):
    keras.backend.clear_session()
    os.environ['CUDA_VISIBLE_DEVICES'] = "-1"  # force CPU inference

    # Architecture presets (num_layers / num_heads / d_model).
    small = {'name': 'Small', 'num_layers': 3, 'num_heads': 4, 'd_model': 128, 'path': 'small_weights', 'addH': True}
    medium = {'name': 'Medium', 'num_layers': 6, 'num_heads': 8, 'd_model': 256, 'path': 'medium_weights2', 'addH': True}
    medium3 = {'name': 'Medium', 'num_layers': 8, 'num_heads': 8, 'd_model': 256, 'path': 'medium_weights2',
               'addH': True}
    large = {'name': 'Large', 'num_layers': 12, 'num_heads': 12, 'd_model': 576, 'path': 'large_weights', 'addH': True}
    medium_without_H = {'name': 'Medium', 'num_layers': 6, 'num_heads': 8, 'd_model': 256, 'path': 'weights_without_H', 'addH': False}
    medium_without_pretrain = {'name': 'Medium', 'num_layers': 6, 'num_heads': 8, 'd_model': 256, 'path': 'medium_without_pretraining_weights', 'addH': True}

    arch = medium3  # presets above: small 3/4/128, medium 6/8/256, large 12/12/576

    pretraining = False
    pretraining_str = 'pretraining' if pretraining else ''

    trained_epoch = 80
    task = 'data'
    seed = 14
    num_layers = arch['num_layers']
    num_heads = arch['num_heads']
    d_model = arch['d_model']
    addH = arch['addH']
    dff = d_model * 2
    vocab_size = 60
    dropout_rate = 0.1

    tf.random.set_seed(seed=seed)
    graph_dataset = predict_smiles(smiles, addH=addH)
    test_dataset = graph_dataset.get_data()

    # Run one batch through the model so its variables are built before loading weights.
    x, adjoin_matrix, y = next(iter(test_dataset.take(1)))
    seq = tf.cast(tf.math.equal(x, 0), tf.float32)
    mask = seq[:, tf.newaxis, tf.newaxis, :]

    model = PredictModel(num_layers=num_layers, d_model=d_model, dff=dff, num_heads=num_heads,
                         vocab_size=vocab_size, dense_dropout=0.2)
    preds = model(x, mask=mask, adjoin_matrix=adjoin_matrix, training=False)
    model.load_weights('{}.h5'.format('15data'))

    class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        """Linear warmup / linear decay schedule (left over from training; unused at inference)."""

        def __init__(self, d_model, total_steps=4000):
            super(CustomSchedule, self).__init__()
            self.d_model = tf.cast(d_model, tf.float32)
            self.total_step = total_steps
            self.warmup_steps = total_steps * 0.10

        def __call__(self, step):
            arg1 = step / self.warmup_steps
            arg2 = 1 - (step - self.warmup_steps) / (self.total_step - self.warmup_steps)
            return 10e-5 * tf.math.minimum(arg1, arg2)

    steps_per_epoch = len(test_dataset)
    value_range = 1
    y_true = []
    y_preds = []

    for x, adjoin_matrix, y in test_dataset:
        seq = tf.cast(tf.math.equal(x, 0), tf.float32)
        mask = seq[:, tf.newaxis, tf.newaxis, :]
        preds = model(x, mask=mask, adjoin_matrix=adjoin_matrix, training=False)
        y_true.append(y.numpy())
        y_preds.append(preds.numpy())
    y_true = np.concatenate(y_true, axis=0).reshape(-1)
    y_preds = np.concatenate(y_preds, axis=0).reshape(-1)

    return y_preds
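A minimal smoke test for the entry point above (a sketch: it assumes this repo's modules and the 15data.h5 weights are in the working directory; the SMILES string is just an arbitrary valid example, matching how app.py calls abcBERT.main on the editor output):

# Hypothetical usage of abcBERT.main; any valid SMILES string works as input.
import abcBERT

preds = abcBERT.main("CC1=CC=C(C=C1)N")  # one molecule in, so a 1-element array out
print(float(preds[0]))                   # predicted PCE for the input molecule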
app.py
ADDED
@@ -0,0 +1,37 @@
import streamlit as st
import pandas as pd
import rdkit
import streamlit_ketcher
from streamlit_ketcher import st_ketcher

import abcBERT

# Page setup
st.set_page_config(page_title="DeepAcceptor", page_icon="🔋", layout="wide")
st.title("DeepAcceptor")

# Read the public Google Sheet through its CSV export (gviz) endpoint
url1 = r"https://docs.google.com/spreadsheets/d/1YOEIg0nMTSPkAOr8wkqxQRLuUhys3-J0I-KPEpmzPLw/gviz/tq?tqx=out:csv&sheet=accept"
url = r"https://docs.google.com/spreadsheets/d/1YOEIg0nMTSPkAOr8wkqxQRLuUhys3-J0I-KPEpmzPLw/gviz/tq?tqx=out:csv&sheet=111"
df1 = pd.read_csv(url1, dtype=str, encoding='utf-8')

text_search = st.text_input("Search papers or molecules", value="")
m1 = df1["name"].str.contains(text_search)
m2 = df1["reference"].str.contains(text_search)
df_search = df1[m1 | m2]
if text_search:
    st.write(df_search)
    st.download_button("Download edited files as .csv", df_search.to_csv(), "df_search.csv", use_container_width=True)

edited_df = st.data_editor(df1, num_rows="dynamic")
# edited_df.to_csv(url)  # NOTE: the gviz endpoint is read-only, so writing back
# to it cannot work; edits persist only through the download button below.
st.download_button(
    "⬇️ Download edited files as .csv", edited_df.to_csv(), "edited_df.csv", use_container_width=True
)

molecule = st.text_input("Molecule")
smile_code = st_ketcher(molecule)
st.markdown(f"Smile code: ``{smile_code}``")
try:
    pce = abcBERT.main(str(smile_code))
    st.markdown(f"PCE: ``{pce}``")
except Exception:
    st.markdown("PCE: None")
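Locally, the app is launched the standard Streamlit way: `streamlit run app.py`.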
compound_constants.py
ADDED
@@ -0,0 +1,156 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Jul 28 21:20:20 2022

@author: BM109X32G-10GPU-02
"""


"""
| Compound constants.
"""


# functional groups from https://www.daylight.com/dayhtml_tutorials/languages/smarts/smarts_examples.html
DAY_LIGHT_FG_SMARTS_LIST = [
    # C
    "[CX4]",
    "[$([CX2](=C)=C)]",
    "[$([CX3]=[CX3])]",
    "[$([CX2]#C)]",
    # C & O
    "[CX3]=[OX1]",
    "[$([CX3]=[OX1]),$([CX3+]-[OX1-])]",
    "[CX3](=[OX1])C",
    "[OX1]=CN",
    "[CX3](=[OX1])O",
    "[CX3](=[OX1])[F,Cl,Br,I]",
    "[CX3H1](=O)[#6]",
    "[CX3](=[OX1])[OX2][CX3](=[OX1])",
    "[NX3][CX3](=[OX1])[#6]",
    "[NX3][CX3]=[NX3+]",
    "[NX3,NX4+][CX3](=[OX1])[OX2,OX1-]",
    "[NX3][CX3](=[OX1])[OX2H0]",
    "[NX3,NX4+][CX3](=[OX1])[OX2H,OX1-]",
    "[CX3](=O)[O-]",
    "[CX3](=[OX1])(O)O",
    "[CX3](=[OX1])([OX2])[OX2H,OX1H0-1]",
    "C[OX2][CX3](=[OX1])[OX2]C",
    "[CX3](=O)[OX2H1]",
    "[CX3](=O)[OX1H0-,OX2H1]",
    "[NX3][CX2]#[NX1]",
    "[#6][CX3](=O)[OX2H0][#6]",
    "[#6][CX3](=O)[#6]",
    "[OD2]([#6])[#6]",
    # H
    "[H]",
    "[!#1]",
    "[H+]",
    "[+H]",
    "[!H]",
    # N
    "[NX3;H2,H1;!$(NC=O)]",
    "[NX3][CX3]=[CX3]",
    "[NX3;H2;!$(NC=[!#6]);!$(NC#[!#6])][#6]",
    "[NX3;H2,H1;!$(NC=O)].[NX3;H2,H1;!$(NC=O)]",
    "[NX3][$(C=C),$(cc)]",
    "[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[O,N]",
    "[NX3H2,NH3X4+][CX4H]([*])[CX3](=[OX1])[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-]",
    "[$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-,N]",
    "[CH3X4]",
    "[CH2X4][CH2X4][CH2X4][NHX3][CH0X3](=[NH2X3+,NHX2+0])[NH2X3]",
    "[CH2X4][CX3](=[OX1])[NX3H2]",
    "[CH2X4][CX3](=[OX1])[OH0-,OH]",
    "[CH2X4][SX2H,SX1H0-]",
    "[CH2X4][CH2X4][CX3](=[OX1])[OH0-,OH]",
    "[$([$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H2][CX3](=[OX1])[OX2H,OX1-,N])]",
    "[CH2X4][#6X3]1:[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]:\
[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]1",
    "[CHX4]([CH3X4])[CH2X4][CH3X4]",
    "[CH2X4][CHX4]([CH3X4])[CH3X4]",
    "[CH2X4][CH2X4][CH2X4][CH2X4][NX4+,NX3+0]",
    "[CH2X4][CH2X4][SX2][CH3X4]",
    "[CH2X4][cX3]1[cX3H][cX3H][cX3H][cX3H][cX3H]1",
    "[$([NX3H,NX4H2+]),$([NX3](C)(C)(C))]1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[OX2H,OX1-,N]",
    "[CH2X4][OX2H]",
    "[NX3][CX3]=[SX1]",
    "[CHX4]([CH3X4])[OX2H]",
    "[CH2X4][cX3]1[cX3H][nX3H][cX3]2[cX3H][cX3H][cX3H][cX3H][cX3]12",
    "[CH2X4][cX3]1[cX3H][cX3H][cX3]([OHX2,OH0X1-])[cX3H][cX3H]1",
    "[CHX4]([CH3X4])[CH3X4]",
    "N[CX4H2][CX3](=[OX1])[O,N]",
    "N1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[O,N]",
    "[$(*-[NX2-]-[NX2+]#[NX1]),$(*-[NX2]=[NX2+]=[NX1-])]",
    "[$([NX1-]=[NX2+]=[NX1-]),$([NX1]#[NX2+]-[NX1-2])]",
    "[#7]",
    "[NX2]=N",
    "[NX2]=[NX2]",
    "[$([NX2]=[NX3+]([O-])[#6]),$([NX2]=[NX3+0](=[O])[#6])]",
    "[$([#6]=[N+]=[N-]),$([#6-]-[N+]#[N])]",
    "[$([nr5]:[nr5,or5,sr5]),$([nr5]:[cr5]:[nr5,or5,sr5])]",
    "[NX3][NX3]",
    "[NX3][NX2]=[*]",
    "[CX3;$([C]([#6])[#6]),$([CH][#6])]=[NX2][#6]",
    "[$([CX3]([#6])[#6]),$([CX3H][#6])]=[$([NX2][#6]),$([NX2H])]",
    "[NX3+]=[CX3]",
    "[CX3](=[OX1])[NX3H][CX3](=[OX1])",
    "[CX3](=[OX1])[NX3H0]([#6])[CX3](=[OX1])",
    "[CX3](=[OX1])[NX3H0]([NX3H0]([CX3](=[OX1]))[CX3](=[OX1]))[CX3](=[OX1])",
    "[$([NX3](=[OX1])(=[OX1])O),$([NX3+]([OX1-])(=[OX1])O)]",
    "[$([OX1]=[NX3](=[OX1])[OX1-]),$([OX1]=[NX3+]([OX1-])[OX1-])]",
    "[NX1]#[CX2]",
    "[CX1-]#[NX2+]",
    "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]",
    "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8].[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]",
    "[NX2]=[OX1]",
    "[$([#7+][OX1-]),$([#7v5]=[OX1]);!$([#7](~[O])~[O]);!$([#7]=[#7])]",
    # O
    "[OX2H]",
    "[#6][OX2H]",
    "[OX2H][CX3]=[OX1]",
    "[OX2H]P",
    "[OX2H][#6X3]=[#6]",
    "[OX2H][cX3]:[c]",
    "[OX2H][$(C=C),$(cc)]",
    "[$([OH]-*=[!#6])]",
    "[OX2,OX1-][OX2,OX1-]",
    # P
    "[$(P(=[OX1])([$([OX2H]),$([OX1-]),$([OX2]P)])([$([OX2H]),$([OX1-]),\
$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)]),$([P+]([OX1-])([$([OX2H]),$([OX1-])\
,$([OX2]P)])([$([OX2H]),$([OX1-]),$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)])]",
    "[$(P(=[OX1])([OX2][#6])([$([OX2H]),$([OX1-]),$([OX2][#6])])[$([OX2H]),\
$([OX1-]),$([OX2][#6]),$([OX2]P)]),$([P+]([OX1-])([OX2][#6])([$([OX2H]),$([OX1-]),\
$([OX2][#6])])[$([OX2H]),$([OX1-]),$([OX2][#6]),$([OX2]P)])]",
    # S
    "[S-][CX3](=S)[#6]",
    "[#6X3](=[SX1])([!N])[!N]",
    "[SX2]",
    "[#16X2H]",
    "[#16!H0]",
    "[#16X2H0]",
    "[#16X2H0][!#16]",
    "[#16X2H0][#16X2H0]",
    "[#16X2H0][!#16].[#16X2H0][!#16]",
    "[$([#16X3](=[OX1])[OX2H0]),$([#16X3+]([OX1-])[OX2H0])]",
    "[$([#16X3](=[OX1])[OX2H,OX1H0-]),$([#16X3+]([OX1-])[OX2H,OX1H0-])]",
    "[$([#16X4](=[OX1])=[OX1]),$([#16X4+2]([OX1-])[OX1-])]",
    "[$([#16X4](=[OX1])(=[OX1])([#6])[#6]),$([#16X4+2]([OX1-])([OX1-])([#6])[#6])]",
    "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H,OX1H0-]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H,OX1H0-])]",
    "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H0]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H0])]",
    "[$([#16X4]([NX3])(=[OX1])(=[OX1])[#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[#6])]",
    "[SX4](C)(C)(=O)=N",
    "[$([SX4](=[OX1])(=[OX1])([!O])[NX3]),$([SX4+2]([OX1-])([OX1-])([!O])[NX3])]",
    "[$([#16X3]=[OX1]),$([#16X3+][OX1-])]",
    "[$([#16X3](=[OX1])([#6])[#6]),$([#16X3+]([OX1-])([#6])[#6])]",
    "[$([#16X4](=[OX1])(=[OX1])([OX2H,OX1H0-])[OX2][#6]),$([#16X4+2]([OX1-])([OX1-])([OX2H,OX1H0-])[OX2][#6])]",
    "[$([SX4](=O)(=O)(O)O),$([SX4+2]([O-])([O-])(O)O)]",
    "[$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6]),$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6])]",
    "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2][#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2][#6])]",
    "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2H,OX1H0-]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2H,OX1H0-])]",
    "[#16X2][OX2H,OX1H0-]",
    "[#16X2][OX2H0]",
    # X
    "[#6][F,Cl,Br,I]",
    "[F,Cl,Br,I]",
    "[F,Cl,Br,I].[F,Cl,Br,I].[F,Cl,Br,I]",
]
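These SMARTS are plain strings; none of the other files shown in this commit reference DAY_LIGHT_FG_SMARTS_LIST, but the usual way to apply such patterns is RDKit substructure matching. A minimal sketch (the aspirin SMILES is just an example input):

# Sketch: counting functional-group matches with RDKit.
from rdkit import Chem

from compound_constants import DAY_LIGHT_FG_SMARTS_LIST

mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")  # aspirin, as an example
for smarts in DAY_LIGHT_FG_SMARTS_LIST[:10]:
    patt = Chem.MolFromSmarts(smarts)
    if patt is not None:  # a few entries (e.g. bare "[H+]") may not parse
        print(smarts, len(mol.GetSubstructMatches(patt)))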
dataset.py
ADDED
@@ -0,0 +1,497 @@
import pandas as pd
import numpy as np
import tensorflow as tf

from utils import mol_to_geognn_graph_data_MMFF3d as smiles2adjoin

# Atom-symbol vocabulary shared by every dataset class below.
str2num = {'<pad>': 0, 'H': 1, 'C': 2, 'N': 3, 'O': 4, 'F': 5, 'S': 6, 'Cl': 7, 'P': 8, 'Br': 9,
           'B': 10, 'I': 11, 'Si': 12, 'Se': 13, '<unk>': 14, '<mask>': 15, '<global>': 16}
num2str = {i: j for j, i in str2num.items()}


class Graph_Bert_Dataset(object):
    """Masked-atom pretraining dataset. Each CSV row stores a stringified atom
    list and the path of a precomputed .npy adjacency matrix."""

    def __init__(self, path, smiles_field=['0'], adj=['1'], addH=True):
        if path.endswith('.txt') or path.endswith('.tsv'):
            self.df = pd.read_csv(path, sep='\t')  # original passed sep='\n\t', which pandas cannot use as a separator
        else:
            self.df = pd.read_csv(path)
        self.smiles_field = smiles_field
        self.adj = adj
        self.vocab = str2num
        self.devocab = num2str
        self.addH = addH

    def get_data(self):
        data = self.df
        # Random 90/10 split into pretraining and held-out sets.
        train_idx = []
        idx = data.sample(frac=0.9).index
        train_idx.extend(idx)
        data1 = data[data.index.isin(train_idx)]
        data2 = data[~data.index.isin(train_idx)]

        self.dataset1 = tf.data.Dataset.from_tensor_slices((data1[self.smiles_field], data1[self.adj]))
        self.dataset1 = self.dataset1.map(self.tf_numerical_smiles).padded_batch(256, padded_shapes=(
            tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([None]),
            tf.TensorShape([None]))).prefetch(50)

        self.dataset2 = tf.data.Dataset.from_tensor_slices((data2[self.smiles_field], data2[self.adj]))
        self.dataset2 = self.dataset2.map(self.tf_numerical_smiles).padded_batch(512, padded_shapes=(
            tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([None]),
            tf.TensorShape([None]))).prefetch(50)
        return self.dataset1, self.dataset2

    def numerical_smiles(self, atom, adj):
        # The atom column holds a stringified Python list; strip brackets and
        # quotes to recover the atom symbols.
        atom = np.array(atom)
        atom = atom[0].decode()
        atom = atom.replace('\n', '')
        atom = atom.replace('[', ' ')
        atom = atom.replace(']', ' ')
        atom = atom.split("'")

        atoms_list = []
        for i in atom:
            if i not in [' ']:
                atoms_list.append(i)

        # The adj column holds the path of the precomputed adjacency matrix.
        adj = np.array(adj)[0].decode()
        adjoin_matrix = np.load(adj)

        atoms_list = ['<global>'] + atoms_list
        nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]
        temp = np.ones((len(nums_list), len(nums_list)))
        temp[1:, 1:] = adjoin_matrix
        temp[np.where(temp == 0)] = -1e9
        adjoin_matrix = temp

        # BERT-style corruption: mask 15% of atoms (80% <mask>, 10% random atom, 10% unchanged).
        choices = np.random.permutation(len(nums_list) - 1)[:max(int(len(nums_list) * 0.15), 1)] + 1
        y = np.array(nums_list).astype('int64')
        weight = np.zeros(len(nums_list))
        for i in choices:
            rand = np.random.rand()
            weight[i] = 1
            if rand < 0.8:
                nums_list[i] = str2num['<mask>']
            elif rand < 0.9:
                nums_list[i] = int(np.random.rand() * 14 + 1)

        x = np.array(nums_list).astype('int64')
        weight = weight.astype('float32')
        return x, adjoin_matrix, y, weight

    def tf_numerical_smiles(self, atom, adj):
        x, adjoin_matrix, y, weight = tf.py_function(self.numerical_smiles, (atom, adj),
                                                     [tf.int64, tf.float32, tf.int64, tf.float32])
        x.set_shape([None])
        adjoin_matrix.set_shape([None, None])
        y.set_shape([None])
        weight.set_shape([None])
        return x, adjoin_matrix, y, weight


class Graph_Regression_Dataset_test(object):
    def __init__(self, path, smiles_field='SMILES', label_field='PCE', normalize=False, max_len=1000, addH=True):
        if path.endswith('.txt') or path.endswith('.tsv'):
            self.df = pd.read_csv(path.format('test'), sep='\t')
        else:
            self.df = pd.read_csv(path.format('test'))
        self.smiles_field = smiles_field
        self.label_field = label_field
        self.vocab = str2num
        self.devocab = num2str
        self.df = self.df[self.df[smiles_field].str.len() <= max_len]
        self.addH = addH
        if normalize:
            self.max = self.df[self.label_field].max()
            self.min = self.df[self.label_field].min()
            self.df[self.label_field] = (self.df[self.label_field] - self.min) / (self.max - self.min) - 0.5
            self.value_range = self.max - self.min

    def get_data(self):
        train_data = self.df
        self.dataset1 = tf.data.Dataset.from_tensor_slices((train_data[self.smiles_field], train_data[self.label_field]))
        self.dataset1 = self.dataset1.map(self.tf_numerical_smiles).padded_batch(64, padded_shapes=(
            tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([1])))
        return self.dataset1

    def numerical_smiles(self, smiles, label):
        smiles = smiles.numpy().decode()
        atoms_list, adjoin_matrix = smiles2adjoin(smiles)  # original called the undefined name smiles2adjoins
        atoms_list = list(atoms_list)
        atoms_list = ['<global>'] + atoms_list
        nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]
        temp = np.ones((len(nums_list), len(nums_list)))
        temp[1:, 1:] = adjoin_matrix
        temp[np.where(temp == 0)] = -1e9
        adjoin_matrix = temp
        x = np.array(nums_list).astype('int64')
        y = np.array([label]).astype('float32')
        return x, adjoin_matrix, y

    def tf_numerical_smiles(self, smiles, label):
        x, adjoin_matrix, y = tf.py_function(self.numerical_smiles, [smiles, label],
                                             [tf.int64, tf.float32, tf.float32])
        x.set_shape([None])
        adjoin_matrix.set_shape([None, None])
        y.set_shape([None])
        return x, adjoin_matrix, y


class predict_smiles(object):
    """Single-molecule inference dataset used by abcBERT.main."""

    def __init__(self, smiles, normalize=False, max_len=1000, addH=True):
        self.smiles_field = smiles
        self.label_field = float(0)  # dummy label; only predictions are used
        self.vocab = str2num
        self.devocab = num2str
        self.addH = addH
        if normalize:
            # NOTE: kept from the original; this branch references self.df,
            # which this class never defines.
            self.max = self.df[self.label_field].max()
            self.min = self.df[self.label_field].min()
            self.df[self.label_field] = (self.df[self.label_field] - self.min) / (self.max - self.min) - 0.5
            self.value_range = self.max - self.min

    def numerical_smiles(self, atoms_list, adj, label):
        atom = np.array(atoms_list)
        atoms_list = []
        for i in atom:
            if i not in [' ']:
                atoms_list.append(i)
        label = np.array(label)
        adj = np.array(adj)
        adjoin_matrix = adj

        atoms_list = ['<global>'] + atoms_list
        nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]
        temp = np.ones((len(nums_list), len(nums_list)))
        temp[1:, 1:] = adjoin_matrix
        temp[np.where(temp == 0)] = -1e9
        adjoin_matrix = temp
        x = np.array(nums_list).astype('int64')
        y = np.array([label]).astype('float32')
        return x, adjoin_matrix, y

    def get_data(self):
        atom, adj = smiles2adjoin(self.smiles_field)
        atom = np.array(atom)
        atoms_list = []
        for i in atom:
            if i not in [' ']:
                atoms_list.append(i)
        adj = np.array(adj)
        adjoin_matrix = adj
        self.dataset1 = tf.data.Dataset.from_tensors((atoms_list, adjoin_matrix, self.label_field))
        self.dataset1 = self.dataset1.map(self.tf_numerical_smiles).cache().padded_batch(1, padded_shapes=(
            tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([1])))
        return self.dataset1

    def tf_numerical_smiles(self, atoms_list, adj, label):
        x, adjoin_matrix, y = tf.py_function(self.numerical_smiles, (atoms_list, adj, label),
                                             [tf.int64, tf.float32, tf.float32])
        x.set_shape([None])
        adjoin_matrix.set_shape([None, None])
        y.set_shape([None])
        return x, adjoin_matrix, y


class Graph_Regression_test(object):
    def __init__(self, path, smiles_field=['0'], adj=['1'], label_field=['2'], normalize=False, max_len=1000, addH=True):
        if path.endswith('.txt') or path.endswith('.tsv'):
            self.dv = pd.read_csv(path.format('val3'), sep='\t')
        else:
            self.dv = pd.read_csv(path.format('val/val'))
        self.smiles_field = smiles_field
        self.adj = adj
        self.label_field = label_field
        self.vocab = str2num
        self.devocab = num2str
        self.addH = addH
        if normalize:
            # NOTE: kept from the original; references self.df although only
            # self.dv is loaded above.
            self.max = self.df[self.label_field].max()
            self.min = self.df[self.label_field].min()
            self.df[self.label_field] = (self.df[self.label_field] - self.min) / (self.max - self.min) - 0.5
            self.value_range = self.max - self.min

    def get_data(self):
        train_data = self.dv
        self.dataset1 = tf.data.Dataset.from_tensor_slices(
            (train_data[self.smiles_field], train_data[self.adj], train_data[self.label_field]))
        self.dataset1 = self.dataset1.map(self.tf_numerical_smiles).cache().padded_batch(64, padded_shapes=(
            tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([1]))).prefetch(100)
        return self.dataset1

    def numerical_smiles(self, atom, adj, label):
        atom = np.array(atom)
        atom = atom[0].decode()
        atom = atom.replace('\n', '')
        atom = atom.replace('[', ' ')
        atom = atom.replace(']', ' ')
        atom = atom.split("'")

        atoms_list = []
        for i in atom:
            if i not in [' ']:
                atoms_list.append(i)
        label = np.array(label)[0]

        adj = np.array(adj)[0].decode()
        adjoin_matrix = np.load(adj)

        atoms_list = ['<global>'] + atoms_list
        nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]
        temp = np.ones((len(nums_list), len(nums_list)))
        temp[1:, 1:] = adjoin_matrix
        temp[np.where(temp == 0)] = -1e9
        adjoin_matrix = temp
        x = np.array(nums_list).astype('int64')
        y = np.array([label]).astype('float32')
        return x, adjoin_matrix, y

    def tf_numerical_smiles(self, smiles, adj, label):
        x, adjoin_matrix, y = tf.py_function(self.numerical_smiles, (smiles, adj, label),
                                             [tf.int64, tf.float32, tf.float32])
        x.set_shape([None])
        adjoin_matrix.set_shape([None, None])
        y.set_shape([None])
        return x, adjoin_matrix, y


class Graph_Regression(object):
    def __init__(self, path, smiles_field=['0'], adj=['1'], label_field=['2'], normalize=False, max_len=1000, addH=True):
        if path.endswith('.txt') or path.endswith('.tsv'):
            self.df = pd.read_csv(path.format('train3'), sep='\t')
            self.dt = pd.read_csv(path.format('test3'), sep='\t')
        else:
            self.df = pd.read_csv(path.format('train/train'))
            self.dt = pd.read_csv(path.format('test/test'))
        self.smiles_field = smiles_field
        self.adj = adj
        self.label_field = label_field
        self.vocab = str2num
        self.devocab = num2str
        self.addH = addH
        if normalize:
            self.max = self.df[self.label_field].max()
            self.min = self.df[self.label_field].min()
            self.df[self.label_field] = (self.df[self.label_field] - self.min) / (self.max - self.min) - 0.5
            self.value_range = self.max - self.min

    def get_data(self):
        train_data = self.df
        test_data = self.dt
        data2 = test_data

        self.dataset1 = tf.data.Dataset.from_tensor_slices(
            (train_data[self.smiles_field], train_data[self.adj], train_data[self.label_field]))
        self.dataset1 = self.dataset1.map(self.tf_numerical_smiles).cache().padded_batch(64, padded_shapes=(
            tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([1]))).prefetch(100)

        self.dataset2 = tf.data.Dataset.from_tensor_slices(
            (test_data[self.smiles_field], test_data[self.adj], test_data[self.label_field]))
        self.dataset2 = self.dataset2.map(self.tf_numerical_smiles).padded_batch(64, padded_shapes=(
            tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([1]))).cache().prefetch(100)

        self.dataset3 = tf.data.Dataset.from_tensor_slices(
            (data2[self.smiles_field], test_data[self.adj], data2[self.label_field]))
        self.dataset3 = self.dataset3.map(self.tf_numerical_smiles).padded_batch(64, padded_shapes=(
            tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([1]))).cache().prefetch(100)

        return self.dataset1, self.dataset2, self.dataset3

    def numerical_smiles(self, atom, adj, label):
        # Same parsing as Graph_Regression_test.numerical_smiles.
        atom = np.array(atom)
        atom = atom[0].decode()
        atom = atom.replace('\n', '')
        atom = atom.replace('[', ' ')
        atom = atom.replace(']', ' ')
        atom = atom.split("'")

        atoms_list = []
        for i in atom:
            if i not in [' ']:
                atoms_list.append(i)
        label = np.array(label)[0]

        adj = np.array(adj)[0].decode()
        adjoin_matrix = np.load(adj)

        atoms_list = ['<global>'] + atoms_list
        nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]
        temp = np.ones((len(nums_list), len(nums_list)))
        temp[1:, 1:] = adjoin_matrix
        temp[np.where(temp == 0)] = -1e9
        adjoin_matrix = temp
        x = np.array(nums_list).astype('int64')
        y = np.array([label]).astype('float32')
        return x, adjoin_matrix, y

    def tf_numerical_smiles(self, smiles, adj, label):
        x, adjoin_matrix, y = tf.py_function(self.numerical_smiles, (smiles, adj, label),
                                             [tf.int64, tf.float32, tf.float32])
        x.set_shape([None])
        adjoin_matrix.set_shape([None, None])
        y.set_shape([None])
        return x, adjoin_matrix, y


class Inference_Dataset(object):
    """File-based inference dataset. NOTE: shadowed by the list-based
    Inference_Dataset redefined below, which is the one that remains importable."""

    def __init__(self, path, smiles_field='Smiles', addH=True):
        if path.endswith('.txt') or path.endswith('.tsv'):
            self.df = pd.read_csv(path, sep='\t')
        else:
            self.df = pd.read_csv(path)
        self.smiles_field = smiles_field
        self.vocab = str2num
        self.devocab = num2str
        self.addH = addH

    def get_data(self):
        data = self.df
        train_idx = []
        idx = data.sample(frac=0.9).index
        train_idx.extend(idx)
        data1 = data[data.index.isin(train_idx)]
        data2 = data[~data.index.isin(train_idx)]
        print(len(data1))
        self.dataset1 = tf.data.Dataset.from_tensor_slices(data1[self.smiles_field].tolist())
        self.dataset1 = self.dataset1.map(self.tf_numerical_smiles).padded_batch(1, padded_shapes=(
            tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([None]),
            tf.TensorShape([None]))).prefetch(50)
        print(self.dataset1)
        self.dataset2 = tf.data.Dataset.from_tensor_slices(data2[self.smiles_field].tolist())
        self.dataset2 = self.dataset2.map(self.tf_numerical_smiles).padded_batch(1, padded_shapes=(
            tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([None]),
            tf.TensorShape([None]))).prefetch(50)
        return self.dataset1, self.dataset2

    def numerical_smiles(self, smiles):
        smiles = smiles.numpy().decode()
        # Original called the undefined name smiles2adjoins; the
        # explicit_hydrogens kwarg is kept as written and may not exist on the
        # aliased utils function.
        atoms_list, adjoin_matrix = smiles2adjoin(smiles, explicit_hydrogens=self.addH)
        print(atoms_list)
        atoms_list = ['<global>'] + atoms_list
        nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]
        temp = np.ones((len(nums_list), len(nums_list)))
        temp[1:, 1:] = adjoin_matrix
        temp[np.where(temp == 0)] = -1e9
        adjoin_matrix = temp
        x = np.array(nums_list).astype('int64')
        return x, adjoin_matrix, [smiles], atoms_list

    def tf_numerical_smiles(self, data):
        # Original unpacked mismatched names (y/weight vs smiles/atom_list);
        # fixed so the declared output dtypes match numerical_smiles above.
        x, adjoin_matrix, smiles, atom_list = tf.py_function(
            self.numerical_smiles, [data], [tf.int64, tf.float32, tf.string, tf.string])
        x.set_shape([None])
        adjoin_matrix.set_shape([None, None])
        smiles.set_shape([1])
        atom_list.set_shape([None])
        return x, adjoin_matrix, smiles, atom_list


class Inference_Dataset(object):
    """List-based inference dataset; this redefinition shadows the class above."""

    def __init__(self, sml_list, max_len=1000, addH=True):
        self.vocab = str2num
        self.devocab = num2str
        self.sml_list = [i for i in sml_list if len(i) < max_len]
        self.addH = addH

    def get_data(self):
        self.dataset = tf.data.Dataset.from_tensor_slices((self.sml_list,))
        self.dataset = self.dataset.map(self.tf_numerical_smiles).padded_batch(64, padded_shapes=(
            tf.TensorShape([None]), tf.TensorShape([None, None]), tf.TensorShape([1]),
            tf.TensorShape([None]))).cache().prefetch(20)
        return self.dataset

    def numerical_smiles(self, smiles):
        smiles_origin = smiles
        smiles = smiles.numpy().decode()
        atoms_list, adjoin_matrix = smiles2adjoin(smiles)  # original called the undefined name smiles2adjoins
        atoms_list = ['<global>'] + atoms_list
        nums_list = [str2num.get(i, str2num['<unk>']) for i in atoms_list]
        temp = np.ones((len(nums_list), len(nums_list)))
        temp[1:, 1:] = adjoin_matrix
        adjoin_matrix = (1 - temp) * (-1e9)
        x = np.array(nums_list).astype('int64')
        return x, adjoin_matrix, [smiles], atoms_list

    def tf_numerical_smiles(self, smiles):
        x, adjoin_matrix, smiles, atom_list = tf.py_function(
            self.numerical_smiles, [smiles], [tf.int64, tf.float32, tf.string, tf.string])
        x.set_shape([None])
        adjoin_matrix.set_shape([None, None])
        smiles.set_shape([1])
        atom_list.set_shape([None])
        return x, adjoin_matrix, smiles, atom_list
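The same mask construction recurs in every numerical_smiles above: prepend a <global> row/column of ones, then turn zero (non-bonded) entries into -1e9 so they vanish under softmax. A toy illustration with a 3-atom chain:

# Toy illustration of the attention-mask construction used throughout dataset.py.
import numpy as np

adjacency = np.array([[1., 1., 0.],
                      [1., 1., 1.],
                      [0., 1., 1.]])  # 3-atom chain, self-loops included
temp = np.ones((4, 4))                # one extra row/column for the <global> token
temp[1:, 1:] = adjacency
temp[np.where(temp == 0)] = -1e9      # non-bonded pairs get a huge negative logit
print(temp)                           # row/column 0 stays all ones: <global> attends everywhere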
model.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import tensorflow as tf

import time  # imported but unused in this module
import numpy as np
import matplotlib.pyplot as plt  # imported but unused in this module


def gelu(x):
    return 0.5 * x * (1.0 + tf.math.erf(x / tf.sqrt(2.)))


def scaled_dot_product_attention(q, k, v, mask, adjoin_matrix):
    """Calculate the attention weights.
    q, k, v must have matching leading dimensions.
    k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
    The mask has different shapes depending on its type (padding or look ahead)
    but it must be broadcastable for addition.

    Args:
        q: query shape == (..., seq_len_q, depth)
        k: key shape == (..., seq_len_k, depth)
        v: value shape == (..., seq_len_v, depth_v)
        mask: Float tensor with shape broadcastable
              to (..., seq_len_q, seq_len_k). Defaults to None.

    Returns:
        output, attention_weights
    """

    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

    # scale matmul_qk
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # add the mask to the scaled tensor.
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)
    if adjoin_matrix is not None:
        # adjoin_matrix1 = tf.where(adjoin_matrix > 0, 0.0, -1e9)
        # scaled_attention_logits += adjoin_matrix1
        # scaled_attention_logits = scaled_attention_logits * adjoin_matrix
        scaled_attention_logits += adjoin_matrix

    # softmax is normalized on the last axis (seq_len_k) so that the scores
    # add up to 1.
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output, attention_weights
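
# Editor's demo (illustrative, not part of the uploaded file): the adjoin_matrix
# acts as an additive attention mask -- bonded pairs contribute 0, unbonded pairs
# -1e9, so the softmax drives attention between disconnected atoms to zero.
if __name__ == "__main__":
    q = k = v = tf.random.normal((1, 1, 3, 4))  # (batch, heads, seq, depth)
    adj = tf.constant([[[[0., -1e9, 0.],
                         [-1e9, 0., 0.],
                         [0., 0., 0.]]]])       # 0 = connected, -1e9 = masked
    out, w = scaled_dot_product_attention(q, k, v, mask=None, adjoin_matrix=adj)
    print(out.shape, w.shape)                   # (1, 1, 3, 4) (1, 1, 3, 3)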

class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0

        self.depth = d_model // self.num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask, adjoin_matrix):
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask, adjoin_matrix)

        scaled_attention = tf.transpose(scaled_attention,
                                        perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output, attention_weights
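
# Editor's sketch: quick shape check for MultiHeadAttention (illustrative only).
if __name__ == "__main__":
    mha = MultiHeadAttention(d_model=256, num_heads=8)
    dummy = tf.random.normal((2, 5, 256))
    out, attn = mha(dummy, dummy, dummy, None, None)
    print(out.shape, attn.shape)  # (2, 5, 256) (2, 8, 5, 5)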

def point_wise_feed_forward_network(d_model, dff):
    return tf.keras.Sequential([
        tf.keras.layers.Dense(dff, activation=gelu),  # (batch_size, seq_len, dff); alternative activation: tf.keras.layers.LeakyReLU(0.01)
        tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
    ])


class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()

        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask, adjoin_matrix):
        attn_output, attention_weights = self.mha(x, x, x, mask, adjoin_matrix)  # (batch_size, input_seq_len, d_model)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)

        ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)

        return out2, attention_weights


class Encoder(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 maximum_position_encoding, rate=0.1):
        super(Encoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        # self.pos_encoding = positional_encoding(maximum_position_encoding,
        #                                         self.d_model)

        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
                           for _ in range(num_layers)]

        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask, adjoin_matrix):
        seq_len = tf.shape(x)[1]
        adjoin_matrix = adjoin_matrix[:, tf.newaxis, :, :]
        # adding embedding and position encoding.
        x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))

        x = self.dropout(x, training=training)

        for i in range(self.num_layers):
            x, attention_weights = self.enc_layers[i](x, training, mask, adjoin_matrix)
        return x  # (batch_size, input_seq_len, d_model)
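
# Editor's sketch: running the graph encoder on a dummy batch. adjoin_matrix is
# the additive mask produced by the dataset (0 / -1e9); shapes are illustrative.
if __name__ == "__main__":
    enc = Encoder(num_layers=2, d_model=64, num_heads=4, dff=128,
                  input_vocab_size=17, maximum_position_encoding=200)
    x_ids = tf.zeros((2, 5), tf.int64)
    adj = tf.zeros((2, 5, 5))  # fully connected toy graph
    h = enc(x_ids, training=False, mask=None, adjoin_matrix=adj)
    print(h.shape)             # (2, 5, 64)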

class Encoder_test(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 maximum_position_encoding, rate=0.1):
        super(Encoder_test, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        # self.pos_encoding = positional_encoding(maximum_position_encoding,
        #                                         self.d_model)

        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
                           for _ in range(num_layers)]

        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask, adjoin_matrix):
        seq_len = tf.shape(x)[1]
        adjoin_matrix = adjoin_matrix[:, tf.newaxis, :, :]
        # adding embedding and position encoding.
        x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        # x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training)
        attention_weights_list = []
        xs = []

        for i in range(self.num_layers):
            x, attention_weights = self.enc_layers[i](x, training, mask, adjoin_matrix)
            attention_weights_list.append(attention_weights)
            xs.append(x)

        return x, attention_weights_list, xs


class BertModel_test(tf.keras.Model):
    def __init__(self, num_layers=6, d_model=256, dff=512, num_heads=8, vocab_size=17, dropout_rate=0.1):
        super(BertModel_test, self).__init__()
        self.encoder = Encoder_test(num_layers=num_layers, d_model=d_model,
                                    num_heads=num_heads, dff=dff, input_vocab_size=vocab_size,
                                    maximum_position_encoding=200, rate=dropout_rate)
        self.fc1 = tf.keras.layers.Dense(d_model, activation=gelu)
        self.layernorm = tf.keras.layers.LayerNormalization(-1)
        self.fc2 = tf.keras.layers.Dense(vocab_size)

    def call(self, x, adjoin_matrix, mask, training=False):
        x, att, xs = self.encoder(x, training=training, mask=mask, adjoin_matrix=adjoin_matrix)
        x = self.fc1(x)
        x = self.layernorm(x)
        x = self.fc2(x)
        return x, att, xs


class BertModel(tf.keras.Model):
    def __init__(self, num_layers=6, d_model=256, dff=512, num_heads=8, vocab_size=17, dropout_rate=0.1):
        super(BertModel, self).__init__()
        self.encoder = Encoder(num_layers=num_layers, d_model=d_model,
                               num_heads=num_heads, dff=dff, input_vocab_size=vocab_size,
                               maximum_position_encoding=200, rate=dropout_rate)
        self.fc1 = tf.keras.layers.Dense(d_model, activation=gelu)
        self.layernorm = tf.keras.layers.LayerNormalization(-1)
        self.fc2 = tf.keras.layers.Dense(vocab_size)

    def call(self, x, adjoin_matrix, mask, training=False):
        x = self.encoder(x, training=training, mask=mask, adjoin_matrix=adjoin_matrix)
        x = self.fc1(x)
        x = self.layernorm(x)
        x = self.fc2(x)
        return x


class PredictModel(tf.keras.Model):
    def __init__(self, num_layers=8, d_model=256, dff=512, num_heads=8, vocab_size=17,
                 dropout_rate=0.1, dense_dropout=0.1):
        super(PredictModel, self).__init__()
        self.encoder = Encoder(num_layers=num_layers, d_model=d_model,
                               num_heads=num_heads, dff=dff, input_vocab_size=vocab_size,
                               maximum_position_encoding=200, rate=dropout_rate)

        self.fc1 = tf.keras.layers.Dense(256, activation=tf.keras.layers.LeakyReLU(0.25))
        self.fc2 = tf.keras.layers.Dense(256, activation=tf.keras.layers.LeakyReLU(0.25))
        self.dropout = tf.keras.layers.Dropout(dense_dropout)
        self.fc3 = tf.keras.layers.Dense(1)

    def call(self, x, adjoin_matrix, mask, training=False):
        x = self.encoder(x, training=training, mask=mask, adjoin_matrix=adjoin_matrix)
        x = x[:, 0, :]
        x = self.fc1(x)
        x = self.dropout(x, training=training)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


class PredictModel_test(tf.keras.Model):
    def __init__(self, num_layers=6, d_model=256, dff=512, num_heads=8, vocab_size=17,
                 dropout_rate=0.1, dense_dropout=0.5):
        super(PredictModel_test, self).__init__()
        self.encoder = Encoder_test(num_layers=num_layers, d_model=d_model,
                                    num_heads=num_heads, dff=dff, input_vocab_size=vocab_size,
                                    maximum_position_encoding=200, rate=dropout_rate)

        self.fc1 = tf.keras.layers.Dense(256, activation=tf.keras.layers.LeakyReLU(0.1))
        self.dropout = tf.keras.layers.Dropout(dense_dropout)
        self.fc2 = tf.keras.layers.Dense(1)

    def call(self, x, adjoin_matrix, mask, training=False):
        x, att, xs = self.encoder(x, training=training, mask=mask, adjoin_matrix=adjoin_matrix)
        x = x[:, 0, :]
        x = self.fc1(x)
        x = self.dropout(x, training=training)
        x = self.fc2(x)
        return x, att, xs
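
# Editor's sketch: wiring the regression head the way the app presumably does --
# encode, take the <global> token embedding (position 0), predict a scalar property.
if __name__ == "__main__":
    model = PredictModel(num_layers=2, d_model=64, dff=128, num_heads=4, vocab_size=17)
    x_ids = tf.zeros((2, 5), tf.int64)
    adj = tf.zeros((2, 5, 5))
    pred = model(x_ids, adjoin_matrix=adj, mask=None, training=False)
    print(pred.shape)  # (2, 1)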
requirements.txt
ADDED
@@ -0,0 +1,10 @@
altair
streamlit
streamlit-ketcher
tensorflow
pandas
rdkit
scikit-learn
matplotlib
utils.py
ADDED
@@ -0,0 +1,696 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Jul 28 14:40:59 2022

@author: BM109X32G-10GPU-02
"""

import os
from collections import OrderedDict

import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdchem

from compound_constants import DAY_LIGHT_FG_SMARTS_LIST


def get_gasteiger_partial_charges(mol, n_iter=12):
    """
    Calculates a list of Gasteiger partial charges for each atom in a mol object.
    Args:
        mol: rdkit mol object.
        n_iter(int): number of iterations. Default 12.
    Returns:
        list of computed partial charges for each atom.
    """
    Chem.rdPartialCharges.ComputeGasteigerCharges(mol, nIter=n_iter,
                                                  throwOnParamFailure=True)
    partial_charges = [float(a.GetProp('_GasteigerCharge')) for a in
                       mol.GetAtoms()]
    return partial_charges


def create_standardized_mol_id(smiles):
    """
    Args:
        smiles: smiles sequence.
    Returns:
        inchi.
    """
    if check_smiles_validity(smiles):
        # remove stereochemistry
        smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles),
                                     isomericSmiles=False)
        mol = Chem.AddHs(AllChem.MolFromSmiles(smiles))

        if not mol is None:  # to catch a weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21
            if '.' in smiles:  # if multiple species, pick the largest molecule
                mol_species_list = split_rdkit_mol_obj(mol)
                largest_mol = get_largest_mol(mol_species_list)
                inchi = AllChem.MolToInchi(largest_mol)
            else:
                inchi = AllChem.MolToInchi(mol)
            return inchi
        else:
            return
    else:
        return


def check_smiles_validity(smiles):
    """
    Check whether the smiles can be converted to an rdkit mol object.
    """
    try:
        m = Chem.MolFromSmiles(smiles)
        if m:
            return True
        else:
            return False
    except Exception:
        return False


def split_rdkit_mol_obj(mol):
    """
    Split an rdkit mol object containing multiple species or one species into a
    list of mol objects, or a list containing a single object, respectively.
    Args:
        mol: rdkit mol object.
    """
    smiles = AllChem.MolToSmiles(mol, isomericSmiles=True)
    smiles_list = smiles.split('.')
    mol_species_list = []
    for s in smiles_list:
        if check_smiles_validity(s):
            mol_species_list.append(AllChem.MolFromSmiles(s))
    return mol_species_list


def get_largest_mol(mol_list):
    """
    Given a list of rdkit mol objects, returns the mol object containing the
    largest number of atoms. If multiple mols contain the largest number of
    atoms, picks the first one.
    Args:
        mol_list(list): a list of rdkit mol objects.
    Returns:
        the largest mol.
    """
    num_atoms_list = [len(m.GetAtoms()) for m in mol_list]
    largest_mol_idx = num_atoms_list.index(max(num_atoms_list))
    return mol_list[largest_mol_idx]


def rdchem_enum_to_list(values):
    """values = {0: rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
                 1: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
                 2: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
                 3: rdkit.Chem.rdchem.ChiralType.CHI_OTHER}
    """
    return [values[i] for i in range(len(values))]


def safe_index(alist, elem):
    """
    Return the index of elem in alist. If elem is not present, return the last index.
    """
    try:
        return alist.index(elem)
    except ValueError:
        return len(alist) - 1


def get_atom_feature_dims(list_acquired_feature_names):
    """Vocabulary sizes for the requested atom features."""
    return list(map(len, [CompoundKit.atom_vocab_dict[name] for name in list_acquired_feature_names]))


def get_bond_feature_dims(list_acquired_feature_names):
    """Vocabulary sizes for the requested bond features (+1 for self-loop edges)."""
    list_bond_feat_dim = list(map(len, [CompoundKit.bond_vocab_dict[name] for name in list_acquired_feature_names]))
    # +1 for self loop edges
    return [_l + 1 for _l in list_bond_feat_dim]
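
# Editor's note (illustrative): safe_index maps unseen values to the last
# vocabulary slot, which is why the vocab lists end in a catch-all such as 'misc'.
if __name__ == "__main__":
    assert safe_index([1, 2, 3, 'misc'], 2) == 1    # known value -> its index
    assert safe_index([1, 2, 3, 'misc'], 42) == 3   # unknown value -> 'misc' slot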


class CompoundKit(object):
    """
    CompoundKit
    """
    atom_vocab_dict = {
        "atomic_num": list(range(1, 119)) + ['misc'],
        "chiral_tag": rdchem_enum_to_list(rdchem.ChiralType.values),
        # ring-size vocabularies: assumed here to follow the original GEM
        # CompoundKit (counts 0-8 plus a catch-all); get_atom_names below
        # relies on these keys being present.
        "in_num_ring_with_size3": list(range(9)) + ['misc'],
        "in_num_ring_with_size4": list(range(9)) + ['misc'],
        "in_num_ring_with_size5": list(range(9)) + ['misc'],
        "in_num_ring_with_size6": list(range(9)) + ['misc'],
        "in_num_ring_with_size7": list(range(9)) + ['misc'],
        "in_num_ring_with_size8": list(range(9)) + ['misc'],
    }
    bond_vocab_dict = {
        "bond_dir": rdchem_enum_to_list(rdchem.BondDir.values),
        "bond_type": rdchem_enum_to_list(rdchem.BondType.values),
    }
    # float features
    atom_float_names = ["van_der_waals_radis", "partial_charge", 'mass']
    # bond_float_feats = ["bond_length", "bond_angle"]  # optional

    ### functional groups
    day_light_fg_smarts_list = DAY_LIGHT_FG_SMARTS_LIST
    day_light_fg_mo_list = [Chem.MolFromSmarts(smarts) for smarts in day_light_fg_smarts_list]

    morgan_fp_N = 200
    morgan2048_fp_N = 2048
    maccs_fp_N = 167

    period_table = Chem.GetPeriodicTable()

    ### atom

    @staticmethod
    def get_atom_value(atom, name):
        """get atom values"""
        if name == 'atomic_num':
            return atom.GetAtomicNum()
        elif name == 'chiral_tag':
            return atom.GetChiralTag()
        elif name == 'degree':
            return atom.GetDegree()
        elif name == 'explicit_valence':
            return atom.GetExplicitValence()
        elif name == 'formal_charge':
            return atom.GetFormalCharge()
        elif name == 'hybridization':
            return atom.GetHybridization()
        elif name == 'implicit_valence':
            return atom.GetImplicitValence()
        elif name == 'is_aromatic':
            return int(atom.GetIsAromatic())
        elif name == 'mass':
            return int(atom.GetMass())
        elif name == 'total_numHs':
            return atom.GetTotalNumHs()
        elif name == 'num_radical_e':
            return atom.GetNumRadicalElectrons()
        elif name == 'atom_is_in_ring':
            return int(atom.IsInRing())
        elif name == 'valence_out_shell':
            return CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())
        else:
            raise ValueError(name)

    @staticmethod
    def get_atom_feature_id(atom, name):
        """get atom feature id"""
        assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
        return safe_index(CompoundKit.atom_vocab_dict[name], CompoundKit.get_atom_value(atom, name))

    @staticmethod
    def get_atom_feature_size(name):
        """get atom feature size"""
        assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
        return len(CompoundKit.atom_vocab_dict[name])

    ### bond

    @staticmethod
    def get_bond_value(bond, name):
        """get bond values"""
        if name == 'bond_dir':
            return bond.GetBondDir()
        elif name == 'bond_type':
            return bond.GetBondType()
        elif name == 'is_in_ring':
            return int(bond.IsInRing())
        elif name == 'is_conjugated':
            return int(bond.GetIsConjugated())
        elif name == 'bond_stereo':
            return bond.GetStereo()
        else:
            raise ValueError(name)

    @staticmethod
    def get_bond_feature_id(bond, name):
        """get bond feature id"""
        assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
        return safe_index(CompoundKit.bond_vocab_dict[name], CompoundKit.get_bond_value(bond, name))

    @staticmethod
    def get_bond_feature_size(name):
        """get bond feature size"""
        assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
        return len(CompoundKit.bond_vocab_dict[name])

    ### fingerprint

    @staticmethod
    def get_morgan_fingerprint(mol, radius=2):
        """get morgan fingerprint"""
        nBits = CompoundKit.morgan_fp_N
        mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return [int(b) for b in mfp.ToBitString()]

    @staticmethod
    def get_morgan2048_fingerprint(mol, radius=2):
        """get morgan2048 fingerprint"""
        nBits = CompoundKit.morgan2048_fp_N
        mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return [int(b) for b in mfp.ToBitString()]

    @staticmethod
    def get_maccs_fingerprint(mol):
        """get maccs fingerprint"""
        fp = AllChem.GetMACCSKeysFingerprint(mol)
        return [int(b) for b in fp.ToBitString()]

    ### functional groups

    @staticmethod
    def get_daylight_functional_group_counts(mol):
        """get daylight functional group counts"""
        fg_counts = []
        for fg_mol in CompoundKit.day_light_fg_mo_list:
            sub_structs = Chem.Mol.GetSubstructMatches(mol, fg_mol, uniquify=True)
            fg_counts.append(len(sub_structs))
        return fg_counts

    @staticmethod
    def get_ring_size(mol):
        """return an (N, 6) list: per atom, ring-membership counts for ring sizes 3-8"""
        rings = mol.GetRingInfo()
        rings_info = []
        for r in rings.AtomRings():
            rings_info.append(r)
        ring_list = []
        for atom in mol.GetAtoms():
            atom_result = []
            for ringsize in range(3, 9):
                num_of_ring_at_ringsize = 0
                for r in rings_info:
                    if len(r) == ringsize and atom.GetIdx() in r:
                        num_of_ring_at_ringsize += 1
                if num_of_ring_at_ringsize > 8:
                    num_of_ring_at_ringsize = 9
                atom_result.append(num_of_ring_at_ringsize)

            ring_list.append(atom_result)
        return ring_list

    @staticmethod
    def atom_to_feat_vector(atom):
        """per-atom feature dict; extended here (editor's assumption) so that every
        name in atom_vocab_dict / atom_float_names used by new_mol_to_graph_data
        is actually produced"""
        atom_names = {
            "atomic_num": safe_index(CompoundKit.atom_vocab_dict["atomic_num"], atom.GetAtomicNum()),
            "chiral_tag": safe_index(CompoundKit.atom_vocab_dict["chiral_tag"], atom.GetChiralTag()),
            # float features (names as in atom_float_names)
            "van_der_waals_radis": CompoundKit.period_table.GetRvdw(atom.GetAtomicNum()),
            "partial_charge": CompoundKit.check_partial_charge(atom),
            "mass": atom.GetMass(),
        }
        return atom_names

    @staticmethod
    def get_atom_names(mol):
        """get atom name list
        TODO: to be removed in the future
        """
        atom_features_dicts = []
        Chem.rdPartialCharges.ComputeGasteigerCharges(mol)
        for i, atom in enumerate(mol.GetAtoms()):
            atom_features_dicts.append(CompoundKit.atom_to_feat_vector(atom))

        ring_list = CompoundKit.get_ring_size(mol)
        for i, atom in enumerate(mol.GetAtoms()):
            atom_features_dicts[i]['in_num_ring_with_size3'] = safe_index(
                CompoundKit.atom_vocab_dict['in_num_ring_with_size3'], ring_list[i][0])
            atom_features_dicts[i]['in_num_ring_with_size4'] = safe_index(
                CompoundKit.atom_vocab_dict['in_num_ring_with_size4'], ring_list[i][1])
            atom_features_dicts[i]['in_num_ring_with_size5'] = safe_index(
                CompoundKit.atom_vocab_dict['in_num_ring_with_size5'], ring_list[i][2])
            atom_features_dicts[i]['in_num_ring_with_size6'] = safe_index(
                CompoundKit.atom_vocab_dict['in_num_ring_with_size6'], ring_list[i][3])
            atom_features_dicts[i]['in_num_ring_with_size7'] = safe_index(
                CompoundKit.atom_vocab_dict['in_num_ring_with_size7'], ring_list[i][4])
            atom_features_dicts[i]['in_num_ring_with_size8'] = safe_index(
                CompoundKit.atom_vocab_dict['in_num_ring_with_size8'], ring_list[i][5])

        return atom_features_dicts

    @staticmethod
    def check_partial_charge(atom):
        """read the precomputed Gasteiger charge, mapping nan/inf to finite values"""
        pc = atom.GetDoubleProp('_GasteigerCharge')
        if pc != pc:
            # unsupported atom, replace nan with 0
            pc = 0
        if pc == float('inf'):
            # max is around 4 for other atoms; set to 10 here if inf is obtained
            pc = 10
        return pc
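
# Editor's sketch (not part of the uploaded file): featurizing a single molecule
# with CompoundKit. Gasteiger charges must be computed before check_partial_charge
# is used, which get_atom_names does internally.
if __name__ == "__main__":
    demo_mol = Chem.MolFromSmiles('c1ccccc1O')  # phenol
    print(CompoundKit.get_atom_feature_id(demo_mol.GetAtomWithIdx(0), 'atomic_num'))
    print(sum(CompoundKit.get_morgan_fingerprint(demo_mol)))  # number of set bits
    print(CompoundKit.get_ring_size(demo_mol)[0])             # ring counts for atom 0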


class Compound3DKit(object):
    """the 3D kit of Compound"""
    @staticmethod
    def get_atom_poses(mol, conf):
        """extract (x, y, z) positions for every atom from a conformer"""
        atom_poses = []
        for i, atom in enumerate(mol.GetAtoms()):
            if atom.GetAtomicNum() == 0:
                return [[0.0, 0.0, 0.0]] * len(mol.GetAtoms())
            pos = conf.GetAtomPosition(i)
            atom_poses.append([pos.x, pos.y, pos.z])
        return atom_poses

    @staticmethod
    def get_MMFF_atom_poses(mol, numConfs=None, return_energy=False):
        """the atoms of mol will be changed in some cases."""
        try:
            new_mol = Chem.AddHs(mol)
            res = AllChem.EmbedMultipleConfs(new_mol, numConfs=numConfs)
            ### MMFF generates multiple conformations
            res = AllChem.MMFFOptimizeMoleculeConfs(new_mol)
            # new_mol = Chem.RemoveHs(new_mol)
            index = np.argmin([x[1] for x in res])
            energy = res[index][1]
            conf = new_mol.GetConformer(id=int(index))
        except Exception:
            # embedding/optimization failed: fall back to 2D coordinates
            new_mol = Chem.AddHs(mol)
            AllChem.Compute2DCoords(new_mol)
            energy = 0
            conf = new_mol.GetConformer()

        atom_poses = Compound3DKit.get_atom_poses(new_mol, conf)
        if return_energy:
            return new_mol, atom_poses, energy
        else:
            return new_mol, atom_poses

    @staticmethod
    def get_2d_atom_poses(mol):
        """get 2d atom poses"""
        AllChem.Compute2DCoords(mol)
        conf = mol.GetConformer()
        atom_poses = Compound3DKit.get_atom_poses(mol, conf)
        return atom_poses

    @staticmethod
    def get_bond_lengths(edges, atom_poses):
        """get bond lengths"""
        bond_lengths = []
        for src_node_i, tar_node_j in edges:
            bond_lengths.append(np.linalg.norm(atom_poses[tar_node_j] - atom_poses[src_node_i]))
        bond_lengths = np.array(bond_lengths, 'float32')
        return bond_lengths

    @staticmethod
    def get_superedge_angles(edges, atom_poses, dir_type='HT'):
        """get superedge angles"""
        def _get_vec(atom_poses, edge):
            return atom_poses[edge[1]] - atom_poses[edge[0]]

        def _get_angle(vec1, vec2):
            norm1 = np.linalg.norm(vec1)
            norm2 = np.linalg.norm(vec2)
            if norm1 == 0 or norm2 == 0:
                return 0
            vec1 = vec1 / (norm1 + 1e-5)  # 1e-5: prevent numerical errors
            vec2 = vec2 / (norm2 + 1e-5)
            angle = np.arccos(np.dot(vec1, vec2))
            return angle

        E = len(edges)
        edge_indices = np.arange(E)
        super_edges = []
        bond_angles = []
        bond_angle_dirs = []
        for tar_edge_i in range(E):
            tar_edge = edges[tar_edge_i]
            if dir_type == 'HT':
                src_edge_indices = edge_indices[edges[:, 1] == tar_edge[0]]
            elif dir_type == 'HH':
                src_edge_indices = edge_indices[edges[:, 1] == tar_edge[1]]
            else:
                raise ValueError(dir_type)
            for src_edge_i in src_edge_indices:
                if src_edge_i == tar_edge_i:
                    continue
                src_edge = edges[src_edge_i]
                src_vec = _get_vec(atom_poses, src_edge)
                tar_vec = _get_vec(atom_poses, tar_edge)
                super_edges.append([src_edge_i, tar_edge_i])
                angle = _get_angle(src_vec, tar_vec)
                bond_angles.append(angle)
                bond_angle_dirs.append(src_edge[1] == tar_edge[0])  # H -> H or H -> T

        if len(super_edges) == 0:
            super_edges = np.zeros([0, 2], 'int64')
            bond_angles = np.zeros([0, ], 'float32')
        else:
            super_edges = np.array(super_edges, 'int64')
            bond_angles = np.array(bond_angles, 'float32')
        return super_edges, bond_angles, bond_angle_dirs
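
# Editor's sketch: generating MMFF conformer positions for a small molecule and
# deriving per-edge bond lengths (illustrative only).
if __name__ == "__main__":
    demo = Chem.MolFromSmiles('CCO')
    demo_mol, poses = Compound3DKit.get_MMFF_atom_poses(demo, numConfs=3)
    poses = np.array(poses, 'float32')
    demo_edges = np.array([(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in demo_mol.GetBonds()])
    print(Compound3DKit.get_bond_lengths(demo_edges, poses))  # one length per bond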



def new_smiles_to_graph_data(smiles, **kwargs):
    """
    Convert smiles to graph data.
    """
    mol = AllChem.MolFromSmiles(smiles)
    if mol is None:  # check before AddHs: AddHs(None) would raise
        return None
    mol = Chem.AddHs(mol)
    data = new_mol_to_graph_data(mol)
    return data


def new_mol_to_graph_data(mol):
    """
    mol_to_graph_data
    Returns a dict holding:
        atom features, edge features, morgan fingerprint, functional group counts.
    """
    if len(mol.GetAtoms()) == 0:
        return None

    atom_id_names = list(CompoundKit.atom_vocab_dict.keys()) + CompoundKit.atom_float_names
    bond_id_names = list(CompoundKit.bond_vocab_dict.keys())

    ### atom features
    data = {name: [] for name in atom_id_names}

    raw_atom_feat_dicts = CompoundKit.get_atom_names(mol)
    for atom_feat in raw_atom_feat_dicts:
        for name in atom_id_names:
            data[name].append(atom_feat[name])

    ### bond and bond features
    for name in bond_id_names:
        data[name] = []
    data['edges'] = []

    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        # i->j and j->i
        data['edges'] += [(i, j), (j, i)]
        for name in bond_id_names:
            bond_feature_id = CompoundKit.get_bond_feature_id(bond, name)
            data[name] += [bond_feature_id] * 2

    #### self loop
    N = len(data[atom_id_names[0]])
    for i in range(N):
        data['edges'] += [(i, i)]
    for name in bond_id_names:
        bond_feature_id = get_bond_feature_dims([name])[0] - 1  # self loop: value = len - 1
        data[name] += [bond_feature_id] * N

    ### make ndarray and check length
    for name in list(CompoundKit.atom_vocab_dict.keys()):
        data[name] = np.array(data[name], 'int64')
    for name in CompoundKit.atom_float_names:
        data[name] = np.array(data[name], 'float32')
    for name in bond_id_names:
        data[name] = np.array(data[name], 'int64')
    data['edges'] = np.array(data['edges'], 'int64')

    ### morgan fingerprint
    data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
    # data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64')
    data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
    data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
    return data
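
# Editor's sketch: the dict produced by new_smiles_to_graph_data (assuming the
# vocabulary fixes above), keyed by feature name, with doubled and self-loop edges.
if __name__ == "__main__":
    g = new_smiles_to_graph_data('CCO')
    if g is not None:
        print(g['edges'].shape)      # (2 * num_bonds + num_atoms, 2)
        print(g['atomic_num'][:5])
        print(g['morgan_fp'].shape)  # (200,)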


def mol_to_graph_data(mol):
    """
    mol_to_graph_data
    Returns a dict holding:
        atom features, edge features, and the atom symbol list.
    """
    if len(mol.GetAtoms()) == 0:
        return None

    atom_id_names = [
        "atomic_num"
    ]
    bond_id_names = [
        "bond_dir", "bond_type"
    ]

    data = {}
    for name in atom_id_names:
        data[name] = []
    data['mass'] = []
    for name in bond_id_names:
        data[name] = []
    data['edges'] = []

    ### atom features
    for i, atom in enumerate(mol.GetAtoms()):
        if atom.GetAtomicNum() == 0:
            return None
        for name in atom_id_names:
            data[name].append(CompoundKit.get_atom_feature_id(atom, name) + 1)  # 0: OOV
        data['mass'].append(CompoundKit.get_atom_value(atom, 'mass') * 0.01)

    ### bond features
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        # i->j and j->i
        data['edges'] += [(i, j), (j, i)]
        for name in bond_id_names:
            bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + 1  # 0: OOV
            data[name] += [bond_feature_id] * 2

    num_atoms = mol.GetNumAtoms()
    atoms_list = []
    for i in range(num_atoms):
        atom = mol.GetAtomWithIdx(i)
        atoms_list.append(atom.GetSymbol())

    ### self loop (+2)
    N = len(data[atom_id_names[0]])
    for i in range(N):
        data['edges'] += [(i, i)]
    for name in bond_id_names:
        bond_feature_id = CompoundKit.get_bond_feature_size(name) + 2  # N + 2: self loop
        data[name] += [bond_feature_id] * N

    ### check whether edge exists
    if len(data['edges']) == 0:  # mol has no bonds
        for name in bond_id_names:
            data[name] = np.zeros((0,), dtype="int64")
        data['edges'] = np.zeros((0, 2), dtype="int64")

    ### make ndarray and check length
    for name in atom_id_names:
        data[name] = np.array(data[name], 'int64')
    data['mass'] = np.array(data['mass'], 'float32')
    for name in bond_id_names:
        data[name] = np.array(data[name], 'int64')
    data['edges'] = np.array(data['edges'], 'int64')
    data['atoms'] = np.array(atoms_list)
    ### morgan fingerprint (disabled here)
    # data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
    # data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64')
    # data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
    # data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
    # return data['bonds_dir'], data['adj_angle']
    return data


def mol_to_geognn_graph_data(mol, atom_poses, dir_type):
    """
    mol: rdkit molecule
    dir_type: direction type for the bond_angle graph
    """
    if len(mol.GetAtoms()) == 0:
        return None

    data = mol_to_graph_data(mol)

    data['atom_pos'] = np.array(atom_poses, 'float32')
    data['bond_length'] = Compound3DKit.get_bond_lengths(data['edges'], data['atom_pos'])
    # BondAngleGraph_edges, bond_angles, bond_angle_dirs = \
    #     Compound3DKit.get_superedge_angles(data['edges'], data['atom_pos'])
    # data['BondAngleGraph_edges'] = BondAngleGraph_edges
    # data['bond_angle'] = np.array(bond_angles, 'float32')
    data['adj_node'] = gen_adj(len(data['atoms']), data['edges'], data['bond_length'])
    # data['adj_edge'] = gen_adj(len(data['bond_dir']), data['BondAngleGraph_edges'], data['bond_angle'])
    return data['atoms'], data['adj_node']


def mol_to_geognn_graph_data_MMFF3d(smiles):
    """SMILES -> (atom symbols, distance-weighted adjacency), using MMFF conformers
    for small molecules and 2D coordinates for very large ones."""
    mol = AllChem.MolFromSmiles(smiles)
    if mol is None:  # guard: AddHs(None) would raise
        return None
    mol = Chem.AddHs(mol)
    if len(mol.GetAtoms()) <= 400:
        mol, atom_poses = Compound3DKit.get_MMFF_atom_poses(mol, numConfs=10)
    else:
        atom_poses = Compound3DKit.get_2d_atom_poses(mol)
    return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT')


def mol_to_geognn_graph_data_raw3d(mol):
    """graph data from a molecule that already carries a 3D conformer"""
    atom_poses = Compound3DKit.get_atom_poses(mol, mol.GetConformer())
    return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT')


def gen_adj(shape, edges, length):
    """build a dense (shape, shape) adjacency matrix whose entries are the bond
    lengths (rounded to 3 decimals); self loops are left at 0"""
    adj = edges
    e = shape
    ones = np.zeros([e, e])

    for i in range(len(length)):
        if adj[i, 0] != adj[i, 1]:
            ones[adj[i, 0], adj[i, 1]] = round(float(length[i]), 3)

    return ones


if __name__ == "__main__":
    import pandas as pd
    from tqdm import tqdm

    f = pd.read_csv(r"data/reg/train3.csv")
    re = []
    pce = f['PCE']
    for ind, smile in enumerate(f.iloc[:, 1]):
        print(ind)
        atom, adj = mol_to_geognn_graph_data_MMFF3d(smile)
        np.save('data/reg/train/adj' + str(ind) + '.npy', np.array(adj))
        re.append([atom, 'data/reg/train/adj' + str(ind) + '.npy', pce[ind]])
    r = pd.DataFrame(re)
    r.to_csv('data/reg/train/train.csv')
    re = []

    f = pd.read_csv(r'data/reg/test3.csv')
    re = []
    pce = f['PCE']

    for ind, smile in enumerate(f.iloc[:, 1]):
        print(ind)
        atom, adj = mol_to_geognn_graph_data_MMFF3d(smile)
        np.save('data/reg/test/adj' + str(ind) + '.npy', np.array(adj))
        re.append([atom, 'data/reg/test/adj' + str(ind) + '.npy', pce[ind]])
    r = pd.DataFrame(re)
    r.to_csv('data/reg/test/test.csv')

    f = pd.read_csv(r'val.csv')
    re = []
    pce = f['PCE']

    for ind, smile in enumerate(f.iloc[:, 1]):
        print(ind)
        atom, adj = mol_to_geognn_graph_data_MMFF3d(smile)
        np.save('data/reg/val/adj' + str(ind) + '.npy', np.array(adj))
        re.append([atom, 'data/reg/val/adj' + str(ind) + '.npy', pce[ind]])
    r = pd.DataFrame(re)
    r.to_csv('data/reg/val/val.csv')