Hukuna commited on
Commit
e9e75df
·
verified ·
1 Parent(s): 589196f

Upload 275 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +12 -0
  2. __pycache__/demo.cpython-38.pyc +0 -0
  3. app.py +33 -0
  4. chroma/CONTRIBUTING.md +11 -0
  5. chroma/Dockerfile +24 -0
  6. chroma/LICENSE.txt +202 -0
  7. chroma/README.md +255 -0
  8. chroma/Untitled.ipynb +6 -0
  9. chroma/assets/LiberationSans-Regular.ttf +0 -0
  10. chroma/assets/chroma_logo.svg +85 -0
  11. chroma/assets/chroma_logo_outline.svg +109 -0
  12. chroma/assets/conditioners.png +0 -0
  13. chroma/assets/lattice.png +3 -0
  14. chroma/assets/logo.png +0 -0
  15. chroma/assets/proteins.png +3 -0
  16. chroma/assets/refolding.png +3 -0
  17. chroma/chroma/__init__.py +19 -0
  18. chroma/chroma/__pycache__/__init__.cpython-38.pyc +0 -0
  19. chroma/chroma/constants/__init__.py +16 -0
  20. chroma/chroma/constants/__pycache__/__init__.cpython-38.pyc +0 -0
  21. chroma/chroma/constants/__pycache__/geometry.cpython-38.pyc +0 -0
  22. chroma/chroma/constants/__pycache__/named_models.cpython-38.pyc +0 -0
  23. chroma/chroma/constants/__pycache__/sequence.cpython-38.pyc +0 -0
  24. chroma/chroma/constants/geometry.py +558 -0
  25. chroma/chroma/constants/named_models.py +54 -0
  26. chroma/chroma/constants/sequence.py +112 -0
  27. chroma/chroma/data/__init__.py +19 -0
  28. chroma/chroma/data/__pycache__/__init__.cpython-38.pyc +0 -0
  29. chroma/chroma/data/__pycache__/protein.cpython-38.pyc +0 -0
  30. chroma/chroma/data/__pycache__/system.cpython-38.pyc +0 -0
  31. chroma/chroma/data/__pycache__/xcs.cpython-38.pyc +0 -0
  32. chroma/chroma/data/protein.py +513 -0
  33. chroma/chroma/data/system.py +0 -0
  34. chroma/chroma/data/xcs.py +121 -0
  35. chroma/chroma/layers/__init__.py +18 -0
  36. chroma/chroma/layers/__pycache__/__init__.cpython-38.pyc +0 -0
  37. chroma/chroma/layers/__pycache__/attention.cpython-38.pyc +0 -0
  38. chroma/chroma/layers/__pycache__/basic.cpython-38.pyc +0 -0
  39. chroma/chroma/layers/__pycache__/complexity.cpython-38.pyc +0 -0
  40. chroma/chroma/layers/__pycache__/conv.cpython-38.pyc +0 -0
  41. chroma/chroma/layers/__pycache__/graph.cpython-38.pyc +0 -0
  42. chroma/chroma/layers/__pycache__/linalg.cpython-38.pyc +0 -0
  43. chroma/chroma/layers/__pycache__/norm.cpython-38.pyc +0 -0
  44. chroma/chroma/layers/__pycache__/sde.cpython-38.pyc +0 -0
  45. chroma/chroma/layers/attention.py +347 -0
  46. chroma/chroma/layers/basic.py +467 -0
  47. chroma/chroma/layers/complexity.py +201 -0
  48. chroma/chroma/layers/conv.py +58 -0
  49. chroma/chroma/layers/graph.py +1126 -0
  50. chroma/chroma/layers/linalg.py +98 -0
