Spaces:
Sleeping
Sleeping
Upload 275 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +12 -0
- __pycache__/demo.cpython-38.pyc +0 -0
- app.py +33 -0
- chroma/CONTRIBUTING.md +11 -0
- chroma/Dockerfile +24 -0
- chroma/LICENSE.txt +202 -0
- chroma/README.md +255 -0
- chroma/Untitled.ipynb +6 -0
- chroma/assets/LiberationSans-Regular.ttf +0 -0
- chroma/assets/chroma_logo.svg +85 -0
- chroma/assets/chroma_logo_outline.svg +109 -0
- chroma/assets/conditioners.png +0 -0
- chroma/assets/lattice.png +3 -0
- chroma/assets/logo.png +0 -0
- chroma/assets/proteins.png +3 -0
- chroma/assets/refolding.png +3 -0
- chroma/chroma/__init__.py +19 -0
- chroma/chroma/__pycache__/__init__.cpython-38.pyc +0 -0
- chroma/chroma/constants/__init__.py +16 -0
- chroma/chroma/constants/__pycache__/__init__.cpython-38.pyc +0 -0
- chroma/chroma/constants/__pycache__/geometry.cpython-38.pyc +0 -0
- chroma/chroma/constants/__pycache__/named_models.cpython-38.pyc +0 -0
- chroma/chroma/constants/__pycache__/sequence.cpython-38.pyc +0 -0
- chroma/chroma/constants/geometry.py +558 -0
- chroma/chroma/constants/named_models.py +54 -0
- chroma/chroma/constants/sequence.py +112 -0
- chroma/chroma/data/__init__.py +19 -0
- chroma/chroma/data/__pycache__/__init__.cpython-38.pyc +0 -0
- chroma/chroma/data/__pycache__/protein.cpython-38.pyc +0 -0
- chroma/chroma/data/__pycache__/system.cpython-38.pyc +0 -0
- chroma/chroma/data/__pycache__/xcs.cpython-38.pyc +0 -0
- chroma/chroma/data/protein.py +513 -0
- chroma/chroma/data/system.py +0 -0
- chroma/chroma/data/xcs.py +121 -0
- chroma/chroma/layers/__init__.py +18 -0
- chroma/chroma/layers/__pycache__/__init__.cpython-38.pyc +0 -0
- chroma/chroma/layers/__pycache__/attention.cpython-38.pyc +0 -0
- chroma/chroma/layers/__pycache__/basic.cpython-38.pyc +0 -0
- chroma/chroma/layers/__pycache__/complexity.cpython-38.pyc +0 -0
- chroma/chroma/layers/__pycache__/conv.cpython-38.pyc +0 -0
- chroma/chroma/layers/__pycache__/graph.cpython-38.pyc +0 -0
- chroma/chroma/layers/__pycache__/linalg.cpython-38.pyc +0 -0
- chroma/chroma/layers/__pycache__/norm.cpython-38.pyc +0 -0
- chroma/chroma/layers/__pycache__/sde.cpython-38.pyc +0 -0
- chroma/chroma/layers/attention.py +347 -0
- chroma/chroma/layers/basic.py +467 -0
- chroma/chroma/layers/complexity.py +201 -0
- chroma/chroma/layers/conv.py +58 -0
- chroma/chroma/layers/graph.py +1126 -0
- chroma/chroma/layers/linalg.py +98 -0
.gitattributes
CHANGED
@@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
chroma/assets/lattice.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
chroma/assets/proteins.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
chroma/assets/refolding.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
chroma/notebooks/complex_trajectory.cif filter=lfs diff=lfs merge=lfs -text
|
40 |
+
chroma/notebooks/shaped_protein_trajectory.cif filter=lfs diff=lfs merge=lfs -text
|
41 |
+
chroma/notebooks/symmetric_protein_trajectory.cif filter=lfs diff=lfs merge=lfs -text
|
42 |
+
chroma/tests/_streamlit/demoapp/complex_trajectory.cif filter=lfs diff=lfs merge=lfs -text
|
43 |
+
chroma/tests/_streamlit/demoapp/shaped_protein_trajectory.cif filter=lfs diff=lfs merge=lfs -text
|
44 |
+
chroma/tests/_streamlit/demoapp/symmetric_protein_trajectory.cif filter=lfs diff=lfs merge=lfs -text
|
45 |
+
output/complex_trajectory.cif filter=lfs diff=lfs merge=lfs -text
|
46 |
+
output/shaped_protein_trajectory.cif filter=lfs diff=lfs merge=lfs -text
|
47 |
+
output/symmetric_protein_trajectory.cif filter=lfs diff=lfs merge=lfs -text
|
__pycache__/demo.cpython-38.pyc
ADDED
Binary file (10.4 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import demo
|
3 |
+
|
4 |
+
st.set_page_config(
|
5 |
+
page_title="Chroma Demos",
|
6 |
+
page_icon="🧬",
|
7 |
+
layout="wide",
|
8 |
+
initial_sidebar_state="expanded",
|
9 |
+
)
|
10 |
+
|
11 |
+
st.title("Demos for Chroma")
|
12 |
+
|
13 |
+
# sidebar
|
14 |
+
st.sidebar.header("Demo Config")
|
15 |
+
|
16 |
+
# 创建字典映射demo
|
17 |
+
demoDict={
|
18 |
+
"getProtein":demo.getProteinDemo,
|
19 |
+
"complexSample":demo.complexSampleDemo,
|
20 |
+
"symmetricSample":demo.symmetricSampleDemo,
|
21 |
+
"shapeSample":demo.shapeSampleDemo,
|
22 |
+
"foldSample":demo.foldSampleDemo,
|
23 |
+
"ssSample":demo.ssSampleDemo,
|
24 |
+
"substructureSample":demo.substructureSampleDemo,
|
25 |
+
|
26 |
+
}
|
27 |
+
# 在侧边栏中添加一个选择框,用于选择demo
|
28 |
+
selected_branch = st.sidebar.selectbox("Select demo", list(demoDict.keys()))
|
29 |
+
style=st.sidebar.selectbox("Select style:Can be 'stick', 'sphere', 'cross','cartoon'",('stick', 'sphere', 'cross','cartoon'),key='style')
|
30 |
+
resn=st.sidebar.selectbox("Select display resn:PDB resn labels:['ALA','ARG','LYS','THR','TRP','TYR','VAL']",('','ALA','ARG','LYS','THR','TRP','TYR','VAL'),key='resn')
|
31 |
+
|
32 |
+
# 执行选定分支对应的函数
|
33 |
+
demoDict[selected_branch](style,resn)
|
chroma/CONTRIBUTING.md
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code contributions
|
2 |
+
|
3 |
+
We welcome contributions to the Chroma code base, including new conditioners, integrators, patches, bug fixes, and others.
|
4 |
+
|
5 |
+
Note that your contributions will be governed by the Apache 2.0 license, meaning that you will be giving us permission to use your contributed code under the conditions specified in the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0) (also available in [LICENSE.txt](LICENSE.txt)).
|
6 |
+
|
7 |
+
## How to Contribute
|
8 |
+
|
9 |
+
Please use GitHub pull requests to contribute code. See
|
10 |
+
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
|
11 |
+
information on using pull requests. We will try to monitor incoming requests with some regularity, but cannot promise a specific timeframe within which we will review your request.
|
chroma/Dockerfile
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:11.3.1-devel-ubuntu20.04
|
2 |
+
ARG DEBIAN_FRONTEND=noninteractive
|
3 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
4 |
+
build-essential \
|
5 |
+
cmake \
|
6 |
+
git \
|
7 |
+
curl \
|
8 |
+
ca-certificates \
|
9 |
+
libjpeg-dev \
|
10 |
+
libpng-dev && \
|
11 |
+
rm -rf /var/lib/apt/lists/*
|
12 |
+
|
13 |
+
WORKDIR /tmp
|
14 |
+
|
15 |
+
RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
|
16 |
+
chmod +x ~/miniconda.sh && \
|
17 |
+
~/miniconda.sh -b -p /opt/conda && \
|
18 |
+
rm ~/miniconda.sh
|
19 |
+
RUN /opt/conda/bin/conda create --name chroma python=3.9.7
|
20 |
+
RUN /opt/conda/envs/chroma/bin/pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
|
21 |
+
WORKDIR /workspace
|
22 |
+
COPY . .
|
23 |
+
RUN /opt/conda/envs/chroma/bin/pip install .
|
24 |
+
ENV PATH /opt/conda/envs/chroma/bin:$PATH
|
chroma/LICENSE.txt
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Apache License
|
3 |
+
Version 2.0, January 2004
|
4 |
+
http://www.apache.org/licenses/
|
5 |
+
|
6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
7 |
+
|
8 |
+
1. Definitions.
|
9 |
+
|
10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
12 |
+
|
13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
14 |
+
the copyright owner that is granting the License.
|
15 |
+
|
16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
17 |
+
other entities that control, are controlled by, or are under common
|
18 |
+
control with that entity. For the purposes of this definition,
|
19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
20 |
+
direction or management of such entity, whether by contract or
|
21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
23 |
+
|
24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
25 |
+
exercising permissions granted by this License.
|
26 |
+
|
27 |
+
"Source" form shall mean the preferred form for making modifications,
|
28 |
+
including but not limited to software source code, documentation
|
29 |
+
source, and configuration files.
|
30 |
+
|
31 |
+
"Object" form shall mean any form resulting from mechanical
|
32 |
+
transformation or translation of a Source form, including but
|
33 |
+
not limited to compiled object code, generated documentation,
|
34 |
+
and conversions to other media types.
|
35 |
+
|
36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
37 |
+
Object form, made available under the License, as indicated by a
|
38 |
+
copyright notice that is included in or attached to the work
|
39 |
+
(an example is provided in the Appendix below).
|
40 |
+
|
41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
42 |
+
form, that is based on (or derived from) the Work and for which the
|
43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
45 |
+
of this License, Derivative Works shall not include works that remain
|
46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
47 |
+
the Work and Derivative Works thereof.
|
48 |
+
|
49 |
+
"Contribution" shall mean any work of authorship, including
|
50 |
+
the original version of the Work and any modifications or additions
|
51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
55 |
+
means any form of electronic, verbal, or written communication sent
|
56 |
+
to the Licensor or its representatives, including but not limited to
|
57 |
+
communication on electronic mailing lists, source code control systems,
|
58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
60 |
+
excluding communication that is conspicuously marked or otherwise
|
61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
62 |
+
|
63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
65 |
+
subsequently incorporated within the Work.
|
66 |
+
|
67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
72 |
+
Work and such Derivative Works in Source or Object form.
|
73 |
+
|
74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
77 |
+
(except as stated in this section) patent license to make, have made,
|
78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
79 |
+
where such license applies only to those patent claims licensable
|
80 |
+
by such Contributor that are necessarily infringed by their
|
81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
83 |
+
institute patent litigation against any entity (including a
|
84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
85 |
+
or a Contribution incorporated within the Work constitutes direct
|
86 |
+
or contributory patent infringement, then any patent licenses
|
87 |
+
granted to You under this License for that Work shall terminate
|
88 |
+
as of the date such litigation is filed.
|
89 |
+
|
90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
91 |
+
Work or Derivative Works thereof in any medium, with or without
|
92 |
+
modifications, and in Source or Object form, provided that You
|
93 |
+
meet the following conditions:
|
94 |
+
|
95 |
+
(a) You must give any other recipients of the Work or
|
96 |
+
Derivative Works a copy of this License; and
|
97 |
+
|
98 |
+
(b) You must cause any modified files to carry prominent notices
|
99 |
+
stating that You changed the files; and
|
100 |
+
|
101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
102 |
+
that You distribute, all copyright, patent, trademark, and
|
103 |
+
attribution notices from the Source form of the Work,
|
104 |
+
excluding those notices that do not pertain to any part of
|
105 |
+
the Derivative Works; and
|
106 |
+
|
107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
108 |
+
distribution, then any Derivative Works that You distribute must
|
109 |
+
include a readable copy of the attribution notices contained
|
110 |
+
within such NOTICE file, excluding those notices that do not
|
111 |
+
pertain to any part of the Derivative Works, in at least one
|
112 |
+
of the following places: within a NOTICE text file distributed
|
113 |
+
as part of the Derivative Works; within the Source form or
|
114 |
+
documentation, if provided along with the Derivative Works; or,
|
115 |
+
within a display generated by the Derivative Works, if and
|
116 |
+
wherever such third-party notices normally appear. The contents
|
117 |
+
of the NOTICE file are for informational purposes only and
|
118 |
+
do not modify the License. You may add Your own attribution
|
119 |
+
notices within Derivative Works that You distribute, alongside
|
120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
121 |
+
that such additional attribution notices cannot be construed
|
122 |
+
as modifying the License.
|
123 |
+
|
124 |
+
You may add Your own copyright statement to Your modifications and
|
125 |
+
may provide additional or different license terms and conditions
|
126 |
+
for use, reproduction, or distribution of Your modifications, or
|
127 |
+
for any such Derivative Works as a whole, provided Your use,
|
128 |
+
reproduction, and distribution of the Work otherwise complies with
|
129 |
+
the conditions stated in this License.
|
130 |
+
|
131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
133 |
+
by You to the Licensor shall be under the terms and conditions of
|
134 |
+
this License, without any additional terms or conditions.
|
135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
136 |
+
the terms of any separate license agreement you may have executed
|
137 |
+
with Licensor regarding such Contributions.
|
138 |
+
|
139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
141 |
+
except as required for reasonable and customary use in describing the
|
142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
143 |
+
|
144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
145 |
+
agreed to in writing, Licensor provides the Work (and each
|
146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
148 |
+
implied, including, without limitation, any warranties or conditions
|
149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
151 |
+
appropriateness of using or redistributing the Work and assume any
|
152 |
+
risks associated with Your exercise of permissions under this License.
|
153 |
+
|
154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
155 |
+
whether in tort (including negligence), contract, or otherwise,
|
156 |
+
unless required by applicable law (such as deliberate and grossly
|
157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
158 |
+
liable to You for damages, including any direct, indirect, special,
|
159 |
+
incidental, or consequential damages of any character arising as a
|
160 |
+
result of this License or out of the use or inability to use the
|
161 |
+
Work (including but not limited to damages for loss of goodwill,
|
162 |
+
work stoppage, computer failure or malfunction, or any and all
|
163 |
+
other commercial damages or losses), even if such Contributor
|
164 |
+
has been advised of the possibility of such damages.
|
165 |
+
|
166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
169 |
+
or other liability obligations and/or rights consistent with this
|
170 |
+
License. However, in accepting such obligations, You may act only
|
171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
172 |
+
of any other Contributor, and only if You agree to indemnify,
|
173 |
+
defend, and hold each Contributor harmless for any liability
|
174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
175 |
+
of your accepting any such warranty or additional liability.
|
176 |
+
|
177 |
+
END OF TERMS AND CONDITIONS
|
178 |
+
|
179 |
+
APPENDIX: How to apply the Apache License to your work.
|
180 |
+
|
181 |
+
To apply the Apache License to your work, attach the following
|
182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
183 |
+
replaced with your own identifying information. (Don't include
|
184 |
+
the brackets!) The text should be enclosed in the appropriate
|
185 |
+
comment syntax for the file format. We also recommend that a
|
186 |
+
file or class name and description of purpose be included on the
|
187 |
+
same "printed page" as the copyright notice for easier
|
188 |
+
identification within third-party archives.
|
189 |
+
|
190 |
+
Copyright [yyyy] [name of copyright owner]
|
191 |
+
|
192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
193 |
+
you may not use this file except in compliance with the License.
|
194 |
+
You may obtain a copy of the License at
|
195 |
+
|
196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
197 |
+
|
198 |
+
Unless required by applicable law or agreed to in writing, software
|
199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
201 |
+
See the License for the specific language governing permissions and
|
202 |
+
limitations under the License.
|
chroma/README.md
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<img src="assets/chroma_logo_outline.svg" width="280">
|
2 |
+
|
3 |
+
[**Get Started**](#get-started)
|
4 |
+
| [**Sampling**](#sampling)
|
5 |
+
| [**Design**](#design)
|
6 |
+
| [**Conditioners**](#conditioners)
|
7 |
+
| [**License**](#license)
|
8 |
+
|
9 |
+
Chroma is a generative model for designing proteins **programmatically**.
|
10 |
+
|
11 |
+
Protein space is complex and hard to navigate. With Chroma, protein design problems are represented in terms of [composable building blocks](#conditioners) from which diverse, [all-atom protein structures can be automatically generated](#sampling). As a joint model of structure and sequence, Chroma can also be used for common protein modeling tasks such as [generating sequences given backbones](#design), packing side-chains, and scoring designs.
|
12 |
+
|
13 |
+
We provide protein conditioners for a variety of constraints, including substructure, symmetry, shape, and neural-network predictions of some protein classes and annotations. We also provide an API for [creating your own conditioners](#conditioners-api) in a few lines of code.
|
14 |
+
|
15 |
+
Internally, Chroma uses diffusion modeling, equivariant graph neural networks, and conditional random fields to efficiently sample all-atom structures with a complexity that is sub-quadratic in the number of residues. It can generate large complexes in a few minutes on a commodity GPU. You can read more about Chroma, including biophysical and crystallographic validation of some early designs, in our paper, [*Illuminating protein space with a programmable generative model*. Nature 2023](https://doi.org/10.1038/s41586-023-06728-8).
|
16 |
+
|
17 |
+
<div align="center">
|
18 |
+
<img src="assets/proteins.png" alt="Generated protein examples" width="700px" align="middle"/>
|
19 |
+
</div>
|
20 |
+
|
21 |
+
## Get Started
|
22 |
+
> **Note:** An API key is required to download and use the pretrained model weights. It can be obtained [here](https://chroma-weights.generatebiomedicines.com/).
|
23 |
+
|
24 |
+
|
25 |
+
**Colab Notebooks**. The quickest way to get started with Chroma is our Colab notebooks, which provide starting points for a variety of use cases in a preconfigured, in-browser environment
|
26 |
+
|
27 |
+
* [Chroma Quickstart](https://colab.research.google.com/github/generatebio/chroma/blob/main/notebooks/ChromaDemo.ipynb): GUI notebook demonstrating unconditional and conditional generation of proteins with Chroma.
|
28 |
+
* [Chroma API Tutorial](https://colab.research.google.com/github/generatebio/chroma/blob/main/notebooks/ChromaAPI.ipynb): Code notebook demonstrating protein I/O, sampling, and design configuration directly in `python`.
|
29 |
+
* [Chroma Conditioner API Tutorial](https://colab.research.google.com/github/generatebio/chroma/blob/main/notebooks/ChromaConditioners.ipynb): A deeper dive under the hood for implementing new Chroma [Conditioners](#conditioner-api).
|
30 |
+
|
31 |
+
**PyPi package**.You can install the latest release of Chroma with:
|
32 |
+
```
|
33 |
+
pip install generate-chroma
|
34 |
+
```
|
35 |
+
|
36 |
+
**Install latest Chroma from github**
|
37 |
+
```
|
38 |
+
git clone https://github.com/generatebio/chroma.git
|
39 |
+
pip install -e chroma # use `-e` for it to be editable locally.
|
40 |
+
```
|
41 |
+
|
42 |
+
## Sampling
|
43 |
+
**Unconditional monomer**. We provide a unified entry point to both unconditional and conditional protein design with the `Chroma.sample()` method. When no conditioners are specified, we can sample a simple 200-amino acid monomeric protein with
|
44 |
+
```python
|
45 |
+
from chroma import Chroma
|
46 |
+
|
47 |
+
chroma = Chroma()
|
48 |
+
protein = chroma.sample(chain_lengths=[200])
|
49 |
+
|
50 |
+
protein.to("sample.cif")
|
51 |
+
display(protein)
|
52 |
+
```
|
53 |
+
|
54 |
+
Generally, `Chroma.sample()` takes as input design hyperparameters and [Conditioners](#conditioners) and outputs `Protein` objects representing the all-atom structures of protein systems which can be loaded to and from disk in PDB or mmCIF formats.
|
55 |
+
|
56 |
+
**Unconditional complex**. To sample a complex instead of a monomer, we can simply do
|
57 |
+
```python
|
58 |
+
from chroma import Chroma
|
59 |
+
|
60 |
+
chroma = Chroma()
|
61 |
+
protein = chroma.sample(chain_lengths=[100, 200])
|
62 |
+
|
63 |
+
protein.to("sample-complex.cif")
|
64 |
+
```
|
65 |
+
|
66 |
+
**Conditional complex**. We can further customize sampling towards design objectives via [Conditioners](#conditioners) and sampling hyperparameters. For example, to sample a C3-symmetric homo-trimer with 100 residues per monomer, we can do
|
67 |
+
|
68 |
+
```python
|
69 |
+
from chroma import Chroma, conditioners
|
70 |
+
|
71 |
+
chroma = Chroma()
|
72 |
+
conditioner = conditioners.SymmetryConditioner(G="C_3", num_chain_neighbors=2)
|
73 |
+
protein = chroma.sample(
|
74 |
+
chain_lengths=[100],
|
75 |
+
conditioner=conditioner,
|
76 |
+
langevin_factor=8,
|
77 |
+
inverse_temperature=8,
|
78 |
+
sde_func="langevin",
|
79 |
+
potts_symmetry_order=conditioner.potts_symmetry_order)
|
80 |
+
|
81 |
+
protein.to("sample-C3.cif")
|
82 |
+
```
|
83 |
+
|
84 |
+
Because compositions of conditioners are conditioners, even relatively complex design problems can follow this basic usage pattern. See the [demo notebooks](#get-started) and docstrings for more information on hyperparameters, conditioners, and starting points.
|
85 |
+
|
86 |
+
## Design
|
87 |
+
**Robust design**. Chroma is a joint model of sequence and structure that uses a common graph neural network base architecture to parameterize both backbone generation and conditional sequence and sidechain generation. These sequence and sidechain decoders are *diffusion-aware* in the sense that they have been trained to predict sequence and side chain not just for natural structures at diffusion time $t=0$ but also on noisy structures at all diffusion times $t \in [0,1]$. As a result, the $t$ hyperpameter of the design network provides a kind of tunable robustness via **diffusion augmentation** in we trade off between how much the model attempts to design the backbone *exactly* as specified (e.g. $t=0.0$) versus *robust* design within a small neighborhood of nearby backbone conformations (e.g. $t=0.5$).
|
88 |
+
|
89 |
+
While all results presented in the Chroma [publication](https://doi.org/10.1038/s41586-023-06728-8) were done with **exact design** at $t=0.0$, we have found **robust design** at times near $t=0.5$ frequently improves one-shot refolding while incurring only minor, often Ångstrom-scale, relaxation adjustments to target backbones. When we compare the performance of these two design modes on our set of 50,000 unconditional backbones that were analyzed in the paper, we see very large improvements in refolding across both [AlphaFold](https://github.com/google-deepmind/alphafold) and [ESMFold](https://github.com/facebookresearch/esm) that stratifies well across protein length, percent helicity, or similarity to a known structure (See Chroma [Supplementary Figure 14](https://doi.org/10.1038/s41586-023-06728-8) for further context).
|
90 |
+
|
91 |
+
|
92 |
+
<div align="center">
|
93 |
+
<img src="./assets/refolding.png" alt="alt text" width="700px" align="middle"/>
|
94 |
+
</div></br>
|
95 |
+
|
96 |
+
The value of diffusion time conditioning $t$ can be set via the `design_t` parameter in `Chroma.sample` and `Chroma.design`. We find that for generated structures, $t = 0.5$ produces highly robust refolding results and is, therefore, the default setting. For experimentally-precise structures, $t = 0.0$ may be more appropriate, and values in between may provide a useful tradeoff between these two regimes.
|
97 |
+
|
98 |
+
**Design *a la carte***. Chroma's design network can be accessed separately to design, redesign, and pack arbitrary protein systems. Here we load a protein from the PDB and redesign as
|
99 |
+
```python
|
100 |
+
# Redesign a Protein
|
101 |
+
from chroma import Protein, Chroma
|
102 |
+
chroma = Chroma()
|
103 |
+
|
104 |
+
protein = Protein('1GFP')
|
105 |
+
protein = chroma.design(protein)
|
106 |
+
|
107 |
+
protein.to("1GFP-redesign.cif")
|
108 |
+
```
|
109 |
+
|
110 |
+
Clamped sub-sequence redesign is also available and compatible with a built-in selection algebra, along with position- and mutation-specific mask constraints as
|
111 |
+
```python
|
112 |
+
# Redesign a Protein
|
113 |
+
from chroma import Protein, Chroma
|
114 |
+
chroma = Chroma()
|
115 |
+
|
116 |
+
protein = Protein('my_favorite_protein.cif') # PDB is fine too
|
117 |
+
protein = chroma.design(protein, design_selection="resid 20-50 around 5.0") # 5 angstrom bubble around indices 20-50
|
118 |
+
|
119 |
+
protein.to("my_favorite_protein_redesign.cif")
|
120 |
+
```
|
121 |
+
|
122 |
+
We provide more examples of design in the [demo notebooks](#get-started).
|
123 |
+
|
124 |
+
## Conditioners
|
125 |
+
Protein design with Chroma is **programmable**. Our `Conditioner` framework allows for automatic conditional sampling under arbitrary compositions of protein specifications, which can come in the forms of restraints (biasing the distribution of states) or constraints (directly restrict the domain of underlying sampling process); see Supplementary Appendix M in [our paper](https://doi.org/10.1038/s41586-023-06728-8). We have pre-defined multiple conditioners, including for controlling substructure, symmetry, shape, semantics, and natural-language prompts (see `chroma.layers.structure.conditioners`), which can be used in arbitrary combinations.
|
126 |
+
|
127 |
+
<div align="center">
|
128 |
+
|
129 |
+
| Conditioner | Class(es) in [`chroma.conditioners`](chroma/layers/structure/conditioners.py) | Example applications |
|
130 |
+
|----------|----------|----------|
|
131 |
+
| Symmetry constraint | `SymmetryConditioner`, `ScrewConditioner` | Large symmetric assemblies |
|
132 |
+
| Substructure constraint | `SubstructureConditioner` | Substructure grafting, scaffold enforcement |
|
133 |
+
| Shape restraint | `ShapeConditioner` | Molecular shape control |
|
134 |
+
| Secondary structure | `ProClassConditioner` | Secondary-structure specification |
|
135 |
+
| Domain classification | `ProClassConditioner` | Specification of class, such as Pfam, CATH, or Taxonomy |
|
136 |
+
| Text caption | `ProCapConditioner` | Natural language prompting |
|
137 |
+
| Sequence | `SubsequenceConditioner` | Subsequence constraints. |
|
138 |
+
|
139 |
+
</div>
|
140 |
+
|
141 |
+
**How it works**. The central idea of Conditioners is *composable state transformations*, where each Conditioner is a function that modifies the state and/or energy of a protein system in a differentiable way ([Supplementary Appendix M](https://doi.org/10.1038/s41586-023-06728-8)). For example, to encode symmetry as a *constraint* we can take as input the assymetric unit and tesselate it according to the desired symmetry group to output a protein system that is symmetric by construction. To encode something like a neural network restraint, we can adjust the total system energy by the negative log probability of the target condition. For both of these, we add on the diffusion energy to the output of the Conditioner(s) and then backpropagate the total energy through all intermediate transformations to compute the unconstrained forces that are compatible with generic sampling SDE such as annealed Langevin Dynamics.
|
142 |
+
|
143 |
+
We schematize this overall Conditioners framework below.
|
144 |
+
<div align="center">
|
145 |
+
<img src="./assets/conditioners.png" alt="alt text" width="600px" align="middle"/><br>
|
146 |
+
<figcaption><i>The <code>Conditioner</code> class is the composable building block of protein design with Chroma.</i></figcaption>
|
147 |
+
</div>
|
148 |
+
|
149 |
+
#### Conditioner API
|
150 |
+
It is simple to develop new conditioners. A `Conditioner` is a Pytorch `nn.Module` which takes in the system state - i.e. the structure, energy, and diffusion time - and outputs potentially updated structures and energies as
|
151 |
+
|
152 |
+
```python
|
153 |
+
|
154 |
+
class Conditioner(torch.nn.Module):
|
155 |
+
"""A composable function for parameterizing protein design problems.
|
156 |
+
"""
|
157 |
+
def __init__(self, *args, **kwargs):
|
158 |
+
super().__init__()
|
159 |
+
# Setup your conditioner's hyperparameters
|
160 |
+
|
161 |
+
def forward(
|
162 |
+
self,
|
163 |
+
X: torch.Tensor, # Input coordinates
|
164 |
+
C: torch.LongTensor, # Input chain map (for complexes)
|
165 |
+
O: torch.Tensor, # Input sequence (one-hot, not used)
|
166 |
+
U: torch.Tensor, # Input energy (one-hot, not used)
|
167 |
+
t: Union[torch.Tensor, float], # Diffusion time
|
168 |
+
):
|
169 |
+
# Update the state, e.g. map from an unconstrained to constrained manifold
|
170 |
+
X_update, C_update = update_state(X, C, t)
|
171 |
+
|
172 |
+
# Update the energy, e.g. add a restraint potential
|
173 |
+
U_update = U + update_energy(X, C, t)
|
174 |
+
return X_update, C_update, O, U_update, t
|
175 |
+
```
|
176 |
+
Roughly speaking, `Conditioner`s are composable by construction because their input and output type signatures are matched (i.e. they are an endomorphism). So we also simply build conditioners from conditioners by "stacking" them much as we would with traditional neural network layer development. With the final `Conditioner` as an input, `Chroma.sample()` will then leverage Pytorch's automatic differentiation facilities to automatically furnish a diffusion-annealed MCMC sampling algorithm to sample with this conditioner (We note this isn't magic and taking care to scale and parameterize appropriately is [important](#note-on-conditioners)).
|
177 |
+
|
178 |
+
##### A minimal Conditioner: 2D lattice symmetry
|
179 |
+
The code snippet below shows how in a few lines of code we can add a conditioner that stipulates the generation of a 2D crystal-like object, where generated proteins are arrayed in an `M x N` rectangular lattice.
|
180 |
+
|
181 |
+
```python
|
182 |
+
import torch
|
183 |
+
from chroma.models import Chroma
|
184 |
+
from chroma.layers.structure import conditioners
|
185 |
+
|
186 |
+
class Lattice2DConditioner(conditioners.Conditioner):
|
187 |
+
def __init__(self, M, N, cell):
|
188 |
+
super().__init__()
|
189 |
+
# Setup the coordinates of a 2D lattice
|
190 |
+
self.order = M*N
|
191 |
+
x = torch.arange(M) * cell[0]
|
192 |
+
y = torch.arange(N) * cell[1]
|
193 |
+
xx, yy = torch.meshgrid(x, y, indexing="ij")
|
194 |
+
dX = torch.stack([xx.flatten(), yy.flatten(), torch.zeros(M * N)], dim=1)
|
195 |
+
self.register_buffer("dX", dX)
|
196 |
+
|
197 |
+
def forward(self, X, C, O, U, t):
|
198 |
+
        # Tessellate the unit cell on the lattice
|
199 |
+
X = (X[:,None,...] + self.dX[None,:,None,None]).reshape(1, -1, 4, 3)
|
200 |
+
C = torch.cat([C + C.unique().max() * i for i in range(self.dX.shape[0])], dim=1)
|
201 |
+
# Average the gradient across the group (simplifies force scaling)
|
202 |
+
X.register_hook(lambda gradX: gradX / self.order)
|
203 |
+
return X, C, O, U, t
|
204 |
+
|
205 |
+
chroma = Chroma().cuda()
|
206 |
+
conditioner = Lattice2DConditioner(M=3, N=4, cell=[20., 15.]).cuda()
|
207 |
+
protein = chroma.sample(
|
208 |
+
chain_lengths=[70], conditioner=conditioner, sde_func='langevin',
|
209 |
+
potts_symmetry_order=conditioner.order
|
210 |
+
)
|
211 |
+
|
212 |
+
protein.to_CIF("lattice_protein.cif")
|
213 |
+
```
|
214 |
+
|
215 |
+
<div align="center">
|
216 |
+
<img src="./assets/lattice.png" alt="alt text" width="700px" align="middle"/>
|
217 |
+
</div>
|
218 |
+
|
219 |
+
#### Note on Conditioners
|
220 |
+
|
221 |
+
An attractive aspect of this conditioner framework is that it is very general, enabling both constraints (which involve operations on $x$) and restraints (which amount to changes to $U$). At the same time, generation under restraints can still be (and often is) challenging, as the resulting effective energy landscape can become arbitrarily rugged and difficult to integrate. We therefore advise caution when using and developing new conditioners or conditioner combinations. We find that inspecting diffusion trajectories (including unconstrained and denoised trajectories, $\hat{x}_t$ and $\tilde{x}_t$) can be a good tool for identifying integration challenges and defining either better conditioner forms or better sampling regimes.
|
222 |
+
|
223 |
+
## Citing Chroma
|
224 |
+
|
225 |
+
If you use Chroma in your research, please cite:
|
226 |
+
|
227 |
+
J. B. Ingraham, M. Baranov, Z. Costello, K. W. Barber, W. Wang, A. Ismail, V. Frappier, D. M. Lord, C. Ng-Thow-Hing, E. R. Van Vlack, S. Tie, V. Xue, S. C. Cowles, A. Leung, J. V. Rodrigues, C. L. Morales-Perez, A. M. Ayoub, R. Green, K. Puentes, F. Oplinger, N. V. Panwar, F. Obermeyer, A. R. Root, A. L. Beam, F. J. Poelwijk, and G. Grigoryan, "Illuminating protein space with a programmable generative model", *Nature*, 2023 (10.1038/s41586-023-06728-8).
|
228 |
+
|
229 |
+
```bibtex
|
230 |
+
@Article{Chroma2023,
|
231 |
+
author = {Ingraham, John B. and Baranov, Max and Costello, Zak and Barber, Karl W. and Wang, Wujie and Ismail, Ahmed and Frappier, Vincent and Lord, Dana M. and Ng-Thow-Hing, Christopher and Van Vlack, Erik R. and Tie, Shan and Xue, Vincent and Cowles, Sarah C. and Leung, Alan and Rodrigues, Jo\~{a}o V. and Morales-Perez, Claudio L. and Ayoub, Alex M. and Green, Robin and Puentes, Katherine and Oplinger, Frank and Panwar, Nishant V. and Obermeyer, Fritz and Root, Adam R. and Beam, Andrew L. and Poelwijk, Frank J. and Grigoryan, Gevorg},
|
232 |
+
journal = {Nature},
|
233 |
+
title = {Illuminating protein space with a programmable generative model},
|
234 |
+
year = {2023},
|
235 |
+
volume = {},
|
236 |
+
number = {},
|
237 |
+
pages = {},
|
238 |
+
doi = {10.1038/s41586-023-06728-8}
|
239 |
+
}
|
240 |
+
```
|
241 |
+
|
242 |
+
## Acknowledgements
|
243 |
+
The Chroma codebase is the work of many contributors at Generate Biomedicines. We would like to acknowledge: Ahmed Ismail, Alan Witmer, Alex Ramos, Alexander Bock, Ameya Harmalkar, Brinda Monian, Craig Mackenzie, Dan Luu, David Moore, Frank Oplinger, Fritz Obermeyer, George Kent-Scheller, Gevorg Grigoryan, Jacob Feala, James Lucas, Jenhan Tao, John Ingraham, Martin Jankowiak, Max Baranov, Meghan Franklin, Mick Ward, Rudraksh Tuwani, Ryan Nelson, Shan Tie, Vincent Frappier, Vincent Xue, William Wolfe-McGuire, Wujie Wang, Zak Costello, Zander Harteveld.
|
244 |
+
|
245 |
+
## License
|
246 |
+
|
247 |
+
Copyright Generate Biomedicines, Inc.
|
248 |
+
|
249 |
+
### Chroma Code License
|
250 |
+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this code except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0.
|
251 |
+
|
252 |
+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. See the License for the specific language governing permissions and limitations under the License.
|
253 |
+
|
254 |
+
### Model Weights License
|
255 |
+
Chroma weights are freely available to academic researchers and non-profit entities who accept and agree to be bound under the terms of the Chroma Parameters License. Please visit the [weights download page](https://chroma-weights.generatebiomedicines.com/) for more information. If you are not eligible to use the Chroma Parameters under the terms of the provided License or if you would like to share the Chroma Parameters and/or otherwise use the Chroma Parameters beyond the scope of the rights granted in the License (including for commercial purposes), you may contact the Licensor at: [email protected].
|
chroma/Untitled.ipynb
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [],
|
3 |
+
"metadata": {},
|
4 |
+
"nbformat": 4,
|
5 |
+
"nbformat_minor": 5
|
6 |
+
}
|
chroma/assets/LiberationSans-Regular.ttf
ADDED
Binary file (139 kB). View file
|
|
chroma/assets/chroma_logo.svg
ADDED
|
chroma/assets/chroma_logo_outline.svg
ADDED
|
chroma/assets/conditioners.png
ADDED
![]() |
chroma/assets/lattice.png
ADDED
![]() |
Git LFS Details
|
chroma/assets/logo.png
ADDED
![]() |
chroma/assets/proteins.png
ADDED
![]() |
Git LFS Details
|
chroma/assets/refolding.png
ADDED
![]() |
Git LFS Details
|
chroma/chroma/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Generate Biomedicines, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
__version__ = "1.0.0"
|
16 |
+
from chroma.data.protein import Protein
|
17 |
+
from chroma.layers.structure import conditioners
|
18 |
+
from chroma.models.chroma import Chroma
|
19 |
+
from chroma.utility import api
|
chroma/chroma/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (365 Bytes). View file
|
|
chroma/chroma/constants/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Generate Biomedicines, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from chroma.constants.geometry import AA_GEOMETRY
|
16 |
+
from chroma.constants.sequence import *
|
chroma/chroma/constants/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (248 Bytes). View file
|
|
chroma/chroma/constants/__pycache__/geometry.cpython-38.pyc
ADDED
Binary file (6.97 kB). View file
|
|
chroma/chroma/constants/__pycache__/named_models.cpython-38.pyc
ADDED
Binary file (1.24 kB). View file
|
|
chroma/chroma/constants/__pycache__/sequence.cpython-38.pyc
ADDED
Binary file (2.06 kB). View file
|
|
chroma/chroma/constants/geometry.py
ADDED
@@ -0,0 +1,558 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Generate Biomedicines, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Dictionary containing ideal internal coordinates and chi angle assignments
|
16 |
+
for building amino acid 3D coordinates"""
|
17 |
+
from typing import Dict
|
18 |
+
|
19 |
+
AA_GEOMETRY: Dict[str, dict] = {
|
20 |
+
"ALA": {
|
21 |
+
"atoms": ["CB"],
|
22 |
+
"chi_indices": [],
|
23 |
+
"parents": [["N", "C", "CA"]],
|
24 |
+
"types": {"C": "C", "CA": "CT1", "CB": "CT3", "N": "NH1", "O": "O"},
|
25 |
+
"z-angles": [111.09],
|
26 |
+
"z-dihedrals": [123.23],
|
27 |
+
"z-lengths": [1.55],
|
28 |
+
},
|
29 |
+
"ARG": {
|
30 |
+
"atoms": ["CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"],
|
31 |
+
"chi_indices": [1, 2, 3, 4],
|
32 |
+
"parents": [
|
33 |
+
["N", "C", "CA"],
|
34 |
+
["N", "CA", "CB"],
|
35 |
+
["CA", "CB", "CG"],
|
36 |
+
["CB", "CG", "CD"],
|
37 |
+
["CG", "CD", "NE"],
|
38 |
+
["CD", "NE", "CZ"],
|
39 |
+
["NH1", "NE", "CZ"],
|
40 |
+
],
|
41 |
+
"types": {
|
42 |
+
"C": "C",
|
43 |
+
"CA": "CT1",
|
44 |
+
"CB": "CT2",
|
45 |
+
"CD": "CT2",
|
46 |
+
"CG": "CT2",
|
47 |
+
"CZ": "C",
|
48 |
+
"N": "NH1",
|
49 |
+
"NE": "NC2",
|
50 |
+
"NH1": "NC2",
|
51 |
+
"NH2": "NC2",
|
52 |
+
"O": "O",
|
53 |
+
},
|
54 |
+
"z-angles": [112.26, 115.95, 114.01, 107.09, 123.05, 118.06, 122.14],
|
55 |
+
"z-dihedrals": [123.64, 180.0, 180.0, 180.0, 180.0, 180.0, 178.64],
|
56 |
+
"z-lengths": [1.56, 1.55, 1.54, 1.5, 1.34, 1.33, 1.33],
|
57 |
+
},
|
58 |
+
"ASN": {
|
59 |
+
"atoms": ["CB", "CG", "OD1", "ND2"],
|
60 |
+
"chi_indices": [1, 2],
|
61 |
+
"parents": [
|
62 |
+
["N", "C", "CA"],
|
63 |
+
["N", "CA", "CB"],
|
64 |
+
["CA", "CB", "CG"],
|
65 |
+
["OD1", "CB", "CG"],
|
66 |
+
],
|
67 |
+
"types": {
|
68 |
+
"C": "C",
|
69 |
+
"CA": "CT1",
|
70 |
+
"CB": "CT2",
|
71 |
+
"CG": "CC",
|
72 |
+
"N": "NH1",
|
73 |
+
"ND2": "NH2",
|
74 |
+
"O": "O",
|
75 |
+
"OD1": "O",
|
76 |
+
},
|
77 |
+
"z-angles": [113.04, 114.3, 122.56, 116.15],
|
78 |
+
"z-dihedrals": [121.18, 180.0, 180.0, -179.19],
|
79 |
+
"z-lengths": [1.56, 1.53, 1.23, 1.35],
|
80 |
+
},
|
81 |
+
"ASP": {
|
82 |
+
"atoms": ["CB", "CG", "OD1", "OD2"],
|
83 |
+
"chi_indices": [1, 2],
|
84 |
+
"parents": [
|
85 |
+
["N", "C", "CA"],
|
86 |
+
["N", "CA", "CB"],
|
87 |
+
["CA", "CB", "CG"],
|
88 |
+
["OD1", "CB", "CG"],
|
89 |
+
],
|
90 |
+
"types": {
|
91 |
+
"C": "C",
|
92 |
+
"CA": "CT1",
|
93 |
+
"CB": "CT2A",
|
94 |
+
"CG": "CC",
|
95 |
+
"N": "NH1",
|
96 |
+
"O": "O",
|
97 |
+
"OD1": "OC",
|
98 |
+
"OD2": "OC",
|
99 |
+
},
|
100 |
+
"z-angles": [114.1, 112.6, 117.99, 117.7],
|
101 |
+
"z-dihedrals": [122.33, 180.0, 180.0, -170.23],
|
102 |
+
"z-lengths": [1.56, 1.52, 1.26, 1.25],
|
103 |
+
},
|
104 |
+
"CYS": {
|
105 |
+
"atoms": ["CB", "SG"],
|
106 |
+
"chi_indices": [1],
|
107 |
+
"parents": [["N", "C", "CA"], ["N", "CA", "CB"]],
|
108 |
+
"types": {"C": "C", "CA": "CT1", "CB": "CT2", "N": "NH1", "O": "O", "SG": "S"},
|
109 |
+
"z-angles": [111.98, 113.87],
|
110 |
+
"z-dihedrals": [121.79, 180.0],
|
111 |
+
"z-lengths": [1.56, 1.84],
|
112 |
+
},
|
113 |
+
"GLN": {
|
114 |
+
"atoms": ["CB", "CG", "CD", "OE1", "NE2"],
|
115 |
+
"chi_indices": [1, 2, 3],
|
116 |
+
"parents": [
|
117 |
+
["N", "C", "CA"],
|
118 |
+
["N", "CA", "CB"],
|
119 |
+
["CA", "CB", "CG"],
|
120 |
+
["CB", "CG", "CD"],
|
121 |
+
["OE1", "CG", "CD"],
|
122 |
+
],
|
123 |
+
"types": {
|
124 |
+
"C": "C",
|
125 |
+
"CA": "CT1",
|
126 |
+
"CB": "CT2",
|
127 |
+
"CD": "CC",
|
128 |
+
"CG": "CT2",
|
129 |
+
"N": "NH1",
|
130 |
+
"NE2": "NH2",
|
131 |
+
"O": "O",
|
132 |
+
"OE1": "O",
|
133 |
+
},
|
134 |
+
"z-angles": [111.68, 115.52, 112.5, 121.52, 116.84],
|
135 |
+
"z-dihedrals": [121.91, 180.0, 180.0, 180.0, 179.57],
|
136 |
+
"z-lengths": [1.55, 1.55, 1.53, 1.23, 1.35],
|
137 |
+
},
|
138 |
+
"GLU": {
|
139 |
+
"atoms": ["CB", "CG", "CD", "OE1", "OE2"],
|
140 |
+
"chi_indices": [1, 2, 3],
|
141 |
+
"parents": [
|
142 |
+
["N", "C", "CA"],
|
143 |
+
["N", "CA", "CB"],
|
144 |
+
["CA", "CB", "CG"],
|
145 |
+
["CB", "CG", "CD"],
|
146 |
+
["OE1", "CG", "CD"],
|
147 |
+
],
|
148 |
+
"types": {
|
149 |
+
"C": "C",
|
150 |
+
"CA": "CT1",
|
151 |
+
"CB": "CT2A",
|
152 |
+
"CD": "CC",
|
153 |
+
"CG": "CT2",
|
154 |
+
"N": "NH1",
|
155 |
+
"O": "O",
|
156 |
+
"OE1": "OC",
|
157 |
+
"OE2": "OC",
|
158 |
+
},
|
159 |
+
"z-angles": [111.71, 115.69, 115.73, 114.99, 120.08],
|
160 |
+
"z-dihedrals": [121.9, 180.0, 180.0, 180.0, -179.1],
|
161 |
+
"z-lengths": [1.55, 1.56, 1.53, 1.26, 1.25],
|
162 |
+
},
|
163 |
+
"GLY": {
|
164 |
+
"atoms": [],
|
165 |
+
"chi_indices": [],
|
166 |
+
"parents": [],
|
167 |
+
"types": {"C": "C", "CA": "CT2", "N": "NH1", "O": "O"},
|
168 |
+
"z-angles": [],
|
169 |
+
"z-dihedrals": [],
|
170 |
+
"z-lengths": [],
|
171 |
+
},
|
172 |
+
"HIS": {
|
173 |
+
"atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"],
|
174 |
+
"chi_indices": [1, 2],
|
175 |
+
"parents": [
|
176 |
+
["N", "C", "CA"],
|
177 |
+
["N", "CA", "CB"],
|
178 |
+
["CA", "CB", "CG"],
|
179 |
+
["ND1", "CB", "CG"],
|
180 |
+
["CB", "CG", "ND1"],
|
181 |
+
["CB", "CG", "CD2"],
|
182 |
+
],
|
183 |
+
"types": {
|
184 |
+
"C": "C",
|
185 |
+
"CA": "CT1",
|
186 |
+
"CB": "CT2",
|
187 |
+
"CD2": "CPH1",
|
188 |
+
"CE1": "CPH2",
|
189 |
+
"CG": "CPH1",
|
190 |
+
"N": "NH1",
|
191 |
+
"ND1": "NR1",
|
192 |
+
"NE2": "NR2",
|
193 |
+
"O": "O",
|
194 |
+
},
|
195 |
+
"z-angles": [109.99, 114.05, 124.1, 129.6, 107.03, 110.03],
|
196 |
+
"z-dihedrals": [122.46, 180.0, 90.0, -171.29, -173.21, 171.99],
|
197 |
+
"z-lengths": [1.55, 1.5, 1.38, 1.36, 1.35, 1.38],
|
198 |
+
},
|
199 |
+
"HSD": {
|
200 |
+
"atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"],
|
201 |
+
"chi_indices": [1, 2],
|
202 |
+
"parents": [
|
203 |
+
["N", "C", "CA"],
|
204 |
+
["N", "CA", "CB"],
|
205 |
+
["CA", "CB", "CG"],
|
206 |
+
["ND1", "CB", "CG"],
|
207 |
+
["CB", "CG", "ND1"],
|
208 |
+
["CB", "CG", "CD2"],
|
209 |
+
],
|
210 |
+
"types": {
|
211 |
+
"C": "C",
|
212 |
+
"CA": "CT1",
|
213 |
+
"CB": "CT2",
|
214 |
+
"CD2": "CPH1",
|
215 |
+
"CE1": "CPH2",
|
216 |
+
"CG": "CPH1",
|
217 |
+
"N": "NH1",
|
218 |
+
"ND1": "NR1",
|
219 |
+
"NE2": "NR2",
|
220 |
+
"O": "O",
|
221 |
+
},
|
222 |
+
"z-angles": [109.99, 114.05, 124.1, 129.6, 107.03, 110.03],
|
223 |
+
"z-dihedrals": [122.46, 180.0, 90.0, -171.29, -173.21, 171.99],
|
224 |
+
"z-lengths": [1.55, 1.5, 1.38, 1.36, 1.35, 1.38],
|
225 |
+
},
|
226 |
+
"HSE": {
|
227 |
+
"atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"],
|
228 |
+
"chi_indices": [],
|
229 |
+
"parents": [
|
230 |
+
["N", "C", "CA"],
|
231 |
+
["N", "CA", "CB"],
|
232 |
+
["CA", "CB", "CG"],
|
233 |
+
["ND1", "CB", "CG"],
|
234 |
+
["CB", "CG", "ND1"],
|
235 |
+
["CB", "CG", "CD2"],
|
236 |
+
],
|
237 |
+
"types": {
|
238 |
+
"C": "C",
|
239 |
+
"CA": "CT1",
|
240 |
+
"CB": "CT2",
|
241 |
+
"CD2": "CPH1",
|
242 |
+
"CE1": "CPH2",
|
243 |
+
"CG": "CPH1",
|
244 |
+
"N": "NH1",
|
245 |
+
"ND1": "NR2",
|
246 |
+
"NE2": "NR1",
|
247 |
+
"O": "O",
|
248 |
+
},
|
249 |
+
"z-angles": [111.67, 116.94, 120.17, 129.71, 105.2, 105.8],
|
250 |
+
"z-dihedrals": [123.52, 180.0, 90.0, -178.26, -179.2, 178.66],
|
251 |
+
"z-lengths": [1.56, 1.51, 1.39, 1.36, 1.32, 1.38],
|
252 |
+
},
|
253 |
+
"HSP": {
|
254 |
+
"atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"],
|
255 |
+
"chi_indices": [],
|
256 |
+
"parents": [
|
257 |
+
["N", "C", "CA"],
|
258 |
+
["N", "CA", "CB"],
|
259 |
+
["CA", "CB", "CG"],
|
260 |
+
["ND1", "CB", "CG"],
|
261 |
+
["CB", "CG", "ND1"],
|
262 |
+
["CB", "CG", "CD2"],
|
263 |
+
],
|
264 |
+
"types": {
|
265 |
+
"C": "C",
|
266 |
+
"CA": "CT1",
|
267 |
+
"CB": "CT2A",
|
268 |
+
"CD2": "CPH1",
|
269 |
+
"CE1": "CPH2",
|
270 |
+
"CG": "CPH1",
|
271 |
+
"N": "NH1",
|
272 |
+
"ND1": "NR3",
|
273 |
+
"NE2": "NR3",
|
274 |
+
"O": "O",
|
275 |
+
},
|
276 |
+
"z-angles": [109.38, 114.18, 122.94, 128.93, 108.9, 106.93],
|
277 |
+
"z-dihedrals": [125.13, 180.0, 90.0, -165.26, -167.62, 167.13],
|
278 |
+
"z-lengths": [1.55, 1.52, 1.37, 1.35, 1.33, 1.37],
|
279 |
+
},
|
280 |
+
"ILE": {
|
281 |
+
"atoms": ["CB", "CG1", "CG2", "CD1"],
|
282 |
+
"chi_indices": [1, 3],
|
283 |
+
"parents": [
|
284 |
+
["N", "C", "CA"],
|
285 |
+
["N", "CA", "CB"],
|
286 |
+
["CG1", "CA", "CB"],
|
287 |
+
["CA", "CB", "CG1"],
|
288 |
+
],
|
289 |
+
"types": {
|
290 |
+
"C": "C",
|
291 |
+
"CA": "CT1",
|
292 |
+
"CB": "CT1",
|
293 |
+
"CD": "CT3",
|
294 |
+
"CG1": "CT2",
|
295 |
+
"CG2": "CT3",
|
296 |
+
"N": "NH1",
|
297 |
+
"O": "O",
|
298 |
+
},
|
299 |
+
"z-angles": [112.93, 113.63, 113.93, 114.09],
|
300 |
+
"z-dihedrals": [124.22, 180.0, -130.04, 180.0],
|
301 |
+
"z-lengths": [1.57, 1.55, 1.55, 1.54],
|
302 |
+
},
|
303 |
+
"LEU": {
|
304 |
+
"atoms": ["CB", "CG", "CD1", "CD2"],
|
305 |
+
"chi_indices": [1, 2],
|
306 |
+
"parents": [
|
307 |
+
["N", "C", "CA"],
|
308 |
+
["N", "CA", "CB"],
|
309 |
+
["CA", "CB", "CG"],
|
310 |
+
["CD1", "CB", "CG"],
|
311 |
+
],
|
312 |
+
"types": {
|
313 |
+
"C": "C",
|
314 |
+
"CA": "CT1",
|
315 |
+
"CB": "CT2",
|
316 |
+
"CD1": "CT3",
|
317 |
+
"CD2": "CT3",
|
318 |
+
"CG": "CT1",
|
319 |
+
"N": "NH1",
|
320 |
+
"O": "O",
|
321 |
+
},
|
322 |
+
"z-angles": [112.12, 117.46, 110.48, 112.57],
|
323 |
+
"z-dihedrals": [121.52, 180.0, 180.0, 120.0],
|
324 |
+
"z-lengths": [1.55, 1.55, 1.54, 1.54],
|
325 |
+
},
|
326 |
+
"LYS": {
|
327 |
+
"atoms": ["CB", "CG", "CD", "CE", "NZ"],
|
328 |
+
"chi_indices": [1, 2, 3, 4],
|
329 |
+
"parents": [
|
330 |
+
["N", "C", "CA"],
|
331 |
+
["N", "CA", "CB"],
|
332 |
+
["CA", "CB", "CG"],
|
333 |
+
["CB", "CG", "CD"],
|
334 |
+
["CG", "CD", "CE"],
|
335 |
+
],
|
336 |
+
"types": {
|
337 |
+
"C": "C",
|
338 |
+
"CA": "CT1",
|
339 |
+
"CB": "CT2",
|
340 |
+
"CD": "CT2",
|
341 |
+
"CE": "CT2",
|
342 |
+
"CG": "CT2",
|
343 |
+
"N": "NH1",
|
344 |
+
"NZ": "NH3",
|
345 |
+
"O": "O",
|
346 |
+
},
|
347 |
+
"z-angles": [111.36, 115.76, 113.28, 112.33, 110.46],
|
348 |
+
"z-dihedrals": [122.23, 180.0, 180.0, 180.0, 180.0],
|
349 |
+
"z-lengths": [1.56, 1.54, 1.54, 1.53, 1.46],
|
350 |
+
},
|
351 |
+
"MET": {
|
352 |
+
"atoms": ["CB", "CG", "SD", "CE"],
|
353 |
+
"chi_indices": [1, 2, 3],
|
354 |
+
"parents": [
|
355 |
+
["N", "C", "CA"],
|
356 |
+
["N", "CA", "CB"],
|
357 |
+
["CA", "CB", "CG"],
|
358 |
+
["CB", "CG", "SD"],
|
359 |
+
],
|
360 |
+
"types": {
|
361 |
+
"C": "C",
|
362 |
+
"CA": "CT1",
|
363 |
+
"CB": "CT2",
|
364 |
+
"CE": "CT3",
|
365 |
+
"CG": "CT2",
|
366 |
+
"N": "NH1",
|
367 |
+
"O": "O",
|
368 |
+
"SD": "S",
|
369 |
+
},
|
370 |
+
"z-angles": [111.88, 115.92, 110.28, 98.94],
|
371 |
+
"z-dihedrals": [121.62, 180.0, 180.0, 180.0],
|
372 |
+
"z-lengths": [1.55, 1.55, 1.82, 1.82],
|
373 |
+
},
|
374 |
+
"PHE": {
|
375 |
+
"atoms": ["CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"],
|
376 |
+
"chi_indices": [1, 2],
|
377 |
+
"parents": [
|
378 |
+
["N", "C", "CA"],
|
379 |
+
["N", "CA", "CB"],
|
380 |
+
["CA", "CB", "CG"],
|
381 |
+
["CD1", "CB", "CG"],
|
382 |
+
["CB", "CG", "CD1"],
|
383 |
+
["CB", "CG", "CD2"],
|
384 |
+
["CG", "CD1", "CE1"],
|
385 |
+
],
|
386 |
+
"types": {
|
387 |
+
"C": "C",
|
388 |
+
"CA": "CT1",
|
389 |
+
"CB": "CT2",
|
390 |
+
"CD1": "CA",
|
391 |
+
"CD2": "CA",
|
392 |
+
"CE1": "CA",
|
393 |
+
"CE2": "CA",
|
394 |
+
"CG": "CA",
|
395 |
+
"CZ": "CA",
|
396 |
+
"N": "NH1",
|
397 |
+
"O": "O",
|
398 |
+
},
|
399 |
+
"z-angles": [112.45, 112.76, 120.32, 120.76, 120.63, 120.62, 119.93],
|
400 |
+
"z-dihedrals": [122.49, 180.0, 90.0, -177.96, -177.37, 177.2, -0.12],
|
401 |
+
"z-lengths": [1.56, 1.51, 1.41, 1.41, 1.4, 1.4, 1.4],
|
402 |
+
},
|
403 |
+
"PRO": {
|
404 |
+
"atoms": ["CB", "CG", "CD"],
|
405 |
+
"chi_indices": [1, 2],
|
406 |
+
"parents": [["N", "C", "CA"], ["N", "CA", "CB"], ["CA", "CB", "CG"]],
|
407 |
+
"types": {
|
408 |
+
"C": "C",
|
409 |
+
"CA": "CP1",
|
410 |
+
"CB": "CP2",
|
411 |
+
"CD": "CP3",
|
412 |
+
"CG": "CP2",
|
413 |
+
"N": "N",
|
414 |
+
"O": "O",
|
415 |
+
},
|
416 |
+
"z-angles": [111.74, 104.39, 103.21],
|
417 |
+
"z-dihedrals": [113.74, 31.61, -34.59],
|
418 |
+
"z-lengths": [1.54, 1.53, 1.53],
|
419 |
+
},
|
420 |
+
"SER": {
|
421 |
+
"atoms": ["CB", "OG"],
|
422 |
+
"chi_indices": [1],
|
423 |
+
"parents": [["N", "C", "CA"], ["N", "CA", "CB"]],
|
424 |
+
"types": {
|
425 |
+
"C": "C",
|
426 |
+
"CA": "CT1",
|
427 |
+
"CB": "CT2",
|
428 |
+
"N": "NH1",
|
429 |
+
"O": "O",
|
430 |
+
"OG": "OH1",
|
431 |
+
},
|
432 |
+
"z-angles": [111.4, 112.45],
|
433 |
+
"z-dihedrals": [124.75, 180.0],
|
434 |
+
"z-lengths": [1.56, 1.43],
|
435 |
+
},
|
436 |
+
"THR": {
|
437 |
+
"atoms": ["CB", "OG1", "CG2"],
|
438 |
+
"chi_indices": [1],
|
439 |
+
"parents": [["N", "C", "CA"], ["N", "CA", "CB"], ["OG1", "CA", "CB"]],
|
440 |
+
"types": {
|
441 |
+
"C": "C",
|
442 |
+
"CA": "CT1",
|
443 |
+
"CB": "CT1",
|
444 |
+
"CG2": "CT3",
|
445 |
+
"N": "NH1",
|
446 |
+
"O": "O",
|
447 |
+
"OG1": "OH1",
|
448 |
+
},
|
449 |
+
"z-angles": [112.74, 112.16, 115.91],
|
450 |
+
"z-dihedrals": [126.46, 180.0, -124.13],
|
451 |
+
"z-lengths": [1.57, 1.43, 1.53],
|
452 |
+
},
|
453 |
+
"TRP": {
|
454 |
+
"atoms": ["CB", "CG", "CD2", "CD1", "CE2", "NE1", "CE3", "CZ3", "CH2", "CZ2"],
|
455 |
+
"chi_indices": [1, 2],
|
456 |
+
"parents": [
|
457 |
+
["N", "C", "CA"],
|
458 |
+
["N", "CA", "CB"],
|
459 |
+
["CA", "CB", "CG"],
|
460 |
+
["CD2", "CB", "CG"],
|
461 |
+
["CD1", "CG", "CD2"],
|
462 |
+
["CG", "CD2", "CE2"],
|
463 |
+
["CE2", "CG", "CD2"],
|
464 |
+
["CE2", "CD2", "CE3"],
|
465 |
+
["CD2", "CE3", "CZ3"],
|
466 |
+
["CE3", "CZ3", "CH2"],
|
467 |
+
],
|
468 |
+
"types": {
|
469 |
+
"C": "C",
|
470 |
+
"CA": "CT1",
|
471 |
+
"CB": "CT2",
|
472 |
+
"CD1": "CA",
|
473 |
+
"CD2": "CPT",
|
474 |
+
"CE2": "CPT",
|
475 |
+
"CE3": "CAI",
|
476 |
+
"CG": "CY",
|
477 |
+
"CH2": "CA",
|
478 |
+
"CZ2": "CAI",
|
479 |
+
"CZ3": "CA",
|
480 |
+
"N": "NH1",
|
481 |
+
"NE1": "NY",
|
482 |
+
"O": "O",
|
483 |
+
},
|
484 |
+
"z-angles": [
|
485 |
+
111.23,
|
486 |
+
115.14,
|
487 |
+
123.95,
|
488 |
+
129.18,
|
489 |
+
106.65,
|
490 |
+
107.87,
|
491 |
+
132.54,
|
492 |
+
118.16,
|
493 |
+
120.97,
|
494 |
+
120.87,
|
495 |
+
],
|
496 |
+
"z-dihedrals": [
|
497 |
+
122.68,
|
498 |
+
180.0,
|
499 |
+
90.0,
|
500 |
+
-172.81,
|
501 |
+
-0.08,
|
502 |
+
0.14,
|
503 |
+
179.21,
|
504 |
+
-0.2,
|
505 |
+
0.1,
|
506 |
+
0.01,
|
507 |
+
],
|
508 |
+
"z-lengths": [1.56, 1.52, 1.44, 1.37, 1.41, 1.37, 1.4, 1.4, 1.4, 1.4],
|
509 |
+
},
|
510 |
+
"TYR": {
|
511 |
+
"atoms": ["CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"],
|
512 |
+
"chi_indices": [1, 2],
|
513 |
+
"parents": [
|
514 |
+
["N", "C", "CA"],
|
515 |
+
["N", "CA", "CB"],
|
516 |
+
["CA", "CB", "CG"],
|
517 |
+
["CD1", "CB", "CG"],
|
518 |
+
["CB", "CG", "CD1"],
|
519 |
+
["CB", "CG", "CD2"],
|
520 |
+
["CG", "CD1", "CE1"],
|
521 |
+
["CE1", "CE2", "CZ"],
|
522 |
+
],
|
523 |
+
"types": {
|
524 |
+
"C": "C",
|
525 |
+
"CA": "CT1",
|
526 |
+
"CB": "CT2",
|
527 |
+
"CD1": "CA",
|
528 |
+
"CD2": "CA",
|
529 |
+
"CE1": "CA",
|
530 |
+
"CE2": "CA",
|
531 |
+
"CG": "CA",
|
532 |
+
"CZ": "CA",
|
533 |
+
"N": "NH1",
|
534 |
+
"O": "O",
|
535 |
+
"OH": "OH1",
|
536 |
+
},
|
537 |
+
"z-angles": [112.34, 112.94, 120.49, 120.46, 120.4, 120.56, 120.09, 120.25],
|
538 |
+
"z-dihedrals": [122.27, 180.0, 90.0, -176.46, -175.49, 175.32, -0.19, -178.98],
|
539 |
+
"z-lengths": [1.56, 1.51, 1.41, 1.41, 1.4, 1.4, 1.4, 1.41],
|
540 |
+
},
|
541 |
+
"VAL": {
|
542 |
+
"atoms": ["CB", "CG1", "CG2"],
|
543 |
+
"chi_indices": [1],
|
544 |
+
"parents": [["N", "C", "CA"], ["N", "CA", "CB"], ["CG1", "CA", "CB"]],
|
545 |
+
"types": {
|
546 |
+
"C": "C",
|
547 |
+
"CA": "CT1",
|
548 |
+
"CB": "CT1",
|
549 |
+
"CG1": "CT3",
|
550 |
+
"CG2": "CT3",
|
551 |
+
"N": "NH1",
|
552 |
+
"O": "O",
|
553 |
+
},
|
554 |
+
"z-angles": [111.23, 113.97, 112.17],
|
555 |
+
"z-dihedrals": [122.95, 180.0, 123.99],
|
556 |
+
"z-lengths": [1.57, 1.54, 1.54],
|
557 |
+
},
|
558 |
+
}
|
chroma/chroma/constants/named_models.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Generate Biomedicines, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
""" Paths for named models in the zoo """
|
16 |
+
|
17 |
+
GRAPH_BACKBONE_MODELS = {
|
18 |
+
"public": {
|
19 |
+
"s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_backbone_v1.0.pt",
|
20 |
+
"data": "Generate Structure ETL: July 25 2022",
|
21 |
+
"task": "BLNL backbone model training with EMA, trained July 2023",
|
22 |
+
},
|
23 |
+
}
|
24 |
+
|
25 |
+
GRAPH_CLASSIFIER_MODELS = {
|
26 |
+
"public": {
|
27 |
+
"s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_proclass_v1.0.pt",
|
28 |
+
"data": "Generate Structure ETL: June 2022",
|
29 |
+
"task": "Backbone classification model training with cross-entropy loss",
|
30 |
+
},
|
31 |
+
}
|
32 |
+
|
33 |
+
GRAPH_DESIGN_MODELS = {
|
34 |
+
"public": {
|
35 |
+
"s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_design_v1.0.pt",
|
36 |
+
"data": "Generate Structure ETL: July 25 2022",
|
37 |
+
"task": "Autoregressive joint prediction of sequence and chi angles, two-stage",
|
38 |
+
},
|
39 |
+
}
|
40 |
+
|
41 |
+
PROCAP_MODELS = {
|
42 |
+
"public": {
|
43 |
+
"s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_procap_v1.0.pt",
|
44 |
+
"data": "Generate Structure ETL: June 2022",
|
45 |
+
"task": "Backbone caption model training with cross-entropy loss, using M5 ProClass GNN embeddings",
|
46 |
+
},
|
47 |
+
}
|
48 |
+
|
49 |
+
NAMED_MODELS = {
|
50 |
+
"GraphBackbone": GRAPH_BACKBONE_MODELS,
|
51 |
+
"GraphDesign": GRAPH_DESIGN_MODELS,
|
52 |
+
"GraphClassifier": GRAPH_CLASSIFIER_MODELS,
|
53 |
+
"ProteinCaption": PROCAP_MODELS,
|
54 |
+
}
|
chroma/chroma/constants/sequence.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Generate Biomedicines, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Constants used across protein representations.
|
16 |
+
|
17 |
+
These constants standardize protein tokenization alphabets, ideal structure
|
18 |
+
geometries and topologies, etc.
|
19 |
+
"""
|
20 |
+
from chroma.constants.geometry import AA_GEOMETRY
|
21 |
+
|
22 |
+
# Standard tokenization for Omniprot and Omniprot-interacting models
|
23 |
+
OMNIPROT_TOKENS = "ABCDEFGHIKLMNOPQRSTUVWYXZ*-#"
|
24 |
+
POTTS_EXTENDED_TOKENS = "ACDEFGHIKLMNPQRSTVWY-*#"
|
25 |
+
PAD = "-"
|
26 |
+
START = "@"
|
27 |
+
STOP = "*"
|
28 |
+
MASK = "#"
|
29 |
+
DNA_TOKENS = "ACGT"
|
30 |
+
RNA_TOKENS = "AGCU"
|
31 |
+
PROTEIN_TOKENS = "ACDEFGHIKLMNPQRSTVWY"
|
32 |
+
|
33 |
+
# Minimal 20-letter alphabet and corresponding triplet codes
|
34 |
+
AA20 = "ACDEFGHIKLMNPQRSTVWY"
|
35 |
+
AA20_3_TO_1 = {
|
36 |
+
"ALA": "A",
|
37 |
+
"ARG": "R",
|
38 |
+
"ASN": "N",
|
39 |
+
"ASP": "D",
|
40 |
+
"CYS": "C",
|
41 |
+
"GLN": "Q",
|
42 |
+
"GLU": "E",
|
43 |
+
"GLY": "G",
|
44 |
+
"HIS": "H",
|
45 |
+
"ILE": "I",
|
46 |
+
"LEU": "L",
|
47 |
+
"LYS": "K",
|
48 |
+
"MET": "M",
|
49 |
+
"PHE": "F",
|
50 |
+
"PRO": "P",
|
51 |
+
"SER": "S",
|
52 |
+
"THR": "T",
|
53 |
+
"TRP": "W",
|
54 |
+
"TYR": "Y",
|
55 |
+
"VAL": "V",
|
56 |
+
}
|
57 |
+
AA20_1_TO_3 = {
|
58 |
+
"A": "ALA",
|
59 |
+
"R": "ARG",
|
60 |
+
"N": "ASN",
|
61 |
+
"D": "ASP",
|
62 |
+
"C": "CYS",
|
63 |
+
"Q": "GLN",
|
64 |
+
"E": "GLU",
|
65 |
+
"G": "GLY",
|
66 |
+
"H": "HIS",
|
67 |
+
"I": "ILE",
|
68 |
+
"L": "LEU",
|
69 |
+
"K": "LYS",
|
70 |
+
"M": "MET",
|
71 |
+
"F": "PHE",
|
72 |
+
"P": "PRO",
|
73 |
+
"S": "SER",
|
74 |
+
"T": "THR",
|
75 |
+
"W": "TRP",
|
76 |
+
"Y": "TYR",
|
77 |
+
"V": "VAL",
|
78 |
+
}
|
79 |
+
AA20_3 = [AA20_1_TO_3[aa] for aa in AA20]
|
80 |
+
|
81 |
+
# Adding noncanonical amino acids
|
82 |
+
NONCANON_AA = [
|
83 |
+
"HSD",
|
84 |
+
"HSE",
|
85 |
+
"HSC",
|
86 |
+
"HSP",
|
87 |
+
"MSE",
|
88 |
+
"CSO",
|
89 |
+
"SEC",
|
90 |
+
"CSX",
|
91 |
+
"HIP",
|
92 |
+
"SEP",
|
93 |
+
"TPO",
|
94 |
+
]
|
95 |
+
AA31_3 = AA20_3 + NONCANON_AA
|
96 |
+
|
97 |
+
# Chain alphabet for PDB chain naming
|
98 |
+
CHAIN_ALPHABET = "_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
99 |
+
|
100 |
+
# Standard atom indexing
|
101 |
+
ATOMS_BB = ["N", "CA", "C", "O"]
|
102 |
+
|
103 |
+
ATOM_SYMMETRIES = {
|
104 |
+
"ARG": [("NH1", "NH2")], # Correct handling of NH1 and NH2 is relabeling
|
105 |
+
"ASP": [("OD1", "OD2")],
|
106 |
+
"GLU": [("OE1", "OE2")],
|
107 |
+
"PHE": [("CD1", "CD2"), ("CE1", "CE2")],
|
108 |
+
"TYR": [("CD1", "CD2"), ("CE1", "CE2")],
|
109 |
+
}
|
110 |
+
|
111 |
+
# Per-residue heavy-atom counts: 4 backbone atoms (N, CA, C, O) plus the
# side-chain atoms listed in AA_GEOMETRY, ordered to match AA20_3.
AA20_NUM_ATOMS = [len(AA_GEOMETRY[triplet]["atoms"]) + 4 for triplet in AA20_3]
# Number of chi (side-chain dihedral) angles per residue type, same order.
AA20_NUM_CHI = [len(AA_GEOMETRY[triplet]["chi_indices"]) for triplet in AA20_3]
|
chroma/chroma/data/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Generate Biomedicines, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""
|
16 |
+
This package includes io formats and tools for a few common datatypes,
|
17 |
+
including antibodies, proteins, sequences, and structures.
|
18 |
+
"""
|
19 |
+
from chroma.data.protein import Protein
|
chroma/chroma/data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (339 Bytes). View file
|
|
chroma/chroma/data/__pycache__/protein.cpython-38.pyc
ADDED
Binary file (19.3 kB). View file
|
|
chroma/chroma/data/__pycache__/system.cpython-38.pyc
ADDED
Binary file (136 kB). View file
|
|
chroma/chroma/data/__pycache__/xcs.cpython-38.pyc
ADDED
Binary file (3.83 kB). View file
|
|
chroma/chroma/data/protein.py
ADDED
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Generate Biomedicines, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import annotations
|
16 |
+
|
17 |
+
import copy
|
18 |
+
import os
|
19 |
+
import tempfile
|
20 |
+
from typing import List, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import nglview as nv
|
23 |
+
import torch
|
24 |
+
|
25 |
+
import chroma.utility.polyseq as polyseq
|
26 |
+
from chroma.constants import CHAIN_ALPHABET, PROTEIN_TOKENS
|
27 |
+
from chroma.data.system import System, SystemEntity
|
28 |
+
|
29 |
+
|
30 |
+
class Protein:
    """A utility class for managing proteins within the Chroma ecosystem.

    The Protein class offers a suite of methods for loading, saving,
    transforming, and viewing protein structures and trajectories from a
    variety of input sources such as PDB IDs, PDB/CIF files, raw sequences,
    and XCS tensor representations.

    Attributes:
        sys (System): A protein system object used for molecular operations.
        device (str): Device on which tensors are managed. Defaults to "cpu".
    """

    sys: System
    device: str = "cpu"

    def __new__(cls, *args, **kwargs):
        """Dispatch construction based on the input type.

        Supported inputs:
          * a single ``System`` instance,
          * an ``(X, C, S)`` tensor triple (``X`` may be a list of tensors,
            which is interpreted as a trajectory),
          * a single string: a ``.cif``/``.pdb`` path, a PDB ID, or a raw
            one-letter sequence.

        Raises:
            TypeError: If a 3-argument call does not follow the XCS format.
            NotImplementedError: For s3 URIs or unrecognized string inputs.
        """
        if len(args) == 1 and isinstance(args[0], System):
            return cls.from_system(*args, **kwargs)

        elif len(args) == 3:  # 3 Tensor Arguments
            X, C, S = args
            assert isinstance(
                C, torch.Tensor
            ), f"arg[1] must be a chain (C) torch.Tensor, but get {type(C)}"
            assert isinstance(
                S, torch.Tensor
            ), f"arg[2] must be a sequence (S) torch.Tensor, but get {type(S)}"
            if isinstance(X, list):
                assert all(
                    isinstance(x, torch.Tensor) for x in X
                ), "arg[0] must be an X torch.Tensor or a list of X torch.Tensors"
                return cls.from_XCS_trajectory(X, C, S)
            elif isinstance(X, torch.Tensor):
                return cls.from_XCS(X, C, S)
            else:
                raise TypeError(
                    f"X must be a list of torch.Tensor that respects XCS format, but get {type(X), type(C), type(S)}"
                )

        elif len(args) == 1 and isinstance(args[0], str):
            if args[0].lower().startswith("s3:"):
                raise NotImplementedError(
                    "download of cifs or pdbs from s3 not supported."
                )

            if args[0].endswith(".cif"):
                return cls.from_CIF(*args, **kwargs)
            elif args[0].endswith(".pdb"):
                return cls.from_PDB(*args, **kwargs)
            else:  # PDB ID or sequence string
                import warnings

                import requests

                # NOTE(review): PDB ID validation requires internet access;
                # offline, a valid ID falls through to the sequence branch.
                url = f"https://data.rcsb.org/rest/v1/core/entry/{args[0]}"
                VALID_PDBID = requests.get(url, timeout=30).status_code == 200
                VALID_SEQUENCE = all(s in PROTEIN_TOKENS for s in args[0])

                if VALID_PDBID:
                    if VALID_SEQUENCE:
                        # Fixed: previously `raise Warning(...)` aborted here,
                        # making the documented "interpret as PDBID" fallthrough
                        # unreachable. Now we warn and proceed.
                        warnings.warn(
                            "Ambiguous input, this is both a valid Sequence string and"
                            " a valid PDBID. Interpreting as a PDBID, if you wish to"
                            " initialize as a sequence string please explicitly"
                            " initialize as Protein.from_sequence(MY_SEQUENCE)."
                        )
                    return cls.from_PDBID(*args, **kwargs)
                elif VALID_SEQUENCE:
                    return cls.from_sequence(*args, **kwargs)
                else:
                    raise NotImplementedError(
                        "Could Not Identify a valid input Type. See docstring for"
                        " details."
                    )
        else:
            raise NotImplementedError(
                "Inputs must either be a 3-tuple of XCS tensors, or a single string"
            )

    @classmethod
    def from_system(cls, system: System, device: str = "cpu") -> Protein:
        """Wrap an existing System in a Protein.

        Args:
            system (System): The system to wrap (not copied).
            device (str, optional): Device for tensor operations. Defaults to "cpu".

        Returns:
            Protein: Protein object backed by ``system``.
        """
        protein = super(Protein, cls).__new__(cls)
        protein.sys = system
        protein.device = device
        return protein

    @classmethod
    def from_XCS(cls, X: torch.Tensor, C: torch.Tensor, S: torch.Tensor) -> Protein:
        """Create a Protein object from XCS representations.

        Args:
            X (torch.Tensor): Atomic coordinates with dimensions
                `(batch, residues, atoms (4 or 14), coordinates (3))`.
            C (torch.Tensor): Chain label tensor of shape `(batch, residues)`.
                Sign indicates presence (+) or absence (-) of structural
                information; magnitude indicates the chain index.
            S (torch.Tensor): Sequence tensor of shape `(batch, residues)` of
                non-negative integer residue types.

        Returns:
            Protein: Initialized Protein object.
        """
        protein = super(Protein, cls).__new__(cls)
        protein.sys = System.from_XCS(X, C, S)
        # NOTE(review): X.device is a torch.device rather than the str the
        # class annotation declares; kept as-is for backward compatibility.
        protein.device = X.device
        return protein

    @classmethod
    def from_XCS_trajectory(
        cls, X_traj: List[torch.Tensor], C: torch.Tensor, S: torch.Tensor
    ) -> Protein:
        """Initialize a Protein from a trajectory of XCS representations.

        Args:
            X_traj (List[torch.Tensor]): X coordinate tensors over time, each
                of shape `(batch, residues, atoms (4 or 14), coordinates (3))`.
            C (torch.Tensor): Chain label tensor of shape `(batch, residues)`.
            S (torch.Tensor): Sequence tensor of shape `(batch, residues)`.

        Returns:
            Protein: Protein whose models are the trajectory frames.
        """
        protein = super(Protein, cls).__new__(cls)
        protein.sys = System.from_XCS(X_traj[0], C, S)
        protein.device = C.device
        # Each subsequent frame becomes an additional model; only structured
        # residues (C > 0) carry coordinates.
        for X in X_traj[1:]:
            protein.sys.add_model_from_X(X[C > 0])
        return protein

    @classmethod
    def from_PDB(cls, input_file: str, device: str = "cpu") -> Protein:
        """Load a Protein object from a provided PDB file.

        Args:
            input_file (str): Path to the PDB file to be loaded.
            device (str, optional): Device for tensor operations. Defaults to "cpu".

        Returns:
            Protein: Initialized Protein object.
        """
        protein = super(Protein, cls).__new__(cls)
        protein.sys = System.from_PDB(input_file)
        protein.device = device
        return protein

    @classmethod
    def from_CIF(
        cls, input_file: str, canonicalize: bool = True, device: str = "cpu"
    ) -> Protein:
        """Load a Protein object from a provided CIF file.

        Args:
            input_file (str): Path to the CIF file to be loaded.
            canonicalize (bool, optional): If True, canonicalize the backbone
                geometry after loading. Defaults to True.
            device (str, optional): Device for tensor operations. Defaults to "cpu".

        Returns:
            Protein: Initialized Protein object.
        """
        protein = super(Protein, cls).__new__(cls)
        protein.sys = System.from_CIF(input_file)
        protein.device = device
        if canonicalize:
            protein.canonicalize()
        return protein

    @classmethod
    def from_PDBID(
        cls, pdb_id: str, canonicalize: bool = True, device: str = "cpu"
    ) -> Protein:
        """Load a Protein by PDB ID, fetching the CIF from the Protein Data Bank.

        Downloads the CIF file for the given ID into a temporary location,
        builds the Protein, then deletes the temporary file.

        Args:
            pdb_id (str): The PDB ID of the protein to fetch.
            canonicalize (bool, optional): Canonicalize after loading. Defaults to True.
            device (str, optional): Device for tensor operations. Defaults to "cpu".

        Returns:
            Protein: Protein initialized from the fetched CIF file.
        """
        from os import unlink

        from chroma.utility.fetchdb import RCSB_file_download

        file_cif = os.path.join(tempfile.gettempdir(), f"{pdb_id}.cif")
        RCSB_file_download(pdb_id, ".cif", file_cif)
        protein = cls.from_CIF(file_cif, canonicalize=canonicalize, device=device)
        unlink(file_cif)  # clean up the temporary download
        return protein

    @classmethod
    def from_sequence(
        cls, chains: Union[List[str], str], device: str = "cpu"
    ) -> Protein:
        """Load a Protein purely from sequence, with no structural content.

        Args:
            chains (Union[List[str], str]): A one-letter sequence string, or a
                list of them (one per chain).
            device (str, optional): Device for torch outputs. Defaults to "cpu".

        Returns:
            Protein: Protein initialized from the sequence(s).
        """
        if isinstance(chains, str):
            chains = [chains]

        system = System("system")
        for c_ix, seq in enumerate(chains):
            # CHAIN_ALPHABET[0] is "_", so chain IDs start at "A".
            chain_id = CHAIN_ALPHABET[c_ix + 1]
            chain = system.add_chain(chain_id)

            # Populate the chain residue-by-residue (1-indexed residue numbers).
            three_letter_sequence = []
            for s_ix, s in enumerate(seq):
                resname = polyseq.to_triple(s)
                three_letter_sequence.append(resname)
                chain.add_residue(resname, s_ix + 1, "")

            # Register a polymer entity describing this chain's sequence;
            # the [False] mask marks every residue as lacking coordinates.
            sys_entity = SystemEntity(
                "polymer",
                f"Sequence Chain {chain_id}",
                "polypeptide(L)",
                three_letter_sequence,
                [False] * len(three_letter_sequence),
            )
            system.add_new_entity(sys_entity, [c_ix])

        protein = super(Protein, cls).__new__(cls)
        protein.sys = system
        protein.device = device
        return protein

    def to_CIF(self, output_file: str, force: bool = False) -> None:
        """Save the current Protein object to a file in CIF format.

        Args:
            output_file (str): The path where the CIF file should be saved.
            force (bool, optional): Accepted for interface compatibility;
                currently unused by this writer.
        """
        if output_file.lower().startswith("s3:"):
            raise NotImplementedError("cif output to an s3 bucket not supported.")
        else:
            self.sys.to_CIF(output_file)

    def to_PDB(self, output_file: str, force: bool = False) -> None:
        """Save the current Protein object to a file in PDB format.

        Args:
            output_file (str): The path where the PDB file should be saved.
            force (bool, optional): Accepted for interface compatibility;
                currently unused by this writer.
        """
        if output_file.lower().startswith("s3:"):
            raise NotImplementedError("pdb output to an s3 bucket not supported.")
        else:
            self.sys.to_PDB(output_file)

    def to_XCS(
        self, all_atom: bool = False, device: Optional[str] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Convert the current Protein object to its XCS tensor representation.

        Args:
            all_atom (bool, optional): If True, include side-chain atoms
                (14 per residue) instead of backbone only (4). Defaults to False.
            device (str, optional): Device for the output tensors. Defaults to
                the instance's ``device`` attribute.

        Returns:
            X (torch.Tensor): Coordinates `(batch, residues, atoms, 3)`.
            C (torch.Tensor): Chain map `(batch, residues)`; sign encodes
                presence of structure, magnitude the chain index.
            S (torch.Tensor): Sequence indices `(batch, residues)`.
        """
        if device is None:
            device = self.device

        X, C, S = [tensor.to(device) for tensor in self.sys.to_XCS(all_atom=all_atom)]

        return X, C, S

    def to_XCS_trajectory(
        self,
        device: Optional[str] = None,
    ) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor]:
        """Convert all models of this Protein to an XCS trajectory.

        Args:
            device (str, optional): Device for the output tensors. Defaults to
                the instance's ``device`` attribute.

        Returns:
            X_traj (List[torch.Tensor]): Per-frame coordinate tensors, each of
                shape `(1, residues, atoms, 3)`.
            C (torch.Tensor): Chain map `(batch, residues)` shared by frames.
            S (torch.Tensor): Sequence indices `(batch, residues)`.
        """
        X, C, S = [], None, None
        for i in range(self.sys.num_models()):
            self.sys.swap_model(i)
            if i == 0:
                # The first frame also yields location indices reused to read
                # raw coordinates for the remaining frames.
                X_frame, C, S, loc_indices = self.sys.to_XCS(get_indices=True)
            else:
                # Overwrite the frame buffer in-place from raw coordinate storage.
                X_frame.flatten(0, 2)[:] = torch.from_numpy(
                    self.sys._locations["coor"][loc_indices, 0:3]
                )
            X.append(X_frame.clone())
            self.sys.swap_model(i)  # swap back to restore original model order
        X = torch.cat(X)

        if device is None:
            device = self.device

        Xtraj, C, S = [tensor.to(device) for tensor in [X, C, S]]
        return [each.unsqueeze(0) for each in Xtraj], C, S

    def to(self, file_path: str, force: bool = False) -> None:
        """General export for the Protein class.

        Exports to PDB or CIF based on the file extension; explicit saving is
        still available via :meth:`to_PDB` / :meth:`to_CIF`.

        Args:
            file_path (str): Output path ending in ``.pdb`` or ``.cif``.
            force (bool, optional): Passed through to the format-specific writer.

        Raises:
            NotImplementedError: If the extension is neither ``.pdb`` nor ``.cif``.
        """
        if file_path.lower().endswith(".pdb"):
            self.to_PDB(file_path, force=force)
        elif file_path.lower().endswith(".cif"):
            self.to_CIF(file_path, force=force)
        else:
            raise NotImplementedError(
                "file path must end with either *.cif or *.pdb for export."
            )

    def length(self, structured: bool = False) -> int:
        """Retrieve the length of the protein.

        Args:
            structured (bool, optional): If True, return the residue count of
                the structured part only. Defaults to False.

        Returns:
            int: Residue count. (Annotation fixed: was incorrectly ``-> None``.)
        """
        if structured:
            return self.sys.num_structured_residues()
        return self.sys.num_residues()

    __len__ = length

    def canonicalize(self) -> None:
        """Canonicalize the protein's backbone geometry.

        Drops residues with unknown coordinates or missing backbone atoms so
        the system conforms to a canonical form.
        """
        self.sys.canonicalize_protein(
            level=2,
            drop_coors_unknowns=True,
            drop_coors_missing_backbone=True,
        )

    def sequence(self, format: str = "one-letter-string") -> Union[List[str], str]:
        """Retrieve the sequence of the protein in the specified format.

        Args:
            format (str, optional): "three-letter-list" or "one-letter-string".
                Defaults to "one-letter-string".

        Returns:
            Union[List[str], str]: The protein sequence in the desired format.

        Raises:
            Exception: If an unknown sequence format is provided.
        """
        if format == "three-letter-list":
            return list(self.sys.sequence())
        elif format == "one-letter-string":
            return self.sys.sequence("one-letter-string")
        else:
            raise Exception(f"unknown sequence format {format}")

    def display(self, representations: Optional[list] = None) -> None:
        """Display the protein using NGL view.

        Args:
            representations (list, optional): Extra visual representations to
                add (applied in the single-model case only). Defaults to None.
                (Fixed: previously a mutable ``[]`` default argument.)

        Returns:
            viewer: An nglview widget for interactive visualization.
        """
        from chroma.utility.ngl import SystemTrajectory, view_gsystem

        if representations is None:
            representations = []

        if self.sys.num_models() == 1:
            viewer = view_gsystem(self.sys)
            for rep in representations:
                viewer.add_representation(rep)
        else:
            # Multi-model systems are shown as a trajectory widget.
            trajectory = SystemTrajectory(self)
            viewer = nv.NGLWidget(trajectory)
        return viewer

    def _ipython_display_(self):
        # Jupyter display hook; `display` here is the IPython notebook builtin.
        display(self.display())

    def __str__(self):
        """Return the protein name plus each chain's size and sequence."""
        protein_string = f"Protein: {self.sys.name}\n"
        for chain in self.sys.chains():
            # Fixed: previously tested `chain.sequence is not None`, which
            # compared the bound method (always truthy) instead of its result.
            chain_seq = chain.sequence()
            if chain_seq is not None:
                protein_string += (
                    f"> Chain {chain.cid} ({len(chain_seq)} residues)\n"
                )
                protein_string += "".join(
                    [polyseq.to_single(s) for s in chain_seq]
                )
                protein_string += "\n\n"

        return protein_string

    def get_mask(self, selection: str) -> torch.Tensor:
        """Generate a residue mask tensor from a selection string.

        Args:
            selection (str): A selection expression specifying which residues
                to include in the mask.

        Returns:
            torch.Tensor: Mask of shape `(1, protein length)` with value 1 at
            selected residue positions and 0 elsewhere.
        """
        residue_gtis = self.sys.select_residues(selection, gti=True)
        D = torch.zeros(1, self.sys.num_residues(), device=self.device)
        for gti in residue_gtis:
            D[0, gti] = 1
        return D

    def __copy__(self):
        # Shallow copy: the new Protein wraps a shallow copy of the System.
        new_system = copy.copy(self.sys)
        return Protein(new_system, device=self.device)

    def __deepcopy__(self, memo):
        # Deep copy: fully independent System.
        new_system = copy.deepcopy(self.sys)
        return Protein(new_system, device=self.device)
|
chroma/chroma/data/system.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
chroma/chroma/data/xcs.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Generate Biomedicines, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""XCS represents protein structure as a tuple of PyTorch tensors.
|
16 |
+
|
17 |
+
The tensors in an XCS representation are:
|
18 |
+
|
19 |
+
`X` (FloatTensor), the Cartesian coordinates representing the protein
|
20 |
+
structure with shape `(num_batch, num_residues, num_atoms, 3)`. The
|
21 |
+
`num_atoms` dimension can be one of two sizes: `num_atoms=4` for
|
22 |
+
backbone-only structures or `num_atoms=14` for all-atom structures
|
23 |
+
(excluding hydrogens). The first four atoms will always be
|
24 |
+
`N, CA, C, O`, and the meaning of the optional 10 additional atom
|
25 |
+
positions will vary based on the residue identity at
|
26 |
+
a given position. Atom orders for each amino acid are defined in
|
27 |
+
`constants.AA_GEOMETRY[TRIPLET_CODE]["atoms"]`.
|
28 |
+
|
29 |
+
`C` (LongTensor), the chain map encoding per-residue chain assignments with
|
30 |
+
shape `(num_batch, num_residues)`. The chain map codes positions as `0`
|
31 |
+
when masked, positive integers for chain indices, and negative integers
|
32 |
+
to represent missing residues (of the corresponding positive integers).
|
33 |
+
|
34 |
+
`S` (LongTensor), the sequence of the protein as alphabet indices with
|
35 |
+
shape `(num_batch, num_residues)`. The standard alphabet is
|
36 |
+
`ACDEFGHIKLMNPQRSTVWY`, also defined in `constants.AA20`.
|
37 |
+
"""
|
38 |
+
|
39 |
+
|
40 |
+
from functools import partial, wraps
|
41 |
+
from inspect import getfullargspec
|
42 |
+
|
43 |
+
import torch
|
44 |
+
from torch.nn import functional as F
|
45 |
+
|
46 |
+
# NOTE(review): dead code removed. The original block was
#     try:
#         pass
#     except ImportError:
#         print("MST not installed!")
# `pass` can never raise ImportError — the guarded import was deleted at some
# point but its try/except scaffold was left behind, so the message could
# never print. Removing it is behavior-identical.
|
50 |
+
|
51 |
+
|
52 |
+
def validate_XCS(all_atom=None, sequence=True):
    """Decorator factory that adds XCS validation to any function.

    Args:
        all_atom (bool, optional): If True, requires that input structure
            tensors have 14 residues per atom. If False, reduces to 4 residues
            per atom. If None, applies no transformation on input structures.
        sequence (bool, optional): If True, makes sure that if S and O are both
            provided, that they match, i.e. that O is a one-hot version of S.
            If only one of S or O is provided, the other is generated, and both
            are passed.
    """

    def decorator(func):
        @wraps(func)
        def new_func(*args, **kwargs):
            # Mutable copy so positional arguments can be rewritten in place.
            args = list(args)
            # Positional parameter names of the wrapped function, used to map
            # each tensor name to its positional slot.
            arg_list = getfullargspec(func)[0]
            tensors = {}
            # Locate each XCS tensor either in kwargs or by positional index.
            for var in ["X", "C", "S", "O"]:
                try:
                    if var in kwargs:
                        tensors[var] = kwargs[var]
                    else:
                        tensors[var] = args[arg_list.index(var)]
                except IndexError:  # empty args_list
                    tensors[var] = None
                except ValueError:  # variable not an argument of function
                    # S/O may legitimately be absent when sequence checking is
                    # off; anything else missing is a hard error.
                    if not sequence and var in ["S", "O"]:
                        pass
                    else:
                        raise Exception(
                            f"Variable {var} is required by validation but not defined!"
                        )
            # X and C must agree on (num_batch, num_residues).
            if tensors["X"] is not None and tensors["C"] is not None:
                if tensors["X"].shape[:2] != tensors["C"].shape[:2]:
                    raise ValueError(
                        f"X shape {tensors['X'].shape} does not match C shape"
                        f" {tensors['C'].shape}"
                    )
            if all_atom is not None and tensors["X"] is not None:
                if all_atom and tensors["X"].shape[2] != 14:
                    raise ValueError("Side chain atoms missing!")
                elif not all_atom:
                    # Reduce to the four backbone atoms (N, CA, C, O), writing
                    # back to whichever slot (kwarg or positional) X came from.
                    if "X" in kwargs:
                        kwargs["X"] = tensors["X"][:, :, :4]
                    else:
                        args[arg_list.index("X")] = tensors["X"][:, :, :4]
            if sequence and (tensors["S"] is not None or tensors["O"] is not None):
                # Derive whichever of S / O is missing, or verify consistency
                # when both are given.
                if tensors["O"] is None:
                    if "O" in kwargs:
                        kwargs["O"] = F.one_hot(tensors["S"], 20).float()
                    else:
                        args[arg_list.index("O")] = F.one_hot(tensors["S"], 20).float()
                elif tensors["S"] is None:
                    if "S" in kwargs:
                        kwargs["S"] = tensors["O"].argmax(dim=2)
                    else:
                        args[arg_list.index("S")] = tensors["O"].argmax(dim=2)
                else:
                    if not torch.allclose(tensors["O"].argmax(dim=2), tensors["S"]):
                        raise ValueError("S and O are both provided but don't match!")
            return func(*args, **kwargs)

        return new_func

    return decorator
|
119 |
+
|
120 |
+
|
121 |
+
# Structure-only variant: validates X/C tensors but skips all S/O sequence checks.
validate_XC = partial(validate_XCS, sequence=False)
|
chroma/chroma/layers/__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Generate Biomedicines, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""
|
16 |
+
This package contains low-level PyTorch layers, including ``nn.Module`` s and ops.
|
17 |
+
These layers are often used in :mod:`chroma.models`.
|
18 |
+
"""
|
chroma/chroma/layers/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (291 Bytes). View file
|
|
chroma/chroma/layers/__pycache__/attention.cpython-38.pyc
ADDED
Binary file (12.8 kB). View file
|
|
chroma/chroma/layers/__pycache__/basic.cpython-38.pyc
ADDED
Binary file (18.6 kB). View file
|
|
chroma/chroma/layers/__pycache__/complexity.cpython-38.pyc
ADDED
Binary file (5.45 kB). View file
|
|
chroma/chroma/layers/__pycache__/conv.cpython-38.pyc
ADDED
Binary file (1.14 kB). View file
|
|
chroma/chroma/layers/__pycache__/graph.cpython-38.pyc
ADDED
Binary file (34.6 kB). View file
|
|
chroma/chroma/layers/__pycache__/linalg.cpython-38.pyc
ADDED
Binary file (3.2 kB). View file
|
|
chroma/chroma/layers/__pycache__/norm.cpython-38.pyc
ADDED
Binary file (7.03 kB). View file
|
|
chroma/chroma/layers/__pycache__/sde.cpython-38.pyc
ADDED
Binary file (2.83 kB). View file
|
|
chroma/chroma/layers/attention.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Generate Biomedicines, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
|
18 |
+
"""
|
19 |
+
们实现了Transformer模型中的关键组件:缩放点积注意力(Scaled Dot Product Attention)和多头注意力(Multi-Head Attention)。
|
20 |
+
"""
|
21 |
+
class ScaledDotProductAttention(nn.Module):
    """Scaled dot product attention, Eqn 1 of Vaswani et al. 2017
    [https://arxiv.org/abs/1706.03762]:

        Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

    The query feature dimension must equal the key feature dimension (``d_k``),
    and keys and values must share the same sequence length. See
    'The Illustrated Transformer' [http://jalammar.github.io/illustrated-transformer/]
    for a pictorial depiction of attention.

    Inputs:
        Q (torch.tensor): of shape (batch_size, sequence_length_q, d_k)
        K (torch.tensor): of shape (batch_size, sequence_length_k, d_k)
        V (torch.tensor): of shape (batch_size, sequence_length_k, d_v)
        mask (torch.tensor): optional bool/byte tensor of shape
            (batch_size, 1, sequence_length_k); zero (or False) entries mark
            positions that cannot contribute to attention.

    Outputs:
        output (torch.tensor): of shape (batch_size, sequence_length_q, d_v);
            each output row is a convex combination of value rows.
        attentions (torch.tensor): of shape
            (batch_size, sequence_length_q, sequence_length_k); entry [b, i, j]
            is the attention weight of key position j for query position i.
    """

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, mask=None):
        d_k = K.size(-1)
        # Similarity logits, scaled by 1/sqrt(d_k).
        scores = torch.bmm(Q, K.transpose(1, 2)) / d_k ** 0.5
        if mask is not None:
            # Large negative logits drive masked positions toward 0 after softmax.
            scores = scores.float().masked_fill(mask == 0, -1e9)
        weights = self.softmax(scores)
        if mask is not None:
            # Zero the masked weights exactly (softmax alone only makes them tiny).
            weights = weights.float().masked_fill(mask == 0, 0)
        if V.dtype == torch.float16:
            # Attention was computed in float32; match V's half precision for bmm.
            weights = weights.half()
        output = torch.bmm(weights, V)
        return output, weights
|
61 |
+
|
62 |
+
|
63 |
+
class MultiHeadAttention(nn.Module):
    """Multi-head attention with scaled dot product attention. See 'The Annotated Transformer'
    http://nlp.seas.harvard.edu/2018/04/03/attention.html or 'The Illustrated Transformer' http://jalammar.github.io/illustrated-transformer/
    for details and intuition.

    Args:
        n_head (int): number of attention heads
        d_k (int): dimension of the keys and queries in each attention head
        d_v (int): dimension of the values in each attention head
        d_model (int): input and output dimension for the layer
        dropout (float): dropout rate, default is 0.1

    Inputs:
        Q (torch.tensor): query tensor of shape ```(batch_size, sequence_length_q, d_model)```
        K (torch.tensor): key tensor of shape ```(batch_size, sequence_length_k, d_model)```
        V (torch.tensor): value tensor of shape ```(batch_size, sequence_length_k, d_model)```
        mask (torch.tensor): (optional) of dtype ```bool``` or ```byte``` and size (batch_size, 1, sequence_length_k),
            zeroes (or False) indicate positions that cannot contribute to attention

    Outputs:
        output (torch.tensor) : of shape ```(batch_size, sequence_length_q, d_model)```
        attentions (torch.tensor): of shape ```(batch_size * n_head, sequence_length_q, sequence_length_k)``` where
            ```attentions[batch_size*(i):batch_size*(i+1),:,:]``` corresponds to the batch of attention blocks for i'th head. See
            ```chroma.layers.attention.ScaledDotProductAttention``` for more details
    """

    def __init__(self, n_head, d_k, d_v, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.d_model = d_model
        # Per-head projection weights, stacked along a leading head axis.
        self.Wq = nn.Parameter(torch.Tensor(n_head, d_model, d_k))
        self.Wk = nn.Parameter(torch.Tensor(n_head, d_model, d_k))
        self.Wv = nn.Parameter(torch.Tensor(n_head, d_model, d_v))
        # Output projection from the concatenated heads back to d_model.
        self.Wo = nn.Parameter(torch.Tensor(n_head * d_v, d_model))
        self.attention = ScaledDotProductAttention()
        self.dropout = nn.Dropout(p=dropout)
        self.reset_parameters()

    def reset_parameters(self):
        # Xavier-normal for the head projections, Kaiming-uniform for Wo.
        nn.init.xavier_normal_(self.Wq)
        nn.init.xavier_normal_(self.Wk)
        nn.init.xavier_normal_(self.Wv)
        nn.init.kaiming_uniform_(self.Wo)

    def forward(self, Q, K, V, bias=None, mask=None):
        # NOTE(review): `bias` is accepted here but never used in the body.
        mb_size, len_q, d_q_in = Q.size()
        mb_size, len_k, d_k_in = K.size()
        mb_size, len_v, d_v_in = V.size()
        d_model = self.d_model
        if d_q_in != d_model:
            raise ValueError("Dimension of Q does not match d_model.")

        if d_k_in != d_model:
            raise ValueError("Dimension of K does not match d_model.")

        if d_v_in != d_model:
            raise ValueError("Dimension of V does not match d_model.")

        # treat as a (n_head) size batch and project to d_k and d_v;
        # heads are concatenated along the batch axis (head-major ordering).
        q_s = torch.cat([Q @ W for W in self.Wq])  # (n_head*mb_size) x len_q x d_k
        k_s = torch.cat([K @ W for W in self.Wk])  # (n_head*mb_size) x len_k x d_k
        v_s = torch.cat([V @ W for W in self.Wv])  # (n_head*mb_size) x len_v x d_v

        # Attention
        if mask is not None:
            # Reuse the same mask for every head by tiling along the batch axis.
            mask = mask.repeat(self.n_head, 1, 1)
        outputs, attns = self.attention(q_s, k_s, v_s, mask=mask)

        # Back to original mb_size batch, result size = mb_size x len_q x (n_head*d_v)
        outputs = torch.cat(torch.split(outputs, mb_size, dim=0), dim=-1)

        # Project back to residual size
        outputs = outputs @ self.Wo
        outputs = self.dropout(outputs)
        return outputs, attns
|
140 |
+
|
141 |
+
|
142 |
+
class AttentionChainPool(nn.Module):
    """Pools residue-based representations to chain-based representations using a chain mask and attention.

    Args:
        n_head (int): number of attention heads
        d_model (int): dimension of embeddings to be pooled

    Inputs:
        h (torch.tensor): of size (batch_size, sequence_length, d_model)
        C (torch.tensor): of size (batch_size, sequence_length)

    Outputs:
        output (torch.tensor): of size (batch_size, n_chains, d_model)
        chain_mask (torch.tensor): of size (batch_size, n_chains)
    """

    def __init__(self, n_head, d_model):
        super().__init__()
        self.attention = MultiHeadAttention(
            n_head, d_model, d_model, d_model, dropout=0.0
        )

    def get_query(self, x):
        # A single constant (all-ones) query per pooled element; attention over
        # the chain's residues then yields one pooled vector per chain.
        return torch.ones(x.size(0), 1, x.size(2)).type(x.dtype).to(x.device)

    def forward(self, h, C):
        bs, num_res = C.size()
        # Chain ids, ignoring sign (negative codes mark missing residues).
        chains = C.abs().unique()
        # One copy of each positive chain id per batch element, flattened so the
        # chain axis can be folded into the batch axis below.
        chains = (
            chains[chains > 0].unsqueeze(-1).repeat(1, bs).reshape(-1).unsqueeze(-1)
        )
        num_chains = len(chains.unique())

        # Tile inputs so each (chain, batch) pair is attended to independently.
        h_repeat = h.repeat(num_chains, 1, 1)
        C_repeat = C.repeat(num_chains, 1)
        # Mask selects only residues belonging to the chain being pooled.
        mask = (C_repeat == chains).unsqueeze(-2)

        output, _ = self.attention(
            self.get_query(h_repeat), h_repeat, h_repeat, mask=mask
        )
        # Unfold the chain axis back out of the batch axis: (bs, n_chains, d_model).
        output = torch.cat(output.split(bs), 1)
        # A chain is "present" in a batch element if any residue matched its id.
        chain_mask = torch.stack(mask.squeeze(1).any(dim=-1).split(bs), -1)
        return output, chain_mask
|
184 |
+
|
185 |
+
|
186 |
+
class Attention(nn.Module):
    """
    A multi-head attention layer with optional gating and bias as implemented in Jumper et al. (2021)

    Args:
        n_head (int): Number of heads of attention
        d_model (int): Dimension of input and outputs
        d_k (int): Dimension of keys/queries (default: d_model // n_head)
        d_v (int): Dimension of values (default: d_model // n_head)
        gate (bool): Whether to include a gate connection (as in Jumper et al. (2021))

    Inputs:
        Q (torch.tensor): of size (batch_size, num_queries, d_model)
        K (torch.tensor): of size (batch_size, num_keys, d_model)
        V (torch.tensor): of size (batch_size, num_keys, d_model)
        bias (torch.tensor): (optional) of size (n_head, num_queries, num_keys)
            or (batch_size, n_head, num_queries, num_keys), added to the logits
        mask (torch.BoolTensor): (optional) of size (batch_size, n_head, num_queries, num_keys)

    Outputs:
        output (torch.tensor): of size (batch_size, num_queries, d_model)
    """

    def __init__(self, n_head, d_model, d_k=None, d_v=None, gate=False):
        super().__init__()
        self.n_head = n_head
        self.d_model = d_model
        # Default per-head dims split d_model evenly across heads.
        self.d_k = d_model // n_head if d_k is None else d_k
        self.d_v = d_model // n_head if d_v is None else d_v
        self.gate = gate
        self.q_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_k))
        self.k_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_k))
        self.v_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_v))
        self.o_weights = nn.Parameter(torch.Tensor(n_head, self.d_v, d_model))
        self.o_bias = nn.Parameter(torch.Tensor(d_model))
        if self.gate:
            self.g_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_v))
            self.g_bias = nn.Parameter(torch.Tensor(n_head, self.d_v))
        self.softmax = nn.Softmax(dim=-1)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier init for projections; gate starts near pass-through (zero weights, ones bias)."""
        nn.init.xavier_uniform_(self.q_weights)
        nn.init.xavier_uniform_(self.k_weights)
        nn.init.xavier_uniform_(self.v_weights)
        nn.init.xavier_uniform_(self.o_weights)
        nn.init.zeros_(self.o_bias)
        if self.gate:
            nn.init.zeros_(self.g_weights)
            nn.init.ones_(self.g_bias)

    def forward(self, Q, K, V, bias=None, mask=None):
        self._check_inputs(Q, K, V, bias, mask)
        # Project queries per head and pre-scale by 1/sqrt(d_k).
        q = torch.einsum("bqa,ahc->bqhc", Q, self.q_weights) * self.d_k ** (-0.5)
        k = torch.einsum("bka,ahc->bkhc", K, self.k_weights)
        v = torch.einsum("bka,ahc->bkhc", V, self.v_weights)
        logits = torch.einsum("bqhc,bkhc->bhqk", q, k)

        if bias is not None:
            logits = logits + bias

        weights = torch.nn.functional.softmax(logits, dim=-1)

        # NOTE: the mask is applied after the softmax, so masked keys are
        # zeroed but the remaining weights are not renormalized (preserved
        # from the original design).
        if mask is not None:
            weights = weights.masked_fill(~mask, 0.0)

        weighted_avg = torch.einsum("bhqk,bkhc->bqhc", weights, v)

        if self.gate:
            gate_values = torch.einsum("bqa,ahc->bqhc", Q, self.g_weights) + self.g_bias
            # BUG FIX: torch.sigmoid takes no `dim` argument; the original
            # `torch.sigmoid(gate_values, dim=-1)` raised a TypeError whenever
            # gate=True.
            gate_values = torch.sigmoid(gate_values)
            weighted_avg = weighted_avg * gate_values

        output = (
            torch.einsum("bqhc,hco->bqo", weighted_avg, self.o_weights) + self.o_bias
        )
        return output

    def _check_inputs(self, Q, K, V, bias, mask):
        """Validate input shapes and dtypes, raising ValueError on mismatch."""
        batch_size_q, num_queries, d_q_in = Q.size()
        batch_size_k, num_keys, d_k_in = K.size()
        batch_size_v, num_values, d_v_in = V.size()

        if d_q_in != self.d_model:
            raise ValueError(
                f"Dimension of Q tensor needs to be (batch_size, number_queries, d_model)"
            )

        if d_k_in != self.d_model:
            raise ValueError(
                f"Dimension of K tensor needs to be (batch_size, number_keys, d_model)"
            )

        if d_v_in != self.d_model:
            raise ValueError(
                f"Dimension of V tensor needs to be (batch_size, number_values, d_model)"
            )

        if num_keys != num_values:
            raise ValueError(f"Number of keys needs to match number of values passed")

        if (batch_size_q != batch_size_k) or (batch_size_k != batch_size_v):
            raise ValueError(
                f"Found batch size mismatch among inputs, all tensors must agree in size of dimension 0"
            )

        if bias is not None:
            if (bias.dim() != 3) and (bias.dim() != 4):
                raise ValueError(
                    f"Bias specified but dimension mismatched: passed {bias.dim()}-dimensional tensor but should be 3-dimensional"
                    f"of shape (n_head, num_queries, num_keys) or 4-dimensional of shape (batch_size, n_head, num_queries, num_keys)"
                )
            if bias.dim() == 3:
                n_head_b, num_queries_b, num_keys_b = bias.size()
            else:
                # BUG FIX: the original nested an `if bias.dim() == 3:` check
                # inside this 4-D branch, so 4-D biases were never validated.
                _, n_head_b, num_queries_b, num_keys_b = bias.size()
            if n_head_b != self.n_head:
                raise ValueError(
                    f"Bias specified but number of heads (dim of axis=0) does not match number of heads: {self.n_head}"
                )
            if num_queries_b != num_queries:
                raise ValueError(
                    f"Bias specified but number of queries (dim of axis=1) does not match number of queries given in Q tensor"
                )
            if num_keys_b != num_keys:
                raise ValueError(
                    f"Bias specified but number of keys (dim of axis=2) does not match number of queries given in K tensor "
                    f"(dimenson of axis=1)"
                )

        if mask is not None:
            if mask.dtype != torch.bool:
                raise ValueError(
                    f"Mask specified but not given by correct dtype, should be torch.bool but found {mask.dtype}"
                )
            if mask.dim() != 4:
                raise ValueError(
                    f"Mask specified but dimension mismatched: passed {mask.dim()}-dimensional tensor but should be 4-dimensional"
                    f"of shape (batch_size, n_head, num_queries, num_keys)"
                )
            batch_size_b, _, num_queries_b, num_keys_b = mask.size()
            if (num_queries_b != num_queries) and (num_queries_b != 1):
                raise ValueError(
                    f"Bias specified but number of queries (dim of axis=2) does not match number of queries given in Q tensor"
                )
            if (num_keys_b != num_keys) and (num_keys_b != 1):
                raise ValueError(
                    f"Bias specified but number of keys (dim of axis=3) does not match number of queries given in K tensor "
                    f"(dimenson of axis=1)"
                )
|
chroma/chroma/layers/basic.py
ADDED
@@ -0,0 +1,467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Generate Biomedicines, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
|
22 |
+
from chroma.layers.norm import MaskedBatchNorm1d
|
23 |
+
|
24 |
+
|
25 |
+
class NoOp(nn.Module):
    """Identity module: returns its input unchanged, ignoring extra kwargs.

    Useful as a placeholder where an ``nn.Module`` is structurally required
    (e.g. an optional normalization slot) but no operation is desired.

    Inputs:
        x (any)

    Outputs:
        x (any), returned as-is
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, **kwargs):
        # Deliberately ignore kwargs so this can stand in for layers that
        # accept extra arguments (e.g. input masks).
        return x
|
40 |
+
|
41 |
+
|
42 |
+
class Transpose(nn.Module):
    """Module form of ``torch.transpose``: swaps two dimensions of its input.

    Args:
        d1 (int): the first of the two dimensions to swap (default 1)
        d2 (int): the second of the two dimensions to swap (default 2)

    Inputs:
        x (torch.tensor)

    Outputs:
        y (torch.tensor): equal to ``x.transpose(d1, d2)``
    """

    def __init__(self, d1=1, d2=2):
        super().__init__()
        self.d1 = d1
        self.d2 = d2

    def forward(self, x):
        return torch.transpose(x, self.d1, self.d2)
|
63 |
+
|
64 |
+
|
65 |
+
class Unsqueeze(nn.Module):
    """Module form of ``torch.unsqueeze``: inserts a size-1 dimension.

    Args:
        dim (int): position at which to insert the new dimension (default 1)

    Inputs:
        x (torch.tensor)

    Outputs:
        y (torch.tensor): equal to ``x.unsqueeze(dim)``
    """

    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return torch.unsqueeze(x, self.dim)
|
84 |
+
|
85 |
+
|
86 |
+
class OneHot(nn.Module):
    """Module form of ``F.one_hot``: one-hot encodes integer token tensors.

    Args:
        n_tokens (int): the number of tokens comprising input sequences

    Inputs:
        x (torch.LongTensor): of size ``(batch_size, *)``

    Outputs:
        y (torch.tensor): of size ``(batch_size, *, n_tokens)`` on x's device
    """

    def __init__(self, n_tokens):
        super().__init__()
        self.n_tokens = n_tokens

    def forward(self, x):
        return F.one_hot(x, num_classes=self.n_tokens)
|
105 |
+
|
106 |
+
|
107 |
+
class MeanEmbedding(nn.Module):
    """Embedding lookup that also accepts soft (one-hot-like) inputs.

    Wraps an ``nn.Embedding`` so inputs may be given either as integer token
    indices or as a per-token distribution over the vocabulary. For 3-D
    inputs the output is the weighted mean of embedding rows; when
    ``use_softmax`` is set, the last dimension is first normalized with a
    softmax (useful when the input is a log-PMF). For an exact one-hot input
    the result matches the corresponding index lookup.

    Args:
        embedding (nn.Embedding): Embedding to wrap.
        use_softmax (bool): Whether to softmax the last dimension of 3-D
            inputs before mixing embedding rows. Default True.

    Inputs:
        x (torch.tensor): (batch_size, sequence_length) integer tokens, or
            (batch_size, sequence_length, number_tokens) soft assignments.

    Outputs:
        y (torch.tensor): (batch_size, sequence_length, embedding_dimension)

    Raises:
        NotImplementedError: if ``x`` is neither 2-D nor 3-D.
    """

    def __init__(self, embedding, use_softmax=True):
        super(MeanEmbedding, self).__init__()
        self.embedding = embedding
        self.use_softmax = use_softmax
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        rank = x.dim()
        if rank == 2:
            # Hard token indices: plain embedding lookup.
            return self.embedding(x)
        if rank == 3:
            # Soft assignments: mix embedding rows by (optionally softmaxed) weights.
            mix = self.softmax(x) if self.use_softmax else x
            return mix @ self.embedding.weight
        raise (NotImplementedError)
|
142 |
+
|
143 |
+
|
144 |
+
class PeriodicPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding, adapted from 'The Annotated Transformer'
    http://nlp.seas.harvard.edu/2018/04/03/attention.html

    A fixed table of interleaved sin/cos features is precomputed and added to
    the input, followed by dropout.

    Args:
        d_model (int): input and output dimension for the layer
        max_seq_len (int): maximum allowed sequence length (default 4000)
        dropout (float): dropout rate (default 0.0)

    Inputs:
        x (torch.tensor): of size (batch_size, sequence_length, d_model)

    Outputs:
        y (torch.tensor): of size (batch_size, sequence_length, d_model)
    """

    def __init__(self, d_model, max_seq_len=4000, dropout=0.0):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0.0, max_seq_len).unsqueeze(1)
        # Geometric progression of frequencies over the even feature indices.
        freqs = torch.exp(
            torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model)
        )

        table = torch.zeros(max_seq_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        # Register as a buffer so it moves with the module but is not trained.
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        return self.dropout(x + self.pe[:, : x.size(1)])
|
179 |
+
|
180 |
+
|
181 |
+
class PositionWiseFeedForward(nn.Module):
    """Position-wise feed-forward network implemented with 1x1 convolutions,
    a building block of legacy Transformer code (not optimized).

    Args:
        d_model (int): input and output dimension for the layer
        d_hidden (int): size of the hidden layer
        dropout (float): dropout rate applied to the output (default 0.1)

    Inputs:
        x (torch.tensor): of size (batch_size, sequence_length, d_model)

    Outputs:
        y (torch.tensor): of size (batch_size, sequence_length, d_model)
    """

    def __init__(self, d_model, d_hidden, dropout=0.1):
        super().__init__()
        self.activation = nn.ReLU()
        # 1x1 convolutions act as per-position linear layers.
        self.linear1 = nn.Conv1d(d_model, d_hidden, 1)
        self.linear2 = nn.Conv1d(d_hidden, d_model, 1)
        self.dropout = nn.Dropout(p=dropout)

    def reset_parameters(self):
        self.linear1.reset_parameters()
        self.linear2.reset_parameters()

    def forward(self, x):
        # Conv1d expects (batch, channels, length): move features to dim 1.
        h = x.transpose(1, 2)
        h = self.activation(self.linear1(h))
        h = self.linear2(h)
        return self.dropout(h.transpose(1, 2))
|
210 |
+
|
211 |
+
|
212 |
+
class DropNormLin(nn.Module):
    """Linear projection followed by normalization, activation, and dropout.

    Args:
        in_features (int): input dimension
        out_features (int): output dimension
        norm_type (str): ``'ln'`` for layer normalization, ``'bn'`` for masked
            batch normalization, anything else skips normalization
        dropout (float): dropout rate (default 0.0)
        actn (nn.Module): activation function (default ``nn.ReLU()``)

    Inputs:
        x (torch.tensor): of size (batch_size, sequence_length, in_features)
        input_mask (torch.tensor): of size (batch_size, 1, sequence_length),
            optional; only used by the masked batch-norm path

    Outputs:
        y (torch.tensor): of size (batch_size, sequence_length, out_features)
    """

    def __init__(
        self, in_features, out_features, norm_type="ln", dropout=0.0, actn=nn.ReLU()
    ):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        if norm_type == "ln":
            self.norm_layer = nn.LayerNorm(out_features)
        elif norm_type == "bn":
            self.norm_layer = MaskedBatchNorm1d(out_features)
        else:
            self.norm_layer = NoOp()
        self.dropout = nn.Dropout(p=dropout)
        self.actn = actn

    def forward(self, x, input_mask=None):
        h = self.linear(x)
        if isinstance(self.norm_layer, MaskedBatchNorm1d):
            # MaskedBatchNorm1d normalizes over the channel axis, so move
            # features to dim 1 and back again.
            h = self.norm_layer(h.transpose(1, 2), input_mask=input_mask).transpose(
                1, 2
            )
        else:
            h = self.norm_layer(h)
        return self.dropout(self.actn(h))
|
253 |
+
|
254 |
+
|
255 |
+
class ResidualLinearLayer(nn.Module):
    """Simple residual block: linear layer, ReLU, and an optional layer norm.

    Args:
        d_model (int): Model dimension.
        use_norm (bool, *optional*): Apply a layer norm to the branch before
            adding the residual. Default `True`.
    """

    def __init__(self, d_model, use_norm=True):
        super(ResidualLinearLayer, self).__init__()
        self.linear = nn.Linear(d_model, d_model)
        self.ReLU = nn.ReLU()
        self.use_norm = use_norm
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        branch = self.ReLU(self.linear(x))
        if self.use_norm:
            branch = self.norm(branch)
        return x + branch
|
276 |
+
|
277 |
+
|
278 |
+
class TriangleMultiplication(nn.Module):
    """Triangle multiplication as defined in Jumper et al. (2021).

    Pair representations are mixed along the residue axes with an einsum
    contraction over "outgoing" edges (algorithm 11) or "incoming" edges
    (algorithm 12), gated by a sigmoid skip connection.

    Args:
        d_model (int): Dimension of the embedding at each position.
        mode (str): Must be 'outgoing' (algorithm 11) or 'incoming'
            (algorithm 12).

    Inputs:
        X (torch.Tensor): Pair representations of size
            (batch, nres, nres, channels).
        mask (torch.Tensor): Of dtype `torch.bool` and size
            (batch, nres, nres, channels) (or broadcastable to this size).

    Outputs:
        Y (torch.Tensor): Pair representations of size
            (batch, nres, nres, channels).
    """

    def __init__(self, d_model=512, mode="outgoing"):
        super().__init__()
        self.mode = mode
        assert self.mode in ["outgoing", "incoming"]
        self.equation = (
            "bikc,bjkc->bijc" if self.mode == "outgoing" else "bkjc,bkic->bijc"
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.left_edge_mlp = nn.Sequential(
            nn.Linear(d_model, d_model), nn.Sigmoid(), nn.Linear(d_model, d_model)
        )
        self.right_edge_mlp = nn.Sequential(
            nn.Linear(d_model, d_model), nn.Sigmoid(), nn.Linear(d_model, d_model)
        )
        self.skip = nn.Sequential(nn.Linear(d_model, d_model), nn.Sigmoid())
        self.combine = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, d_model))

    def forward(self, X, mask=None):
        h = self.layer_norm(X)

        left = self.left_edge_mlp(h)
        right = self.right_edge_mlp(h)
        gate = self.skip(h)

        if mask is not None:
            # Zero out masked pair positions before the contraction.
            left = left.masked_fill(~mask, 0.0)
            right = right.masked_fill(~mask, 0.0)

        mixed = torch.einsum(self.equation, left, right)
        return self.combine(mixed) * gate
|
325 |
+
|
326 |
+
|
327 |
+
class NodeProduct(nn.Module):
    """Build edge features from pairs of node embeddings.

    Like Alg. 10 in Jumper et al. (2021), but instead of computing a mean
    over an MSA dimension, processes single-sequence inputs: each (i, j)
    edge embedding is a linear map of the concatenated per-node projections.

    Args:
        d_in (int): Dimension of node embeddings (inputs).
        d_out (int): Dimension of edge embeddings (outputs).

    Inputs:
        node_features (torch.Tensor): Of size (batch_size, nres, d_in).
        node_mask (torch.Tensor, optional): Of size (batch_size, nres).
        edge_mask (torch.Tensor, optional): Of size (batch_size, nres, nres).

    Outputs:
        edge_features (torch.Tensor): Of size (batch_size, nres, nres, d_out).
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_in)
        self.left_lin = nn.Linear(d_in, d_in)
        self.right_lin = nn.Linear(d_in, d_in)
        self.edge_lin = nn.Linear(2 * d_in, d_out)

    def forward(self, node_features, node_mask=None, edge_mask=None):
        nres = node_features.size(1)

        normed = self.layer_norm(node_features)
        left = self.left_lin(normed)
        right = self.right_lin(normed)

        if node_mask is not None:
            keep = node_mask[:, :, None]
            left = left.masked_fill(~keep, 0.0)
            right = right.masked_fill(~keep, 0.0)

        # Broadcast left along rows and right along columns to form all pairs.
        left = left[:, None, :, :].repeat(1, nres, 1, 1)
        right = right[:, :, None, :].repeat(1, 1, nres, 1)
        edges = self.edge_lin(torch.cat([left, right], dim=-1))

        if edge_mask is not None:
            edges = edges.masked_fill(~edge_mask[:, :, :, None], 0.0)

        return edges
|
374 |
+
|
375 |
+
|
376 |
+
class FourierFeaturization(nn.Module):
    """Fourier featurization of low-dimensional (usually spatial) input data.

    As described in [https://arxiv.org/abs/2006.10739], with the projection
    optionally trainable as described in [https://arxiv.org/abs/2106.02795].

    Args:
        d_input (int): Dimension of inputs.
        d_model (int): Dimension of outputs (must be even).
        trainable (bool): Whether to learn the frequency of Fourier features.
        scale (float): If not trainable, controls the scale of Fourier
            feature periods (see reference for description; this parameter
            matters and should be tuned!).

    Inputs:
        input (torch.Tensor): Of size (batch_size, *, d_input).

    Outputs:
        output (torch.Tensor): Of size (batch_size, *, d_model).
    """

    def __init__(self, d_input, d_model, trainable=False, scale=1.0):
        super().__init__()
        self.scale = scale

        if d_model % 2 != 0:
            raise ValueError(
                "d_model needs to be even for this featurization, try again!"
            )

        # Random projection matrix; half of d_model for cos, half for sin.
        B = 2 * math.pi * scale * torch.randn(d_input, d_model // 2)
        self.trainable = trainable
        if trainable:
            self.register_parameter("B", torch.nn.Parameter(B))
        else:
            self.register_buffer("B", B)

    def forward(self, inputs):
        projected = inputs @ self.B
        return torch.cat([projected.cos(), projected.sin()], -1)
|
413 |
+
|
414 |
+
|
415 |
+
class PositionalEncoding(nn.Module):
    """Axis-aligned positional encodings with log-linear spacing.

    Each input coordinate is expanded into cos/sin features whose periods
    are log-linearly spaced over `period_range`.

    Args:
        d_model (int): Dimension of outputs.
        d_input (int): Dimension of inputs.
        period_range (tuple of floats): Min and maximum periods for the
            frequency components. Fourier features will be log-linearly
            spaced between these values (inclusive).

    Inputs:
        input (torch.Tensor): Of size (..., d_input).

    Outputs:
        output (torch.Tensor): Of size (..., d_model).
    """

    def __init__(self, d_model, d_input=1, period_range=(1.0, 1000.0)):
        super().__init__()

        if d_model % (2 * d_input) != 0:
            raise ValueError(
                "d_model needs to be divisible by 2*d_input for this featurization, "
                f"but got {d_model} versus {d_input}"
            )

        num_frequencies = d_model // (2 * d_input)
        # Periods are log-linearly spaced between the bounds (inclusive);
        # angular frequencies are their reciprocals times 2*pi.
        log_bounds = np.log10(period_range)
        periods = torch.logspace(log_bounds[0], log_bounds[1], num_frequencies, base=10.0)
        self.register_buffer("w", 2 * math.pi / periods)

    def forward(self, inputs):
        batch_dims = list(inputs.shape)[:-1]
        # Broadcast (..., 1, num_frequencies) against (..., d_input, 1).
        freqs = self.w.reshape(len(batch_dims) * [1] + [1, -1])
        phase = freqs * inputs[..., None]
        return torch.cat([phase.cos(), phase.sin()], -1).reshape(batch_dims + [-1])
|
455 |
+
|
456 |
+
|
457 |
+
class MaybeOnehotEmbedding(nn.Embedding):
    """Embedding layer accepting either index or one-hot encoded inputs.

    Wrapper around :class:`torch.nn.Embedding`: floating-point inputs are
    treated as one-hot (or soft) encodings and embedded via matrix product
    with the embedding weights; integer inputs use the standard lookup.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not x.dtype.is_floating_point:
            return super().forward(x)
        # One-hot / soft inputs: embed by matrix multiplication.
        return x @ self.weight
|
chroma/chroma/layers/complexity.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Generate Biomedicines, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Layers for computing sequence complexities.
|
16 |
+
"""
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
import torch.nn.functional as F
|
21 |
+
|
22 |
+
from chroma.constants import AA20
|
23 |
+
from chroma.layers.graph import collect_neighbors
|
24 |
+
|
25 |
+
|
26 |
+
def compositions(S: torch.Tensor, C: torch.LongTensor, w: int = 30):
    """Compute local compositions per residue.

    Each residue is summarized by the amino-acid counts and frequencies of a
    window of `w` positions centered on it, restricted to the same chain.

    Args:
        S (torch.Tensor): Sequence tensor with shape `(num_batch, num_residues)`
            (long) or `(num_batch, num_residues, num_alphabet)` (float).
        C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`.
        w (int, optional): Window size.

    Returns:
        P (torch.Tensor): Local compositions with shape
            `(num_batch, num_residues, num_alphabet)`.
        N (torch.Tensor): Local counts with shape
            `(num_batch, num_residues, num_alphabet)`.
        edge_idx (torch.LongTensor): Window member indices with shape
            `(num_batch, num_residues, w)`.
        mask_i (torch.Tensor): Node mask with shape
            `(num_batch, num_residues)`.
        mask_ij (torch.Tensor): Window membership mask with shape
            `(num_batch, num_residues, w)`.
    """
    device = S.device
    Q = len(AA20)
    mask_i = (C > 0).float()
    if len(S.shape) == 2:
        S = F.one_hot(S, Q)

    # Build neighborhoods and masks
    S_onehot = mask_i[..., None] * S
    # Window offsets centered on each position.
    kx = torch.arange(w, device=S.device) - w // 2
    edge_idx = (
        torch.arange(S.shape[1], device=S.device)[None, :, None] + kx[None, None, :]
    )
    # NOTE(review): `edge_idx > 0` excludes position 0 from every window —
    # possibly intended as `edge_idx >= 0`; confirm against upstream.
    mask_ij = (edge_idx > 0) & (edge_idx < S.shape[1])
    edge_idx = edge_idx.clamp(min=0, max=S.shape[1] - 1)
    C_i = C[..., None]
    C_j = collect_neighbors(C_i, edge_idx)[..., 0]
    # Valid window members are in-range, unmasked, and on the same chain.
    mask_ij = (mask_ij & C_j.eq(C_i) & (C_i > 0) & (C_j > 0)).float()

    # Sum neighborhood composition
    S_j = mask_ij[..., None] * collect_neighbors(S_onehot, edge_idx)
    N = S_j.sum(2)

    num_N = N.sum(-1, keepdims=True)
    P = N / (num_N + 1e-5)
    mask_i = ((num_N[..., 0] > 0) & (C > 0)).float()
    mask_ij = mask_i[..., None] * mask_ij
    return P, N, edge_idx, mask_i, mask_ij
|
70 |
+
|
71 |
+
|
72 |
+
def complexity_lcp(
    S: torch.LongTensor,
    C: torch.LongTensor,
    w: int = 30,
    entropy_min: float = 2.32,
    method: str = "naive",
    differentiable=True,
    eps: float = 1e-5,
    min_coverage=0.9,
    # entropy_min: float = 2.52,
    # method = "chao-shen"
) -> torch.Tensor:
    """Compute the Local Composition Perplexity metric.

    Each residue's window entropy is converted to a perplexity and penalized
    quadratically for falling below `exp(entropy_min)`; penalties are summed
    per batch member.

    Args:
        S (torch.Tensor): Sequence tensor with shape `(num_batch, num_residues)`
            (index tensor) or `(num_batch, num_residues, num_alphabet)`.
        C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`.
        w (int): Window size (clamped to the sequence length).
        entropy_min (float): Entropy threshold below which windows incur a
            penalty.
        method (str): Entropy estimator name passed to `estimate_entropy`.
        differentiable (bool): If True and `S` is one-hot/soft, use a
            straight-through estimator so gradients flow back to `S`.
        eps (float): Small number for numerical stability in division and
            logarithms.
        min_coverage (float): Minimum fraction of a window that must be
            observed for the window to count.

    Returns:
        U (torch.Tensor): Complexities with shape `(num_batch)`.
    """

    # Adjust window size based on sequence length
    if S.shape[1] < w:
        w = S.shape[1]

    P, N, edge_idx, mask_i, mask_ij = compositions(S, C, w)

    # Only count windows with at least `min_coverage` of positions observed
    # (previously this threshold was computed twice; reuse the value).
    min_N = int(min_coverage * w)
    mask_coverage = N.sum(-1) > min_N

    H = estimate_entropy(N, method=method)
    U = mask_coverage * (torch.exp(H) - np.exp(entropy_min)).clamp(max=0).square()

    # Compute entropy as a function of perturbed counts
    if differentiable and len(S.shape) == 3:
        # Compute how a mutation changes entropy for each neighbor
        N_neighbors = collect_neighbors(N, edge_idx)
        mask_coverage_j = collect_neighbors(mask_coverage[..., None], edge_idx)
        N_ij = (N_neighbors - S[:, :, None, :])[..., None, :] + torch.eye(
            N.shape[-1], device=N.device
        )[None, None, None, ...]
        N_ij = N_ij.clamp(min=0)
        H_ij = estimate_entropy(N_ij, method=method)
        U_ij = (torch.exp(H_ij) - np.exp(entropy_min)).clamp(max=0).square()
        U_ij = mask_ij[..., None] * mask_coverage_j * U_ij
        # Straight-through estimator: forward value comes from the hard
        # penalty, gradients from the soft per-mutation penalties times S.
        U_differentiable = (U_ij.detach() * S[:, :, None, :]).sum([-1, -2])
        U = U.detach() + U_differentiable - U_differentiable.detach()

    U = (mask_i * U).sum(1)
    return U
|
129 |
+
|
130 |
+
|
131 |
+
def complexity_scores_lcp_t(
    t,
    S: torch.LongTensor,
    C: torch.LongTensor,
    idx: torch.LongTensor,
    edge_idx_t: torch.LongTensor,
    mask_ij_t: torch.Tensor,
    w: int = 30,
    entropy_min: float = 2.515,
    eps: float = 1e-5,  # NOTE(review): currently unused in this function
    method: str = "chao-shen",
) -> torch.Tensor:
    """Compute local LCP scores for autoregressive decoding.

    For decoding step `t`, scores each candidate amino acid by the entropy
    of the local window around position `t` as if that amino acid were
    appended; low-entropy (repetitive) extensions receive negative scores.

    Args:
        t (int): Current decoding position.
        S (torch.LongTensor): Sequence decoded so far,
            `(num_batch, num_residues)`.
        C (torch.LongTensor): Chain map, `(num_batch, num_residues)`.
        idx (torch.LongTensor): Residue position indices,
            `(num_batch, num_residues)`.
        edge_idx_t (torch.LongTensor): Neighbor indices at step `t`.
        mask_ij_t (torch.Tensor): Neighbor mask at step `t`.
        w (int): Window size.
        entropy_min (float): Entropy threshold below which extensions are
            penalized.
        eps (float): Unused; kept for interface compatibility.
        method (str): Entropy estimator passed to `estimate_entropy`.

    Returns:
        U (torch.Tensor): Non-positive per-alphabet scores; more negative
            means a lower-entropy extension.
    """
    Q = len(AA20)
    O = F.one_hot(S, Q)
    O_j = collect_neighbors(O, edge_idx_t)
    idx_i = idx[:, t, None]
    C_i = C[:, t, None]
    idx_j = collect_neighbors(idx[..., None], edge_idx_t)[..., 0]
    C_j = collect_neighbors(C[..., None], edge_idx_t)[..., 0]

    # Sum valid neighbor counts: within the window radius and on-chain.
    is_near = (idx_i - idx_j).abs() <= w / 2
    same_chain = C_i == C_j
    valid_ij_t = (is_near * same_chain * (mask_ij_t > 0)).float()[..., None]
    N_k = (valid_ij_t * O_j).sum(-2)

    # Compute counts under all possible extensions
    N_k = N_k[:, :, None, :] + torch.eye(Q, device=N_k.device)[None, None, ...]

    H = estimate_entropy(N_k, method=method)
    U = -(torch.exp(H) - np.exp(entropy_min)).clamp(max=0).square()
    return U
|
164 |
+
|
165 |
+
|
166 |
+
def estimate_entropy(
    N: torch.Tensor, method: str = "chao-shen", eps: float = 1e-11
) -> torch.Tensor:
    """Estimate entropy (in nats) from counts.

    See Chao, A., & Shen, T. J. (2003) for details of the coverage-adjusted
    estimator.

    Args:
        N (torch.Tensor): Tensor of counts with shape `(..., num_bins)`.
        method (str): One of `"chao-shen"`, `"miller-maddow"`, `"laplace"`;
            any other value selects the naive plug-in estimator.
        eps (float): Small number for numerical stability.

    Returns:
        H (torch.Tensor): Estimated entropy with shape `(...)`.
    """
    N = N.float()
    N_total = N.sum(-1, keepdims=True)
    P = N / (N_total + eps)

    if method == "chao-shen":
        # Coverage-adjusted frequencies with Horvitz-Thompson-style
        # inclusion probabilities.
        singletons = N.long().eq(1).sum(-1, keepdims=True).float()
        C = 1.0 - singletons / (N_total + eps)
        P_adjust = C * P
        P_inclusion = (1.0 - (1.0 - P_adjust) ** N_total).clamp(min=eps)
        return -(P_adjust * torch.log(P_adjust.clamp(min=eps)) / P_inclusion).sum(-1)
    if method == "miller-maddow":
        # Plug-in estimate plus the Miller-Madow bias correction.
        bins = (N > 0).float().sum(-1)
        bias = (bins - 1) / (2 * N_total[..., 0] + eps)
        return -(P * torch.log(P + eps)).sum(-1) + bias
    if method == "laplace":
        # Add a flat pseudocount of 1/num_bins before the plug-in estimate.
        N = N.float() + 1 / N.shape[-1]
        N_total = N.sum(-1, keepdims=True)
        P = N / (N_total + eps)
        return -(P * torch.log(P)).sum(-1)
    # Naive plug-in estimator.
    return -(P * torch.log(P + eps)).sum(-1)
|
chroma/chroma/layers/conv.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Generate Biomedicines, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import platform
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
|
20 |
+
MACHINE = platform.machine()
|
21 |
+
"""
|
22 |
+
一维线性衰减滤波器
|
23 |
+
"""
|
24 |
+
|
25 |
+
def filter1D_linear_decay(Z, B):
    """Apply a low-pass filter with batch-heterogeneous coefficients.

    Computes `x_i = z_i + b * x_{i-1}` where `b` varies per batch member.

    Args:
        Z (torch.Tensor): Batch of one-dimensional signals with shape `(N, W)`.
        B (torch.Tensor): Batch of coefficients with shape `(N)`.

    Returns:
        X (torch.Tensor): Result of applying the linear recurrence with
            shape `(N, W)`.
    """
    num_signals, width = Z.shape

    # The recurrence unrolls to x_i = sum_k b^(i-k) z_k, so each batch
    # member's filter kernel is the decreasing powers of its coefficient.
    powers = (width - 1) - torch.arange(width, device=Z.device)
    kernel = B[:, None, None] ** powers[None, None, :]

    # Pad on the left so the convolution only looks backwards in time.
    Z_pad = F.pad(Z, (width - 1, 0))[None, ...]

    # Grouped convolution effectively applies one filter per batch member.
    while True:
        X = F.conv1d(Z_pad, kernel, stride=1, padding=0, groups=num_signals)[0, :, :]
        # On arm64 (M1 Mac) this convolution erroneously sometimes produces
        # NaNs; retry whenever the inputs themselves are NaN-free.
        spurious_nan = (
            (platform.machine() == "arm64")
            and torch.isnan(X).any()
            and (not torch.isnan(Z_pad).any())
            and (not torch.isnan(kernel).any())
        )
        if not spurious_nan:
            break
    return X
|
chroma/chroma/layers/graph.py
ADDED
@@ -0,0 +1,1126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Generate Biomedicines, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Layers for building graph neural networks.
|
16 |
+
|
17 |
+
This module contains layers for building neural networks that can process
|
18 |
+
graph-structured data. The internal representations of these layers
|
19 |
+
are node and edge embeddings.
|
20 |
+
"""
|
21 |
+
|
22 |
+
from typing import Callable, List, Optional, Tuple
|
23 |
+
|
24 |
+
import torch
|
25 |
+
import torch.nn as nn
|
26 |
+
from torch.utils.checkpoint import checkpoint
|
27 |
+
from tqdm.autonotebook import tqdm
|
28 |
+
|
29 |
+
from chroma.layers.attention import Attention
|
30 |
+
|
31 |
+
|
32 |
+
class GraphNN(nn.Module):
|
33 |
+
"""Graph neural network with optional edge updates.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
num_layers (int): Number of layers.
|
37 |
+
dim_nodes (int): Hidden dimension of node tensor.
|
38 |
+
dim_edges (int): Hidden dimension of edge tensor.
|
39 |
+
dropout (float): Dropout rate.
|
40 |
+
node_mlp_layers (int): Node update function, number of hidden layers.
|
41 |
+
Default is 1.
|
42 |
+
node_mlp_dim (int): Node update function, hidden dimension.
|
43 |
+
Default is to match MLP output dimension.
|
44 |
+
update_edge (Boolean): Include an edge-update step. Default: True
|
45 |
+
edge_mlp_layers (int): Edge update function, number of hidden layers.
|
46 |
+
Default is 1.
|
47 |
+
edge_mlp_dim (int): Edge update function, hidden dimension.
|
48 |
+
Default is to match MLP output dimension.
|
49 |
+
mlp_activation (str): MLP nonlinearity.
|
50 |
+
`'relu'`: Rectified linear unit.
|
51 |
+
`'softplus'`: Softplus.
|
52 |
+
norm (str): Which normalization function to apply between layers.
|
53 |
+
`'transformer'`: Default layernorm
|
54 |
+
`'layer'`: Masked Layer norm with shape (input.shape[1:])
|
55 |
+
`'instance'`: Masked Instance norm
|
56 |
+
scale (float): Scaling factor of edge input when updating node (default=1.0)
|
57 |
+
attentional (bool): If True, use attention for message aggregation function
|
58 |
+
instead of a sum. Default is False.
|
59 |
+
num_attention_heads (int): Number of attention heads (if attentional) to use.
|
60 |
+
Default is 4.
|
61 |
+
|
62 |
+
Inputs:
|
63 |
+
node_h (torch.Tensor): Node features with shape
|
64 |
+
`(num_batch, num_nodes, dim_nodes)`.
|
65 |
+
edge_h (torch.Tensor): Edge features with shape
|
66 |
+
`(num_batch, num_nodes, num_neighbors, dim_edges)`.
|
67 |
+
edge_idx (torch.LongTensor): Edge indices for neighbors with shape
|
68 |
+
`(num_batch, num_nodes, num_neighbors)`.
|
69 |
+
mask_i (tensor, optional): Node mask with shape `(num_batch, num_nodes)`
|
70 |
+
mask_ij (tensor, optional): Edge mask with shape
|
71 |
+
`(num_batch, num_nodes, num_neighbors)`
|
72 |
+
|
73 |
+
Outputs:
|
74 |
+
node_h_out (torch.Tensor): Updated node features with shape
|
75 |
+
`(num_batch, num_nodes, dim_nodes)`.
|
76 |
+
edge_h_out (torch.Tensor): Updated edge features with shape
|
77 |
+
`(num_batch, num_nodes, num_neighbors, dim_edges)`.
|
78 |
+
"""
|
79 |
+
|
80 |
+
    def __init__(
        self,
        num_layers: int,
        dim_nodes: int,
        dim_edges: int,
        node_mlp_layers: int = 1,
        node_mlp_dim: Optional[int] = None,
        edge_update: bool = True,
        edge_mlp_layers: int = 1,
        edge_mlp_dim: Optional[int] = None,
        mlp_activation: str = "relu",
        dropout: float = 0.0,
        norm: str = "transformer",
        scale: float = 1.0,
        skip_connect_input: bool = False,
        attentional: bool = False,
        num_attention_heads: int = 4,
        checkpoint_gradients: bool = False,
    ):
        """Build the stack of graph layers; see the class docstring for args."""
        super(GraphNN, self).__init__()
        # Residual connection to the input embeddings (see `forward`).
        self.skip_connect_input = skip_connect_input
        # Gradient checkpointing trades compute for memory: instead of
        # storing every layer's activations for the backward pass, they are
        # recomputed on demand.
        self.checkpoint_gradients = checkpoint_gradients
        self.layers = nn.ModuleList(
            [
                GraphLayer(
                    dim_nodes=dim_nodes,
                    dim_edges=dim_edges,
                    node_mlp_layers=node_mlp_layers,
                    node_mlp_dim=node_mlp_dim,
                    edge_update=edge_update,
                    edge_mlp_layers=edge_mlp_layers,
                    edge_mlp_dim=edge_mlp_dim,
                    mlp_activation=mlp_activation,
                    dropout=dropout,
                    norm=norm,
                    scale=scale,
                    attentional=attentional,
                    num_attention_heads=num_attention_heads,
                )
                for _ in range(num_layers)
            ]
        )
|
127 |
+
|
128 |
+
    def forward(
        self,
        node_h: torch.Tensor,
        edge_h: torch.Tensor,
        edge_idx: torch.LongTensor,
        mask_i: Optional[torch.Tensor] = None,
        mask_ij: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply all graph layers sequentially.

        Args:
            node_h (torch.Tensor): Node features with shape
                `(num_batch, num_nodes, dim_nodes)`.
            edge_h (torch.Tensor): Edge features with shape
                `(num_batch, num_nodes, num_neighbors, dim_edges)`.
            edge_idx (torch.LongTensor): Neighbor indices with shape
                `(num_batch, num_nodes, num_neighbors)`.
            mask_i (torch.Tensor, optional): Node mask,
                `(num_batch, num_nodes)`.
            mask_ij (torch.Tensor, optional): Edge mask,
                `(num_batch, num_nodes, num_neighbors)`.

        Returns:
            Updated node and edge features with the input shapes.
        """
        # Run every layer sequentially
        node_h_init = node_h
        edge_h_init = edge_h
        for i, layer in enumerate(self.layers):
            if self.skip_connect_input:
                # Re-inject the initial embeddings before each layer...
                node_h = node_h + node_h_init
                edge_h = edge_h + edge_h_init

            # Update edge and node (optionally under gradient checkpointing)
            node_h, edge_h = self.checkpoint(
                layer, node_h, edge_h, edge_idx, mask_i, mask_ij
            )

            if self.skip_connect_input:
                # ...and subtract them after, so each layer sees residuals.
                node_h = node_h - node_h_init
                edge_h = edge_h - edge_h_init

        # If mask was provided, apply it
        if mask_i is not None:
            node_h = node_h * (mask_i.unsqueeze(-1) != 0).type(torch.float32)
        if mask_ij is not None:
            edge_h = edge_h * (mask_ij.unsqueeze(-1) != 0).type(torch.float32)
        return node_h, edge_h
|
159 |
+
|
160 |
+
def checkpoint(self, layer, *args):
    """Apply `layer(*args)`, rematerializing activations in the backward
    pass when `checkpoint_gradients` is enabled (trades compute for memory)."""
    if not self.checkpoint_gradients:
        return layer(*args)
    return checkpoint(layer, *args)
|
165 |
+
|
166 |
+
def sequential(
    self,
    tensors: dict,
    pre_step_function: Callable = None,
    post_step_function: Callable = None,
) -> dict:
    """Decode the GNN sequentially along the node index `t`, with callbacks.

    Args:
        tensors (dict): Initial set of state tensors. At minimum this should
            include the arguments to `forward`, namely `node_h`, `edge_h`,
            `edge_idx`, `mask_i`, and `mask_ij`.
        pre_step_function (function, optional): Callback applied to `tensors`
            before each sequential GNN step as
            `tensors = pre_step_function(t, tensors)` where `t` is the node
            index being updated. It should update elements of the `tensors`
            dictionary, and it can access and update the intermediate GNN
            state cache via the keyed lists of tensors in `node_h_cache`
            and `edge_h_cache`.
        post_step_function (function, optional): Same as `pre_step_function`,
            but applied after each sequential GNN step.

    Returns:
        tensors (dict): Processed set of tensors.
    """
    # Build the per-layer caches of intermediate node and edge states.
    tensors["node_h_cache"], tensors["edge_h_cache"] = self.init_steps(
        tensors["node_h"], tensors["edge_h"]
    )

    # Visit every node position in order, firing the callbacks around
    # each single-index GNN update.
    num_steps = tensors["node_h"].size(1)
    for t in tqdm(range(num_steps), desc="Sequential decoding"):
        if pre_step_function is not None:
            tensors = pre_step_function(t, tensors)

        tensors["node_h_cache"], tensors["edge_h_cache"] = self.step(
            t,
            tensors["node_h_cache"],
            tensors["edge_h_cache"],
            tensors["edge_idx"],
            tensors["mask_i"],
            tensors["mask_ij"],
        )

        if post_step_function is not None:
            tensors = post_step_function(t, tensors)

    return tensors
|
216 |
+
|
217 |
+
def init_steps(
    self, node_h: torch.Tensor, edge_h: torch.Tensor
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Initialize cached node and edge features for sequential decoding.

    Args:
        node_h (torch.Tensor): Node features with shape
            `(num_batch, num_nodes, dim_nodes)`.
        edge_h (torch.Tensor): Edge features with shape
            `(num_batch, num_nodes, num_neighbors, dim_edges)`.

    Returns:
        node_h_cache (List[torch.Tensor]): `num_layers + 1` cloned tensors of
            shape `(num_batch, num_nodes, dim_nodes)` — one state per layer
            boundary.
        edge_h_cache (List[torch.Tensor]): `num_layers + 1` cloned tensors of
            shape `(num_batch, num_nodes, num_neighbors, dim_edges)`.
    """
    # One cache slot per layer boundary (inputs + outputs of each layer).
    depth = len(self.layers) + 1
    node_h_cache = [node_h.clone() for _ in range(depth)]
    edge_h_cache = [edge_h.clone() for _ in range(depth)]
    return node_h_cache, edge_h_cache
|
238 |
+
|
239 |
+
def step(
    self,
    t: int,
    node_h_cache: List[torch.Tensor],
    edge_h_cache: List[torch.Tensor],
    edge_idx: torch.LongTensor,
    mask_i: Optional[torch.Tensor] = None,
    mask_ij: Optional[torch.Tensor] = None,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Process GNN update for a specific node index t from cached intermediates.

    Args:
        t (int): Node index to decode.
        node_h_cache (List[torch.Tensor]): List of cached node features with
            `num_layers + 1` tensors of shape `(num_batch, num_nodes, dim_nodes)`.
        edge_h_cache (List[torch.Tensor]): List of cached edge features with
            `num_layers + 1` tensors of shape
            `(num_batch, num_nodes, num_neighbors, dim_edges)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.
        mask_i (torch.Tensor, optional): Node mask with shape
            `(num_batch, num_nodes)`.
        mask_ij (torch.Tensor, optional): Edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        node_h_cache (List[torch.Tensor]): Updated list of cached node features
            with `num_layers + 1` tensors of shape
            `(num_batch, num_nodes, dim_nodes)`. This method updates the tensors
            in place for memory.
        edge_h_cache (List[torch.Tensor]): Updated list of cached edge features
            with `num_layers + 1` tensors of shape
            `(num_batch, num_nodes, num_neighbors, dim_edges)`.
    """
    # Sequential decoding with the input skip connection is not supported.
    if self.skip_connect_input:
        raise NotImplementedError

    for i, layer in enumerate(self.layers):
        # Because the edge updates depend on the updated nodes,
        # we need both the input node features node_h and also
        # the previous output node states node_h_out (cache slot i+1).
        node_h = node_h_cache[i]
        node_h_out = node_h_cache[i + 1]
        edge_h = edge_h_cache[i]
        # Update edge and node at index t only (checkpointed for memory).
        node_h_t, edge_h_t = checkpoint(
            layer.step, t, node_h, node_h_out, edge_h, edge_idx, mask_i, mask_ij
        )

        # Scatter the single-index results into the caches in place.
        # NOTE(review): the index tensor `t * ones_like(...)` writes row t of
        # dim 1 for every feature position of the step output.
        node_h_cache[i + 1].scatter_(
            1, (t * torch.ones_like(node_h_t)).long(), node_h_t
        )
        edge_h_cache[i + 1].scatter_(
            1, (t * torch.ones_like(edge_h_t)).long(), edge_h_t
        )

    return node_h_cache, edge_h_cache
|
297 |
+
|
298 |
+
## GraphLayer: a single message-passing layer
|
299 |
+
class GraphLayer(nn.Module):
    """Graph layer that updates each node i given adjacent nodes and edges.

    Args:
        dim_nodes (int): Hidden dimension of node tensor.
        dim_edges (int): Hidden dimension of edge tensor.
        node_mlp_layers (int): Node update function, number of hidden layers.
            Default: 1.
        node_mlp_dim (int): Node update function, hidden dimension.
            Default: Matches MLP output dimension.
        edge_update (Boolean): Include an edge-update step. Default: True
        edge_mlp_layers (int): Edge update function, number of hidden layers.
            Default: 1.
        edge_mlp_dim (int): Edge update function, hidden dimension.
            Default: Matches MLP output dimension.
        mlp_activation (str): MLP nonlinearity.
            `'relu'`: Rectified linear unit.
            `'softplus'`: Softplus.
        dropout (float): Dropout rate.
        norm (str): Which normalization function to apply between layers.
            `'transformer'`: Default layernorm
            `'layer'`: Masked Layer norm with shape (input.shape[1:])
            `'instance'`: Masked Instance norm
        scale (float): Scaling factor of edge input when updating node
            (default=1.0).
        attentional (bool): Aggregate messages with attention instead of a
            scaled sum. Default: False.
        num_attention_heads (int): Number of attention heads when
            `attentional` is True. Default: 4.

    Inputs:
        node_h (torch.Tensor): Node features with shape
            `(num_batch, num_nodes, dim_nodes)`.
        edge_h (torch.Tensor): Edge features with shape
            `(num_batch, num_nodes, num_neighbors, dim_edges)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.
        mask_i (tensor, optional): Node mask with shape `(num_batch, num_nodes)`
        mask_ij (tensor, optional): Edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`

    Outputs:
        node_h_out (torch.Tensor): Updated node features with shape
            `(num_batch, num_nodes, dim_nodes)`.
        edge_h_out (torch.Tensor): Updated edge features with shape
            `(num_batch, num_nodes, num_neighbors, dim_edges)`.
    """

    def __init__(
        self,
        dim_nodes: int,
        dim_edges: int,
        node_mlp_layers: int = 1,
        node_mlp_dim: Optional[int] = None,
        edge_update: bool = True,
        edge_mlp_layers: int = 1,
        edge_mlp_dim: Optional[int] = None,
        mlp_activation: str = "relu",
        dropout: float = 0.0,
        norm: str = "transformer",
        scale: float = 1.0,
        attentional: bool = False,
        num_attention_heads: int = 4,
    ):
        super(GraphLayer, self).__init__()

        # Store configuration used by the update functions below.
        self.scale = scale
        self.dim_nodes = dim_nodes
        self.dim_edges = dim_edges
        self.attentional = attentional

        self.node_norm_layer = MaskedNorm(
            dim=1, num_features=dim_nodes, affine=True, norm=norm
        )

        # Message function: maps [node_i, node_j, edge_ij] -> message_ij.
        self.message_mlp = MLP(
            dim_in=2 * dim_nodes + dim_edges,
            dim_out=dim_nodes,
            num_layers_hidden=edge_mlp_layers,
            dim_hidden=edge_mlp_dim,
            activation=mlp_activation,
            dropout=dropout,
        )
        # Node update: maps [node_i, aggregated_message_i] -> residual update.
        self.update_mlp = MLP(
            dim_in=2 * dim_nodes,
            dim_out=dim_nodes,
            num_layers_hidden=node_mlp_layers,
            dim_hidden=node_mlp_dim,
            activation=mlp_activation,
            dropout=dropout,
        )
        self.edge_update = edge_update
        self.edge_norm_layer = MaskedNorm(
            dim=2, num_features=dim_edges, affine=True, norm=norm
        )
        if self.edge_update:
            # Edge update: maps [node_i, node_j, edge_ij] -> residual update.
            self.edge_mlp = MLP(
                dim_in=2 * dim_nodes + dim_edges,
                dim_out=dim_edges,
                num_layers_hidden=edge_mlp_layers,
                dim_hidden=edge_mlp_dim,
                activation=mlp_activation,
                dropout=dropout,
            )

        if self.attentional:
            self.attention = Attention(n_head=num_attention_heads, d_model=dim_nodes)

    def attend(
        self, node_h: torch.Tensor, messages: torch.Tensor, mask_ij: torch.Tensor
    ) -> torch.Tensor:
        """Aggregate per-edge messages with multi-head attention, using each
        node as the query over its own neighbors' messages."""
        B, L, K, D = messages.size()
        # Flatten (batch, node) so each node attends over its K messages.
        queries = node_h.reshape(-1, 1, D)
        keys = messages.reshape(-1, K, D)
        values = messages.reshape(-1, K, D)
        mask = mask_ij.reshape(-1, 1, 1, K).bool() if mask_ij is not None else None
        return self.attention(queries, keys, values, mask=mask).reshape(B, L, D)

    def _normalize(self, node_h, edge_h, mask_i=None, mask_ij=None):
        """Apply masked normalization to node and edge embeddings."""
        node_h_norm = self.node_norm_layer(node_h, mask_i)
        edge_h_norm = self.edge_norm_layer(edge_h, mask_ij)
        return node_h_norm, edge_h_norm

    def _normalize_t(
        self, edge_node_stack_t, mask_ij_t, include_nodes=True, include_edges=True
    ):
        """Normalize the packed [node_i, node_j, edge_ij] stack for a single
        step (used because only index-t information has been normalized)."""
        # Split the packed stack back into its node-i / node-j / edge parts.
        node_i_t = edge_node_stack_t[:, :, :, : self.dim_nodes]
        node_j_t = edge_node_stack_t[:, :, :, self.dim_nodes : 2 * self.dim_nodes]
        edge_h_t = edge_node_stack_t[:, :, :, 2 * self.dim_nodes :]
        if include_nodes:
            node_i_t = self.node_norm_layer(node_i_t, mask_ij_t)
            node_j_t = self.node_norm_layer(node_j_t, mask_ij_t)
        if include_edges:
            edge_h_t = self.edge_norm_layer(edge_h_t, mask_ij_t)
        edge_node_stack_t = torch.cat([node_i_t, node_j_t, edge_h_t], -1)
        return edge_node_stack_t

    def _update_nodes(
        self, node_h, node_h_norm, edge_h_norm, edge_idx, mask_i=None, mask_ij=None
    ):
        """Update nodes given adjacent nodes and edges"""
        # Compute messages at each ij
        edge_node_stack = pack_edges(node_h_norm, edge_h_norm, edge_idx)
        messages = self.message_mlp(edge_node_stack)
        if mask_ij is not None:
            messages = messages * mask_ij.unsqueeze(-1)

        # Aggregate messages: attention-weighted or scaled sum.
        if self.attentional:
            message = self.attend(node_h_norm, messages, mask_ij)
        else:
            message = messages.sum(2) / self.scale

        node_stack = torch.cat([node_h_norm, message], -1)

        # Update nodes given aggregated messages (residual connection).
        node_h_out = node_h + self.update_mlp(node_stack)
        if mask_i is not None:
            node_h_out = node_h_out * mask_i.unsqueeze(-1)
        return node_h_out

    def _update_nodes_t(
        self,
        t,
        node_h,
        node_h_norm_t,
        edge_h_norm_t,
        edge_idx_t,
        mask_i_t=None,
        mask_ij_t=None,
    ):
        """Update nodes at index t given adjacent nodes and edges"""
        # Compute messages at each ij.
        # NOTE(review): mask_ij_t is applied unconditionally here, so despite
        # the default it must not be None — confirm against callers.
        edge_node_stack_t = mask_ij_t.unsqueeze(-1) * pack_edges_step(
            t, node_h, edge_h_norm_t, edge_idx_t
        )

        # Apply normalization of gathered tensors (edge part is already
        # normalized by the caller, so only normalize the node parts).
        edge_node_stack_t = self._normalize_t(
            edge_node_stack_t, mask_ij_t, include_edges=False
        )

        messages_t = self.message_mlp(edge_node_stack_t)
        if mask_ij_t is not None:
            messages_t = messages_t * mask_ij_t.unsqueeze(-1)

        # Aggregate messages
        if self.attentional:
            message_t = self.attend(node_h_norm_t, messages_t, mask_ij_t)
        else:
            message_t = messages_t.sum(2) / self.scale

        node_stack_t = torch.cat([node_h_norm_t, message_t], -1)
        # Update nodes given aggregated messages (residual from raw node_h).
        node_h_t = node_h[:, t, :].unsqueeze(1)
        node_h_out_t = node_h_t + self.update_mlp(node_stack_t)
        if mask_i_t is not None:
            node_h_out_t = node_h_out_t * mask_i_t.unsqueeze(-1)
        return node_h_out_t

    def _update_edges(self, edge_h, node_h_out, edge_h_norm, edge_idx, mask_ij):
        """Update edges given adjacent nodes and edges"""
        edge_node_stack = pack_edges(node_h_out, edge_h_norm, edge_idx)

        # Residual edge update.
        edge_h_out = edge_h + self.edge_mlp(edge_node_stack)
        if mask_ij is not None:
            edge_h_out = edge_h_out * mask_ij.unsqueeze(-1)
        return edge_h_out

    def _update_edges_t(
        self, t, edge_h_t, node_h_out, edge_h_t_norm, edge_idx_t, mask_ij_t
    ):
        """Update edges at index t given adjacent nodes and edges"""
        edge_node_stack_t = pack_edges_step(t, node_h_out, edge_h_t_norm, edge_idx_t)

        edge_h_out_t = edge_h_t + self.edge_mlp(edge_node_stack_t)
        if mask_ij_t is not None:
            edge_h_out_t = edge_h_out_t * mask_ij_t.unsqueeze(-1)
        return edge_h_out_t

    def forward(
        self,
        node_h: torch.Tensor,
        edge_h: torch.Tensor,
        edge_idx: torch.LongTensor,
        mask_i: Optional[torch.Tensor] = None,
        mask_ij: Optional[torch.Tensor] = None,
    ):
        # Normalize embeddings, then binarize the masks before applying them.
        node_h_norm, edge_h_norm = self._normalize(node_h, edge_h, mask_i, mask_ij)
        if mask_i is not None:
            mask_i = (mask_i != 0).type(torch.float32)
        if mask_ij is not None:
            mask_ij = (mask_ij != 0).type(torch.float32)
        node_h_out = self._update_nodes(
            node_h, node_h_norm, edge_h_norm, edge_idx, mask_i, mask_ij
        )
        # Edge update (optional) uses the freshly updated node states.
        edge_h_out = None
        if self.edge_update:
            edge_h_out = self._update_edges(
                edge_h, node_h_out, edge_h_norm, edge_idx, mask_ij
            )
        return node_h_out, edge_h_out

    def step(
        self,
        t: int,
        node_h: torch.Tensor,
        node_h_out: torch.Tensor,
        edge_h: torch.Tensor,
        edge_idx: torch.LongTensor,
        mask_i: Optional[torch.Tensor] = None,
        mask_ij: Optional[torch.Tensor] = None,
    ):
        """Compute update for a single node index `t`.

        This function can be useful for sequential computation of graph
        updates, for example with autoregressive architectures.

        Args:
            t (int): Index of node dimension to update
            node_h (torch.Tensor): Node features with shape
                `(num_batch, num_nodes, dim_nodes)`.
            node_h_out (torch.Tensor): Cached outputs of preceding steps with shape
                `(num_batch, num_nodes, dim_nodes)`.
            edge_h (torch.Tensor): Edge features with shape
                `(num_batch, num_nodes, num_neighbors, dim_edges)`.
            edge_idx (torch.LongTensor): Edge indices for neighbors with shape
                `(num_batch, num_nodes, num_neighbors)`.
            mask_i (tensor, optional): Node mask with shape
                `(num_batch, num_nodes)`. NOTE(review): indexed directly below,
                so it must be provided despite the Optional annotation.
            mask_ij (tensor, optional): Edge mask with shape
                `(num_batch, num_nodes, num_neighbors)`. Same caveat as mask_i.

        Returns:
            node_h_t (torch.Tensor): Updated node features with shape
                `(num_batch, 1, dim_nodes)`.
            edge_h_t (torch.Tensor): Updated edge features with shape
                `(num_batch, 1, num_neighbors, dim_edges)`.
        """
        # Slice out index t from every input tensor.
        node_h_t = node_h[:, t, :].unsqueeze(1)
        edge_h_t = edge_h[:, t, :, :].unsqueeze(1)
        edge_idx_t = edge_idx[:, t, :].unsqueeze(1)
        mask_i_t = mask_i[:, t].unsqueeze(1)
        mask_ij_t = mask_ij[:, t, :].unsqueeze(1)

        # For a single step we need to apply the normalization both at node t
        # and also for all of the neighborhood tensors that feed in at t.
        node_h_t_norm, edge_h_t_norm = self._normalize(
            node_h_t, edge_h_t, mask_i_t, mask_ij_t
        )
        node_h_t = self._update_nodes_t(
            t, node_h, node_h_t_norm, edge_h_t_norm, edge_idx_t, mask_i_t, mask_ij_t
        )

        if self.edge_update:
            # Write the new node-t state into the (non-in-place) output cache
            # before computing the edge update that depends on it.
            node_h_out = node_h_out.scatter(
                1, (t * torch.ones_like(node_h_t)).long(), node_h_t
            )
            edge_h_t = self._update_edges_t(
                t, edge_h_t, node_h_out, edge_h_t_norm, edge_idx_t, mask_ij_t
            )
        return node_h_t, edge_h_t
|
599 |
+
|
600 |
+
## MLP: plain feed-forward transformation (linear layers + nonlinearity)
|
601 |
+
class MLP(nn.Module):
    """Multilayer perceptron with variable input, hidden, and output dims.

    Args:
        dim_in (int): Feature dimension of input tensor.
        dim_hidden (int or None): Feature dimension of intermediate layers.
            Defaults to matching output dimension.
        dim_out (int or None): Feature dimension of output tensor.
            Defaults to matching input dimension.
        num_layers_hidden (int): Number of hidden MLP layers. With 0 hidden
            layers the MLP reduces to a single linear map.
        activation (str): MLP nonlinearity.
            `'relu'`: Rectified linear unit.
            `'softplus'`: Softplus.
        dropout (float): Dropout rate. Default is 0.

    Inputs:
        h (torch.Tensor): Input tensor with shape `(..., dim_in)`

    Outputs:
        h (torch.Tensor): Output tensor with shape `(..., dim_out)`
    """

    def __init__(
        self,
        dim_in: int,
        dim_hidden: Optional[int] = None,
        dim_out: Optional[int] = None,
        num_layers_hidden: int = 1,
        activation: str = "relu",
        dropout: float = 0.0,
    ):
        super(MLP, self).__init__()

        # Dimension-preserving defaults.
        if dim_out is None:
            dim_out = dim_in
        if dim_hidden is None:
            dim_hidden = dim_out

        # Resolve the nonlinearity by name.
        activation_func = {"relu": nn.ReLU, "softplus": nn.Softplus}[activation]

        if num_layers_hidden == 0:
            # Degenerate case: a single linear projection.
            modules = [nn.Linear(dim_in, dim_out)]
        else:
            modules = []
            for i in range(num_layers_hidden):
                modules.extend(
                    [
                        nn.Linear(dim_in if i == 0 else dim_hidden, dim_hidden),
                        activation_func(),
                        nn.Dropout(dropout),
                    ]
                )
            modules.append(nn.Linear(dim_hidden, dim_out))
        self.layers = nn.Sequential(*modules)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.layers(h)
|
657 |
+
|
658 |
+
|
659 |
+
def collect_neighbors(node_h: torch.Tensor, edge_idx: torch.Tensor) -> torch.Tensor:
    """Collect neighbor node features as edge features.

    For each node i, gathers the embeddings of its neighbors {j in N(i)} so
    they can be used as per-edge features neighbor_ij.

    Args:
        node_h (torch.Tensor): Node features with shape
            `(num_batch, num_nodes, num_features)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        neighbor_h (torch.Tensor): Edge features containing neighbor node
            information with shape
            `(num_batch, num_nodes, num_neighbors, num_features)`.
    """
    num_batch, num_nodes, num_neighbors = edge_idx.shape
    num_features = node_h.shape[2]

    # torch.gather works on a flat (node * neighbor) axis, so expand the
    # indices across the feature dimension, gather, then restore the shape.
    gather_idx = edge_idx.reshape(num_batch, num_nodes * num_neighbors, 1)
    gather_idx = gather_idx.expand(-1, -1, num_features)
    neighbor_h = torch.gather(node_h, 1, gather_idx)
    return neighbor_h.reshape(num_batch, num_nodes, num_neighbors, num_features)
|
684 |
+
|
685 |
+
|
686 |
+
def collect_edges(
    edge_h_dense: torch.Tensor, edge_idx: torch.LongTensor
) -> torch.Tensor:
    """Collect sparse edge features from a dense pairwise tensor.

    Args:
        edge_h_dense (torch.Tensor): Dense edge features with shape
            `(num_batch, num_nodes, num_nodes, num_features)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        edge_h (torch.Tensor): Edge features with shape
            `(num_batch, num_nodes, num_neighbors, num_features)`.
    """
    # Broadcast the neighbor indices over the feature dimension and gather
    # along the second node axis of the dense tensor.
    num_features = edge_h_dense.size(-1)
    gather_idx = edge_idx.unsqueeze(-1).expand(-1, -1, -1, num_features)
    return torch.gather(edge_h_dense, 2, gather_idx)
|
704 |
+
|
705 |
+
|
706 |
+
def collect_edges_transpose(
    edge_h: torch.Tensor, edge_idx: torch.LongTensor, mask_ij: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Collect edge embeddings of reversed (transposed) edges in-place.

    Args:
        edge_h (torch.Tensor): Edge features with shape
            `(num_batch, num_nodes, num_neighbors, num_features_edges)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.
        mask_ij (torch.Tensor): Edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`

    Returns:
        edge_h_transpose (torch.Tensor): Edge features of transpose with shape
            `(num_batch, num_nodes, num_neighbors, num_features_edges)`.
        mask_ji (torch.Tensor): Mask indicating presence of reversed edge with
            shape `(num_batch, num_nodes, num_neighbors)`.
    """
    num_batch, num_residues, num_k, num_features = list(edge_h.size())

    # Flat indices mapping each edge (i,j) to its reverse (j,i), plus a mask
    # for which reverse edges actually exist in the directed topology.
    ij_to_ji, mask_ji = transpose_edge_idx(edge_idx, mask_ij)

    # Gather the features of the reverse edges over the flattened edge axis.
    edge_h_flat = edge_h.reshape(num_batch, num_residues * num_k, -1)
    ij_to_ji = ij_to_ji.unsqueeze(-1).expand(-1, -1, num_features)
    gathered = torch.gather(edge_h_flat, 1, ij_to_ji).reshape(
        num_batch, num_residues, num_k, num_features
    )
    edge_h_transpose = mask_ji.unsqueeze(-1) * gathered
    return edge_h_transpose, mask_ji
|
739 |
+
|
740 |
+
|
741 |
+
def scatter_edges(edge_h: torch.Tensor, edge_idx: torch.LongTensor) -> torch.Tensor:
    """Scatter sparse edge features into a dense pairwise tensor.

    Args:
        edge_h (torch.Tensor): Edge features with shape
            `(num_batch, num_nodes, num_neighbors, num_features_edges)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        edge_h_dense (torch.Tensor): Dense edge features with shape
            `(batch_size, num_nodes, num_nodes, dimensions)`.
    """
    assert edge_h.dim() == 4
    assert edge_idx.dim() == 3
    num_batch, num_nodes, _, dim = edge_h.size()
    # Broadcast indices across the feature dimension, then scatter each
    # sparse neighbor entry into its dense (i, j) slot.
    index = edge_idx.unsqueeze(-1).repeat(1, 1, 1, dim)
    dense = edge_h.new_zeros(num_batch, num_nodes, num_nodes, dim)
    return dense.scatter(dim=2, index=index, src=edge_h)
|
761 |
+
|
762 |
+
|
763 |
+
def pack_edges(
    node_h: torch.Tensor, edge_h: torch.Tensor, edge_idx: torch.LongTensor
) -> torch.Tensor:
    """Pack node and edge features into combined edge features.

    Each edge ij is expanded into the concatenation
    `[node_i, node_j, edge_ij]`.

    Args:
        node_h (torch.Tensor): Node features with shape
            `(num_batch, num_nodes, num_features_nodes)`.
        edge_h (torch.Tensor): Edge features with shape
            `(num_batch, num_nodes, num_neighbors, num_features_edges)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        edge_packed (torch.Tensor): Concatenated node and edge features with
            shape `(num_batch, num_nodes, num_neighbors,
            2 * num_features_nodes + num_features_edges)`.
    """
    num_neighbors = edge_h.shape[2]
    # node_i: the source node, broadcast across all of its neighbor slots.
    node_i = node_h.unsqueeze(2).expand(-1, -1, num_neighbors, -1)
    # node_j: features of each neighbor, gathered per edge.
    node_j = collect_neighbors(node_h, edge_idx)
    return torch.cat([node_i, node_j, edge_h], -1)
|
789 |
+
|
790 |
+
|
791 |
+
def pack_edges_step(
    t: int, node_h: torch.Tensor, edge_h_t: torch.Tensor, edge_idx_t: torch.LongTensor
) -> torch.Tensor:
    """Pack node and edge features into edge features for a single node index t.

    Expands each edge tj by packing node t, node j, and edge tj into
    `[node_t, node_j, edge_tj]`.

    Args:
        t (int): Node index to decode.
        node_h (torch.Tensor): Node features at all positions with shape
            `(num_batch, num_nodes, num_features_nodes)`.
        edge_h_t (torch.Tensor): Edge features at index `t` with shape
            `(num_batch, 1, num_neighbors, num_features_edges)`.
        edge_idx_t (torch.LongTensor): Edge indices at index `t` for neighbors
            with shape `(num_batch, 1, num_neighbors)`.

    Returns:
        edge_packed (torch.Tensor): Concatenated node and edge features for
            index `t` with shape `(num_batch, 1, num_neighbors,
            2 * num_features_nodes + num_features_edges)`.
    """
    # Fix: removed the unused local `num_nodes_i = node_h.shape[1]`.
    num_neighbors = edge_h_t.shape[2]
    # Features of node t, broadcast across its neighbor slots.
    node_h_t = node_h[:, t, :].unsqueeze(1)
    node_i = node_h_t.unsqueeze(2).expand(-1, -1, num_neighbors, -1)
    # Features of each neighbor j of node t.
    node_j = collect_neighbors(node_h, edge_idx_t)
    edge_packed = torch.cat([node_i, node_j, edge_h_t], -1)
    return edge_packed
|
821 |
+
|
822 |
+
|
823 |
+
def transpose_edge_idx(
    edge_idx: torch.LongTensor, mask_ij: torch.Tensor
) -> Tuple[torch.LongTensor, torch.Tensor]:
    """Collect edge indices of reverse edges in-place at each edge.

    The tensor `edge_idx` stores a directed graph topology as a tensor of
    neighbor indices, where an element `edge_idx[b,i,k]` corresponds to the
    node index of neighbor `k` of node `i` in batch member `b`.

    This function takes a directed graph topology and returns an index tensor
    that maps, in-place, to the reversed edges (if they exist). The indices
    correspond to the contracted dimension of `edge_idx` when it is viewed as
    `(num_batch, num_nodes * num_neighbors)`, and can be used with
    `torch.gather` to collect edge embeddings of `j->i` at `i->j`.

    For reverse `j->i` edges that do not exist in the directed graph, the
    returned binary mask `mask_ji` indicates which edges have both `i->j`
    and `j->i` present.

    Args:
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.
        mask_ij (torch.Tensor): Edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`

    Returns:
        ij_to_ji (torch.LongTensor): Flat indices for indexing ji in-place at
            ij with shape `(num_batch, num_nodes * num_neighbors)`.
        mask_ji (torch.Tensor): Mask indicating presence of the reversed edge
            with shape `(num_batch, num_nodes, num_neighbors)`.
    """
    num_batch, num_residues, num_k = list(edge_idx.size())

    # 1. For every edge (i, j), gather the full neighbor list of j.
    flat_j = edge_idx.reshape(num_batch, num_residues * num_k, 1).expand(
        -1, -1, num_k
    )
    # (b, i, j, k) gives the kth neighbor of the jth neighbor of i.
    neighbors_of_j = torch.gather(edge_idx, 1, flat_j).reshape(
        num_batch, num_residues, num_k, num_k
    )

    # 2. Determine which neighbor slot k of j maps back to i (if it exists).
    residue_i = torch.arange(num_residues, device=edge_idx.device).reshape(
        (1, -1, 1, 1)
    )
    is_reverse = (neighbors_of_j == residue_i).type(torch.float32)
    return_mask, return_idx = torch.max(is_reverse, -1)

    # 3. Flatten (j, k_back) pairs into indices over the edge axis.
    ij_to_ji = (edge_idx * num_k + return_idx).reshape(num_batch, -1)

    # 4. A reverse edge is valid only if ij is masked in, j actually lists i,
    #    and ji itself is masked in.
    mask_ji = torch.gather(mask_ij.reshape(num_batch, -1), -1, ij_to_ji)
    mask_ji = mask_ij * return_mask * mask_ji.reshape(
        num_batch, num_residues, num_k
    )
    return ij_to_ji, mask_ji
|
883 |
+
|
884 |
+
|
885 |
+
def permute_tensor(
    tensor: torch.Tensor, dim: int, permute_idx: torch.LongTensor
) -> torch.Tensor:
    """Permute a tensor along a dimension given a permutation vector.

    Args:
        tensor (torch.Tensor): Input tensor with shape
            `([batch_dims], permutation_length, [content_dims])`.
        dim (int): Dimension to permute along (may be negative).
        permute_idx (torch.LongTensor): Permutation index tensor with shape
            `([batch_dims], permutation_length)`.

    Returns:
        tensor_permute (torch.Tensor): Permuted tensor with shape
            `([batch_dims], permutation_length, [content_dims])`.
    """
    # Resolve a possibly-negative dim to an absolute axis index.
    dim = range(len(list(tensor.shape)))[dim]

    # Collapse all trailing content dimensions into one so a single gather
    # along `dim` moves entire content slices.
    shape = list(tensor.shape)
    flat = tensor.reshape(shape[:dim] + [shape[dim], -1])

    # Broadcast the permutation across the flattened content dimension.
    idx = permute_idx.unsqueeze(-1).expand(flat.shape)

    permuted_flat = torch.gather(flat, dim, idx)
    return permuted_flat.reshape(tensor.shape)
|
915 |
+
|
916 |
+
|
917 |
+
def permute_graph_embeddings(
    node_h: torch.Tensor,
    edge_h: torch.Tensor,
    edge_idx: torch.LongTensor,
    mask_i: torch.Tensor,
    mask_ij: torch.Tensor,
    permute_idx: torch.LongTensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.Tensor, torch.Tensor]:
    """Permute graph embeddings given a permutation vector.

    Args:
        node_h (torch.Tensor): Node features with shape
            `(num_batch, num_nodes, dim_nodes)`.
        edge_h (torch.Tensor): Edge features with shape
            `(num_batch, num_nodes, num_neighbors, dim_edges)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.
        mask_i (tensor, optional): Node mask with shape `(num_batch, num_nodes)`
        mask_ij (tensor, optional): Edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`.
        permute_idx (torch.LongTensor): Permutation vector with shape
            `(num_batch, num_nodes)`.

    Returns:
        node_h_permute (torch.Tensor): Permuted node features with shape
            `(num_batch, num_nodes, dim_nodes)`.
        edge_h_permute (torch.Tensor): Permuted edge features with shape
            `(num_batch, num_nodes, num_neighbors, dim_edges)`.
        edge_idx_permute (torch.LongTensor): Permuted edge indices for
            neighbors with shape `(num_batch, num_nodes, num_neighbors)`.
        mask_i_permute (tensor, optional): Permuted node mask with shape
            `(num_batch, num_nodes)`.
        mask_ij_permute (tensor, optional): Permuted edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`.
    """
    # Tensors indexed only by node position i permute via a single gather
    # along dimension 1
    node_h_permute = permute_tensor(node_h, 1, permute_idx)
    edge_h_permute = permute_tensor(edge_h, 1, permute_idx)
    mask_i_permute = permute_tensor(mask_i, 1, permute_idx)
    mask_ij_permute = permute_tensor(mask_ij, 1, permute_idx)

    # edge_idx stores node indices as *values*, so both the i axis and the
    # stored j values must be remapped:
    #     edge^(p)[i,k] = p^(-1)[edge[p(i),k]]
    # Step 1: reorder rows into the permuted i order
    edge_idx_rows = permute_tensor(edge_idx, 1, permute_idx)
    # Step 2: remap each stored j through the inverse permutation
    inverse_idx = torch.argsort(permute_idx, dim=-1)
    rows_flat = edge_idx_rows.reshape([edge_idx.shape[0], -1])
    remapped_flat = torch.gather(inverse_idx, 1, rows_flat)
    edge_idx_permute = remapped_flat.reshape(edge_idx.shape)

    return (
        node_h_permute,
        edge_h_permute,
        edge_idx_permute,
        mask_i_permute,
        mask_ij_permute,
    )
|
979 |
+
|
980 |
+
|
981 |
+
def edge_mask_causal(edge_idx: torch.LongTensor, mask_ij: torch.Tensor) -> torch.Tensor:
    """Make an edge mask causal with mask_ij = 0 for j >= i.

    Args:
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.
        mask_ij (torch.Tensor): Edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        mask_ij_causal (torch.Tensor): Causal edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`.
    """
    # Node index i, broadcastable against (num_batch, num_nodes, num_neighbors)
    node_ids = torch.arange(edge_idx.size(1), device=edge_idx.device)
    node_ids = node_ids.reshape([1, -1, 1])
    # Keep only edges whose neighbor index j is strictly before node i
    return mask_ij * (edge_idx < node_ids).float()
|
998 |
+
|
999 |
+
|
1000 |
+
class MaskedNorm(nn.Module):
    """Masked normalization layer.

    Normalizes node (1D) or edge (2D) features while excluding masked
    positions from the mean/variance statistics. Mask values act as group
    ids: statistics are computed separately for each nonzero mask id (see
    the `mask.unique()` loop in `forward`), and positions with mask 0 are
    zeroed in the output.

    Args:
        dim (int): Dimensionality of the normalization. Can be 1 for 1D
            normalization along dimension 1 or 2 for 2D normalization along
            dimensions 1 and 2.
        num_features (int): Channel dimension; only needed if `affine` is True.
        affine (bool): If True, include a learnable affine transformation
            post-normalization. Default is False.
        norm (str): Type of normalization, can be `instance`, `layer`, or
            `transformer`.
        eps (float): Small number for numerical stability.

    Inputs:
        data (torch.Tensor): Input tensor with shape
            `(num_batch, num_nodes, num_channels)` (1D) or
            `(num_batch, num_nodes, num_nodes, num_channels)` (2D).
        mask (torch.Tensor): Mask tensor with shape
            `(num_batch, num_nodes)` (1D) or
            `(num_batch, num_nodes, num_nodes)` (2D).

    Outputs:
        norm_data (torch.Tensor): Mask-normalized tensor with shape
            `(num_batch, num_nodes, num_channels)` (1D) or
            `(num_batch, num_nodes, num_nodes, num_channels)` (2D).
    """

    def __init__(
        self,
        dim: int,
        num_features: int = -1,
        affine: bool = False,
        norm: str = "instance",
        eps: float = 1e-5,
    ):
        super(MaskedNorm, self).__init__()

        self.norm_type = norm
        self.dim = dim
        # Combined key such as "instance1" or "layer2" selects reduction dims
        self.norm = norm + str(dim)
        self.affine = affine
        self.eps = eps

        # Dimensions reduced when computing mean/std:
        #   instance: over the spatial dim(s), per channel
        #   layer: over the spatial dim(s) and channels
        #   transformer: over the channel dimension only (LayerNorm-style)
        if self.norm == "instance1":
            self.sum_dims = [1]
        elif self.norm == "layer1":
            self.sum_dims = [1, 2]
        elif self.norm == "transformer1":
            self.sum_dims = [-1]
        elif self.norm == "instance2":
            self.sum_dims = [1, 2]
        elif self.norm == "layer2":
            self.sum_dims = [1, 2, 3]
        elif self.norm == "transformer2":
            self.sum_dims = [-1]
        else:
            raise NotImplementedError

        # Number of features, only required if affine
        self.num_features = num_features

        # Affine transformation is a per-channel scale and shift
        # NOTE(review): weights initialize from U[0, 1) rather than ones —
        # presumably intentional; confirm against training configs
        if self.affine:
            self.weights = nn.Parameter(torch.rand(self.num_features))
            self.bias = nn.Parameter(torch.zeros(self.num_features))

    def forward(
        self, data: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Broadcast the mask to data's shape (append channel dim if missing)
        if mask is not None:
            if len(mask.shape) == len(data.shape) - 1:
                mask = mask.unsqueeze(-1)
            if data.shape != mask.shape:
                mask = mask.expand(data.shape)

        dims = self.sum_dims
        # Transformer norm reduces only over channels, so masked positions
        # cannot contaminate other positions' statistics; take the fast path.
        if (mask is None) or (self.norm_type == "transformer"):
            mask_mean = data.mean(dim=dims, keepdim=True)
            mask_std = torch.sqrt(
                (((data - mask_mean)).pow(2)).mean(dim=dims, keepdim=True) + self.eps
            )

            # Normalize
            norm_data = (data - mask_mean) / mask_std

        else:
            # Accumulator for the per-group normalized values
            norm_data = torch.zeros_like(data).to(data.device).type(data.dtype)
            for mask_id in mask.unique():
                # Mask id 0 means "masked out"; skip it
                if mask_id == 0:
                    continue

                # Indicator for positions belonging to this mask id
                tmask = (mask == mask_id).type(torch.float32)

                # Number of positions in this group per reduction
                mask_sum = tmask.sum(dim=dims, keepdim=True)

                # Mean/std computed over this group's positions only
                mask_mean = (data * tmask).sum(dim=dims, keepdim=True) / mask_sum
                mask_std = torch.sqrt(
                    (((data - mask_mean) * tmask).pow(2)).sum(dim=dims, keepdim=True)
                    / mask_sum
                    + self.eps
                )

                # Normalize this group and zero everything outside it
                tnorm = ((data - mask_mean) / mask_std) * tmask
                # An empty group produces 0/0 = NaN; convert those to 0
                tnorm[tnorm != tnorm] = 0

                # Accumulate this group into the output
                norm_data += tnorm

        # Optional learnable per-channel affine transform
        if self.affine:
            norm_data = norm_data * self.weights + self.bias

        # Zero out masked positions in the output
        if mask is not None:
            norm_data = norm_data * (mask != 0).type(data.dtype)
        return norm_data
|
chroma/chroma/layers/linalg.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Generate Biomedicines, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Layers for linear algebra.
|
16 |
+
Provides layers that perform linear algebra computations.
|
17 |
+
|
18 |
+
This module contains additional pytorch layers for linear algebra operations,
|
19 |
+
such as a more parallelization-friendly implementation of eigenvalue estimation.
|
20 |
+
"""
|
21 |
+
|
22 |
+
import torch
|
23 |
+
|
24 |
+
|
25 |
+
def eig_power_iteration(A, num_iterations=50, eps=1e-5):
    """Estimate largest magnitude eigenvalue and associated eigenvector.

    Runs simple power iteration, which is often considerably faster than
    torch's built-in eigenvalue routines. Every step is differentiable, and
    a small constant is added to each division to keep gradients stable.
    See https://en.wikipedia.org/wiki/Power_iteration.

    Args:
        A (tensor): Batch of square matrices with shape
            `(..., num_dims, num_dims)`.
        num_iterations (int, optional): Number of iterations for power
            iteration. Default: 50.
        eps (float, optional): Small number to prevent division by zero.
            Default: 1E-5.

    Returns:
        lam (tensor): Batch of estimated highest-magnitude eigenvalues with
            shape `(...)`.
        v (tensor): Associated eigenvector with shape `(..., num_dims)`.
    """

    def _safe(x):
        # Stabilize divisions against zero denominators
        return x + eps

    # Random starting vector with shape (..., num_dims, 1)
    batch_shape = list(A.size())[:-1]
    v = torch.randn(batch_shape, device=A.device).unsqueeze(-1)

    # Repeatedly apply A and renormalize
    for _ in range(num_iterations):
        v_prev = v
        Av = torch.matmul(A, v)
        v = Av / _safe(Av.norm(p=2, dim=-2, keepdim=True))

    # Rayleigh-quotient style eigenvalue estimate from the final iterate
    v_prev_t = v_prev.transpose(-1, -2)
    lam = torch.matmul(v_prev_t, Av) / _safe(torch.abs(torch.matmul(v_prev_t, v)))

    # Drop the trailing singleton dimensions
    v = v.squeeze(-1)
    lam = lam.view(list(lam.size())[:-2])
    return lam, v
|
65 |
+
|
66 |
+
|
67 |
+
def eig_leading(A, num_iterations=50):
    """Estimate largest positive eigenvalue and associated eigenvector.

    This estimates the *most positive* eigenvalue of each matrix in a batch
    of matrices by using two consecutive power iterations with spectral
    shifting: the first pass finds the largest-magnitude eigenvalue lam_1;
    shifting by |lam_1| makes the most positive eigenvalue of `A` the
    largest-magnitude eigenvalue of `A + |lam_1| I`.

    Args:
        A (tensor): Batch of square matrices with shape
            `(..., num_dims, num_dims)`.
        num_iterations (int, optional): Number of iterations for power
            iteration. Default: 50.

    Returns:
        lam (tensor): Estimated most positive eigenvalue with shape `(...)`.
        v (tensor): Associated eigenvectors with shape `(..., num_dims)`.
    """
    batch_dims = list(A.size())[:-2]
    # Fix: build the identity from A's actual size instead of the previous
    # hard-coded 4x4, so non-4x4 matrices are shifted correctly.
    num_dims = A.size(-1)

    # First pass gets the largest-magnitude eigenvalue
    lam_1, vec_1 = eig_power_iteration(A, num_iterations)

    # Second pass is guaranteed to grab the most positive eigenvalue
    lam_1_abs = torch.abs(lam_1)
    lam_I = lam_1_abs.reshape(batch_dims + [1, 1]) * torch.eye(
        num_dims, device=A.device
    ).view([1 for _ in batch_dims] + [num_dims, num_dims])
    A_shift = A + lam_I
    lam_2, vec = eig_power_iteration(A_shift, num_iterations)

    # Shift back to the original spectrum
    lam = lam_2 - lam_1_abs
    return lam, vec
|