.gitattributes CHANGED
@@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ chroma/assets/lattice.png filter=lfs diff=lfs merge=lfs -text
37
+ chroma/assets/proteins.png filter=lfs diff=lfs merge=lfs -text
38
+ chroma/assets/refolding.png filter=lfs diff=lfs merge=lfs -text
39
+ chroma/notebooks/complex_trajectory.cif filter=lfs diff=lfs merge=lfs -text
40
+ chroma/notebooks/shaped_protein_trajectory.cif filter=lfs diff=lfs merge=lfs -text
41
+ chroma/notebooks/symmetric_protein_trajectory.cif filter=lfs diff=lfs merge=lfs -text
42
+ chroma/tests/_streamlit/demoapp/complex_trajectory.cif filter=lfs diff=lfs merge=lfs -text
43
+ chroma/tests/_streamlit/demoapp/shaped_protein_trajectory.cif filter=lfs diff=lfs merge=lfs -text
44
+ chroma/tests/_streamlit/demoapp/symmetric_protein_trajectory.cif filter=lfs diff=lfs merge=lfs -text
45
+ output/complex_trajectory.cif filter=lfs diff=lfs merge=lfs -text
46
+ output/shaped_protein_trajectory.cif filter=lfs diff=lfs merge=lfs -text
47
+ output/symmetric_protein_trajectory.cif filter=lfs diff=lfs merge=lfs -text
__pycache__/demo.cpython-38.pyc ADDED
Binary file (10.4 kB). View file
 
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import demo
3
+
4
+ st.set_page_config(
5
+ page_title="Chroma Demos",
6
+ page_icon="🧬",
7
+ layout="wide",
8
+ initial_sidebar_state="expanded",
9
+ )
10
+
11
+ st.title("Demos for Chroma")
12
+
13
+ # sidebar
14
+ st.sidebar.header("Demo Config")
15
+
16
+ # 创建字典映射demo
17
+ demoDict={
18
+ "getProtein":demo.getProteinDemo,
19
+ "complexSample":demo.complexSampleDemo,
20
+ "symmetricSample":demo.symmetricSampleDemo,
21
+ "shapeSample":demo.shapeSampleDemo,
22
+ "foldSample":demo.foldSampleDemo,
23
+ "ssSample":demo.ssSampleDemo,
24
+ "substructureSample":demo.substructureSampleDemo,
25
+
26
+ }
27
+ # 在侧边栏中添加一个选择框,用于选择demo
28
+ selected_branch = st.sidebar.selectbox("Select demo", list(demoDict.keys()))
29
+ style=st.sidebar.selectbox("Select style:Can be 'stick', 'sphere', 'cross','cartoon'",('stick', 'sphere', 'cross','cartoon'),key='style')
30
+ resn=st.sidebar.selectbox("Select display resn:PDB resn labels:['ALA','ARG','LYS','THR','TRP','TYR','VAL']",('','ALA','ARG','LYS','THR','TRP','TYR','VAL'),key='resn')
31
+
32
+ # 执行选定分支对应的函数
33
+ demoDict[selected_branch](style,resn)
chroma/CONTRIBUTING.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code contributions
2
+
3
+ We welcome contributions to the Chroma code base, including new conditioners, integrators, patches, bug fixes, and others.
4
+
5
+ Note that your contributions will be governed by the Apache 2.0 license, meaning that you will be giving us permission to use your contributed code under the conditions specified in the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0) (also available in [LICENSE.txt](LICENSE.txt)).
6
+
7
+ ## How to Contribute
8
+
9
+ Please use GitHub pull requests to contribute code. See
10
+ [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
11
+ information on using pull requests. We will try to monitor incoming requests with some regularity, but cannot promise a specific timeframe within which we will review your request.
chroma/Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# CUDA 11.3 development image on Ubuntu 20.04 (matches the pinned torch cu113 wheel below).
FROM nvidia/cuda:11.3.1-devel-ubuntu20.04
# Suppress interactive prompts (e.g. tzdata) during apt-get install.
ARG DEBIAN_FRONTEND=noninteractive
# Build toolchain plus image libraries needed to compile Python dependencies;
# apt lists are removed in the same layer to keep the image small.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    cmake \
    git \
    curl \
    ca-certificates \
    libjpeg-dev \
    libpng-dev && \
    rm -rf /var/lib/apt/lists/*

WORKDIR /tmp

# Install Miniconda into /opt/conda (batch mode, no prompts), then drop the installer.
RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
    chmod +x ~/miniconda.sh && \
    ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh
# Dedicated "chroma" environment with a pinned Python version.
RUN /opt/conda/bin/conda create --name chroma python=3.9.7
# Install torch built against CUDA 11.3 from the PyTorch wheel index.
RUN /opt/conda/envs/chroma/bin/pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
WORKDIR /workspace
COPY . .
# Install the chroma package itself from the copied source tree.
RUN /opt/conda/envs/chroma/bin/pip install .
# Put the environment's binaries first on PATH so `python`/`pip` resolve to it.
ENV PATH /opt/conda/envs/chroma/bin:$PATH
chroma/LICENSE.txt ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
chroma/README.md ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <img src="assets/chroma_logo_outline.svg" width="280">
2
+
3
+ [**Get Started**](#get-started)
4
+ | [**Sampling**](#sampling)
5
+ | [**Design**](#design)
6
+ | [**Conditioners**](#conditioners)
7
+ | [**License**](#license)
8
+
9
+ Chroma is a generative model for designing proteins **programmatically**.
10
+
11
+ Protein space is complex and hard to navigate. With Chroma, protein design problems are represented in terms of [composable building blocks](#conditioners) from which diverse, [all-atom protein structures can be automatically generated](#sampling). As a joint model of structure and sequence, Chroma can also be used for common protein modeling tasks such as [generating sequences given backbones](#design), packing side-chains, and scoring designs.
12
+
13
+ We provide protein conditioners for a variety of constraints, including substructure, symmetry, shape, and neural-network predictions of some protein classes and annotations. We also provide an API for [creating your own conditioners](#conditioners-api) in a few lines of code.
14
+
15
+ Internally, Chroma uses diffusion modeling, equivariant graph neural networks, and conditional random fields to efficiently sample all-atom structures with a complexity that is sub-quadratic in the number of residues. It can generate large complexes in a few minutes on a commodity GPU. You can read more about Chroma, including biophysical and crystallographic validation of some early designs, in our paper, [*Illuminating protein space with a programmable generative model*. Nature 2023](https://doi.org/10.1038/s41586-023-06728-8).
16
+
17
+ <div align="center">
18
+ <img src="assets/proteins.png" alt="Generated protein examples" width="700px" align="middle"/>
19
+ </div>
20
+
21
+ ## Get Started
22
+ > **Note:** An API key is required to download and use the pretrained model weights. It can be obtained [here](https://chroma-weights.generatebiomedicines.com/).
23
+
24
+
25
+ **Colab Notebooks**. The quickest way to get started with Chroma is our Colab notebooks, which provide starting points for a variety of use cases in a preconfigured, in-browser environment
26
+
27
+ * [Chroma Quickstart](https://colab.research.google.com/github/generatebio/chroma/blob/main/notebooks/ChromaDemo.ipynb): GUI notebook demonstrating unconditional and conditional generation of proteins with Chroma.
28
+ * [Chroma API Tutorial](https://colab.research.google.com/github/generatebio/chroma/blob/main/notebooks/ChromaAPI.ipynb): Code notebook demonstrating protein I/O, sampling, and design configuration directly in `python`.
29
+ * [Chroma Conditioner API Tutorial](https://colab.research.google.com/github/generatebio/chroma/blob/main/notebooks/ChromaConditioners.ipynb): A deeper dive under the hood for implementing new Chroma [Conditioners](#conditioner-api).
30
+
31
+ **PyPI package**. You can install the latest release of Chroma with:
32
+ ```
33
+ pip install generate-chroma
34
+ ```
35
+
36
+ **Install latest Chroma from github**
37
+ ```
38
+ git clone https://github.com/generatebio/chroma.git
39
+ pip install -e chroma # use `-e` for it to be editable locally.
40
+ ```
41
+
42
+ ## Sampling
43
+ **Unconditional monomer**. We provide a unified entry point to both unconditional and conditional protein design with the `Chroma.sample()` method. When no conditioners are specified, we can sample a simple 200-amino acid monomeric protein with
44
+ ```python
45
+ from chroma import Chroma
46
+
47
+ chroma = Chroma()
48
+ protein = chroma.sample(chain_lengths=[200])
49
+
50
+ protein.to("sample.cif")
51
+ display(protein)
52
+ ```
53
+
54
+ Generally, `Chroma.sample()` takes as input design hyperparameters and [Conditioners](#conditioners) and outputs `Protein` objects representing the all-atom structures of protein systems which can be loaded to and from disk in PDB or mmCIF formats.
55
+
56
+ **Unconditional complex**. To sample a complex instead of a monomer, we can simply do
57
+ ```python
58
+ from chroma import Chroma
59
+
60
+ chroma = Chroma()
61
+ protein = chroma.sample(chain_lengths=[100, 200])
62
+
63
+ protein.to("sample-complex.cif")
64
+ ```
65
+
66
+ **Conditional complex**. We can further customize sampling towards design objectives via [Conditioners](#conditioners) and sampling hyperparameters. For example, to sample a C3-symmetric homo-trimer with 100 residues per monomer, we can do
67
+
68
+ ```python
69
+ from chroma import Chroma, conditioners
70
+
71
+ chroma = Chroma()
72
+ conditioner = conditioners.SymmetryConditioner(G="C_3", num_chain_neighbors=2)
73
+ protein = chroma.sample(
74
+ chain_lengths=[100],
75
+ conditioner=conditioner,
76
+ langevin_factor=8,
77
+ inverse_temperature=8,
78
+ sde_func="langevin",
79
+ potts_symmetry_order=conditioner.potts_symmetry_order)
80
+
81
+ protein.to("sample-C3.cif")
82
+ ```
83
+
84
+ Because compositions of conditioners are conditioners, even relatively complex design problems can follow this basic usage pattern. See the [demo notebooks](#get-started) and docstrings for more information on hyperparameters, conditioners, and starting points.
85
+
86
+ ## Design
87
+ **Robust design**. Chroma is a joint model of sequence and structure that uses a common graph neural network base architecture to parameterize both backbone generation and conditional sequence and sidechain generation. These sequence and sidechain decoders are *diffusion-aware* in the sense that they have been trained to predict sequence and side chain not just for natural structures at diffusion time $t=0$ but also on noisy structures at all diffusion times $t \in [0,1]$. As a result, the $t$ hyperparameter of the design network provides a kind of tunable robustness via **diffusion augmentation** in which we trade off between how much the model attempts to design the backbone *exactly* as specified (e.g. $t=0.0$) versus *robust* design within a small neighborhood of nearby backbone conformations (e.g. $t=0.5$).
88
+
89
+ While all results presented in the Chroma [publication](https://doi.org/10.1038/s41586-023-06728-8) were done with **exact design** at $t=0.0$, we have found **robust design** at times near $t=0.5$ frequently improves one-shot refolding while incurring only minor, often Ångstrom-scale, relaxation adjustments to target backbones. When we compare the performance of these two design modes on our set of 50,000 unconditional backbones that were analyzed in the paper, we see very large improvements in refolding across both [AlphaFold](https://github.com/google-deepmind/alphafold) and [ESMFold](https://github.com/facebookresearch/esm) that stratifies well across protein length, percent helicity, or similarity to a known structure (See Chroma [Supplementary Figure 14](https://doi.org/10.1038/s41586-023-06728-8) for further context).
90
+
91
+
92
+ <div align="center">
93
+ <img src="./assets/refolding.png" alt="alt text" width="700px" align="middle"/>
94
+ </div></br>
95
+
96
+ The value of diffusion time conditioning $t$ can be set via the `design_t` parameter in `Chroma.sample` and `Chroma.design`. We find that for generated structures, $t = 0.5$ produces highly robust refolding results and is, therefore, the default setting. For experimentally-precise structures, $t = 0.0$ may be more appropriate, and values in between may provide a useful tradeoff between these two regimes.
97
+
98
+ **Design *a la carte***. Chroma's design network can be accessed separately to design, redesign, and pack arbitrary protein systems. Here we load a protein from the PDB and redesign as
99
+ ```python
100
+ # Redesign a Protein
101
+ from chroma import Protein, Chroma
102
+ chroma = Chroma()
103
+
104
+ protein = Protein('1GFP')
105
+ protein = chroma.design(protein)
106
+
107
+ protein.to("1GFP-redesign.cif")
108
+ ```
109
+
110
+ Clamped sub-sequence redesign is also available and compatible with a built-in selection algebra, along with position- and mutation-specific mask constraints as
111
+ ```python
112
+ # Redesign a Protein
113
+ from chroma import Protein, Chroma
114
+ chroma = Chroma()
115
+
116
+ protein = Protein('my_favorite_protein.cif') # PDB is fine too
117
+ protein = chroma.design(protein, design_selection="resid 20-50 around 5.0") # 5 angstrom bubble around indices 20-50
118
+
119
+ protein.to("my_favorite_protein_redesign.cif")
120
+ ```
121
+
122
+ We provide more examples of design in the [demo notebooks](#get-started).
123
+
124
+ ## Conditioners
125
+ Protein design with Chroma is **programmable**. Our `Conditioner` framework allows for automatic conditional sampling under arbitrary compositions of protein specifications, which can come in the forms of restraints (biasing the distribution of states) or constraints (directly restrict the domain of underlying sampling process); see Supplementary Appendix M in [our paper](https://doi.org/10.1038/s41586-023-06728-8). We have pre-defined multiple conditioners, including for controlling substructure, symmetry, shape, semantics, and natural-language prompts (see `chroma.layers.structure.conditioners`), which can be used in arbitrary combinations.
126
+
127
+ <div align="center">
128
+
129
+ | Conditioner | Class(es) in [`chroma.conditioners`](chroma/layers/structure/conditioners.py) | Example applications |
130
+ |----------|----------|----------|
131
+ | Symmetry constraint | `SymmetryConditioner`, `ScrewConditioner` | Large symmetric assemblies |
132
+ | Substructure constraint | `SubstructureConditioner` | Substructure grafting, scaffold enforcement |
133
+ | Shape restraint | `ShapeConditioner` | Molecular shape control |
134
+ | Secondary structure | `ProClassConditioner` | Secondary-structure specification |
135
+ | Domain classification | `ProClassConditioner` | Specification of class, such as Pfam, CATH, or Taxonomy |
136
+ | Text caption | `ProCapConditioner` | Natural language prompting |
137
+ | Sequence | `SubsequenceConditioner` | Subsequence constraints. |
138
+
139
+ </div>
140
+
141
+ **How it works**. The central idea of Conditioners is *composable state transformations*, where each Conditioner is a function that modifies the state and/or energy of a protein system in a differentiable way ([Supplementary Appendix M](https://doi.org/10.1038/s41586-023-06728-8)). For example, to encode symmetry as a *constraint* we can take as input the asymmetric unit and tessellate it according to the desired symmetry group to output a protein system that is symmetric by construction. To encode something like a neural network restraint, we can adjust the total system energy by the negative log probability of the target condition. For both of these, we add on the diffusion energy to the output of the Conditioner(s) and then backpropagate the total energy through all intermediate transformations to compute the unconstrained forces that are compatible with a generic sampling SDE such as annealed Langevin Dynamics.
142
+
143
+ We schematize this overall Conditioners framework below.
144
+ <div align="center">
145
+ <img src="./assets/conditioners.png" alt="alt text" width="600px" align="middle"/><br>
146
+ <figcaption><i>The <code>Conditioner</code> class is the composable building block of protein design with Chroma.</i></figcaption>
147
+ </div>
148
+
149
+ #### Conditioner API
150
+ It is simple to develop new conditioners. A `Conditioner` is a Pytorch `nn.Module` which takes in the system state - i.e. the structure, energy, and diffusion time - and outputs potentially updated structures and energies as
151
+
152
+ ```python
153
+
154
+ class Conditioner(torch.nn.Module):
155
+ """A composable function for parameterizing protein design problems.
156
+ """
157
+ def __init__(self, *args, **kwargs):
158
+ super().__init__()
159
+ # Setup your conditioner's hyperparameters
160
+
161
+ def forward(
162
+ self,
163
+ X: torch.Tensor, # Input coordinates
164
+ C: torch.LongTensor, # Input chain map (for complexes)
165
+ O: torch.Tensor, # Input sequence (one-hot, not used)
166
+ U: torch.Tensor, # Input energy (one-hot, not used)
167
+ t: Union[torch.Tensor, float], # Diffusion time
168
+ ):
169
+ # Update the state, e.g. map from an unconstrained to constrained manifold
170
+ X_update, C_update = update_state(X, C, t)
171
+
172
+ # Update the energy, e.g. add a restraint potential
173
+ U_update = U + update_energy(X, C, t)
174
+ return X_update, C_update, O, U_update, t
175
+ ```
176
+ Roughly speaking, `Conditioner`s are composable by construction because their input and output type signatures are matched (i.e. they are an endomorphism). So we can simply build conditioners from conditioners by "stacking" them much as we would with traditional neural network layer development. With the final `Conditioner` as an input, `Chroma.sample()` will then leverage Pytorch's automatic differentiation facilities to automatically furnish a diffusion-annealed MCMC sampling algorithm to sample with this conditioner (We note this isn't magic and taking care to scale and parameterize appropriately is [important](#note-on-conditioners)).
177
+
178
+ ##### A minimal Conditioner: 2D lattice symmetry
179
+ The code snippet below shows how in a few lines of code we can add a conditioner that stipulates the generation of a 2D crystal-like object, where generated proteins are arrayed in an `M x N` rectangular lattice.
180
+
181
+ ```python
182
+ import torch
183
+ from chroma.models import Chroma
184
+ from chroma.layers.structure import conditioners
185
+
186
+ class Lattice2DConditioner(conditioners.Conditioner):
187
+ def __init__(self, M, N, cell):
188
+ super().__init__()
189
+ # Setup the coordinates of a 2D lattice
190
+ self.order = M*N
191
+ x = torch.arange(M) * cell[0]
192
+ y = torch.arange(N) * cell[1]
193
+ xx, yy = torch.meshgrid(x, y, indexing="ij")
194
+ dX = torch.stack([xx.flatten(), yy.flatten(), torch.zeros(M * N)], dim=1)
195
+ self.register_buffer("dX", dX)
196
+
197
+ def forward(self, X, C, O, U, t):
198
+ # Tessellate the unit cell on the lattice
199
+ X = (X[:,None,...] + self.dX[None,:,None,None]).reshape(1, -1, 4, 3)
200
+ C = torch.cat([C + C.unique().max() * i for i in range(self.dX.shape[0])], dim=1)
201
+ # Average the gradient across the group (simplifies force scaling)
202
+ X.register_hook(lambda gradX: gradX / self.order)
203
+ return X, C, O, U, t
204
+
205
+ chroma = Chroma().cuda()
206
+ conditioner = Lattice2DConditioner(M=3, N=4, cell=[20., 15.]).cuda()
207
+ protein = chroma.sample(
208
+ chain_lengths=[70], conditioner=conditioner, sde_func='langevin',
209
+ potts_symmetry_order=conditioner.order
210
+ )
211
+
212
+ protein.to_CIF("lattice_protein.cif")
213
+ ```
214
+
215
+ <div align="center">
216
+ <img src="./assets/lattice.png" alt="alt text" width="700px" align="middle"/>
217
+ </div>
218
+
219
+ #### Note on Conditioners
220
+
221
+ An attractive aspect of this conditioner framework is that it is very general, enabling both constraints (which involve operations on $x$) and restraints (which amount to changes to $U$). At the same time, generation under restraints can still be (and often is) challenging, as the resulting effective energy landscape can become arbitrarily rugged and difficult to integrate. We therefore advise caution when using and developing new conditioners or conditioner combinations. We find that inspecting diffusion trajectories (including unconstrained and denoised trajectories, $\hat{x}_t$ and $\tilde{x}_t$) can be a good tool for identifying integration challenges and defining either better conditioner forms or better sampling regimes.
222
+
223
+ ## Citing Chroma
224
+
225
+ If you use Chroma in your research, please cite:
226
+
227
+ J. B. Ingraham, M. Baranov, Z. Costello, K. W. Barber, W. Wang, A. Ismail, V. Frappier, D. M. Lord, C. Ng-Thow-Hing, E. R. Van Vlack, S. Tie, V. Xue, S. C. Cowles, A. Leung, J. V. Rodrigues, C. L. Morales-Perez, A. M. Ayoub, R. Green, K. Puentes, F. Oplinger, N. V. Panwar, F. Obermeyer, A. R. Root, A. L. Beam, F. J. Poelwijk, and G. Grigoryan, "Illuminating protein space with a programmable generative model", *Nature*, 2023 (10.1038/s41586-023-06728-8).
228
+
229
+ ```bibtex
230
+ @Article{Chroma2023,
231
+ author = {Ingraham, John B. and Baranov, Max and Costello, Zak and Barber, Karl W. and Wang, Wujie and Ismail, Ahmed and Frappier, Vincent and Lord, Dana M. and Ng-Thow-Hing, Christopher and Van Vlack, Erik R. and Tie, Shan and Xue, Vincent and Cowles, Sarah C. and Leung, Alan and Rodrigues, Jo\~{a}o V. and Morales-Perez, Claudio L. and Ayoub, Alex M. and Green, Robin and Puentes, Katherine and Oplinger, Frank and Panwar, Nishant V. and Obermeyer, Fritz and Root, Adam R. and Beam, Andrew L. and Poelwijk, Frank J. and Grigoryan, Gevorg},
232
+ journal = {Nature},
233
+ title = {Illuminating protein space with a programmable generative model},
234
+ year = {2023},
235
+ volume = {},
236
+ number = {},
237
+ pages = {},
238
+ doi = {10.1038/s41586-023-06728-8}
239
+ }
240
+ ```
241
+
242
+ ## Acknowledgements
243
+ The Chroma codebase is the work of many contributors at Generate Biomedicines. We would like to acknowledge: Ahmed Ismail, Alan Witmer, Alex Ramos, Alexander Bock, Ameya Harmalkar, Brinda Monian, Craig Mackenzie, Dan Luu, David Moore, Frank Oplinger, Fritz Obermeyer, George Kent-Scheller, Gevorg Grigoryan, Jacob Feala, James Lucas, Jenhan Tao, John Ingraham, Martin Jankowiak, Max Baranov, Meghan Franklin, Mick Ward, Rudraksh Tuwani, Ryan Nelson, Shan Tie, Vincent Frappier, Vincent Xue, William Wolfe-McGuire, Wujie Wang, Zak Costello, Zander Harteveld.
244
+
245
+ ## License
246
+
247
+ Copyright Generate Biomedicines, Inc.
248
+
249
+ ### Chroma Code License
250
+ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this code except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0.
251
+
252
+ Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. See the License for the specific language governing permissions and limitations under the License.
253
+
254
+ ### Model Weights License
255
+ Chroma weights are freely available to academic researchers and non-profit entities who accept and agree to be bound under the terms of the Chroma Parameters License. Please visit the [weights download page](https://chroma-weights.generatebiomedicines.com/) for more information. If you are not eligible to use the Chroma Parameters under the terms of the provided License or if you would like to share the Chroma Parameters and/or otherwise use the Chroma Parameters beyond the scope of the rights granted in the License (including for commercial purposes), you may contact the Licensor at: [email protected].
chroma/Untitled.ipynb ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [],
3
+ "metadata": {},
4
+ "nbformat": 4,
5
+ "nbformat_minor": 5
6
+ }
chroma/assets/LiberationSans-Regular.ttf ADDED
Binary file (139 kB). View file
 
chroma/assets/chroma_logo.svg ADDED
chroma/assets/chroma_logo_outline.svg ADDED
chroma/assets/conditioners.png ADDED
chroma/assets/lattice.png ADDED

Git LFS Details

  • SHA256: 6f19bae6a7d8c38dece6bdb8eab384bf319264957a1a8ce85f0eb90e21e2b7b7
  • Pointer size: 132 Bytes
  • Size of remote file: 3.47 MB
chroma/assets/logo.png ADDED
chroma/assets/proteins.png ADDED

Git LFS Details

  • SHA256: 9714927ed591ba22d5815ef24219b493dd40389be3c4c4cda8f830e89de48fe3
  • Pointer size: 132 Bytes
  • Size of remote file: 2.89 MB
chroma/assets/refolding.png ADDED

Git LFS Details

  • SHA256: 2b1db27f48d31963e8ff422ea47b8bfbb2a7d2cd6ab1c344896be3d672840bb3
  • Pointer size: 132 Bytes
  • Size of remote file: 4.18 MB
chroma/chroma/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Generate Biomedicines, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ __version__ = "1.0.0"
16
+ from chroma.data.protein import Protein
17
+ from chroma.layers.structure import conditioners
18
+ from chroma.models.chroma import Chroma
19
+ from chroma.utility import api
chroma/chroma/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (365 Bytes). View file
 
chroma/chroma/constants/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Generate Biomedicines, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from chroma.constants.geometry import AA_GEOMETRY
16
+ from chroma.constants.sequence import *
chroma/chroma/constants/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (248 Bytes). View file
 
chroma/chroma/constants/__pycache__/geometry.cpython-38.pyc ADDED
Binary file (6.97 kB). View file
 
chroma/chroma/constants/__pycache__/named_models.cpython-38.pyc ADDED
Binary file (1.24 kB). View file
 
chroma/chroma/constants/__pycache__/sequence.cpython-38.pyc ADDED
Binary file (2.06 kB). View file
 
chroma/chroma/constants/geometry.py ADDED
@@ -0,0 +1,558 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Generate Biomedicines, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Dictionary containing ideal internal coordinates and chi angle assignments
16
+ for building amino acid 3D coordinates"""
17
+ from typing import Dict
18
+
19
+ AA_GEOMETRY: Dict[str, dict] = {
20
+ "ALA": {
21
+ "atoms": ["CB"],
22
+ "chi_indices": [],
23
+ "parents": [["N", "C", "CA"]],
24
+ "types": {"C": "C", "CA": "CT1", "CB": "CT3", "N": "NH1", "O": "O"},
25
+ "z-angles": [111.09],
26
+ "z-dihedrals": [123.23],
27
+ "z-lengths": [1.55],
28
+ },
29
+ "ARG": {
30
+ "atoms": ["CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"],
31
+ "chi_indices": [1, 2, 3, 4],
32
+ "parents": [
33
+ ["N", "C", "CA"],
34
+ ["N", "CA", "CB"],
35
+ ["CA", "CB", "CG"],
36
+ ["CB", "CG", "CD"],
37
+ ["CG", "CD", "NE"],
38
+ ["CD", "NE", "CZ"],
39
+ ["NH1", "NE", "CZ"],
40
+ ],
41
+ "types": {
42
+ "C": "C",
43
+ "CA": "CT1",
44
+ "CB": "CT2",
45
+ "CD": "CT2",
46
+ "CG": "CT2",
47
+ "CZ": "C",
48
+ "N": "NH1",
49
+ "NE": "NC2",
50
+ "NH1": "NC2",
51
+ "NH2": "NC2",
52
+ "O": "O",
53
+ },
54
+ "z-angles": [112.26, 115.95, 114.01, 107.09, 123.05, 118.06, 122.14],
55
+ "z-dihedrals": [123.64, 180.0, 180.0, 180.0, 180.0, 180.0, 178.64],
56
+ "z-lengths": [1.56, 1.55, 1.54, 1.5, 1.34, 1.33, 1.33],
57
+ },
58
+ "ASN": {
59
+ "atoms": ["CB", "CG", "OD1", "ND2"],
60
+ "chi_indices": [1, 2],
61
+ "parents": [
62
+ ["N", "C", "CA"],
63
+ ["N", "CA", "CB"],
64
+ ["CA", "CB", "CG"],
65
+ ["OD1", "CB", "CG"],
66
+ ],
67
+ "types": {
68
+ "C": "C",
69
+ "CA": "CT1",
70
+ "CB": "CT2",
71
+ "CG": "CC",
72
+ "N": "NH1",
73
+ "ND2": "NH2",
74
+ "O": "O",
75
+ "OD1": "O",
76
+ },
77
+ "z-angles": [113.04, 114.3, 122.56, 116.15],
78
+ "z-dihedrals": [121.18, 180.0, 180.0, -179.19],
79
+ "z-lengths": [1.56, 1.53, 1.23, 1.35],
80
+ },
81
+ "ASP": {
82
+ "atoms": ["CB", "CG", "OD1", "OD2"],
83
+ "chi_indices": [1, 2],
84
+ "parents": [
85
+ ["N", "C", "CA"],
86
+ ["N", "CA", "CB"],
87
+ ["CA", "CB", "CG"],
88
+ ["OD1", "CB", "CG"],
89
+ ],
90
+ "types": {
91
+ "C": "C",
92
+ "CA": "CT1",
93
+ "CB": "CT2A",
94
+ "CG": "CC",
95
+ "N": "NH1",
96
+ "O": "O",
97
+ "OD1": "OC",
98
+ "OD2": "OC",
99
+ },
100
+ "z-angles": [114.1, 112.6, 117.99, 117.7],
101
+ "z-dihedrals": [122.33, 180.0, 180.0, -170.23],
102
+ "z-lengths": [1.56, 1.52, 1.26, 1.25],
103
+ },
104
+ "CYS": {
105
+ "atoms": ["CB", "SG"],
106
+ "chi_indices": [1],
107
+ "parents": [["N", "C", "CA"], ["N", "CA", "CB"]],
108
+ "types": {"C": "C", "CA": "CT1", "CB": "CT2", "N": "NH1", "O": "O", "SG": "S"},
109
+ "z-angles": [111.98, 113.87],
110
+ "z-dihedrals": [121.79, 180.0],
111
+ "z-lengths": [1.56, 1.84],
112
+ },
113
+ "GLN": {
114
+ "atoms": ["CB", "CG", "CD", "OE1", "NE2"],
115
+ "chi_indices": [1, 2, 3],
116
+ "parents": [
117
+ ["N", "C", "CA"],
118
+ ["N", "CA", "CB"],
119
+ ["CA", "CB", "CG"],
120
+ ["CB", "CG", "CD"],
121
+ ["OE1", "CG", "CD"],
122
+ ],
123
+ "types": {
124
+ "C": "C",
125
+ "CA": "CT1",
126
+ "CB": "CT2",
127
+ "CD": "CC",
128
+ "CG": "CT2",
129
+ "N": "NH1",
130
+ "NE2": "NH2",
131
+ "O": "O",
132
+ "OE1": "O",
133
+ },
134
+ "z-angles": [111.68, 115.52, 112.5, 121.52, 116.84],
135
+ "z-dihedrals": [121.91, 180.0, 180.0, 180.0, 179.57],
136
+ "z-lengths": [1.55, 1.55, 1.53, 1.23, 1.35],
137
+ },
138
+ "GLU": {
139
+ "atoms": ["CB", "CG", "CD", "OE1", "OE2"],
140
+ "chi_indices": [1, 2, 3],
141
+ "parents": [
142
+ ["N", "C", "CA"],
143
+ ["N", "CA", "CB"],
144
+ ["CA", "CB", "CG"],
145
+ ["CB", "CG", "CD"],
146
+ ["OE1", "CG", "CD"],
147
+ ],
148
+ "types": {
149
+ "C": "C",
150
+ "CA": "CT1",
151
+ "CB": "CT2A",
152
+ "CD": "CC",
153
+ "CG": "CT2",
154
+ "N": "NH1",
155
+ "O": "O",
156
+ "OE1": "OC",
157
+ "OE2": "OC",
158
+ },
159
+ "z-angles": [111.71, 115.69, 115.73, 114.99, 120.08],
160
+ "z-dihedrals": [121.9, 180.0, 180.0, 180.0, -179.1],
161
+ "z-lengths": [1.55, 1.56, 1.53, 1.26, 1.25],
162
+ },
163
+ "GLY": {
164
+ "atoms": [],
165
+ "chi_indices": [],
166
+ "parents": [],
167
+ "types": {"C": "C", "CA": "CT2", "N": "NH1", "O": "O"},
168
+ "z-angles": [],
169
+ "z-dihedrals": [],
170
+ "z-lengths": [],
171
+ },
172
+ "HIS": {
173
+ "atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"],
174
+ "chi_indices": [1, 2],
175
+ "parents": [
176
+ ["N", "C", "CA"],
177
+ ["N", "CA", "CB"],
178
+ ["CA", "CB", "CG"],
179
+ ["ND1", "CB", "CG"],
180
+ ["CB", "CG", "ND1"],
181
+ ["CB", "CG", "CD2"],
182
+ ],
183
+ "types": {
184
+ "C": "C",
185
+ "CA": "CT1",
186
+ "CB": "CT2",
187
+ "CD2": "CPH1",
188
+ "CE1": "CPH2",
189
+ "CG": "CPH1",
190
+ "N": "NH1",
191
+ "ND1": "NR1",
192
+ "NE2": "NR2",
193
+ "O": "O",
194
+ },
195
+ "z-angles": [109.99, 114.05, 124.1, 129.6, 107.03, 110.03],
196
+ "z-dihedrals": [122.46, 180.0, 90.0, -171.29, -173.21, 171.99],
197
+ "z-lengths": [1.55, 1.5, 1.38, 1.36, 1.35, 1.38],
198
+ },
199
+ "HSD": {
200
+ "atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"],
201
+ "chi_indices": [1, 2],
202
+ "parents": [
203
+ ["N", "C", "CA"],
204
+ ["N", "CA", "CB"],
205
+ ["CA", "CB", "CG"],
206
+ ["ND1", "CB", "CG"],
207
+ ["CB", "CG", "ND1"],
208
+ ["CB", "CG", "CD2"],
209
+ ],
210
+ "types": {
211
+ "C": "C",
212
+ "CA": "CT1",
213
+ "CB": "CT2",
214
+ "CD2": "CPH1",
215
+ "CE1": "CPH2",
216
+ "CG": "CPH1",
217
+ "N": "NH1",
218
+ "ND1": "NR1",
219
+ "NE2": "NR2",
220
+ "O": "O",
221
+ },
222
+ "z-angles": [109.99, 114.05, 124.1, 129.6, 107.03, 110.03],
223
+ "z-dihedrals": [122.46, 180.0, 90.0, -171.29, -173.21, 171.99],
224
+ "z-lengths": [1.55, 1.5, 1.38, 1.36, 1.35, 1.38],
225
+ },
226
+ "HSE": {
227
+ "atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"],
228
+ "chi_indices": [],
229
+ "parents": [
230
+ ["N", "C", "CA"],
231
+ ["N", "CA", "CB"],
232
+ ["CA", "CB", "CG"],
233
+ ["ND1", "CB", "CG"],
234
+ ["CB", "CG", "ND1"],
235
+ ["CB", "CG", "CD2"],
236
+ ],
237
+ "types": {
238
+ "C": "C",
239
+ "CA": "CT1",
240
+ "CB": "CT2",
241
+ "CD2": "CPH1",
242
+ "CE1": "CPH2",
243
+ "CG": "CPH1",
244
+ "N": "NH1",
245
+ "ND1": "NR2",
246
+ "NE2": "NR1",
247
+ "O": "O",
248
+ },
249
+ "z-angles": [111.67, 116.94, 120.17, 129.71, 105.2, 105.8],
250
+ "z-dihedrals": [123.52, 180.0, 90.0, -178.26, -179.2, 178.66],
251
+ "z-lengths": [1.56, 1.51, 1.39, 1.36, 1.32, 1.38],
252
+ },
253
+ "HSP": {
254
+ "atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"],
255
+ "chi_indices": [],
256
+ "parents": [
257
+ ["N", "C", "CA"],
258
+ ["N", "CA", "CB"],
259
+ ["CA", "CB", "CG"],
260
+ ["ND1", "CB", "CG"],
261
+ ["CB", "CG", "ND1"],
262
+ ["CB", "CG", "CD2"],
263
+ ],
264
+ "types": {
265
+ "C": "C",
266
+ "CA": "CT1",
267
+ "CB": "CT2A",
268
+ "CD2": "CPH1",
269
+ "CE1": "CPH2",
270
+ "CG": "CPH1",
271
+ "N": "NH1",
272
+ "ND1": "NR3",
273
+ "NE2": "NR3",
274
+ "O": "O",
275
+ },
276
+ "z-angles": [109.38, 114.18, 122.94, 128.93, 108.9, 106.93],
277
+ "z-dihedrals": [125.13, 180.0, 90.0, -165.26, -167.62, 167.13],
278
+ "z-lengths": [1.55, 1.52, 1.37, 1.35, 1.33, 1.37],
279
+ },
280
+ "ILE": {
281
+ "atoms": ["CB", "CG1", "CG2", "CD1"],
282
+ "chi_indices": [1, 3],
283
+ "parents": [
284
+ ["N", "C", "CA"],
285
+ ["N", "CA", "CB"],
286
+ ["CG1", "CA", "CB"],
287
+ ["CA", "CB", "CG1"],
288
+ ],
289
+ "types": {
290
+ "C": "C",
291
+ "CA": "CT1",
292
+ "CB": "CT1",
293
+ "CD": "CT3",
294
+ "CG1": "CT2",
295
+ "CG2": "CT3",
296
+ "N": "NH1",
297
+ "O": "O",
298
+ },
299
+ "z-angles": [112.93, 113.63, 113.93, 114.09],
300
+ "z-dihedrals": [124.22, 180.0, -130.04, 180.0],
301
+ "z-lengths": [1.57, 1.55, 1.55, 1.54],
302
+ },
303
+ "LEU": {
304
+ "atoms": ["CB", "CG", "CD1", "CD2"],
305
+ "chi_indices": [1, 2],
306
+ "parents": [
307
+ ["N", "C", "CA"],
308
+ ["N", "CA", "CB"],
309
+ ["CA", "CB", "CG"],
310
+ ["CD1", "CB", "CG"],
311
+ ],
312
+ "types": {
313
+ "C": "C",
314
+ "CA": "CT1",
315
+ "CB": "CT2",
316
+ "CD1": "CT3",
317
+ "CD2": "CT3",
318
+ "CG": "CT1",
319
+ "N": "NH1",
320
+ "O": "O",
321
+ },
322
+ "z-angles": [112.12, 117.46, 110.48, 112.57],
323
+ "z-dihedrals": [121.52, 180.0, 180.0, 120.0],
324
+ "z-lengths": [1.55, 1.55, 1.54, 1.54],
325
+ },
326
+ "LYS": {
327
+ "atoms": ["CB", "CG", "CD", "CE", "NZ"],
328
+ "chi_indices": [1, 2, 3, 4],
329
+ "parents": [
330
+ ["N", "C", "CA"],
331
+ ["N", "CA", "CB"],
332
+ ["CA", "CB", "CG"],
333
+ ["CB", "CG", "CD"],
334
+ ["CG", "CD", "CE"],
335
+ ],
336
+ "types": {
337
+ "C": "C",
338
+ "CA": "CT1",
339
+ "CB": "CT2",
340
+ "CD": "CT2",
341
+ "CE": "CT2",
342
+ "CG": "CT2",
343
+ "N": "NH1",
344
+ "NZ": "NH3",
345
+ "O": "O",
346
+ },
347
+ "z-angles": [111.36, 115.76, 113.28, 112.33, 110.46],
348
+ "z-dihedrals": [122.23, 180.0, 180.0, 180.0, 180.0],
349
+ "z-lengths": [1.56, 1.54, 1.54, 1.53, 1.46],
350
+ },
351
+ "MET": {
352
+ "atoms": ["CB", "CG", "SD", "CE"],
353
+ "chi_indices": [1, 2, 3],
354
+ "parents": [
355
+ ["N", "C", "CA"],
356
+ ["N", "CA", "CB"],
357
+ ["CA", "CB", "CG"],
358
+ ["CB", "CG", "SD"],
359
+ ],
360
+ "types": {
361
+ "C": "C",
362
+ "CA": "CT1",
363
+ "CB": "CT2",
364
+ "CE": "CT3",
365
+ "CG": "CT2",
366
+ "N": "NH1",
367
+ "O": "O",
368
+ "SD": "S",
369
+ },
370
+ "z-angles": [111.88, 115.92, 110.28, 98.94],
371
+ "z-dihedrals": [121.62, 180.0, 180.0, 180.0],
372
+ "z-lengths": [1.55, 1.55, 1.82, 1.82],
373
+ },
374
+ "PHE": {
375
+ "atoms": ["CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"],
376
+ "chi_indices": [1, 2],
377
+ "parents": [
378
+ ["N", "C", "CA"],
379
+ ["N", "CA", "CB"],
380
+ ["CA", "CB", "CG"],
381
+ ["CD1", "CB", "CG"],
382
+ ["CB", "CG", "CD1"],
383
+ ["CB", "CG", "CD2"],
384
+ ["CG", "CD1", "CE1"],
385
+ ],
386
+ "types": {
387
+ "C": "C",
388
+ "CA": "CT1",
389
+ "CB": "CT2",
390
+ "CD1": "CA",
391
+ "CD2": "CA",
392
+ "CE1": "CA",
393
+ "CE2": "CA",
394
+ "CG": "CA",
395
+ "CZ": "CA",
396
+ "N": "NH1",
397
+ "O": "O",
398
+ },
399
+ "z-angles": [112.45, 112.76, 120.32, 120.76, 120.63, 120.62, 119.93],
400
+ "z-dihedrals": [122.49, 180.0, 90.0, -177.96, -177.37, 177.2, -0.12],
401
+ "z-lengths": [1.56, 1.51, 1.41, 1.41, 1.4, 1.4, 1.4],
402
+ },
403
+ "PRO": {
404
+ "atoms": ["CB", "CG", "CD"],
405
+ "chi_indices": [1, 2],
406
+ "parents": [["N", "C", "CA"], ["N", "CA", "CB"], ["CA", "CB", "CG"]],
407
+ "types": {
408
+ "C": "C",
409
+ "CA": "CP1",
410
+ "CB": "CP2",
411
+ "CD": "CP3",
412
+ "CG": "CP2",
413
+ "N": "N",
414
+ "O": "O",
415
+ },
416
+ "z-angles": [111.74, 104.39, 103.21],
417
+ "z-dihedrals": [113.74, 31.61, -34.59],
418
+ "z-lengths": [1.54, 1.53, 1.53],
419
+ },
420
+ "SER": {
421
+ "atoms": ["CB", "OG"],
422
+ "chi_indices": [1],
423
+ "parents": [["N", "C", "CA"], ["N", "CA", "CB"]],
424
+ "types": {
425
+ "C": "C",
426
+ "CA": "CT1",
427
+ "CB": "CT2",
428
+ "N": "NH1",
429
+ "O": "O",
430
+ "OG": "OH1",
431
+ },
432
+ "z-angles": [111.4, 112.45],
433
+ "z-dihedrals": [124.75, 180.0],
434
+ "z-lengths": [1.56, 1.43],
435
+ },
436
+ "THR": {
437
+ "atoms": ["CB", "OG1", "CG2"],
438
+ "chi_indices": [1],
439
+ "parents": [["N", "C", "CA"], ["N", "CA", "CB"], ["OG1", "CA", "CB"]],
440
+ "types": {
441
+ "C": "C",
442
+ "CA": "CT1",
443
+ "CB": "CT1",
444
+ "CG2": "CT3",
445
+ "N": "NH1",
446
+ "O": "O",
447
+ "OG1": "OH1",
448
+ },
449
+ "z-angles": [112.74, 112.16, 115.91],
450
+ "z-dihedrals": [126.46, 180.0, -124.13],
451
+ "z-lengths": [1.57, 1.43, 1.53],
452
+ },
453
+ "TRP": {
454
+ "atoms": ["CB", "CG", "CD2", "CD1", "CE2", "NE1", "CE3", "CZ3", "CH2", "CZ2"],
455
+ "chi_indices": [1, 2],
456
+ "parents": [
457
+ ["N", "C", "CA"],
458
+ ["N", "CA", "CB"],
459
+ ["CA", "CB", "CG"],
460
+ ["CD2", "CB", "CG"],
461
+ ["CD1", "CG", "CD2"],
462
+ ["CG", "CD2", "CE2"],
463
+ ["CE2", "CG", "CD2"],
464
+ ["CE2", "CD2", "CE3"],
465
+ ["CD2", "CE3", "CZ3"],
466
+ ["CE3", "CZ3", "CH2"],
467
+ ],
468
+ "types": {
469
+ "C": "C",
470
+ "CA": "CT1",
471
+ "CB": "CT2",
472
+ "CD1": "CA",
473
+ "CD2": "CPT",
474
+ "CE2": "CPT",
475
+ "CE3": "CAI",
476
+ "CG": "CY",
477
+ "CH2": "CA",
478
+ "CZ2": "CAI",
479
+ "CZ3": "CA",
480
+ "N": "NH1",
481
+ "NE1": "NY",
482
+ "O": "O",
483
+ },
484
+ "z-angles": [
485
+ 111.23,
486
+ 115.14,
487
+ 123.95,
488
+ 129.18,
489
+ 106.65,
490
+ 107.87,
491
+ 132.54,
492
+ 118.16,
493
+ 120.97,
494
+ 120.87,
495
+ ],
496
+ "z-dihedrals": [
497
+ 122.68,
498
+ 180.0,
499
+ 90.0,
500
+ -172.81,
501
+ -0.08,
502
+ 0.14,
503
+ 179.21,
504
+ -0.2,
505
+ 0.1,
506
+ 0.01,
507
+ ],
508
+ "z-lengths": [1.56, 1.52, 1.44, 1.37, 1.41, 1.37, 1.4, 1.4, 1.4, 1.4],
509
+ },
510
+ "TYR": {
511
+ "atoms": ["CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"],
512
+ "chi_indices": [1, 2],
513
+ "parents": [
514
+ ["N", "C", "CA"],
515
+ ["N", "CA", "CB"],
516
+ ["CA", "CB", "CG"],
517
+ ["CD1", "CB", "CG"],
518
+ ["CB", "CG", "CD1"],
519
+ ["CB", "CG", "CD2"],
520
+ ["CG", "CD1", "CE1"],
521
+ ["CE1", "CE2", "CZ"],
522
+ ],
523
+ "types": {
524
+ "C": "C",
525
+ "CA": "CT1",
526
+ "CB": "CT2",
527
+ "CD1": "CA",
528
+ "CD2": "CA",
529
+ "CE1": "CA",
530
+ "CE2": "CA",
531
+ "CG": "CA",
532
+ "CZ": "CA",
533
+ "N": "NH1",
534
+ "O": "O",
535
+ "OH": "OH1",
536
+ },
537
+ "z-angles": [112.34, 112.94, 120.49, 120.46, 120.4, 120.56, 120.09, 120.25],
538
+ "z-dihedrals": [122.27, 180.0, 90.0, -176.46, -175.49, 175.32, -0.19, -178.98],
539
+ "z-lengths": [1.56, 1.51, 1.41, 1.41, 1.4, 1.4, 1.4, 1.41],
540
+ },
541
+ "VAL": {
542
+ "atoms": ["CB", "CG1", "CG2"],
543
+ "chi_indices": [1],
544
+ "parents": [["N", "C", "CA"], ["N", "CA", "CB"], ["CG1", "CA", "CB"]],
545
+ "types": {
546
+ "C": "C",
547
+ "CA": "CT1",
548
+ "CB": "CT1",
549
+ "CG1": "CT3",
550
+ "CG2": "CT3",
551
+ "N": "NH1",
552
+ "O": "O",
553
+ },
554
+ "z-angles": [111.23, 113.97, 112.17],
555
+ "z-dihedrals": [122.95, 180.0, 123.99],
556
+ "z-lengths": [1.57, 1.54, 1.54],
557
+ },
558
+ }
chroma/chroma/constants/named_models.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Generate Biomedicines, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """ Paths for named models in the zoo """
16
+
17
+ GRAPH_BACKBONE_MODELS = {
18
+ "public": {
19
+ "s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_backbone_v1.0.pt",
20
+ "data": "Generate Structure ETL: July 25 2022",
21
+ "task": "BLNL backbone model training with EMA, trained July 2023",
22
+ },
23
+ }
24
+
25
+ GRAPH_CLASSIFIER_MODELS = {
26
+ "public": {
27
+ "s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_proclass_v1.0.pt",
28
+ "data": "Generate Structure ETL: June 2022",
29
+ "task": "Backbone classification model training with cross-entropy loss",
30
+ },
31
+ }
32
+
33
+ GRAPH_DESIGN_MODELS = {
34
+ "public": {
35
+ "s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_design_v1.0.pt",
36
+ "data": "Generate Structure ETL: July 25 2022",
37
+ "task": "Autoregressive joint prediction of sequence and chi angles, two-stage",
38
+ },
39
+ }
40
+
41
+ PROCAP_MODELS = {
42
+ "public": {
43
+ "s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_procap_v1.0.pt",
44
+ "data": "Generate Structure ETL: June 2022",
45
+ "task": "Backbone caption model training with cross-entropy loss, using M5 ProClass GNN embeddings",
46
+ },
47
+ }
48
+
49
+ NAMED_MODELS = {
50
+ "GraphBackbone": GRAPH_BACKBONE_MODELS,
51
+ "GraphDesign": GRAPH_DESIGN_MODELS,
52
+ "GraphClassifier": GRAPH_CLASSIFIER_MODELS,
53
+ "ProteinCaption": PROCAP_MODELS,
54
+ }
chroma/chroma/constants/sequence.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Generate Biomedicines, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Constants used across protein representations.
16
+
17
+ These constants standardize protein tokenization alphabets, ideal structure
18
+ geometries and topologies, etc.
19
+ """
20
+ from chroma.constants.geometry import AA_GEOMETRY
21
+
22
+ # Standard tokenization for Omniprot and Omniprot-interacting models
23
+ OMNIPROT_TOKENS = "ABCDEFGHIKLMNOPQRSTUVWYXZ*-#"
24
+ POTTS_EXTENDED_TOKENS = "ACDEFGHIKLMNPQRSTVWY-*#"
25
+ PAD = "-"
26
+ START = "@"
27
+ STOP = "*"
28
+ MASK = "#"
29
+ DNA_TOKENS = "ACGT"
30
+ RNA_TOKENS = "AGCU"
31
+ PROTEIN_TOKENS = "ACDEFGHIKLMNPQRSTVWY"
32
+
33
+ # Minimal 20-letter alphabet and corresponding triplet codes
34
+ AA20 = "ACDEFGHIKLMNPQRSTVWY"
35
+ AA20_3_TO_1 = {
36
+ "ALA": "A",
37
+ "ARG": "R",
38
+ "ASN": "N",
39
+ "ASP": "D",
40
+ "CYS": "C",
41
+ "GLN": "Q",
42
+ "GLU": "E",
43
+ "GLY": "G",
44
+ "HIS": "H",
45
+ "ILE": "I",
46
+ "LEU": "L",
47
+ "LYS": "K",
48
+ "MET": "M",
49
+ "PHE": "F",
50
+ "PRO": "P",
51
+ "SER": "S",
52
+ "THR": "T",
53
+ "TRP": "W",
54
+ "TYR": "Y",
55
+ "VAL": "V",
56
+ }
57
+ AA20_1_TO_3 = {
58
+ "A": "ALA",
59
+ "R": "ARG",
60
+ "N": "ASN",
61
+ "D": "ASP",
62
+ "C": "CYS",
63
+ "Q": "GLN",
64
+ "E": "GLU",
65
+ "G": "GLY",
66
+ "H": "HIS",
67
+ "I": "ILE",
68
+ "L": "LEU",
69
+ "K": "LYS",
70
+ "M": "MET",
71
+ "F": "PHE",
72
+ "P": "PRO",
73
+ "S": "SER",
74
+ "T": "THR",
75
+ "W": "TRP",
76
+ "Y": "TYR",
77
+ "V": "VAL",
78
+ }
79
+ AA20_3 = [AA20_1_TO_3[aa] for aa in AA20]
80
+
81
+ # Adding noncanonical amino acids
82
+ NONCANON_AA = [
83
+ "HSD",
84
+ "HSE",
85
+ "HSC",
86
+ "HSP",
87
+ "MSE",
88
+ "CSO",
89
+ "SEC",
90
+ "CSX",
91
+ "HIP",
92
+ "SEP",
93
+ "TPO",
94
+ ]
95
+ AA31_3 = AA20_3 + NONCANON_AA
96
+
97
+ # Chain alphabet for PDB chain naming
98
+ CHAIN_ALPHABET = "_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
99
+
100
+ # Standard atom indexing
101
+ ATOMS_BB = ["N", "CA", "C", "O"]
102
+
103
+ ATOM_SYMMETRIES = {
104
+ "ARG": [("NH1", "NH2")], # Correct handling of NH1 and NH2 is relabeling
105
+ "ASP": [("OD1", "OD2")],
106
+ "GLU": [("OE1", "OE2")],
107
+ "PHE": [("CD1", "CD2"), ("CE1", "CE2")],
108
+ "TYR": [("CD1", "CD2"), ("CE1", "CE2")],
109
+ }
110
+
111
+ AA20_NUM_ATOMS = [4 + len(AA_GEOMETRY[aa]["atoms"]) for aa in AA20_3]
112
+ AA20_NUM_CHI = [len(AA_GEOMETRY[aa]["chi_indices"]) for aa in AA20_3]
chroma/chroma/data/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Generate Biomedicines, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ This package includes io formats and tools for a few common datatypes,
17
+ including antibodies, proteins, sequences, and structures.
18
+ """
19
+ from chroma.data.protein import Protein
chroma/chroma/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (339 Bytes). View file
 
chroma/chroma/data/__pycache__/protein.cpython-38.pyc ADDED
Binary file (19.3 kB). View file
 
chroma/chroma/data/__pycache__/system.cpython-38.pyc ADDED
Binary file (136 kB). View file
 
chroma/chroma/data/__pycache__/xcs.cpython-38.pyc ADDED
Binary file (3.83 kB). View file
 
chroma/chroma/data/protein.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Generate Biomedicines, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
from __future__ import annotations

import copy
import os
import tempfile
import warnings
from typing import List, Optional, Tuple, Union

import nglview as nv
import torch

import chroma.utility.polyseq as polyseq
from chroma.constants import CHAIN_ALPHABET, PROTEIN_TOKENS
from chroma.data.system import System, SystemEntity
28
+
29
+
30
class Protein:
    """Protein: A utility class for managing proteins within the Chroma ecosystem.

    The Protein class offers a suite of methods for loading, saving,
    transforming, and viewing protein structures and trajectories from a
    variety of input sources such as PDB IDs, CIF/PDB files, raw sequences,
    and XCS tensor representations.

    Attributes:
        sys (System): A protein system object used for various molecular operations.
        device (str): Specifies the device on which tensors are managed. Defaults to `cpu`.
    """

    sys: System
    device: str = "cpu"

    def __new__(cls, *args, **kwargs):
        """Dispatch construction based on the input type.

        Accepted inputs:
            * a single ``System`` instance;
            * three tensors ``(X, C, S)``, where ``X`` may also be a list of
              tensors (interpreted as a trajectory);
            * a single string: a ``*.cif`` path, a ``*.pdb`` path, a PDB ID,
              or a one-letter sequence string.

        Raises:
            TypeError: if three arguments are given but are not XCS tensors.
            NotImplementedError: for s3 paths or unrecognized inputs.
        """
        if len(args) == 1 and isinstance(args[0], System):
            return cls.from_system(*args, **kwargs)

        elif len(args) == 3:  # 3 tensor arguments, interpreted as (X, C, S)
            X, C, S = args
            assert isinstance(
                C, torch.Tensor
            ), f"arg[1] must be a chain (C) torch.Tensor, but get {type(C)}"
            assert isinstance(
                S, torch.Tensor
            ), f"arg[2] must be a sequence (S) torch.Tensor, but get {type(S)}"
            if isinstance(X, list):
                assert all(
                    isinstance(x, torch.Tensor) for x in X
                ), "arg[0] must be an X torch.Tensor or a list of X torch.Tensors"
                return cls.from_XCS_trajectory(X, C, S)
            elif isinstance(X, torch.Tensor):
                return cls.from_XCS(X, C, S)
            else:
                raise TypeError(
                    f"X must be a list of torch.Tensor that respects XCS format, but get {type(X), type(C), type(S)}"
                )

        elif len(args) == 1 and isinstance(args[0], str):
            if args[0].lower().startswith("s3:"):
                raise NotImplementedError(
                    "download of cifs or pdbs from s3 not supported."
                )

            if args[0].endswith(".cif"):
                return cls.from_CIF(*args, **kwargs)

            elif args[0].endswith(".pdb"):
                return cls.from_PDB(*args, **kwargs)

            else:  # PDB ID or sequence string
                # Check whether the string names a valid PDB entry.
                # NOTE: this requires internet access; offline callers should
                # use Protein.from_sequence / Protein.from_PDBID explicitly.
                import requests

                url = f"https://data.rcsb.org/rest/v1/core/entry/{args[0]}"
                VALID_PDBID = requests.get(url).status_code == 200
                VALID_SEQUENCE = all([s in PROTEIN_TOKENS for s in args[0]])

                if VALID_PDBID:
                    if VALID_SEQUENCE:
                        # BUGFIX: this previously did `raise Warning(...)`,
                        # which aborted construction even though the message
                        # (and docstring) promise to interpret the input as a
                        # PDB ID. Emit a real warning and proceed instead.
                        warnings.warn(
                            "Ambiguous input, this is both a valid Sequence string"
                            " and a valid PDBID. Interpreting as a PDBID, if you"
                            " wish to initialize as a sequence string please"
                            " explicitly initialize as"
                            " Protein.from_sequence(MY_SEQUENCE)."
                        )
                    return cls.from_PDBID(*args, **kwargs)
                elif VALID_SEQUENCE:
                    return cls.from_sequence(*args, **kwargs)
                else:
                    raise NotImplementedError(
                        "Could Not Identify a valid input Type. See docstring for"
                        " details."
                    )
        else:
            raise NotImplementedError(
                "Inputs must either be a 3-tuple of XCS tensors, or a single string"
            )

    @classmethod
    def from_system(cls, system: System, device: str = "cpu") -> Protein:
        """Wrap an existing System object in a Protein (no copying)."""
        protein = super(Protein, cls).__new__(cls)
        protein.sys = system
        protein.device = device
        return protein

    @classmethod
    def from_XCS(cls, X: torch.Tensor, C: torch.Tensor, S: torch.Tensor) -> Protein:
        """
        Create a Protein object from XCS representations.

        Args:
            X (torch.Tensor): A 4D tensor of atomic coordinates with dimensions
                `(batch, residues, atoms (4 or 14), coordinates (3))`.
            C (torch.Tensor): A chain label tensor of shape `(batch, residues)`.
                Sign indicates presence (+) or absence (-) of structural
                information; magnitude indicates the chain index.
            S (torch.Tensor): A sequence tensor of shape `(batch, residues)` of
                non-negative residue-type integers.

        Returns:
            Protein: Initialized Protein object from the given XCS representation.
        """
        protein = super(Protein, cls).__new__(cls)
        protein.sys = System.from_XCS(X, C, S)
        protein.device = X.device
        return protein

    @classmethod
    def from_XCS_trajectory(
        cls, X_traj: List[torch.Tensor], C: torch.Tensor, S: torch.Tensor
    ) -> Protein:
        """
        Initialize a Protein object from a trajectory of XCS representations.

        Args:
            X_traj (List[torch.Tensor]): List of coordinate tensors over time,
                each of shape `(batch, residues, atoms (4 or 14), 3)`.
            C (torch.Tensor): Chain label tensor of shape `(batch, residues)`.
            S (torch.Tensor): Sequence tensor of shape `(batch, residues)`.

        Returns:
            Protein: Protein object initialized from the XCS trajectory.
        """
        protein = super(Protein, cls).__new__(cls)
        protein.sys = System.from_XCS(X_traj[0], C, S)
        protein.device = C.device
        # Remaining frames become additional models; only structured residues
        # (C > 0) carry coordinates.
        for X in X_traj[1:]:
            protein.sys.add_model_from_X(X[C > 0])
        return protein

    @classmethod
    def from_PDB(cls, input_file: str, device: str = "cpu") -> Protein:
        """
        Load a Protein object from a provided PDB file.

        Args:
            input_file (str): Path to the PDB file to be loaded.
            device (str, optional): The device for tensor operations. Defaults to 'cpu'.

        Returns:
            Protein: Initialized Protein object from the provided PDB file.
        """
        protein = super(Protein, cls).__new__(cls)
        protein.sys = System.from_PDB(input_file)
        protein.device = device
        return protein

    @classmethod
    def from_CIF(
        cls, input_file: str, canonicalize: bool = True, device: str = "cpu"
    ) -> Protein:
        """
        Load a Protein object from a provided CIF file.

        Args:
            input_file (str): Path to the CIF file to be loaded.
            canonicalize (bool, optional): If True, canonicalize the protein's
                backbone geometry after loading. Defaults to True.
            device (str, optional): The device for tensor operations. Defaults to 'cpu'.

        Returns:
            Protein: Initialized Protein object from the provided CIF file.
        """
        protein = super(Protein, cls).__new__(cls)
        protein.sys = System.from_CIF(input_file)
        protein.device = device
        if canonicalize:
            protein.canonicalize()
        return protein

    @classmethod
    def from_PDBID(
        cls, pdb_id: str, canonicalize: bool = True, device: str = "cpu"
    ) -> Protein:
        """
        Load a Protein object by PDB ID, fetching the CIF file from the RCSB.

        Downloads the CIF file for the given PDB ID into a temporary location,
        builds the Protein, then deletes the temporary file.

        Args:
            pdb_id (str): The PDB ID of the protein to fetch.
            canonicalize (bool, optional): Canonicalize post-loading. Defaults to True.
            device (str, optional): The device for tensor operations. Defaults to 'cpu'.

        Returns:
            Protein: An instance initialized from the fetched CIF file.
        """
        from os import unlink

        from chroma.utility.fetchdb import RCSB_file_download

        file_cif = os.path.join(tempfile.gettempdir(), f"{pdb_id}.cif")
        RCSB_file_download(pdb_id, ".cif", file_cif)
        protein = cls.from_CIF(file_cif, canonicalize=canonicalize, device=device)
        unlink(file_cif)
        return protein

    @classmethod
    def from_sequence(
        cls, chains: Union[List[str], str], device: str = "cpu"
    ) -> Protein:
        """
        Load a Protein purely from sequence, with no structural content.

        Args:
            chains (Union[List[str], str]): a sequence string, or a list of
                sequence strings (one per chain).
            device (str, optional): device for torch outputs. Defaults to "cpu".

        Returns:
            Protein: An instance initialized from the sequence(s).
        """
        if isinstance(chains, str):
            chains = [chains]

        system = System("system")
        for c_ix, seq in enumerate(chains):
            # CHAIN_ALPHABET[0] is the placeholder "_", so chains start at "A".
            chain_id = CHAIN_ALPHABET[c_ix + 1]
            chain = system.add_chain(chain_id)

            # Populate the chain residue-by-residue (1-based residue numbers).
            three_letter_sequence = []
            for s_ix, s in enumerate(seq):
                resname = polyseq.to_triple(s)
                three_letter_sequence.append(resname)
                chain.add_residue(resname, s_ix + 1, "")

            # Register the chain as a polymer entity with no coordinates.
            sys_entity = SystemEntity(
                "polymer",
                f"Sequence Chain {chain_id}",
                "polypeptide(L)",
                three_letter_sequence,
                [False] * len(three_letter_sequence),
            )
            system.add_new_entity(sys_entity, [c_ix])

        protein = super(Protein, cls).__new__(cls)
        protein.sys = system
        protein.device = device
        return protein

    def to_CIF(self, output_file: str, force: bool = False) -> None:
        """
        Save the current Protein object to a file in CIF format.

        Args:
            output_file (str): The path where the CIF file should be saved.
            force (bool, optional): Accepted for interface symmetry; not used
                by the local-file export path.
        """
        if output_file.lower().startswith("s3:"):
            raise NotImplementedError("cif output to an s3 bucket not supported.")
        else:
            self.sys.to_CIF(output_file)

    def to_PDB(self, output_file: str, force: bool = False) -> None:
        """
        Save the current Protein object to a file in PDB format.

        Args:
            output_file (str): The path where the PDB file should be saved.
            force (bool, optional): Accepted for interface symmetry; not used
                by the local-file export path.
        """
        if output_file.lower().startswith("s3:"):
            raise NotImplementedError("pdb output to an s3 bucket not supported.")
        else:
            self.sys.to_PDB(output_file)

    def to_XCS(
        self, all_atom: bool = False, device: Optional[str] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Convert the current Protein object to its XCS tensor representations.

        Args:
            all_atom (bool, optional): If True, include side-chain atoms
                (14-atom representation). Defaults to False.
            device (str, optional): device to export XCS tensors to. If not
                specified, uses the instance's `device` attribute. Default None.

        Returns:
            X (torch.Tensor): Coordinates, `(batch, residues, atoms (4 or 14), 3)`.
            C (torch.Tensor): Chain labels, `(batch, residues)`.
            S (torch.Tensor): Sequence indices, `(batch, residues)`.
        """
        if device is None:
            device = self.device

        X, C, S = [tensor.to(device) for tensor in self.sys.to_XCS(all_atom=all_atom)]

        return X, C, S

    def to_XCS_trajectory(
        self,
        device: Optional[str] = None,
    ) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor]:
        """
        Convert the Protein's models to XCS tensors, one X frame per model.

        Args:
            device (str, optional): device to export XCS tensors to. If not
                specified, uses the instance's `device` attribute. Default None.

        Returns:
            X_traj (List[torch.Tensor]): per-model coordinate tensors, each
                `(1, residues, atoms, 3)`.
            C (torch.Tensor): Chain labels, `(batch, residues)`.
            S (torch.Tensor): Sequence indices, `(batch, residues)`.
        """
        X, C, S = [], None, None
        for i in range(self.sys.num_models()):
            self.sys.swap_model(i)
            if i == 0:
                # First frame also yields the location indices used to read
                # coordinates for subsequent frames directly from storage.
                X_frame, C, S, loc_indices = self.sys.to_XCS(get_indices=True)
            else:
                X_frame.flatten(0, 2)[:] = torch.from_numpy(
                    self.sys._locations["coor"][loc_indices, 0:3]
                )
            X.append(X_frame.clone())
            # Swap back so model order is restored for the next iteration.
            self.sys.swap_model(i)
        X = torch.cat(X)

        if device is None:
            device = self.device

        Xtraj, C, S = [tensor.to(device) for tensor in [X, C, S]]
        return [each.unsqueeze(0) for each in Xtraj], C, S

    def to(self, file_path: str, force: bool = False) -> None:
        """
        General export for the Protein class.

        Dispatches to PDB or CIF export based on the file extension; explicit
        saving is still available via the respective export methods.

        Args:
            file_path (str): Output path; must end in `.pdb` or `.cif`.
            force (bool, optional): Passed through to the underlying exporter.
                Defaults to False.

        Raises:
            NotImplementedError: if the extension is neither `.pdb` nor `.cif`.
        """
        if file_path.lower().endswith(".pdb"):
            self.to_PDB(file_path, force=force)
        elif file_path.lower().endswith(".cif"):
            self.to_CIF(file_path, force=force)
        else:
            raise NotImplementedError(
                "file path must end with either *.cif or *.pdb for export."
            )

    def length(self, structured: bool = False) -> int:
        """
        Retrieve the length of the protein.

        Args:
            structured (bool, optional): If True, return the residue count of
                the structured part only. Defaults to False.

        Returns:
            int: Length of the protein or its structured part.
        """
        # BUGFIX: return annotation corrected from `-> None` to `-> int`.
        if structured:
            return self.sys.num_structured_residues()
        return self.sys.num_residues()

    __len__ = length

    def canonicalize(self) -> None:
        """
        Canonicalize the protein's backbone geometry.

        Drops unknown residues and residues with missing backbone coordinates.
        """
        self.sys.canonicalize_protein(
            level=2,
            drop_coors_unknowns=True,
            drop_coors_missing_backbone=True,
        )

    def sequence(self, format: str = "one-letter-string") -> Union[List[str], str]:
        """
        Retrieve the sequence of the protein in the specified format.

        Args:
            format (str, optional): 'three-letter-list' or 'one-letter-string'.
                Defaults to 'one-letter-string'.

        Returns:
            Union[List[str], str]: The protein sequence in the desired format.

        Raises:
            ValueError: If an unknown sequence format is provided.
        """
        if format == "three-letter-list":
            return list(self.sys.sequence())
        elif format == "one-letter-string":
            return self.sys.sequence("one-letter-string")
        else:
            # Narrowed from bare Exception; ValueError still satisfies any
            # caller catching Exception.
            raise ValueError(f"unknown sequence format {format}")

    def display(self, representations: list = []) -> None:
        """
        Display the protein using the provided representations in NGL view.

        Args:
            representations (list, optional): Visual representations to add.
                Defaults to an empty list (note: representations are only
                applied in the single-model case, mirroring prior behavior).

        Returns:
            viewer: A viewer object for interactive visualization.
        """
        from chroma.utility.ngl import SystemTrajectory, view_gsystem

        if self.sys.num_models() == 1:
            viewer = view_gsystem(self.sys)
            for rep in representations:
                viewer.add_representation(rep)
        else:
            t = SystemTrajectory(self)
            viewer = nv.NGLWidget(t)
        return viewer

    def _ipython_display_(self):
        # `display` is injected by IPython at runtime; this hook is only
        # invoked inside notebook environments.
        display(self.display())

    def __str__(self):
        """Return the protein name and per-chain sequences in FASTA-like form."""
        protein_string = f"Protein: {self.sys.name}\n"
        for chain in self.sys.chains():
            # NOTE(review): `chain.sequence` looks like a bound method, so this
            # check is always truthy — confirm against the System chain API.
            if chain.sequence is not None:
                protein_string += (
                    f"> Chain {chain.cid} ({len(chain.sequence())} residues)\n"
                )
                protein_string += "".join(
                    [polyseq.to_single(s) for s in chain.sequence()]
                )
                protein_string += "\n\n"

        return protein_string

    def get_mask(self, selection: str) -> torch.Tensor:
        """
        Generate a mask tensor based on the provided residue selection.

        Args:
            selection (str): A selection string specifying which residues to include.

        Returns:
            torch.Tensor: Mask of shape `(1, protein length)`; selected
                positions have value 1, all others 0.
        """
        residue_gtis = self.sys.select_residues(selection, gti=True)
        D = torch.zeros(1, self.sys.num_residues(), device=self.device)
        for gti in residue_gtis:
            D[0, gti] = 1
        return D

    def __copy__(self):
        new_system = copy.copy(self.sys)
        device = self.device
        return Protein(new_system, device=device)

    def __deepcopy__(self, memo):
        new_system = copy.deepcopy(self.sys)
        device = self.device
        return Protein(new_system, device=device)
chroma/chroma/data/system.py ADDED
The diff for this file is too large to render. See raw diff
 
chroma/chroma/data/xcs.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Generate Biomedicines, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """XCS represents protein structure as a tuple of PyTorch tensors.
16
+
17
+ The tensors in an XCS representation are:
18
+
19
+ `X` (FloatTensor), the Cartesian coordinates representing the protein
20
+ structure with shape `(num_batch, num_residues, num_atoms, 3)`. The
21
+ `num_atoms` dimension can be one of two sizes: `num_atoms=4` for
22
+ backbone-only structures or `num_atoms=14` for all-atom structures
23
+ (excluding hydrogens). The first four atoms will always be
24
+ `N, CA, C, O`, and the meaning of the optional 10 additional atom
25
+ positions will vary based on the residue identity at
26
+ a given position. Atom orders for each amino acid are defined in
27
+ `constants.AA_GEOMETRY[TRIPLET_CODE]["atoms"]`.
28
+
29
`C` (LongTensor), the chain map encoding per-residue chain assignments with
    shape `(num_batch, num_residues)`. The chain map codes positions as `0`
    when masked, positive integers for chain indices, and negative integers
    to represent missing residues (of the corresponding positive integers).
33
+
34
+ `S` (LongTensor), the sequence of the protein as alphabet indices with
35
+ shape `(num_batch, num_residues)`. The standard alphabet is
36
+ `ACDEFGHIKLMNPQRSTVWY`, also defined in `constants.AA20`.
37
+ """
38
+
39
+
40
+ from functools import partial, wraps
41
+ from inspect import getfullargspec
42
+
43
+ import torch
44
+ from torch.nn import functional as F
45
+
46
# NOTE(review): removed a vestigial `try: pass / except ImportError:
# print("MST not installed!")` guard — `pass` can never raise ImportError,
# so the except branch was unreachable dead code (likely left behind after
# an import was deleted).
50
+
51
+
52
def validate_XCS(all_atom=None, sequence=True):
    """Decorator factory that adds XCS validation to any function.

    Args:
        all_atom (bool, optional): If True, requires that input structure
            tensors have 14 residues per atom. If False, reduces to 4 residues
            per atom. If None, applies no transformation on input structures.
        sequence (bool, optional): If True, makes sure that if S and O are both
            provided, that they match, i.e. that O is a one-hot version of S.
            If only one of S or O is provided, the other is generated, and both
            are passed.
    """

    def decorator(func):
        @wraps(func)
        def new_func(*args, **kwargs):
            args = list(args)
            # Positional-parameter names of the wrapped function, used to map
            # X/C/S/O names to their positions inside *args*.
            arg_list = getfullargspec(func)[0]
            tensors = {}
            # Collect each tensor from kwargs first, then positionally.
            for var in ["X", "C", "S", "O"]:
                try:
                    if var in kwargs:
                        tensors[var] = kwargs[var]
                    else:
                        tensors[var] = args[arg_list.index(var)]
                except IndexError:  # empty args_list
                    tensors[var] = None
                except ValueError:  # variable not an argument of function
                    if not sequence and var in ["S", "O"]:
                        pass
                    else:
                        raise Exception(
                            f"Variable {var} is required by validation but not defined!"
                        )
            # X and C must agree on the (num_batch, num_residues) prefix.
            if tensors["X"] is not None and tensors["C"] is not None:
                if tensors["X"].shape[:2] != tensors["C"].shape[:2]:
                    raise ValueError(
                        f"X shape {tensors['X'].shape} does not match C shape"
                        f" {tensors['C'].shape}"
                    )
            if all_atom is not None and tensors["X"] is not None:
                if all_atom and tensors["X"].shape[2] != 14:
                    raise ValueError("Side chain atoms missing!")
                elif not all_atom:
                    # Truncate to the 4 backbone atoms, writing back into
                    # whichever container (kwargs or args) X came from.
                    if "X" in kwargs:
                        kwargs["X"] = tensors["X"][:, :, :4]
                    else:
                        args[arg_list.index("X")] = tensors["X"][:, :, :4]
            # Reconcile S (indices) and O (one-hot): derive the missing one,
            # or verify consistency when both are supplied.
            if sequence and (tensors["S"] is not None or tensors["O"] is not None):
                if tensors["O"] is None:
                    if "O" in kwargs:
                        kwargs["O"] = F.one_hot(tensors["S"], 20).float()
                    else:
                        args[arg_list.index("O")] = F.one_hot(tensors["S"], 20).float()
                elif tensors["S"] is None:
                    if "S" in kwargs:
                        kwargs["S"] = tensors["O"].argmax(dim=2)
                    else:
                        args[arg_list.index("S")] = tensors["O"].argmax(dim=2)
                else:
                    if not torch.allclose(tensors["O"].argmax(dim=2), tensors["S"]):
                        raise ValueError("S and O are both provided but don't match!")
            return func(*args, **kwargs)

        return new_func

    return decorator


# Structure-only variant: validates X/C but skips all S/O sequence handling.
validate_XC = partial(validate_XCS, sequence=False)
chroma/chroma/layers/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Generate Biomedicines, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ This package contains low-level PyTorch layers, including ``nn.Module`` s and ops.
17
+ These layers are often used in :mod:`chroma.models`.
18
+ """
chroma/chroma/layers/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (291 Bytes). View file
 
chroma/chroma/layers/__pycache__/attention.cpython-38.pyc ADDED
Binary file (12.8 kB). View file
 
chroma/chroma/layers/__pycache__/basic.cpython-38.pyc ADDED
Binary file (18.6 kB). View file
 
chroma/chroma/layers/__pycache__/complexity.cpython-38.pyc ADDED
Binary file (5.45 kB). View file
 
chroma/chroma/layers/__pycache__/conv.cpython-38.pyc ADDED
Binary file (1.14 kB). View file
 
chroma/chroma/layers/__pycache__/graph.cpython-38.pyc ADDED
Binary file (34.6 kB). View file
 
chroma/chroma/layers/__pycache__/linalg.cpython-38.pyc ADDED
Binary file (3.2 kB). View file
 
chroma/chroma/layers/__pycache__/norm.cpython-38.pyc ADDED
Binary file (7.03 kB). View file
 
chroma/chroma/layers/__pycache__/sde.cpython-38.pyc ADDED
Binary file (2.83 kB). View file
 
chroma/chroma/layers/attention.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Generate Biomedicines, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ """
19
+ 们实现了Transformer模型中的关键组件:缩放点积注意力(Scaled Dot Product Attention)和多头注意力(Multi-Head Attention)。
20
+ """
21
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention, Eqn. 1 of Vaswani et al. 2017
    [https://arxiv.org/abs/1706.03762].

    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

    The query feature dimension must match the key feature dimension
    (``d_k``), and the key length must match the value length. See
    'The Illustrated Transformer'
    [http://jalammar.github.io/illustrated-transformer/] for a pictorial
    depiction of attention.

    Inputs:
        Q (torch.Tensor): shape (batch_size, sequence_length_q, d_k)
        K (torch.Tensor): shape (batch_size, sequence_length_k, d_k)
        V (torch.Tensor): shape (batch_size, sequence_length_k, d_v)
        mask (torch.Tensor, optional): bool/byte tensor of shape
            (batch_size, 1, sequence_length_k); zero (False) entries mark
            positions that cannot contribute to attention.

    Outputs:
        output (torch.Tensor): shape (batch_size, sequence_length_q, d_v);
            each output row is a convex combination of value rows.
        attentions (torch.Tensor): shape
            (batch_size, sequence_length_q, sequence_length_k); entry
            [b, i, j] is the relative contribution of key position j to
            query position i in batch element b.
    """

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, mask=None):
        # Scale raw dot products by sqrt of the key feature dimension.
        d_k = K.size(-1)
        scores = torch.bmm(Q, K.transpose(1, 2)) / d_k ** 0.5

        # Masked positions get a large negative score so softmax ~zeroes them.
        if mask is not None:
            scores = scores.float().masked_fill(mask == 0, -1e9)

        weights = self.softmax(scores)

        # Explicitly zero masked positions after normalization.
        if mask is not None:
            weights = weights.float().masked_fill(mask == 0, 0)

        # Match half precision of V when needed before the value mixdown.
        if V.dtype == torch.float16:
            weights = weights.half()

        return torch.bmm(weights, V), weights
61
+
62
+
63
+ class MultiHeadAttention(nn.Module):
64
+ """Multi-head attention with scaled dot product attention. See 'The Annotated Transformer'
65
+ http://nlp.seas.harvard.edu/2018/04/03/attention.html or 'The Illustrated Transformer' http://jalammar.github.io/illustrated-transformer/
66
+ for details and intuition.
67
+
68
+ Args:
69
+ n_head (int): number of attention heads
70
+ d_k (int): dimension of the keys and queries in each attention head
71
+ d_v (int): dimension of the values in each attention head
72
+ d_model (int): input and output dimension for the layer
73
+ dropout (float): dropout rate, default is 0.1
74
+
75
+ Inputs:
76
+ Q (torch.tensor): query tensor of shape ```(batch_size, sequence_length_q, d_model)```
77
+ K (torch.tensor): key tensor of shape ```(batch_size, sequence_length_k, d_model)```
78
+ V (torch.tensor): value tensor of shape ```(batch_size, sequence_length_k, d_model)```
79
+ mask (torch.tensor): (optional) of dtype ```bool`` or ```byte``` and size (batch_size, 1, sequence_length_k),
80
+ zeroes (or False) indicate positions that cannot contribute to attention
81
+
82
+ Outputs:
83
+ output (torch.tensor) : of shape ```(batch_size, sequence_length_q, d_model)```
84
+ attentions (torch.tensor): of shape ```(batch_size * n_head, sequence_length_q, sequence_length_k) where
85
+ ```attentions[batch_size*(i):batch_size*(i+1),:,:]``` corresponds to the batch of attention blocks for i'th head. See
86
+ ```chroma.layers.attention.ScaledDotProductAttention``` for more details
87
+ """
88
+
89
    def __init__(self, n_head, d_k, d_v, d_model, dropout=0.1):
        """Build per-head projection parameters and the output projection.

        Args:
            n_head (int): number of attention heads.
            d_k (int): per-head dimension of keys and queries.
            d_v (int): per-head dimension of values.
            d_model (int): input and output feature dimension of the layer.
            dropout (float): dropout rate applied to the output projection.
        """
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.d_model = d_model
        # Per-head projection weights, stacked along the leading head axis.
        self.Wq = nn.Parameter(torch.Tensor(n_head, d_model, d_k))
        self.Wk = nn.Parameter(torch.Tensor(n_head, d_model, d_k))
        self.Wv = nn.Parameter(torch.Tensor(n_head, d_model, d_v))
        # Output projection from concatenated heads back to d_model.
        self.Wo = nn.Parameter(torch.Tensor(n_head * d_v, d_model))
        self.attention = ScaledDotProductAttention()
        self.dropout = nn.Dropout(p=dropout)
        self.reset_parameters()
102
+
103
+ def reset_parameters(self):
104
+ nn.init.xavier_normal_(self.Wq)
105
+ nn.init.xavier_normal_(self.Wk)
106
+ nn.init.xavier_normal_(self.Wv)
107
+ nn.init.kaiming_uniform_(self.Wo)
108
+
109
+ def forward(self, Q, K, V, bias=None, mask=None):
110
+ mb_size, len_q, d_q_in = Q.size()
111
+ mb_size, len_k, d_k_in = K.size()
112
+ mb_size, len_v, d_v_in = V.size()
113
+ d_model = self.d_model
114
+ if d_q_in != d_model:
115
+ raise ValueError("Dimension of Q does not match d_model.")
116
+
117
+ if d_k_in != d_model:
118
+ raise ValueError("Dimension of K does not match d_model.")
119
+
120
+ if d_v_in != d_model:
121
+ raise ValueError("Dimension of V does not match d_model.")
122
+
123
+ # treat as a (n_head) size batch and project to d_k and d_v
124
+ q_s = torch.cat([Q @ W for W in self.Wq]) # (n_head*mb_size) x len_q x d_k
125
+ k_s = torch.cat([K @ W for W in self.Wk]) # (n_head*mb_size) x len_k x d_k
126
+ v_s = torch.cat([V @ W for W in self.Wv]) # (n_head*mb_size) x len_v x d_v
127
+
128
+ # Attention
129
+ if mask is not None:
130
+ mask = mask.repeat(self.n_head, 1, 1)
131
+ outputs, attns = self.attention(q_s, k_s, v_s, mask=mask)
132
+
133
+ # Back to original mb_size batch, result size = mb_size x len_q x (n_head*d_v)
134
+ outputs = torch.cat(torch.split(outputs, mb_size, dim=0), dim=-1)
135
+
136
+ # Project back to residual size
137
+ outputs = outputs @ self.Wo
138
+ outputs = self.dropout(outputs)
139
+ return outputs, attns
140
+
141
+
142
class AttentionChainPool(nn.Module):
    """Pool residue-level representations into chain-level representations
    using a chain map and a single masked attention query per chain.

    Args:
        n_head (int): number of attention heads
        d_model (int): dimension of the embeddings to be pooled

    Inputs:
        h (torch.tensor): of size (batch_size, sequence_length, d_model)
        C (torch.tensor): chain map of size (batch_size, sequence_length)

    Outputs:
        output (torch.tensor): of size (batch_size, n_chains, d_model)
        chain_mask (torch.tensor): of size (batch_size, n_chains)
    """

    def __init__(self, n_head, d_model):
        super().__init__()
        self.attention = MultiHeadAttention(
            n_head, d_model, d_model, d_model, dropout=0.0
        )

    def get_query(self, x):
        # A constant all-ones query: attention then reduces to a learned,
        # mask-weighted average over each chain's residues.
        return torch.ones(x.size(0), 1, x.size(2)).type(x.dtype).to(x.device)

    def forward(self, h, C):
        batch_size, _ = C.size()
        chain_ids = C.abs().unique()
        # One (chain_id, batch_member) pair per row of the tiled batch.
        chain_ids = (
            chain_ids[chain_ids > 0]
            .unsqueeze(-1)
            .repeat(1, batch_size)
            .reshape(-1)
            .unsqueeze(-1)
        )
        n_chains = len(chain_ids.unique())

        h_tiled = h.repeat(n_chains, 1, 1)
        C_tiled = C.repeat(n_chains, 1)
        mask = (C_tiled == chain_ids).unsqueeze(-2)

        pooled, _ = self.attention(
            self.get_query(h_tiled), h_tiled, h_tiled, mask=mask
        )
        # Regroup the chain-major tiling back into (batch, n_chains, d_model).
        pooled = torch.cat(pooled.split(batch_size), 1)
        chain_mask = torch.stack(mask.squeeze(1).any(dim=-1).split(batch_size), -1)
        return pooled, chain_mask
184
+
185
+
186
class Attention(nn.Module):
    """Multi-head attention with optional gating and bias, as implemented in
    Jumper et al. (2021).

    Args:
        n_head (int): Number of attention heads
        d_model (int): Dimension of inputs and outputs
        d_k (int): Dimension of keys/queries (defaults to d_model // n_head)
        d_v (int): Dimension of values (defaults to d_model // n_head)
        gate (bool): Whether to include a sigmoid gate on the values
            (as in Jumper et al. (2021))

    Inputs:
        Q (torch.tensor): of size (batch_size, num_queries, d_model)
        K (torch.tensor): of size (batch_size, num_keys, d_model)
        V (torch.tensor): of size (batch_size, num_keys, d_model)
        bias (torch.tensor): (optional) of size
            (n_head, num_queries, num_keys) or
            (batch_size, n_head, num_queries, num_keys)
        mask (torch.tensor): (optional) bool tensor of size
            (batch_size, n_head, num_queries, num_keys) (axes 2/3 may be 1)

    Outputs:
        output (torch.tensor): of size (batch_size, num_queries, d_model)
    """

    def __init__(self, n_head, d_model, d_k=None, d_v=None, gate=False):
        super().__init__()
        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_model // n_head if d_k is None else d_k
        self.d_v = d_model // n_head if d_v is None else d_v
        self.gate = gate
        self.q_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_k))
        self.k_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_k))
        self.v_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_v))
        self.o_weights = nn.Parameter(torch.Tensor(n_head, self.d_v, d_model))
        self.o_bias = nn.Parameter(torch.Tensor(d_model))
        if self.gate:
            self.g_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_v))
            self.g_bias = nn.Parameter(torch.Tensor(n_head, self.d_v))
        self.softmax = nn.Softmax(dim=-1)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_weights)
        nn.init.xavier_uniform_(self.k_weights)
        nn.init.xavier_uniform_(self.v_weights)
        nn.init.xavier_uniform_(self.o_weights)
        nn.init.zeros_(self.o_bias)
        if self.gate:
            # Zero weights + unit bias => gates start near sigmoid(1) ~ 0.73,
            # i.e. mostly open at initialization.
            nn.init.zeros_(self.g_weights)
            nn.init.ones_(self.g_bias)

    def forward(self, Q, K, V, bias=None, mask=None):
        self._check_inputs(Q, K, V, bias, mask)
        # Scale queries by 1/sqrt(d_k) up front (equivalent to scaling logits).
        q = torch.einsum("bqa,ahc->bqhc", Q, self.q_weights) * self.d_k ** (-0.5)
        k = torch.einsum("bka,ahc->bkhc", K, self.k_weights)
        v = torch.einsum("bka,ahc->bkhc", V, self.v_weights)
        logits = torch.einsum("bqhc,bkhc->bhqk", q, k)

        if bias is not None:
            logits = logits + bias

        weights = torch.nn.functional.softmax(logits, dim=-1)

        # Note: masking is applied after the softmax, so masked positions are
        # zeroed but remaining weights are not renormalized.
        if mask is not None:
            weights = weights.masked_fill(~mask, 0.0)

        weighted_avg = torch.einsum("bhqk,bkhc->bqhc", weights, v)

        if self.gate:
            gate_values = torch.einsum("bqa,ahc->bqhc", Q, self.g_weights) + self.g_bias
            # BUGFIX: torch.sigmoid takes no `dim` argument; the previous
            # call `torch.sigmoid(gate_values, dim=-1)` raised a TypeError.
            gate_values = torch.sigmoid(gate_values)
            weighted_avg = weighted_avg * gate_values

        output = (
            torch.einsum("bqhc,hco->bqo", weighted_avg, self.o_weights) + self.o_bias
        )
        return output

    def _check_inputs(self, Q, K, V, bias, mask):
        """Validate shapes/dtypes of forward() inputs, raising ValueError on mismatch."""
        batch_size_q, num_queries, d_q_in = Q.size()
        batch_size_k, num_keys, d_k_in = K.size()
        batch_size_v, num_values, d_v_in = V.size()

        if d_q_in != self.d_model:
            raise ValueError(
                f"Dimension of Q tensor needs to be (batch_size, number_queries, d_model)"
            )

        if d_k_in != self.d_model:
            raise ValueError(
                f"Dimension of K tensor needs to be (batch_size, number_keys, d_model)"
            )

        if d_v_in != self.d_model:
            raise ValueError(
                f"Dimension of V tensor needs to be (batch_size, number_values, d_model)"
            )

        if num_keys != num_values:
            raise ValueError(f"Number of keys needs to match number of values passed")

        if (batch_size_q != batch_size_k) or (batch_size_k != batch_size_v):
            raise ValueError(
                f"Found batch size mismatch among inputs, all tensors must agree in size of dimension 0"
            )

        if bias is not None:
            if (bias.dim() != 3) and (bias.dim() != 4):
                raise ValueError(
                    f"Bias specified but dimension mismatched: passed {bias.dim()}-dimensional tensor but should be 3-dimensional"
                    f"of shape (n_head, num_queries, num_keys) or 4-dimensional of shape (batch_size, n_head, num_queries, num_keys)"
                )
            if bias.dim() == 3:
                n_head_b, num_queries_b, num_keys_b = bias.size()
                if n_head_b != self.n_head:
                    raise ValueError(
                        f"Bias specified but number of heads (dim of axis=0) does not match number of heads: {self.n_head}"
                    )
                if num_queries_b != num_queries:
                    raise ValueError(
                        f"Bias specified but number of queries (dim of axis=1) does not match number of queries given in Q tensor"
                    )
                if num_keys_b != num_keys:
                    raise ValueError(
                        f"Bias specified but number of keys (dim of axis=2) does not match number of queries given in K tensor "
                        f"(dimenson of axis=1)"
                    )
            elif bias.dim() == 4:
                # BUGFIX: this branch previously re-checked `bias.dim() == 3`
                # (always False here), so 4-dimensional biases were never
                # validated. Unpack the batched shape and validate it.
                n_batch_b, n_head_b, num_queries_b, num_keys_b = bias.size()
                if n_head_b != self.n_head:
                    raise ValueError(
                        f"Bias specified but number of heads (dim of axis=1) does not match number of heads: {self.n_head}"
                    )
                if num_queries_b != num_queries:
                    raise ValueError(
                        f"Bias specified but number of queries (dim of axis=2) does not match number of queries given in Q tensor"
                    )
                if num_keys_b != num_keys:
                    raise ValueError(
                        f"Bias specified but number of keys (dim of axis=3) does not match number of queries given in K tensor "
                        f"(dimenson of axis=1)"
                    )

        if mask is not None:
            if mask.dtype != torch.bool:
                raise ValueError(
                    f"Mask specified but not given by correct dtype, should be torch.bool but found {mask.dtype}"
                )
            if mask.dim() != 4:
                raise ValueError(
                    f"Mask specified but dimension mismatched: passed {mask.dim()}-dimensional tensor but should be 4-dimensional"
                    f"of shape (batch_size, n_head, num_queries, num_keys)"
                )
            batch_size_b, _, num_queries_b, num_keys_b = mask.size()
            if (num_queries_b != num_queries) and (num_queries_b != 1):
                raise ValueError(
                    f"Bias specified but number of queries (dim of axis=2) does not match number of queries given in Q tensor"
                )
            if (num_keys_b != num_keys) and (num_keys_b != 1):
                raise ValueError(
                    f"Bias specified but number of keys (dim of axis=3) does not match number of queries given in K tensor "
                    f"(dimenson of axis=1)"
                )
chroma/chroma/layers/basic.py ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Generate Biomedicines, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from chroma.layers.norm import MaskedBatchNorm1d
23
+
24
+
25
class NoOp(nn.Module):
    """Identity module: returns its input unchanged, ignoring extra kwargs.

    Serves as a structural placeholder where an nn.Module is required.

    Inputs:
        x (any)

    Outputs:
        x (any), unchanged
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, **kwargs):
        return x
40
+
41
+
42
class Transpose(nn.Module):
    """An nn.Module wrapping ``torch.transpose``.

    Args:
        d1 (int): the first of the two dimensions to swap
        d2 (int): the second of the two dimensions to swap

    Inputs:
        x (torch.tensor)

    Outputs:
        y (torch.tensor): ``y = x.transpose(d1, d2)``
    """

    def __init__(self, d1=1, d2=2):
        super().__init__()
        self.d1 = d1
        self.d2 = d2

    def forward(self, x):
        return torch.transpose(x, self.d1, self.d2)
63
+
64
+
65
class Unsqueeze(nn.Module):
    """An nn.Module wrapping ``torch.unsqueeze``.

    Args:
        dim (int): the dimension at which to insert a size-1 axis

    Inputs:
        x (torch.tensor)

    Outputs:
        y (torch.tensor): ``y = x.unsqueeze(dim)``
    """

    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return torch.unsqueeze(x, self.dim)
84
+
85
+
86
class OneHot(nn.Module):
    """An nn.Module wrapping ``F.one_hot``.

    Args:
        n_tokens (int): alphabet size of the input sequences

    Inputs:
        x (torch.LongTensor): of size (batch_size, *)

    Outputs:
        y (torch.Tensor): of size (batch_size, *, n_tokens), on x's device
    """

    def __init__(self, n_tokens):
        super().__init__()
        self.n_tokens = n_tokens

    def forward(self, x):
        return F.one_hot(x, num_classes=self.n_tokens)
105
+
106
+
107
class MeanEmbedding(nn.Module):
    """Wrapper around ``nn.Embedding`` that also accepts one-hot-like inputs.

    A 2-D long tensor is embedded by ordinary table lookup. A 3-D tensor is
    treated as a distribution (or log-PMF) over tokens and embedded by matrix
    multiplication with the embedding weight; with ``use_softmax=True`` the
    last dimension is softmax-normalized first. For an exact one-hot input the
    two paths produce the same embedding.

    Args:
        embedding (nn.Embedding): embedding to wrap
        use_softmax (bool): whether to softmax the last dim of 3-D inputs

    Inputs:
        x (torch.tensor): (batch_size, sequence_length) long tensor, or
            (batch_size, sequence_length, number_tokens) float tensor

    Outputs:
        y (torch.tensor): (batch_size, sequence_length, embedding_dimension)
    """

    def __init__(self, embedding, use_softmax=True):
        super(MeanEmbedding, self).__init__()
        self.embedding = embedding
        self.use_softmax = use_softmax
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        rank = len(x.shape)
        if rank == 2:
            return self.embedding(x)
        if rank == 3:
            weights = self.softmax(x) if self.use_softmax else x
            return weights @ self.embedding.weight
        raise (NotImplementedError)
142
+
143
+
144
class PeriodicPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding, adapted from 'The Annotated Transformer'
    http://nlp.seas.harvard.edu/2018/04/03/attention.html

    Even feature channels carry sines and odd channels carry cosines, with
    geometrically spaced periods; the encoding is added to the input.

    Args:
        d_model (int): input and output dimension for the layer
        max_seq_len (int): maximum allowed sequence length
        dropout (float): dropout rate applied to the summed output

    Inputs:
        x (torch.tensor): of size (batch_size, sequence_length, d_model)

    Outputs:
        y (torch.tensor): of size (batch_size, sequence_length, d_model)
    """

    def __init__(self, d_model, max_seq_len=4000, dropout=0.0):
        super(PeriodicPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0.0, max_seq_len).unsqueeze(1)
        # Geometric frequency ladder from 1 down to 1/10000.
        inv_freq = torch.exp(
            torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_seq_len, d_model)
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        # Stored as a buffer: moves with .to()/.cuda() but is not trained.
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        return self.dropout(x + self.pe[:, : x.size(1)])
179
+
180
+
181
class PositionWiseFeedForward(nn.Module):
    """Position-wise feed-forward block built from 1x1 convolutions; a legacy
    Transformer building block (not performance-optimized).

    Args:
        d_model (int): input and output dimension for the layer
        d_hidden (int): hidden dimension of the feed-forward sublayer
        dropout (float): dropout rate applied to the output

    Inputs:
        x (torch.tensor): of size (batch_size, sequence_length, d_model)

    Outputs:
        y (torch.tensor): of size (batch_size, sequence_length, d_model)
    """

    def __init__(self, d_model, d_hidden, dropout=0.1):
        super(PositionWiseFeedForward, self).__init__()
        self.activation = nn.ReLU()
        self.linear1 = nn.Conv1d(d_model, d_hidden, 1)
        self.linear2 = nn.Conv1d(d_hidden, d_model, 1)
        self.dropout = nn.Dropout(p=dropout)

    def reset_parameters(self):
        self.linear1.reset_parameters()
        self.linear2.reset_parameters()

    def forward(self, x):
        # Conv1d is channels-first, so transpose in and out of (B, C, L).
        h = self.linear1(x.transpose(1, 2))
        h = self.activation(h)
        h = self.linear2(h).transpose(1, 2)
        return self.dropout(h)
210
+
211
+
212
class DropNormLin(nn.Module):
    """Apply a linear layer, then normalization, activation, and dropout.

    Args:
        in_features (int): input dimension
        out_features (int): output dimension
        norm_type (str): ``'ln'`` for layer normalization, ``'bn'`` for masked
            batch normalization; any other value skips normalization.
        dropout (float): dropout probability
        actn (nn.Module, optional): activation module; defaults to a fresh
            ``nn.ReLU()`` per instance.

    Input:
        x (torch.tensor): of size (batch_size, sequence_length, in_features)
        input_mask (torch.tensor): of size (batch_size, 1, sequence_length),
            used only by the masked batch-norm path (optional)

    Output:
        y (torch.tensor): of size (batch_size, sequence_length, out_features)
    """

    def __init__(
        self, in_features, out_features, norm_type="ln", dropout=0.0, actn=None
    ):
        super(DropNormLin, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Remember whether the norm layer is the channels-first masked variant,
        # so forward() does not need an isinstance check.
        self._masked_norm = norm_type == "bn"
        if norm_type == "ln":
            self.norm_layer = nn.LayerNorm(out_features)
        elif norm_type == "bn":
            self.norm_layer = MaskedBatchNorm1d(out_features)
        else:
            self.norm_layer = NoOp()
        self.dropout = nn.Dropout(p=dropout)
        # BUGFIX: the old default `actn=nn.ReLU()` was a mutable default
        # argument — a single module instance created at class-definition time
        # and shared by every DropNormLin. Build one per instance instead.
        self.actn = nn.ReLU() if actn is None else actn

    def forward(self, x, input_mask=None):
        h = self.linear(x)
        if self._masked_norm:
            # MaskedBatchNorm1d expects channels-first input, so transpose
            # around the call.
            h = self.norm_layer(h.transpose(1, 2), input_mask=input_mask).transpose(
                1, 2
            )
        else:
            h = self.norm_layer(h)
        return self.dropout(self.actn(h))
253
+
254
+
255
class ResidualLinearLayer(nn.Module):
    """A simple residual block: ``x + norm(relu(linear(x)))``, where the layer
    norm is optional.

    Args:
        d_model (int): model dimension
        use_norm (bool, *optional*): apply a LayerNorm to the branch before
            the residual addition. Default ``True``.
    """

    def __init__(self, d_model, use_norm=True):
        super(ResidualLinearLayer, self).__init__()
        self.linear = nn.Linear(d_model, d_model)
        self.ReLU = nn.ReLU()
        self.use_norm = use_norm
        # The norm module is always constructed (keeps the parameter set
        # stable) but only applied when use_norm is True.
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        branch = self.ReLU(self.linear(x))
        if self.use_norm:
            branch = self.norm(branch)
        return x + branch
276
+
277
+
278
class TriangleMultiplication(nn.Module):
    """Triangle multiplicative update, as defined in Jumper et al. (2021).

    Args:
        d_model (int): channel dimension of the pair representation
        mode (str): must be 'outgoing' (Algorithm 11) or 'incoming'
            (Algorithm 12)

    Inputs:
        X (torch.tensor): pair representation of size (batch, nres, nres, channels)
        mask (torch.tensor): optional tensor of dtype ``torch.bool``
            broadcastable to (batch, nres, nres, channels)

    Outputs:
        Y (torch.tensor): pair representation of size (batch, nres, nres, channels)
    """

    def __init__(self, d_model=512, mode="outgoing"):
        super().__init__()
        self.mode = mode
        assert self.mode in ["outgoing", "incoming"]
        # Contraction over the shared index k: rows for 'outgoing',
        # columns for 'incoming'.
        self.equation = (
            "bikc,bjkc->bijc" if self.mode == "outgoing" else "bkjc,bkic->bijc"
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.left_edge_mlp = nn.Sequential(
            nn.Linear(d_model, d_model), nn.Sigmoid(), nn.Linear(d_model, d_model)
        )
        self.right_edge_mlp = nn.Sequential(
            nn.Linear(d_model, d_model), nn.Sigmoid(), nn.Linear(d_model, d_model)
        )
        self.skip = nn.Sequential(nn.Linear(d_model, d_model), nn.Sigmoid())
        self.combine = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, d_model))

    def forward(self, X, mask=None):
        normed = self.layer_norm(X)

        left = self.left_edge_mlp(normed)
        right = self.right_edge_mlp(normed)
        gate = self.skip(normed)

        # Masked edges contribute zero to the triangle contraction.
        if mask is not None:
            left = left.masked_fill(~mask, 0.0)
            right = right.masked_fill(~mask, 0.0)

        mixed = torch.einsum(self.equation, left, right)
        return self.combine(mixed) * gate
325
+
326
+
327
class NodeProduct(nn.Module):
    """Build pair (edge) features from per-node features, like Alg. 10 in
    Jumper et al. (2021) but for single-sequence inputs (no mean over an MSA
    dimension).

    Args:
        d_in (int): dimension of node embeddings (inputs)
        d_out (int): dimension of edge embeddings (outputs)

    Inputs:
        node_features (torch.tensor): of size (batch_size, nres, d_in)
        node_mask (torch.tensor): optional bool mask of size (batch_size, nres)
        edge_mask (torch.tensor): optional bool mask of size (batch_size, nres, nres)

    Outputs:
        edge_features (torch.tensor): of size (batch_size, nres, nres, d_out)
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_in)
        self.left_lin = nn.Linear(d_in, d_in)
        self.right_lin = nn.Linear(d_in, d_in)
        self.edge_lin = nn.Linear(2 * d_in, d_out)

    def forward(self, node_features, node_mask=None, edge_mask=None):
        n_res = node_features.size(1)

        normed = self.layer_norm(node_features)
        left = self.left_lin(normed)
        right = self.right_lin(normed)

        # Zero out features of invalid nodes before pairing.
        if node_mask is not None:
            keep = node_mask[:, :, None]
            left = left.masked_fill(~keep, 0.0)
            right = right.masked_fill(~keep, 0.0)

        # Edge (i, j) is the concatenation [left[j], right[i]].
        left = left[:, None, :, :].repeat(1, n_res, 1, 1)
        right = right[:, :, None, :].repeat(1, 1, n_res, 1)
        edge_features = self.edge_lin(torch.cat([left, right], dim=-1))

        if edge_mask is not None:
            edge_features = edge_features.masked_fill(~edge_mask[:, :, :, None], 0.0)

        return edge_features
374
+
375
+
376
class FourierFeaturization(nn.Module):
    """Fourier featurization of low-dimensional (usually spatial) inputs, as
    described in https://arxiv.org/abs/2006.10739, optionally with trainable
    frequencies as in https://arxiv.org/abs/2106.02795.

    Args:
        d_input (int): dimension of inputs
        d_model (int): dimension of outputs (must be even)
        trainable (bool): learn the frequency matrix instead of freezing it
        scale (float): when not trainable, sets the scale of the random
            feature periods (this matters — tune it; see the reference)

    Inputs:
        input (torch.tensor): of size (batch_size, *, d_input)

    Outputs:
        output (torch.tensor): of size (batch_size, *, d_model)
    """

    def __init__(self, d_input, d_model, trainable=False, scale=1.0):
        super().__init__()
        self.scale = scale

        # Half the output channels carry cosines and half sines, so d_model
        # must split evenly in two.
        if d_model % 2 != 0:
            raise ValueError(
                "d_model needs to be even for this featurization, try again!"
            )

        freqs = 2 * math.pi * scale * torch.randn(d_input, d_model // 2)
        self.trainable = trainable
        if trainable:
            self.register_parameter("B", torch.nn.Parameter(freqs))
        else:
            self.register_buffer("B", freqs)

    def forward(self, inputs):
        projected = inputs @ self.B
        return torch.cat([projected.cos(), projected.sin()], -1)
413
+
414
+
415
class PositionalEncoding(nn.Module):
    """Axis-aligned positional encodings with log-linearly spaced periods.

    Each input channel is expanded into cos/sin pairs at frequencies whose
    periods are log-linearly spaced between ``period_range`` bounds
    (inclusive).

    Args:
        d_model (int): dimension of outputs (must be divisible by 2*d_input)
        d_input (int): dimension of inputs
        period_range (tuple of floats): min and max periods for the
            frequency components

    Inputs:
        input (torch.tensor): of size (..., d_input)

    Outputs:
        output (torch.tensor): of size (..., d_model)
    """

    def __init__(self, d_model, d_input=1, period_range=(1.0, 1000.0)):
        super().__init__()

        if d_model % (2 * d_input) != 0:
            raise ValueError(
                "d_model needs to be divisible by 2*d_input for this featurization, "
                f"but got {d_model} versus {d_input}"
            )

        n_freqs = d_model // (2 * d_input)
        periods = torch.logspace(
            math.log10(period_range[0]),
            math.log10(period_range[1]),
            n_freqs,
            base=10.0,
        )
        # Angular frequencies, fixed (buffer, not trained).
        self.register_buffer("w", 2 * math.pi / periods)

    def forward(self, inputs):
        lead_shape = list(inputs.shape)[:-1]
        # Broadcast: (..., d_input, 1) * (1, ..., 1, n_freqs)
        w = self.w.reshape(len(lead_shape) * [1] + [1, -1])
        phases = w * inputs[..., None]
        return torch.cat([phases.cos(), phases.sin()], -1).reshape(lead_shape + [-1])
455
+
456
+
457
class MaybeOnehotEmbedding(nn.Embedding):
    """Embedding layer accepting either int-encoded LongTensors or
    one-hot/soft FloatTensors.

    Integer inputs fall through to the ordinary ``nn.Embedding`` lookup.
    Floating-point inputs are matrix-multiplied with the embedding weight,
    so an exact one-hot row reproduces the corresponding table row.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not x.dtype.is_floating_point:
            return super().forward(x)
        # One-hot / distribution input: blend embedding rows by matmul.
        return x @ self.weight
chroma/chroma/layers/complexity.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Generate Biomedicines, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Layers for computing sequence complexities.
16
+ """
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn.functional as F
21
+
22
+ from chroma.constants import AA20
23
+ from chroma.layers.graph import collect_neighbors
24
+
25
+
26
def compositions(S: torch.Tensor, C: torch.LongTensor, w: int = 30):
    """Compute local amino-acid compositions in a window around each residue.

    Windows are restricted to residues on the same chain as the center
    residue (per the chain map `C`) and clipped at sequence boundaries.

    Args:
        S (torch.Tensor): Sequence tensor with shape `(num_batch, num_residues)`
            (long) or `(num_batch, num_residues, num_alphabet)` (float).
        C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`;
            entries <= 0 mark invalid positions.
        w (int, optional): Window size.

    Returns:
        P (torch.Tensor): Local compositions (normalized counts) with shape
            `(num_batch, num_residues, num_alphabet)`.
        N (torch.Tensor): Local counts with shape
            `(num_batch, num_residues, num_alphabet)`.
        edge_idx (torch.LongTensor): Window member indices with shape
            `(1, num_residues, w)` (clamped to valid range).
        mask_i (torch.Tensor): Per-residue validity mask with shape
            `(num_batch, num_residues)`.
        mask_ij (torch.Tensor): Per-(residue, window-offset) validity mask
            with shape `(num_batch, num_residues, w)`.
    """
    device = S.device  # NOTE(review): unused; kept for interface stability
    Q = len(AA20)
    mask_i = (C > 0).float()
    # Promote index-encoded sequences to one-hot so both input forms share
    # the same code path below.
    if len(S.shape) == 2:
        S = F.one_hot(S, Q)

    # Build neighborhoods and masks: offsets kx cover a centered window.
    S_onehot = mask_i[..., None] * S
    kx = torch.arange(w, device=S.device) - w // 2
    edge_idx = (
        torch.arange(S.shape[1], device=S.device)[None, :, None] + kx[None, None, :]
    )
    # NOTE(review): `edge_idx > 0` also excludes valid position 0 from every
    # window; `edge_idx >= 0` looks intended — confirm before changing.
    mask_ij = (edge_idx > 0) & (edge_idx < S.shape[1])
    edge_idx = edge_idx.clamp(min=0, max=S.shape[1] - 1)
    # Restrict windows to same-chain, valid residues.
    C_i = C[..., None]
    C_j = collect_neighbors(C_i, edge_idx)[..., 0]
    mask_ij = (mask_ij & C_j.eq(C_i) & (C_i > 0) & (C_j > 0)).float()

    # Sum neighborhood composition over the window dimension.
    S_j = mask_ij[..., None] * collect_neighbors(S_onehot, edge_idx)
    N = S_j.sum(2)

    num_N = N.sum(-1, keepdims=True)
    P = N / (num_N + 1e-5)
    # A residue is only valid if its window captured at least one residue.
    mask_i = ((num_N[..., 0] > 0) & (C > 0)).float()
    mask_ij = mask_i[..., None] * mask_ij
    return P, N, edge_idx, mask_i, mask_ij
70
+
71
+
72
def complexity_lcp(
    S: torch.LongTensor,
    C: torch.LongTensor,
    w: int = 30,
    entropy_min: float = 2.32,
    method: str = "naive",
    differentiable=True,
    eps: float = 1e-5,
    min_coverage=0.9,
    # entropy_min: float = 2.52,
    # method = "chao-shen"
) -> torch.Tensor:
    """Compute the Local Composition Perplexity (LCP) penalty per batch member.

    For each residue, the entropy of the local window composition is
    exponentiated into a perplexity; windows whose perplexity falls below
    ``exp(entropy_min)`` are penalized quadratically, others contribute zero.

    Args:
        S (torch.Tensor): Sequence tensor with shape `(num_batch, num_residues)`
            (index tensor) or `(num_batch, num_residues, num_alphabet)`
            (one-hot or soft sequence).
        C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`.
        w (int): Window size (shrunk to the sequence length when shorter).
        entropy_min (float): Entropy threshold in nats; windows below
            perplexity ``exp(entropy_min)`` are penalized.
        method (str): Entropy estimator name passed to ``estimate_entropy``.
        differentiable (bool): If True and ``S`` is one-hot/soft, attach a
            straight-through-style gradient computed from per-mutation
            entropies (the forward value is unchanged).
        eps (float): Small number for numerical stability.
        min_coverage (float): Minimum fraction of the window that must be
            covered by valid same-chain residues for the window to count.

    Returns:
        U (torch.Tensor): Complexity penalties with shape `(num_batch,)`.
    """

    # adjust window size based on sequence length
    if S.shape[1] < w:
        w = S.shape[1]

    P, N, edge_idx, mask_i, mask_ij = compositions(S, C, w)

    # Only count windows with `min_coverage`
    # NOTE(review): min_N is unused; the comparison below recomputes it.
    min_N = int(min_coverage * w)
    mask_coverage = N.sum(-1) > int(min_coverage * w)

    H = estimate_entropy(N, method=method)
    # Squared perplexity shortfall below exp(entropy_min); clamp(max=0)
    # zeroes windows that already exceed the threshold.
    U = mask_coverage * (torch.exp(H) - np.exp(entropy_min)).clamp(max=0).square()

    # Compute entropy as a function of perturbed counts
    if differentiable and len(S.shape) == 3:
        # Compute how a mutation changes entropy for each neighbor: remove
        # the current residue's contribution and add each candidate token.
        N_neighbors = collect_neighbors(N, edge_idx)
        mask_coverage_j = collect_neighbors(mask_coverage[..., None], edge_idx)
        N_ij = (N_neighbors - S[:, :, None, :])[..., None, :] + torch.eye(
            N.shape[-1], device=N.device
        )[None, None, None, ...]
        N_ij = N_ij.clamp(min=0)
        H_ij = estimate_entropy(N_ij, method=method)
        U_ij = (torch.exp(H_ij) - np.exp(entropy_min)).clamp(max=0).square()
        U_ij = mask_ij[..., None] * mask_coverage_j * U_ij
        # Straight-through trick: the forward value stays U.detach(), while
        # gradients flow into the soft sequence S via U_differentiable.
        U_differentiable = (U_ij.detach() * S[:, :, None, :]).sum([-1, -2])
        U = U.detach() + U_differentiable - U_differentiable.detach()

    U = (mask_i * U).sum(1)
    return U
129
+
130
+
131
def complexity_scores_lcp_t(
    t,
    S: torch.LongTensor,
    C: torch.LongTensor,
    idx: torch.LongTensor,
    edge_idx_t: torch.LongTensor,
    mask_ij_t: torch.Tensor,
    w: int = 30,
    entropy_min: float = 2.515,
    eps: float = 1e-5,
    method: str = "chao-shen",
) -> torch.Tensor:
    """Compute local LCP scores at decoding step ``t`` for autoregressive decoding.

    Scores every candidate next token at position ``t`` by the windowed
    composition perplexity that would result from extending the sequence with
    that token; scores are non-positive (0 means no complexity penalty).

    Args:
        t: Current decoding position (index into the residue axis).
        S (torch.LongTensor): Sequence tokens `(num_batch, num_residues)`.
        C (torch.LongTensor): Chain map `(num_batch, num_residues)`.
        idx (torch.LongTensor): Residue position indices `(num_batch, num_residues)`.
        edge_idx_t (torch.LongTensor): Neighbor indices at step ``t``.
        mask_ij_t (torch.Tensor): Neighbor validity mask at step ``t``.
        w (int): Window size.
        entropy_min (float): Entropy threshold in nats.
        eps (float): Unused here; kept for interface symmetry with
            ``complexity_lcp``.
        method (str): Entropy estimator name passed to ``estimate_entropy``.

    Returns:
        U (torch.Tensor): Per-candidate-token scores (negative squared
            perplexity shortfall).
    """
    Q = len(AA20)
    O = F.one_hot(S, Q)
    O_j = collect_neighbors(O, edge_idx_t)
    idx_i = idx[:, t, None]
    C_i = C[:, t, None]
    idx_j = collect_neighbors(idx[..., None], edge_idx_t)[..., 0]
    C_j = collect_neighbors(C[..., None], edge_idx_t)[..., 0]

    # Sum valid neighbor counts: neighbors within w/2 positions on the same chain.
    is_near = (idx_i - idx_j).abs() <= w / 2
    same_chain = C_i == C_j
    valid_ij_t = (is_near * same_chain * (mask_ij_t > 0)).float()[..., None]
    N_k = (valid_ij_t * O_j).sum(-2)

    # Compute counts under all possible extensions (add one of each token).
    N_k = N_k[:, :, None, :] + torch.eye(Q, device=N_k.device)[None, None, ...]

    H = estimate_entropy(N_k, method=method)
    # Negative squared shortfall below exp(entropy_min); 0 above threshold.
    U = -(torch.exp(H) - np.exp(entropy_min)).clamp(max=0).square()
    return U
164
+
165
+
166
def estimate_entropy(
    N: torch.Tensor, method: str = "chao-shen", eps: float = 1e-11
) -> torch.Tensor:
    """Estimate entropy (in nats) from a tensor of counts.

    Supported methods:
        * "chao-shen": coverage-adjusted estimator of Chao & Shen (2003).
        * "miller-maddow": plug-in estimator with Miller-Madow bias correction.
        * "laplace": plug-in estimator with 1/num_bins pseudocounts.
        * anything else: naive plug-in estimator.

    Args:
        N (torch.Tensor): Tensor of counts with shape `(..., num_bins)`.
        method (str): Estimator name (see above).
        eps (float): Small number for numerical stability.

    Returns:
        H (torch.Tensor): Estimated entropy with shape `(...)`.
    """
    N = N.float()
    totals = N.sum(-1, keepdims=True)
    P = N / (totals + eps)

    if method == "chao-shen":
        # Estimate sample coverage from singleton bins, shrink frequencies,
        # and correct with Horvitz-Thompson inclusion probabilities.
        singletons = N.long().eq(1).sum(-1, keepdims=True).float()
        coverage = 1.0 - singletons / (totals + eps)
        P_adjusted = coverage * P
        P_inclusion = (1.0 - (1.0 - P_adjusted) ** totals).clamp(min=eps)
        return -(P_adjusted * torch.log(P_adjusted.clamp(min=eps)) / P_inclusion).sum(
            -1
        )

    if method == "miller-maddow":
        occupied = (N > 0).float().sum(-1)
        bias = (occupied - 1) / (2 * totals[..., 0] + eps)
        return -(P * torch.log(P + eps)).sum(-1) + bias

    if method == "laplace":
        # Add a flat 1/num_bins pseudocount, then use the plug-in estimator.
        N = N + 1 / N.shape[-1]
        totals = N.sum(-1, keepdims=True)
        P = N / (totals + eps)
        return -(P * torch.log(P)).sum(-1)

    # Default: naive plug-in estimator.
    return -(P * torch.log(P + eps)).sum(-1)
chroma/chroma/layers/conv.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Generate Biomedicines, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import platform
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+
20
+ MACHINE = platform.machine()
21
+ """
22
+ 一维线性衰减滤波器
23
+ """
24
+
25
+ def filter1D_linear_decay(Z, B):
26
+ """Apply a low-pass filter with batch-heterogeneous coefficients.
27
+
28
+ Computes `x_i = z_i + b * x_{i-1}` where `b` varies per batch member.
29
+
30
+ Args:
31
+ Z (torch.Tensor): Batch of one-dimensional signals with shape `(N, W)`.
32
+ B (torch.Tensor): Batch of coefficients with shape `(N)`.
33
+
34
+ Returns:
35
+ X (torch.Tensor): Result of applying linear recurrence with shape `(N, W)`.
36
+ """
37
+
38
+ # Build filter coefficients as powers of B
39
+ N, W = Z.shape
40
+ k = (W - 1) - torch.arange(W, device=Z.device)
41
+ kernel = B[:, None, None] ** k[None, None, :]
42
+
43
+ # Pad on left to convolve from backwards in time
44
+ Z_pad = F.pad(Z, (W - 1, 0))[None, ...]
45
+
46
+ # Group convolution can effectively do one filter per batch
47
+ while True:
48
+ X = F.conv1d(Z_pad, kernel, stride=1, padding=0, groups=N)[0, :, :]
49
+ # on arm64 (M1 Mac) this convolution erroneously sometimes produces NaNs
50
+ if (
51
+ (MACHINE == "arm64")
52
+ and torch.isnan(X).any()
53
+ and (not torch.isnan(Z_pad).any())
54
+ and (not torch.isnan(kernel).any())
55
+ ):
56
+ continue
57
+ break
58
+ return X
chroma/chroma/layers/graph.py ADDED
@@ -0,0 +1,1126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Generate Biomedicines, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Layers for building graph neural networks.
16
+
17
+ This module contains layers for building neural networks that can process
18
+ graph-structured data. The internal representations of these layers
19
+ are node and edge embeddings.
20
+ """
21
+
22
+ from typing import Callable, List, Optional, Tuple
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ from torch.utils.checkpoint import checkpoint
27
+ from tqdm.autonotebook import tqdm
28
+
29
+ from chroma.layers.attention import Attention
30
+
31
+
32
class GraphNN(nn.Module):
    """Graph neural network with optional edge updates.

    Args:
        num_layers (int): Number of layers.
        dim_nodes (int): Hidden dimension of node tensor.
        dim_edges (int): Hidden dimension of edge tensor.
        dropout (float): Dropout rate.
        node_mlp_layers (int): Node update function, number of hidden layers.
            Default is 1.
        node_mlp_dim (int): Node update function, hidden dimension.
            Default is to match MLP output dimension.
        update_edge (Boolean): Include an edge-update step. Default: True
        edge_mlp_layers (int): Edge update function, number of hidden layers.
            Default is 1.
        edge_mlp_dim (int): Edge update function, hidden dimension.
            Default is to match MLP output dimension.
        mlp_activation (str): MLP nonlinearity.
            `'relu'`: Rectified linear unit.
            `'softplus'`: Softplus.
        norm (str): Which normalization function to apply between layers.
            `'transformer'`: Default layernorm
            `'layer'`: Masked Layer norm with shape (input.shape[1:])
            `'instance'`: Masked Instance norm
        scale (float): Scaling factor of edge input when updating node (default=1.0)
        skip_connect_input (bool): If True, add the initial node/edge embeddings
            before every layer and subtract them back out afterwards (a long
            skip connection around each layer). Default is False.
        attentional (bool): If True, use attention for message aggregation function
            instead of a sum. Default is False.
        num_attention_heads (int): Number of attention heads (if attentional) to use.
            Default is 4.
        checkpoint_gradients (bool): If True, wrap each layer in gradient
            checkpointing to save memory during training. Default is False.

    Inputs:
        node_h (torch.Tensor): Node features with shape
            `(num_batch, num_nodes, dim_nodes)`.
        edge_h (torch.Tensor): Edge features with shape
            `(num_batch, num_nodes, num_neighbors, dim_edges)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.
        mask_i (tensor, optional): Node mask with shape `(num_batch, num_nodes)`
        mask_ij (tensor, optional): Edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`

    Outputs:
        node_h_out (torch.Tensor): Updated node features with shape
            `(num_batch, num_nodes, dim_nodes)`.
        edge_h_out (torch.Tensor): Updated edge features with shape
            `(num_batch, num_nodes, num_neighbors, dim_edges)`.
    """

    def __init__(
        self,
        num_layers: int,
        dim_nodes: int,
        dim_edges: int,
        node_mlp_layers: int = 1,
        node_mlp_dim: Optional[int] = None,
        edge_update: bool = True,
        edge_mlp_layers: int = 1,
        edge_mlp_dim: Optional[int] = None,
        mlp_activation: str = "relu",
        dropout: float = 0.0,
        norm: str = "transformer",
        scale: float = 1.0,
        skip_connect_input: bool = False,
        attentional: bool = False,
        num_attention_heads: int = 4,
        checkpoint_gradients: bool = False,
    ):
        super(GraphNN, self).__init__()
        # Long skip connection: inject the initial embeddings around each layer
        self.skip_connect_input = skip_connect_input
        # Gradient checkpointing trades memory for compute: instead of storing
        # every intermediate activation for the backward pass, activations are
        # kept only at checkpoint boundaries and recomputed when needed.
        self.checkpoint_gradients = checkpoint_gradients
        self.layers = nn.ModuleList(
            [
                GraphLayer(
                    dim_nodes=dim_nodes,
                    dim_edges=dim_edges,
                    node_mlp_layers=node_mlp_layers,
                    node_mlp_dim=node_mlp_dim,
                    edge_update=edge_update,
                    edge_mlp_layers=edge_mlp_layers,
                    edge_mlp_dim=edge_mlp_dim,
                    mlp_activation=mlp_activation,
                    dropout=dropout,
                    norm=norm,
                    scale=scale,
                    attentional=attentional,
                    num_attention_heads=num_attention_heads,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        node_h: torch.Tensor,
        edge_h: torch.Tensor,
        edge_idx: torch.LongTensor,
        mask_i: Optional[torch.Tensor] = None,
        mask_ij: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Run every layer sequentially
        node_h_init = node_h
        edge_h_init = edge_h
        for i, layer in enumerate(self.layers):
            # Optionally add the initial embeddings before each layer ...
            if self.skip_connect_input:
                node_h = node_h + node_h_init
                edge_h = edge_h + edge_h_init

            # Update edge and node
            node_h, edge_h = self.checkpoint(
                layer, node_h, edge_h, edge_idx, mask_i, mask_ij
            )

            # ... and subtract them back out afterwards
            if self.skip_connect_input:
                node_h = node_h - node_h_init
                edge_h = edge_h - edge_h_init

        # If mask was provided, apply it
        if mask_i is not None:
            node_h = node_h * (mask_i.unsqueeze(-1) != 0).type(torch.float32)
        if mask_ij is not None:
            edge_h = edge_h * (mask_ij.unsqueeze(-1) != 0).type(torch.float32)
        return node_h, edge_h

    def checkpoint(self, layer, *args):
        # Route through torch gradient checkpointing only when enabled
        if self.checkpoint_gradients:
            return checkpoint(layer, *args)
        else:
            return layer(*args)

    def sequential(
        self,
        tensors: dict,
        pre_step_function: Callable = None,
        post_step_function: Callable = None,
    ) -> dict:
        """Decode the GNN sequentially along the node index `t`, with callbacks.

        Args:
            tensors (dict): Initial set of state tensors. At minimum this should
                include the arguments to `forward`, namely `node_h`, `edge_h`,
                `edge_idx`, `mask_i`, and `mask_ij`.
            pre_step_function (function, optional): Callback function that is
                optionally applied to `tensors` before each sequential GNN step as
                `tensors_new = pre_step_function(t, pre_step_function)` where `t` is
                the node index being updated. It should update elements of the
                `tensors` dictionary, and it can access and update the intermediate
                GNN state cache via the keyed lists of tensors in `node_h_cache` and
                `edge_h_cache`.
            post_step_function (function, optional): Same as `pre_step_function`, but
                optionally applied after each sequential GNN step.

        Returns:
            tensors (dict): Processed set of tensors.
        """

        # Initialize the state cache
        tensors["node_h_cache"], tensors["edge_h_cache"] = self.init_steps(
            tensors["node_h"], tensors["edge_h"]
        )

        # Sequential iteration over node positions
        num_steps = tensors["node_h"].size(1)
        for t in tqdm(range(num_steps), desc="Sequential decoding"):
            if pre_step_function is not None:
                tensors = pre_step_function(t, tensors)

            tensors["node_h_cache"], tensors["edge_h_cache"] = self.step(
                t,
                tensors["node_h_cache"],
                tensors["edge_h_cache"],
                tensors["edge_idx"],
                tensors["mask_i"],
                tensors["mask_ij"],
            )

            if post_step_function is not None:
                tensors = post_step_function(t, tensors)

        return tensors

    def init_steps(
        self, node_h: torch.Tensor, edge_h: torch.Tensor
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Initialize cached node and edge features.

        Args:
            node_h (torch.Tensor): Node features with shape
                `(num_batch, num_nodes, dim_nodes)`.
            edge_h (torch.Tensor): Edge features with shape
                `(num_batch, num_nodes, num_neighbors, dim_edges)`.

        Returns:
            node_h_cache (torch.Tensor): List of cached node features with `num_layers + 1`
                tensors of shape `(num_batch, num_nodes, dim_nodes)`.
            edge_h_cache (torch.Tensor): List of cached edge features with `num_layers + 1`
                tensors of shape `(num_batch, num_nodes, num_neighbors, dim_edges)`.
        """
        # One cache slot per layer boundary (inputs of each layer + final output)
        num_layers = len(self.layers)
        node_h_cache = [node_h.clone() for _ in range(num_layers + 1)]
        edge_h_cache = [edge_h.clone() for _ in range(num_layers + 1)]
        return node_h_cache, edge_h_cache

    def step(
        self,
        t: int,
        node_h_cache: List[torch.Tensor],
        edge_h_cache: List[torch.Tensor],
        edge_idx: torch.LongTensor,
        mask_i: Optional[torch.Tensor] = None,
        mask_ij: Optional[torch.Tensor] = None,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Process GNN update for a specific node index t from cached intermediates.

        Inputs:
            t (int): Node index to decode.
            node_h_cache (List[torch.Tensor]): List of cached node features with
                `num_layers + 1` tensors of shape `(num_batch, num_nodes, dim_nodes)`.
            edge_h_cache (List[torch.Tensor]): List of cached edge features with
                `num_layers + 1` tensors of shape
                `(num_batch, num_nodes, num_neighbors, dim_edges)`.
            edge_idx (torch.LongTensor): Edge indices for neighbors with shape
                `(num_batch, num_nodes, num_neighbors)`.
            mask_i (torch.Tensor, optional): Node mask with shape
                `(num_batch, num_nodes)`.
            mask_ij (torch.Tensor, optional): Edge mask with shape
                `(num_batch, num_nodes, num_neighbors)`.

        Outputs:
            node_h_cache (List[torch.Tensor]): Updated list of cached node features
                with `num_layers + 1` tensors of shape
                `(num_batch, num_nodes, dim_nodes)`. This method updates the tensors
                in place for memory.
            edge_h_cache (List[torch.Tensor]): Updated list of cached edge features
                with `num_layers + 1` tensors of shape
                `(num_batch, num_nodes, num_neighbors, dim_edges)`.
        """
        # Long skip connections are not supported in sequential decoding
        if self.skip_connect_input:
            raise NotImplementedError

        for i, layer in enumerate(self.layers):
            # Because the edge updates depend on the updated nodes,
            # we need both the input node features node_h and also
            # the previous output node states node_h
            node_h = node_h_cache[i]
            node_h_out = node_h_cache[i + 1]
            edge_h = edge_h_cache[i]
            # Update edge and node
            # NOTE(review): calls torch checkpoint unconditionally here, unlike
            # forward() which respects self.checkpoint_gradients — confirm intended
            node_h_t, edge_h_t = checkpoint(
                layer.step, t, node_h, node_h_out, edge_h, edge_idx, mask_i, mask_ij
            )

            # Scatter them in place (write the step-t slices back into the cache)
            node_h_cache[i + 1].scatter_(
                1, (t * torch.ones_like(node_h_t)).long(), node_h_t
            )
            edge_h_cache[i + 1].scatter_(
                1, (t * torch.ones_like(edge_h_t)).long(), edge_h_t
            )

        return node_h_cache, edge_h_cache
297
+
298
+ ## GraphLayer: a single message-passing layer used by GraphNN
299
class GraphLayer(nn.Module):
    """Graph layer that updates each node i given adjacent nodes and edges.

    Args:
        dim_nodes (int): Hidden dimension of node tensor.
        dim_edges (int): Hidden dimension of edge tensor.
        node_mlp_layers (int): Node update function, number of hidden layers.
            Default: 1.
        node_mlp_dim (int): Node update function, hidden dimension.
            Default: Matches MLP output dimension.
        update_edge (Boolean): Include an edge-update step. Default: True
        edge_mlp_layers (int): Edge update function, number of hidden layers.
            Default: 1.
        edge_mlp_dim (int): Edge update function, hidden dimension.
            Default: Matches MLP output dimension.
        mlp_activation (str): MLP nonlinearity.
            `'relu'`: Rectified linear unit.
            `'softplus'`: Softplus.
        dropout (float): Dropout rate.
        norm (str): Which normalization function to apply between layers.
            `'transformer'`: Default layernorm
            `'layer'`: Masked Layer norm with shape (input.shape[1:])
            `'instance'`: Masked Instance norm
        scale (float): Scaling factor of edge input when updating node (default=1.0)

    Inputs:
        node_h (torch.Tensor): Node features with shape
            `(num_batch, num_nodes, dim_nodes)`.
        edge_h (torch.Tensor): Edge features with shape
            `(num_batch, num_nodes, num_neighbors, dim_edges)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.
        mask_i (tensor, optional): Node mask with shape `(num_batch, num_nodes)`
        mask_ij (tensor, optional): Edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`

    Outputs:
        node_h_out (torch.Tensor): Updated node features with shape
            `(num_batch, num_nodes, dim_nodes)`.
        edge_h_out (torch.Tensor): Updated edge features with shape
            `(num_batch, num_nodes, num_neighbors, dim_edges)`.
    """

    def __init__(
        self,
        dim_nodes: int,
        dim_edges: int,
        node_mlp_layers: int = 1,
        node_mlp_dim: Optional[int] = None,
        edge_update: bool = True,
        edge_mlp_layers: int = 1,
        edge_mlp_dim: Optional[int] = None,
        mlp_activation: str = "relu",
        dropout: float = 0.0,
        norm: str = "transformer",
        scale: float = 1.0,
        attentional: bool = False,
        num_attention_heads: int = 4,
    ):
        super(GraphLayer, self).__init__()

        # Store scale
        self.scale = scale
        self.dim_nodes = dim_nodes
        self.dim_edges = dim_edges
        self.attentional = attentional

        self.node_norm_layer = MaskedNorm(
            dim=1, num_features=dim_nodes, affine=True, norm=norm
        )

        # Message function: maps packed (node_i, node_j, edge_ij) to a message
        self.message_mlp = MLP(
            dim_in=2 * dim_nodes + dim_edges,
            dim_out=dim_nodes,
            num_layers_hidden=edge_mlp_layers,
            dim_hidden=edge_mlp_dim,
            activation=mlp_activation,
            dropout=dropout,
        )
        # Node update function: maps (node_i, aggregated message) to a residual
        self.update_mlp = MLP(
            dim_in=2 * dim_nodes,
            dim_out=dim_nodes,
            num_layers_hidden=node_mlp_layers,
            dim_hidden=node_mlp_dim,
            activation=mlp_activation,
            dropout=dropout,
        )
        self.edge_update = edge_update
        self.edge_norm_layer = MaskedNorm(
            dim=2, num_features=dim_edges, affine=True, norm=norm
        )
        if self.edge_update:
            self.edge_mlp = MLP(
                dim_in=2 * dim_nodes + dim_edges,
                dim_out=dim_edges,
                num_layers_hidden=edge_mlp_layers,
                dim_hidden=edge_mlp_dim,
                activation=mlp_activation,
                dropout=dropout,
            )

        if self.attentional:
            self.attention = Attention(n_head=num_attention_heads, d_model=dim_nodes)

    ## Attention-based message aggregation (node i attends over its messages)
    def attend(
        self, node_h: torch.Tensor, messages: torch.Tensor, mask_ij: torch.Tensor
    ) -> torch.Tensor:
        # Flatten (batch, nodes) so each node attends over its K messages
        B, L, K, D = messages.size()
        queries = node_h.reshape(-1, 1, D)
        keys = messages.reshape(-1, K, D)
        values = messages.reshape(-1, K, D)
        mask = mask_ij.reshape(-1, 1, 1, K).bool() if mask_ij is not None else None
        return self.attention(queries, keys, values, mask=mask).reshape(B, L, D)

    ## Normalize node and edge embeddings
    def _normalize(self, node_h, edge_h, mask_i=None, mask_ij=None):
        # Normalize node and edge embeddings
        node_h_norm = self.node_norm_layer(node_h, mask_i)
        edge_h_norm = self.edge_norm_layer(edge_h, mask_ij)
        return node_h_norm, edge_h_norm

    ## Normalize a packed (node_i, node_j, edge_ij) tensor for a single step t,
    ## since gathered neighbor features have not yet been normalized
    def _normalize_t(
        self, edge_node_stack_t, mask_ij_t, include_nodes=True, include_edges=True
    ):
        # Apply normalization (since we have only normalized time t information)
        node_i_t = edge_node_stack_t[:, :, :, : self.dim_nodes]
        node_j_t = edge_node_stack_t[:, :, :, self.dim_nodes : 2 * self.dim_nodes]
        edge_h_t = edge_node_stack_t[:, :, :, 2 * self.dim_nodes :]
        if include_nodes:
            node_i_t = self.node_norm_layer(node_i_t, mask_ij_t)
            node_j_t = self.node_norm_layer(node_j_t, mask_ij_t)
        if include_edges:
            edge_h_t = self.edge_norm_layer(edge_h_t, mask_ij_t)
        edge_node_stack_t = torch.cat([node_i_t, node_j_t, edge_h_t], -1)
        return edge_node_stack_t

    def _update_nodes(
        self, node_h, node_h_norm, edge_h_norm, edge_idx, mask_i=None, mask_ij=None
    ):
        """Update nodes given adjacent nodes and edges"""
        # Compute messages at each ij
        edge_node_stack = pack_edges(node_h_norm, edge_h_norm, edge_idx)
        messages = self.message_mlp(edge_node_stack)
        if mask_ij is not None:
            messages = messages * mask_ij.unsqueeze(-1)

        # Aggregate messages (attention-weighted or scaled sum)
        if self.attentional:
            message = self.attend(node_h_norm, messages, mask_ij)
        else:
            message = messages.sum(2) / self.scale

        node_stack = torch.cat([node_h_norm, message], -1)

        # Update nodes given aggregated messages (residual update)
        node_h_out = node_h + self.update_mlp(node_stack)
        if mask_i is not None:
            node_h_out = node_h_out * mask_i.unsqueeze(-1)
        return node_h_out

    def _update_nodes_t(
        self,
        t,
        node_h,
        node_h_norm_t,
        edge_h_norm_t,
        edge_idx_t,
        mask_i_t=None,
        mask_ij_t=None,
    ):
        """Update nodes at index t given adjacent nodes and edges"""
        # Compute messages at each ij
        edge_node_stack_t = mask_ij_t.unsqueeze(-1) * pack_edges_step(
            t, node_h, edge_h_norm_t, edge_idx_t
        )

        # Apply normalization of gathered tensors
        # (edges are already normalized; only node slices need it here)
        edge_node_stack_t = self._normalize_t(
            edge_node_stack_t, mask_ij_t, include_edges=False
        )

        messages_t = self.message_mlp(edge_node_stack_t)
        if mask_ij_t is not None:
            messages_t = messages_t * mask_ij_t.unsqueeze(-1)

        # Aggregate messages
        if self.attentional:
            message_t = self.attend(node_h_norm_t, messages_t, mask_ij_t)
        else:
            message_t = messages_t.sum(2) / self.scale

        node_stack_t = torch.cat([node_h_norm_t, message_t], -1)
        # Update nodes given aggregated messages (residual update at index t)
        node_h_t = node_h[:, t, :].unsqueeze(1)
        node_h_out_t = node_h_t + self.update_mlp(node_stack_t)
        if mask_i_t is not None:
            node_h_out_t = node_h_out_t * mask_i_t.unsqueeze(-1)
        return node_h_out_t

    def _update_edges(self, edge_h, node_h_out, edge_h_norm, edge_idx, mask_ij):
        """Update edges given adjacent nodes and edges"""
        edge_node_stack = pack_edges(node_h_out, edge_h_norm, edge_idx)

        # Residual edge update
        edge_h_out = edge_h + self.edge_mlp(edge_node_stack)
        if mask_ij is not None:
            edge_h_out = edge_h_out * mask_ij.unsqueeze(-1)
        return edge_h_out

    def _update_edges_t(
        self, t, edge_h_t, node_h_out, edge_h_t_norm, edge_idx_t, mask_ij_t
    ):
        """Update edges given adjacent nodes and edges"""
        edge_node_stack_t = pack_edges_step(t, node_h_out, edge_h_t_norm, edge_idx_t)

        # Residual edge update at index t
        edge_h_out_t = edge_h_t + self.edge_mlp(edge_node_stack_t)
        if mask_ij_t is not None:
            edge_h_out_t = edge_h_out_t * mask_ij_t.unsqueeze(-1)
        return edge_h_out_t

    def forward(
        self,
        node_h: torch.Tensor,
        edge_h: torch.Tensor,
        edge_idx: torch.LongTensor,
        mask_i: Optional[torch.Tensor] = None,
        mask_ij: Optional[torch.Tensor] = None,
    ):
        # Pre-norm, then masked residual node update and optional edge update
        node_h_norm, edge_h_norm = self._normalize(node_h, edge_h, mask_i, mask_ij)
        if mask_i is not None:
            mask_i = (mask_i != 0).type(torch.float32)
        if mask_ij is not None:
            mask_ij = (mask_ij != 0).type(torch.float32)
        node_h_out = self._update_nodes(
            node_h, node_h_norm, edge_h_norm, edge_idx, mask_i, mask_ij
        )
        edge_h_out = None
        if self.edge_update:
            edge_h_out = self._update_edges(
                edge_h, node_h_out, edge_h_norm, edge_idx, mask_ij
            )
        return node_h_out, edge_h_out

    def step(
        self,
        t: int,
        node_h: torch.Tensor,
        node_h_out: torch.Tensor,
        edge_h: torch.Tensor,
        edge_idx: torch.LongTensor,
        mask_i: Optional[torch.Tensor] = None,
        mask_ij: Optional[torch.Tensor] = None,
    ):
        """Compute update for a single node index `t`.

        This function can be useful for sequential computation of graph
        updates, for example with autoregressive architectures.

        Args:
            t (int): Index of node dimension to update
            node_h (torch.Tensor): Node features with shape
                `(num_batch, num_nodes, dim_nodes)`.
            node_h_out (torch.Tensor): Cached outputs of preceding steps with shape
                `(num_batch, num_nodes, dim_nodes)`.
            edge_h (torch.Tensor): Edge features with shape
                `(num_batch, num_nodes, num_neighbors, dim_edges)`.
            edge_idx (torch.LongTensor): Edge indices for neighbors with shape
                `(num_batch, num_nodes, num_neighbors)`.
            mask_i (tensor, optional): Node mask with shape `(num_batch, num_nodes)`
            mask_ij (tensor, optional): Edge mask with shape
                `(num_batch, num_nodes, num_neighbors)`

        Returns:
            node_h_t (torch.Tensor): Updated node features with shape
                `(num_batch, 1, dim_nodes)`.
            edge_h_t (torch.Tensor): Updated edge features with shape
                `(num_batch, 1, num_neighbors, dim_edges)`.
        """
        # Slice out the features and masks at node index t
        # NOTE(review): indexes mask_i/mask_ij directly, so unlike forward()
        # this assumes both masks are provided (not None) — confirm with callers
        node_h_t = node_h[:, t, :].unsqueeze(1)
        edge_h_t = edge_h[:, t, :, :].unsqueeze(1)
        edge_idx_t = edge_idx[:, t, :].unsqueeze(1)
        mask_i_t = mask_i[:, t].unsqueeze(1)
        mask_ij_t = mask_ij[:, t, :].unsqueeze(1)

        """ For a single step we need to apply the normalization both at node t and
            also for all of the neighborhood tensors that feed in at t.
        """
        node_h_t_norm, edge_h_t_norm = self._normalize(
            node_h_t, edge_h_t, mask_i_t, mask_ij_t
        )
        node_h_t = self._update_nodes_t(
            t, node_h, node_h_t_norm, edge_h_t_norm, edge_idx_t, mask_i_t, mask_ij_t
        )

        if self.edge_update:
            # Write the freshly updated node t into the output cache before
            # computing the edge update, which depends on updated nodes
            node_h_out = node_h_out.scatter(
                1, (t * torch.ones_like(node_h_t)).long(), node_h_t
            )
            edge_h_t = self._update_edges_t(
                t, edge_h_t, node_h_out, edge_h_t_norm, edge_idx_t, mask_ij_t
            )
        return node_h_t, edge_h_t
599
+
600
+ ## MLP: plain feed-forward stack of linear transformations
601
class MLP(nn.Module):
    """Multilayer perceptron with variable input, hidden, and output dims.

    Args:
        dim_in (int): Feature dimension of input tensor.
        dim_hidden (int or None): Feature dimension of intermediate layers.
            Defaults to matching output dimension.
        dim_out (int or None): Feature dimension of output tensor.
            Defaults to matching input dimension.
        num_layers_hidden (int): Number of hidden MLP layers. With 0 hidden
            layers the module reduces to a single linear map.
        activation (str): MLP nonlinearity.
            `'relu'`: Rectified linear unit.
            `'softplus'`: Softplus.
        dropout (float): Dropout rate. Default is 0.

    Inputs:
        h (torch.Tensor): Input tensor with shape `(..., dim_in)`

    Outputs:
        h (torch.Tensor): Output tensor with shape `(..., dim_out)`
    """

    def __init__(
        self,
        dim_in: int,
        dim_hidden: Optional[int] = None,
        dim_out: Optional[int] = None,
        num_layers_hidden: int = 1,
        activation: str = "relu",
        dropout: float = 0.0,
    ):
        super(MLP, self).__init__()

        # Default is dimension preserving
        dim_out = dim_in if dim_out is None else dim_out
        dim_hidden = dim_out if dim_hidden is None else dim_hidden

        activation_cls = {"relu": nn.ReLU, "softplus": nn.Softplus}[activation]

        if num_layers_hidden == 0:
            # Degenerate case: a single linear projection
            modules = [nn.Linear(dim_in, dim_out)]
        else:
            modules = []
            for layer_ix in range(num_layers_hidden):
                dim_src = dim_in if layer_ix == 0 else dim_hidden
                modules.extend(
                    [
                        nn.Linear(dim_src, dim_hidden),
                        activation_cls(),
                        nn.Dropout(dropout),
                    ]
                )
            modules.append(nn.Linear(dim_hidden, dim_out))
        self.layers = nn.Sequential(*modules)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.layers(h)
657
+
658
+
659
def collect_neighbors(node_h: torch.Tensor, edge_idx: torch.Tensor) -> torch.Tensor:
    """Collect neighbor node features as edge features.

    For each node i, gathers the embeddings of its neighbors {j in N(i)}
    into a per-edge tensor neighbor_ij.

    Args:
        node_h (torch.Tensor): Node features with shape
            `(num_batch, num_nodes, num_features)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        neighbor_h (torch.Tensor): Edge features containing neighbor node information
            with shape `(num_batch, num_nodes, num_neighbors, num_features)`.
    """
    B, L, K = edge_idx.shape
    D = node_h.shape[2]

    # Collapse (node, neighbor) into one gather axis, then restore the shape
    gather_idx = edge_idx.reshape([B, L * K, 1]).expand(-1, -1, D)
    gathered = torch.gather(node_h, 1, gather_idx)
    return gathered.reshape((B, L, K, D))
684
+
685
+
686
def collect_edges(
    edge_h_dense: torch.Tensor, edge_idx: torch.LongTensor
) -> torch.Tensor:
    """Collect sparse edge features from a dense pairwise tensor.

    Args:
        edge_h_dense (torch.Tensor): Dense edges features with shape
            `(num_batch, num_nodes, num_nodes, num_features)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        edge_h (torch.Tensor): Edge features with shape
            `(num_batch, num_nodes, num_neighbors, num_features)`.
    """
    num_features = edge_h_dense.size(-1)
    # Broadcast the neighbor indices across the feature axis, then gather
    # each (i, j) pair out of the dense num_nodes x num_nodes grid
    gather_idx = edge_idx.unsqueeze(-1).expand(-1, -1, -1, num_features)
    return torch.gather(edge_h_dense, 2, gather_idx)
704
+
705
+
706
def collect_edges_transpose(
    edge_h: torch.Tensor, edge_idx: torch.LongTensor, mask_ij: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Collect edge embeddings of reversed (transposed) edges in-place.

    Args:
        edge_h (torch.Tensor): Edge features with shape
            `(num_batch, num_nodes, num_neighbors, num_features_edges)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.
        mask_ij (torch.Tensor): Edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`

    Returns:
        edge_h_transpose (torch.Tensor): Edge features of transpose with shape
            `(num_batch, num_nodes, num_neighbors, num_features_edges)`.
        mask_ji (torch.Tensor): Mask indicating presence of reversed edge with shape
            `(num_batch, num_nodes, num_neighbors)`.
    """
    B, L, K, D = list(edge_h.size())

    # Flat indices mapping each edge (i, j) to its reverse (j, i), together
    # with a mask for reverse edges that actually exist in the k-NN graph
    ij_to_ji, mask_ji = transpose_edge_idx(edge_idx, mask_ij)

    # Gather the reversed-edge features from a flattened (node, neighbor) axis
    flat = edge_h.reshape(B, L * K, -1)
    gather_idx = ij_to_ji.unsqueeze(-1).expand(-1, -1, D)
    transposed = torch.gather(flat, 1, gather_idx).reshape(B, L, K, D)

    # Zero out entries whose reverse edge is absent
    transposed = mask_ji.unsqueeze(-1) * transposed
    return transposed, mask_ji
739
+
740
+
741
def scatter_edges(edge_h: torch.Tensor, edge_idx: torch.LongTensor) -> torch.Tensor:
    """Scatter sparse edge features into a dense pairwise tensor.

    Args:
        edge_h (torch.Tensor): Edge features with shape
            `(num_batch, num_nodes, num_neighbors, num_features_edges)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        edge_h_dense (torch.Tensor): Dense edge features with shape
            `(batch_size, num_nodes, num_nodes, dimensions)`.
    """
    assert edge_h.dim() == 4
    assert edge_idx.dim() == 3
    num_batch, num_nodes, _, num_dims = edge_h.shape

    # Expand the neighbor indices over the feature channel for scatter.
    index = edge_idx.unsqueeze(-1).repeat(1, 1, 1, num_dims)

    # Unreferenced (i, j) pairs stay zero in the dense output.
    dense = edge_h.new_zeros(num_batch, num_nodes, num_nodes, num_dims)
    return dense.scatter(dim=2, index=index, src=edge_h)
761
+
762
+
763
def pack_edges(
    node_h: torch.Tensor, edge_h: torch.Tensor, edge_idx: torch.LongTensor
) -> torch.Tensor:
    """Pack node and edge features into edge features.

    Expands each edge_ij by packing node i, node j, and edge ij into
    {node,node,edge}_ij.

    Args:
        node_h (torch.Tensor): Node features with shape
            `(num_batch, num_nodes, num_features_nodes)`.
        edge_h (torch.Tensor): Edge features with shape
            `(num_batch, num_nodes, num_neighbors, num_features_edges)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        edge_packed (torch.Tensor): Concatenated node and edge features with
            shape `(num_batch, num_nodes, num_neighbors,
            num_features_nodes + 2*num_features_edges)`.
    """
    num_neighbors = edge_h.shape[2]
    # Broadcast node i's features across each of its neighbor slots.
    h_i = node_h.unsqueeze(2).expand(-1, -1, num_neighbors, -1)
    # Gather node j's features at every neighbor index.
    h_j = collect_neighbors(node_h, edge_idx)
    return torch.cat([h_i, h_j, edge_h], -1)
789
+
790
+
791
def pack_edges_step(
    t: int, node_h: torch.Tensor, edge_h_t: torch.Tensor, edge_idx_t: torch.LongTensor
) -> torch.Tensor:
    """Pack node and edge features into edge features for a single node index t.

    Expands each edge_ij by packing node i, node j, and edge ij into
    {node,node,edge}_ij.

    Args:
        t (int): Node index to decode.
        node_h (torch.Tensor): Node features at all positions with shape
            `(num_batch, num_nodes, num_features_nodes)`.
        edge_h_t (torch.Tensor): Edge features at index `t` with shape
            `(num_batch, 1, num_neighbors, num_features_edges)`.
        edge_idx_t (torch.LongTensor): Edge indices at index `t` for neighbors
            with shape `(num_batch, 1, num_neighbors)`.

    Returns:
        edge_packed (torch.Tensor): Concatenated node and edge features
            for index `t` with shape
            `(num_batch, 1, num_neighbors, num_features_nodes
            + 2*num_features_edges)`.
    """
    # (Removed an unused `num_nodes_i = node_h.shape[1]` local from the
    # original implementation; it was never read.)
    num_neighbors = edge_h_t.shape[2]
    # Features of the node being decoded, broadcast over its neighbor slots.
    node_h_t = node_h[:, t, :].unsqueeze(1)
    node_i = node_h_t.unsqueeze(2).expand(-1, -1, num_neighbors, -1)
    # Gather neighbor-node features from the full node table.
    node_j = collect_neighbors(node_h, edge_idx_t)
    edge_packed = torch.cat([node_i, node_j, edge_h_t], -1)
    return edge_packed
821
+
822
+
823
def transpose_edge_idx(
    edge_idx: torch.LongTensor, mask_ij: torch.Tensor
) -> Tuple[torch.LongTensor, torch.Tensor]:
    """Collect edge indices of reverse edges in-place at each edge.

    The tensor `edge_idx` stores a directed graph topology as a tensor of
    neighbor indices, where an element `edge_idx[b,i,k]` corresponds to the
    node index of neighbor `k` of node `i` in batch member `b`.

    This function takes a directed graph topology and returns an index tensor
    that maps, in-place, to the reversed edges (if they exist). The indices
    correspond to the contracted dimension of `edge_index` when it is viewed as
    `(num_batch, num_nodes * num_neighbors)`. These indices can be used in
    conjunction with `torch.gather` to collect edge embeddings of `j->i` at
    `i->j`. See `collect_edges_transpose` for an example.

    For reverse `j->i` edges that do not exist in the directed graph, the
    function also returns a binary mask `mask_ji` indicating which edges
    have both `i->j` and `j->i` present in the graph.

    Args:
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.
        mask_ij (torch.Tensor): Edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        ij_to_ji (torch.LongTensor): Flat indices for indexing ji in-place at
            ij with shape `(num_batch, num_nodes * num_neighbors)`.
        mask_ji (torch.Tensor): Mask indicating presence of reversed edge with
            shape `(num_batch, num_nodes, num_neighbors)`.
    """
    num_batch, num_nodes, num_k = edge_idx.shape

    # 1. For every edge (i, j), gather all num_k neighbor indices of node j.
    flat_j = edge_idx.reshape([num_batch, num_nodes * num_k, 1]).expand(
        -1, -1, num_k
    )
    # (b, i, j, k) gives the kth neighbor of the jth neighbor of i.
    nbrs_of_nbrs = torch.gather(edge_idx, 1, flat_j).reshape(
        [num_batch, num_nodes, num_k, num_k]
    )

    # 2. Find which neighbor slot of j (if any) points back at i.
    node_i = torch.arange(num_nodes, device=edge_idx.device).reshape(
        (1, -1, 1, 1)
    )
    is_reverse = (nbrs_of_nbrs == node_i).type(torch.float32)
    return_mask, return_idx = torch.max(is_reverse, -1)

    # 3. Flat index of edge (j -> i) in the (num_nodes * num_k) edge table.
    ij_to_ji = (edge_idx * num_k + return_idx).reshape(num_batch, -1)

    # 4. The reverse edge counts only if ij, the match, and ji are all unmasked.
    mask_ji = torch.gather(mask_ij.reshape(num_batch, -1), -1, ij_to_ji)
    mask_ji = mask_ij * return_mask * mask_ji.reshape(num_batch, num_nodes, num_k)
    return ij_to_ji, mask_ji
883
+
884
+
885
def permute_tensor(
    tensor: torch.Tensor, dim: int, permute_idx: torch.LongTensor
) -> torch.Tensor:
    """Permute a tensor along a dimension given a permutation vector.

    Args:
        tensor (torch.Tensor): Input tensor with shape
            `([batch_dims], permutation_length, [content_dims])`.
        dim (int): Dimension to permute along.
        permute_idx (torch.LongTensor): Permutation index tensor with shape
            `([batch_dims], permutation_length)`.

    Returns:
        tensor_permute (torch.Tensor): Permuted tensor with shape
            `([batch_dims], permutation_length, [content_dims])`.
    """
    # Resolve a possibly-negative `dim` to its absolute position.
    dim = range(tensor.dim())[dim]

    # Collapse everything after the permuted axis into one content axis.
    shape = list(tensor.shape)
    lead_dims, length = shape[:dim], shape[dim]
    flat = tensor.reshape(lead_dims + [length, -1])

    # Broadcast the permutation over the content axis and gather.
    idx = permute_idx.unsqueeze(-1).expand(flat.shape)
    permuted_flat = torch.gather(flat, dim, idx)

    return permuted_flat.reshape(tensor.shape)
915
+
916
+
917
def permute_graph_embeddings(
    node_h: torch.Tensor,
    edge_h: torch.Tensor,
    edge_idx: torch.LongTensor,
    mask_i: torch.Tensor,
    mask_ij: torch.Tensor,
    permute_idx: torch.LongTensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.Tensor, torch.Tensor]:
    """Permute graph embeddings given a permutation vector.

    Args:
        node_h (torch.Tensor): Node features with shape
            `(num_batch, num_nodes, dim_nodes)`.
        edge_h (torch.Tensor): Edge features with shape
            `(num_batch, num_nodes, num_neighbors, dim_edges)`.
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.
        mask_i (torch.Tensor): Node mask with shape `(num_batch, num_nodes)`.
        mask_ij (torch.Tensor): Edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`.
        permute_idx (torch.LongTensor): Permutation vector with shape
            `(num_batch, num_nodes)`.

    Returns:
        node_h_permute (torch.Tensor): Permuted node features with shape
            `(num_batch, num_nodes, dim_nodes)`.
        edge_h_permute (torch.Tensor): Permuted edge features with shape
            `(num_batch, num_nodes, num_neighbors, dim_edges)`.
        edge_idx_permute (torch.LongTensor): Permuted edge indices with shape
            `(num_batch, num_nodes, num_neighbors)`.
        mask_i_permute (torch.Tensor): Permuted node mask with shape
            `(num_batch, num_nodes)`.
        mask_ij_permute (torch.Tensor): Permuted edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`.
    """
    # Tensors indexed only by node position permute by a plain gather on dim 1.
    node_h_permute = permute_tensor(node_h, 1, permute_idx)
    edge_h_permute = permute_tensor(edge_h, 1, permute_idx)
    mask_i_permute = permute_tensor(mask_i, 1, permute_idx)
    mask_ij_permute = permute_tensor(mask_ij, 1, permute_idx)

    # edge_idx involves node indices on both axes: permute the i axis, then
    # remap each stored neighbor j to its new location p^(-1)[j]:
    #   edge^(p)[i, k] = p^(-1)[edge[p(i), k]]
    edge_idx_i_permuted = permute_tensor(edge_idx, 1, permute_idx)
    permute_idx_inverse = torch.argsort(permute_idx, dim=-1)
    flat = edge_idx_i_permuted.reshape([edge_idx.shape[0], -1])
    edge_idx_permute = torch.gather(permute_idx_inverse, 1, flat).reshape(
        edge_idx.shape
    )

    return (
        node_h_permute,
        edge_h_permute,
        edge_idx_permute,
        mask_i_permute,
        mask_ij_permute,
    )
979
+
980
+
981
def edge_mask_causal(edge_idx: torch.LongTensor, mask_ij: torch.Tensor) -> torch.Tensor:
    """Make an edge mask causal with mask_ij = 0 for j >= i.

    Args:
        edge_idx (torch.LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_nodes, num_neighbors)`.
        mask_ij (torch.Tensor): Edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Returns:
        mask_ij_causal (torch.Tensor): Causal edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`.
    """
    # Position index i, shaped (1, num_nodes, 1) to broadcast against (b, i, k).
    positions = torch.arange(edge_idx.size(1), device=edge_idx.device).reshape(
        [1, -1, 1]
    )
    # Keep only strictly-earlier neighbors (j < i) that were already unmasked.
    return mask_ij * (edge_idx < positions).float()
998
+
999
+
1000
class MaskedNorm(nn.Module):
    """Masked normalization layer.

    Computes mean/variance normalization while respecting a mask: masked-out
    positions (mask value 0) contribute nothing to the statistics and are
    zeroed in the output. Distinct nonzero mask values partition positions
    into groups that are normalized independently of one another.

    Args:
        dim (int): Dimensionality of the normalization. Can be 1 for 1D
            normalization along dimension 1 or 2 for 2D normalization along
            dimensions 1 and 2.
        num_features (int): Channel dimension; only needed if `affine` is True.
        affine (bool): If True, include a learnable affine transformation
            post-normalization. Default is False.
        norm (str): Type of normalization, can be `instance`, `layer`, or
            `transformer`.
        eps (float): Small number for numerical stability.

    Inputs:
        data (torch.Tensor): Input tensor with shape
            `(num_batch, num_nodes, num_channels)` (1D) or
            `(num_batch, num_nodes, num_nodes, num_channels)` (2D).
        mask (torch.Tensor): Mask tensor with shape
            `(num_batch, num_nodes)` (1D) or
            `(num_batch, num_nodes, num_nodes)` (2D).

    Outputs:
        norm_data (torch.Tensor): Mask-normalized tensor with shape
            `(num_batch, num_nodes, num_channels)` (1D) or
            `(num_batch, num_nodes, num_nodes, num_channels)` (2D).
    """

    def __init__(
        self,
        dim: int,
        num_features: int = -1,
        affine: bool = False,
        norm: str = "instance",
        eps: float = 1e-5,
    ):
        super(MaskedNorm, self).__init__()

        self.norm_type = norm
        self.dim = dim
        # Combined key, e.g. "instance" + "1" -> "instance1", used to select
        # the reduction dimensions below.
        self.norm = norm + str(dim)
        self.affine = affine
        self.eps = eps

        # Dimensions reduced over when computing mean/std.
        if self.norm == "instance1":
            self.sum_dims = [1]
        elif self.norm == "layer1":
            self.sum_dims = [1, 2]
        elif self.norm == "transformer1":
            self.sum_dims = [-1]
        elif self.norm == "instance2":
            self.sum_dims = [1, 2]
        elif self.norm == "layer2":
            self.sum_dims = [1, 2, 3]
        elif self.norm == "transformer2":
            self.sum_dims = [-1]
        else:
            raise NotImplementedError

        # Number of features, only required if affine.
        self.num_features = num_features

        # Affine transformation is an elementwise scale and shift on the
        # channel dimension.
        if self.affine:
            self.weights = nn.Parameter(torch.rand(self.num_features))
            self.bias = nn.Parameter(torch.zeros(self.num_features))

    def forward(
        self, data: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Add optional trailing singleton dimension and expand the mask so it
        # matches `data` elementwise.
        if mask is not None:
            if len(mask.shape) == len(data.shape) - 1:
                mask = mask.unsqueeze(-1)
            if data.shape != mask.shape:
                mask = mask.expand(data.shape)

        # Dimensions to reduce, chosen in __init__ from the norm type.
        dims = self.sum_dims
        if (mask is None) or (self.norm_type == "transformer"):
            # Unmasked statistics. "transformer" norm reduces over the channel
            # dimension only, so the mask does not affect its statistics
            # (masked positions are still zeroed at the end).
            mask_mean = data.mean(dim=dims, keepdim=True)
            mask_std = torch.sqrt(
                (((data - mask_mean)).pow(2)).mean(dim=dims, keepdim=True) + self.eps
            )

            # Normalize
            norm_data = (data - mask_mean) / mask_std

        else:
            # Accumulator for the per-group normalized values.
            norm_data = torch.zeros_like(data).to(data.device).type(data.dtype)
            for mask_id in mask.unique():
                # Skip zero: it marks masked-out (padding) positions.
                if mask_id == 0:
                    continue

                # Binary mask selecting only the positions in this group.
                tmask = (mask == mask_id).type(torch.float32)

                # Number of selected elements, for the masked mean.
                mask_sum = tmask.sum(dim=dims, keepdim=True)

                # Masked mean/std: statistics over this group's positions only.
                mask_mean = (data * tmask).sum(dim=dims, keepdim=True) / mask_sum
                mask_std = torch.sqrt(
                    (((data - mask_mean) * tmask).pow(2)).sum(dim=dims, keepdim=True)
                    / mask_sum
                    + self.eps
                )

                # Normalize this group and zero everything outside it.
                tnorm = ((data - mask_mean) / mask_std) * tmask
                # An empty group yields NaNs (0/0); NaN != NaN, so this
                # replaces them with zeros.
                tnorm[tnorm != tnorm] = 0

                # Accumulate the group's contribution.
                norm_data += tnorm

        # Optional learned scale and shift on the channel dimension.
        if self.affine:
            norm_data = norm_data * self.weights + self.bias

        # Zero out masked positions in the output.
        if mask is not None:
            norm_data = norm_data * (mask != 0).type(data.dtype)
        return norm_data
chroma/chroma/layers/linalg.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Generate Biomedicines, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
"""Layers for linear algebra.

Performs linear algebra computations.

This module contains additional pytorch layers for linear algebra operations,
such as a more parallelization-friendly implementation of eigenvalue estimation.
"""
21
+
22
+ import torch
23
+
24
+
25
def eig_power_iteration(A, num_iterations=50, eps=1e-5):
    """Estimate largest magnitude eigenvalue and associated eigenvector.

    This uses a simple power iteration algorithm to estimate leading
    eigenvalues, which can often be considerably faster than torch's built-in
    eigenvalue routines. All steps are differentiable and small constants are
    added to any division to preserve the stability of the gradients. For more
    information on power iteration, see
    https://en.wikipedia.org/wiki/Power_iteration.

    Args:
        A (tensor): Batch of square matrices with shape
            `(..., num_dims, num_dims)`.
        num_iterations (int, optional): Number of iterations for power
            iteration. Default: 50.
        eps (float, optional): Small number to prevent division by zero.
            Default: 1E-5.

    Returns:
        lam (tensor): Batch of estimated highest-magnitude eigenvalues with
            shape `(...)`.
        v (tensor): Associated eigenvector with shape `(..., num_dims)`.
    """

    def _safe(x):
        # Stabilize divisions against zero denominators.
        return x + eps

    # Random start vector with the same batch shape as A's columns.
    batch_shape = list(A.size())[:-1]
    v = torch.randn(batch_shape, device=A.device).unsqueeze(-1)

    # Repeatedly apply A and renormalize; converges toward the dominant
    # eigenvector.
    for _ in range(num_iterations):
        v_prev = v
        Av = torch.matmul(A, v)
        v = Av / _safe(Av.norm(p=2, dim=-2, keepdim=True))

    # Rayleigh-style quotient from the last two iterates gives the eigenvalue.
    v_prev_T = v_prev.transpose(-1, -2)
    lam = torch.matmul(v_prev_T, Av) / _safe(torch.abs(torch.matmul(v_prev_T, v)))

    # Drop the trailing singleton dimensions.
    v = v.squeeze(-1)
    lam = lam.view(list(lam.size())[:-2])
    return lam, v
65
+
66
+
67
def eig_leading(A, num_iterations=50):
    """Estimate largest positive eigenvalue and associated eigenvector.

    This estimates the *most positive* eigenvalue of each matrix in a batch of
    matrices by using two consecutive power iterations with spectral shifting.

    Args:
        A (tensor): Batch of square matrices with shape
            `(..., num_dims, num_dims)`.
        num_iterations (int, optional): Number of iterations for power
            iteration. Default: 50.

    Returns:
        lam (tensor): Estimated most positive eigenvalue with shape `(...)`.
        vec (tensor): Associated eigenvectors with shape `(..., num_dims)`.
    """
    batch_dims = list(A.size())[:-2]
    # Fix: the identity size was hard-coded to 4, which silently restricted
    # this routine to 4x4 matrices despite the documented generic shape.
    num_dims = A.size(-1)

    # First pass gets the largest-magnitude eigenvalue (possibly negative).
    lam_1, vec_1 = eig_power_iteration(A, num_iterations)

    # Shifting by |lam_1| makes the whole spectrum non-negative, so the
    # second pass is guaranteed to grab the most positive eigenvalue.
    lam_1_abs = torch.abs(lam_1)
    lam_I = lam_1_abs.reshape(batch_dims + [1, 1]) * torch.eye(
        num_dims, dtype=A.dtype, device=A.device
    ).view([1 for _ in batch_dims] + [num_dims, num_dims])
    A_shift = A + lam_I
    lam_2, vec = eig_power_iteration(A_shift, num_iterations)

    # Shift back to the original spectrum.
    lam = lam_2 - lam_1_abs
    return lam, vec